gabriel / muse public
diff.py python
336 lines 11.8 KB
6acddccb fix: restore structured --help output for all CLI commands Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """muse diff — compare working tree against HEAD, or compare two commits."""
2
3 from __future__ import annotations
4
5 import argparse
6 import difflib
7 import json
8 import logging
9 import pathlib
10 import sys
11
12 from muse.core.errors import ExitCode
13 from muse.core.object_store import read_object
14 from muse.core.repo import require_repo
15 from muse.core.store import get_commit_snapshot_manifest, get_head_snapshot_manifest, read_current_branch, resolve_commit_ref
16 from muse.core.validation import sanitize_display
17 from muse.domain import DomainOp, SnapshotManifest
18 from muse.plugins.registry import read_domain, resolve_plugin
19
# Module-level logger named after this module (stdlib ``getLogger(__name__)`` convention).
logger: logging.Logger = logging.getLogger(__name__)
21
22
def _read_branch(root: pathlib.Path) -> str:
    """Return the current branch name for the repository at *root*.

    Thin wrapper that delegates to ``read_current_branch``.
    """
    current = read_current_branch(root)
    return current
25
26
27 def _read_repo_id(root: pathlib.Path) -> str:
28 return str(json.loads((root / ".muse" / "repo.json").read_text())["repo_id"])
29
30
31 _MAX_INLINE_CHILDREN = 12
32
33
34 def _green(text: str) -> str:
35 return f"\033[32m{text}\033[0m"
36
37
38 def _red(text: str) -> str:
39 return f"\033[31m{text}\033[0m"
40
41
42 def _yellow(text: str) -> str:
43 return f"\033[33m{text}\033[0m"
44
45
46 def _cyan(text: str) -> str:
47 return f"\033[36m{text}\033[0m"
48
49
50 _LOC_SEP = " L"
51
52
53 def _split_loc(summary: str) -> tuple[str, str]:
54 """Split 'added function foo L4–8' into ('added function foo', 'L4–8').
55
56 Returns the original string and an empty loc when no location suffix is
57 present (e.g. cross-file move annotations that carry no line data).
58 """
59 if _LOC_SEP in summary:
60 label, _, loc = summary.rpartition(_LOC_SEP)
61 return label, f"L{loc}"
62 return summary, ""
63
64
def _print_child_ops(child_ops: list[DomainOp]) -> None:
    """Render symbol-level child ops as a tree with aligned columns and colours.

    Labels are padded to a uniform width within the group so the line-range
    column (``L{start}–{end}``) lines up vertically. Shows up to
    ``_MAX_INLINE_CHILDREN`` entries inline; summarises the rest on a single
    trailing line.

    Args:
        child_ops: The ``child_ops`` list of a ``patch`` op.
    """
    visible = child_ops[:_MAX_INLINE_CHILDREN]
    overflow = len(child_ops) - len(visible)

    # First pass: gather (op_type, unstyled_label, loc) for each visible op.
    # Widths must be measured before ANSI colour codes are applied, because
    # the escape sequences add characters that are not displayed.
    rows: list[tuple[str, str, str]] = []
    for cop in visible:
        if cop["op"] == "insert":
            label, loc = _split_loc(cop["content_summary"])
            rows.append(("insert", label, loc))
        elif cop["op"] == "delete":
            label, loc = _split_loc(cop["content_summary"])
            rows.append(("delete", label, loc))
        elif cop["op"] == "replace":
            label, loc = _split_loc(cop["new_summary"])
            rows.append(("replace", label, loc))
        elif cop["op"] == "move":
            label = f"{cop['address']} ({cop['from_position']} → {cop['to_position']})"
            rows.append(("move", label, ""))
        else:
            # Show at least the address instead of an empty tree row.
            rows.append(("unknown", str(cop.get("address", "")), ""))

    # Only rows carrying a location are padded, so rows without one do not
    # get trailing whitespace.
    width = max((len(label) for _, label, loc in rows if loc), default=0)

    colour = {"insert": _green, "delete": _red, "replace": _yellow, "move": _cyan}
    for i, (op_type, label, loc) in enumerate(rows):
        is_last = (i == len(rows) - 1) and overflow == 0
        connector = "└─" if is_last else "├─"
        style = colour.get(op_type)
        styled = style(label) if style is not None else label
        if loc:
            # Padding goes outside the colour codes, so alignment depends only
            # on the visible label length.
            padding = " " * (width - len(label))
            suffix = f"{padding} {loc}"
        else:
            suffix = ""
        print(f" {connector} {styled}{suffix}")

    if overflow > 0:
        print(f" └─ … and {overflow} more")
113
114
def _print_structured_delta(ops: list[DomainOp]) -> int:
    """Print a colour-coded delta op-by-op. Returns the number of ops printed.

    Colour scheme mirrors standard diff conventions:
    - Green  → added    (A)
    - Red    → deleted  (D)
    - Yellow → modified (M)
    - Cyan   → moved / renamed (R)

    Every op produces at least one status line, so the returned count always
    matches what was shown. Op kinds without a dedicated branch (e.g.
    ``mutate``) are shown as modifications.

    Each branch checks ``op["op"]`` directly so mypy can narrow the
    TypedDict union to the specific subtype before accessing its fields.
    """
    for op in ops:
        if op["op"] == "insert":
            print(_green(f"A {op['address']}"))
        elif op["op"] == "delete":
            print(_red(f"D {op['address']}"))
        elif op["op"] == "replace":
            print(_yellow(f"M {op['address']}"))
        elif op["op"] == "move":
            print(
                _cyan(f"R {op['address']} ({op['from_position']} → {op['to_position']})")
            )
        elif op["op"] == "patch":
            child_ops = op["child_ops"]
            from_address = op.get("from_address")
            if from_address:
                # File was renamed AND edited simultaneously.
                print(_cyan(f"R {from_address} → {op['address']}"))
            # Classify the patch: all-inserts = new file, all-deletes =
            # removed file, mixed = modification. The status prefix makes the
            # output read like `git diff --name-status`. An empty child list
            # is a modification, because all() is vacuously True on empty input.
            elif child_ops and all(c["op"] == "insert" for c in child_ops):
                print(_green(f"A {op['address']}"))
            elif child_ops and all(c["op"] == "delete" for c in child_ops):
                print(_red(f"D {op['address']}"))
            else:
                print(_yellow(f"M {op['address']}"))
            _print_child_ops(child_ops)
        else:
            # Previously such ops were counted but never printed.
            print(_yellow(f"M {op.get('address', '')}"))
    return len(ops)
158
159
160 def _print_text_diff(
161 base_files: dict[str, str],
162 target_files: dict[str, str],
163 root: pathlib.Path,
164 workdir: pathlib.Path | None,
165 ) -> int:
166 """Print a coloured unified diff for every changed file. Returns change count."""
167 base_paths = set(base_files)
168 target_paths = set(target_files)
169 changed = (
170 sorted(target_paths - base_paths) # added
171 + sorted(base_paths - target_paths) # removed
172 + sorted( # modified
173 p for p in base_paths & target_paths
174 if base_files[p] != target_files[p]
175 )
176 )
177
178 for path in changed:
179 # Read base content.
180 if path in base_files:
181 raw_base = read_object(root, base_files[path])
182 base_lines = raw_base.decode("utf-8", errors="replace").splitlines(keepends=True) if raw_base else []
183 base_label = f"a/{path}"
184 else:
185 base_lines = []
186 base_label = "/dev/null"
187
188 # Read target content (object store first, then disk for working tree).
189 if path in target_files:
190 raw_target = read_object(root, target_files[path])
191 if raw_target is None and workdir is not None:
192 disk = workdir / path
193 if disk.is_file():
194 raw_target = disk.read_bytes()
195 target_lines = raw_target.decode("utf-8", errors="replace").splitlines(keepends=True) if raw_target else []
196 target_label = f"b/{path}"
197 else:
198 target_lines = []
199 target_label = "/dev/null"
200
201 hunks = list(difflib.unified_diff(
202 base_lines, target_lines,
203 fromfile=base_label, tofile=target_label,
204 lineterm="",
205 ))
206 if not hunks:
207 continue
208
209 for line in hunks:
210 if line.startswith("---") or line.startswith("+++"):
211 print(f"\033[1m{line}\033[0m")
212 elif line.startswith("@@"):
213 print(_cyan(line))
214 elif line.startswith("+"):
215 print(_green(line))
216 elif line.startswith("-"):
217 print(_red(line))
218 else:
219 print(line)
220
221 return len(changed)
222
223
def register(subparsers: "argparse._SubParsersAction[argparse.ArgumentParser]") -> None:
    """Attach the ``diff`` subcommand and its options to *subparsers*.

    The module docstring is used verbatim as the long description, and
    ``run`` is installed as the handler via ``set_defaults(func=...)``.
    """
    diff_parser = subparsers.add_parser(
        "diff",
        help="Compare working tree against HEAD, or compare two commits.",
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Optional positional refs, in order: base then target.
    positional_refs = (
        ("commit_a", "Base commit ID (default: HEAD)."),
        ("commit_b", "Target commit ID (default: working tree)."),
    )
    for dest, help_text in positional_refs:
        diff_parser.add_argument(dest, nargs="?", default=None, help=help_text)

    # Boolean switches.
    switches = (
        ("--stat", "Show summary statistics only."),
        ("--text", "Show line-level unified diff instead of semantic symbols."),
    )
    for flag, help_text in switches:
        diff_parser.add_argument(flag, action="store_true", help=help_text)

    diff_parser.add_argument(
        "--format", "-f", dest="fmt", default="text", help="Output format: text or json."
    )
    diff_parser.set_defaults(func=run)
238
239
def run(args: argparse.Namespace) -> None:
    """Compare working tree against HEAD, or compare two commits.

    Ref handling:

    * no refs      — HEAD snapshot (base) vs the working tree (``plugin.snapshot``);
    * one ref ``A`` — HEAD snapshot (base) vs ``A`` (target);
    * two refs     — ``A`` (base) vs ``B`` (target).

    Agents should pass ``--format json`` to receive a structured result::

        {
            "summary": "3 changes",
            "added": ["path/to/new_file"],
            "deleted": ["path/to/removed_file"],
            "modified": ["path/to/changed_file"],
            "total_changes": 3
        }

    A ``patch`` op is bucketed the same way the text renderer labels it:
    - all child inserts → ``added``;
    - all child deletes → ``deleted``;
    - anything else, including renames (``from_address``) and empty patches,
      → ``modified``.

    Errors are written to stderr so JSON output on stdout stays parseable.

    Raises:
        SystemExit: ``ExitCode.USER_ERROR`` for an unknown ``--format`` or a
            ref that cannot be resolved.
    """
    commit_a: str | None = args.commit_a
    commit_b: str | None = args.commit_b
    stat: bool = args.stat
    text: bool = args.text
    fmt: str = args.fmt

    if fmt not in ("text", "json"):
        print(f"❌ Unknown --format '{sanitize_display(fmt)}'. Choose text or json.", file=sys.stderr)
        raise SystemExit(ExitCode.USER_ERROR)
    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = _read_branch(root)
    domain = read_domain(root)
    plugin = resolve_plugin(root)

    def _resolve_manifest(ref: str) -> dict[str, str]:
        """Resolve a ref (branch, short SHA, full SHA) to its snapshot manifest."""
        resolved = resolve_commit_ref(root, repo_id, branch, ref)
        if resolved is None:
            print(f"⚠️ Commit '{sanitize_display(ref)}' not found.", file=sys.stderr)
            raise SystemExit(ExitCode.USER_ERROR)
        return get_commit_snapshot_manifest(root, resolved.commit_id) or {}

    if commit_a is None:
        base_snap = SnapshotManifest(
            files=get_head_snapshot_manifest(root, repo_id, branch) or {},
            domain=domain,
        )
        target_snap = plugin.snapshot(root)
    elif commit_b is None:
        # Single ref provided: diff HEAD vs that ref's snapshot.
        # NOTE(review): the ``commit_a`` help text calls it the *base*
        # (default HEAD) with the working tree as the default target, but
        # here it is used as the target against HEAD. Confirm which is
        # intended before changing either.
        base_snap = SnapshotManifest(
            files=get_head_snapshot_manifest(root, repo_id, branch) or {},
            domain=domain,
        )
        target_snap = SnapshotManifest(
            files=_resolve_manifest(commit_a),
            domain=domain,
        )
    else:
        base_snap = SnapshotManifest(
            files=_resolve_manifest(commit_a),
            domain=domain,
        )
        target_snap = SnapshotManifest(
            files=_resolve_manifest(commit_b),
            domain=domain,
        )

    if text and fmt != "json":
        # Only a working-tree diff may fall back to reading files from disk.
        workdir = root if commit_a is None else None
        changed = _print_text_diff(
            base_snap["files"], target_snap["files"], root, workdir
        )
        if changed == 0:
            print("No differences.")
        return

    delta = plugin.diff(base_snap, target_snap, repo_root=root)

    if fmt == "json":
        added: list[str] = []
        deleted: list[str] = []
        modified: list[str] = []
        for op in delta["ops"]:
            if op["op"] == "insert":
                added.append(op["address"])
            elif op["op"] == "delete":
                deleted.append(op["address"])
            elif op["op"] == "patch":
                child_ops = op["child_ops"]
                if op.get("from_address") or not child_ops:
                    modified.append(op["address"])
                elif all(c["op"] == "insert" for c in child_ops):
                    added.append(op["address"])
                elif all(c["op"] == "delete" for c in child_ops):
                    deleted.append(op["address"])
                else:
                    modified.append(op["address"])
            elif op["op"] in ("replace", "mutate", "move"):
                modified.append(op["address"])
        print(json.dumps({
            "summary": delta["summary"],
            "added": sorted(added),
            "deleted": sorted(deleted),
            "modified": sorted(modified),
            "total_changes": len(delta["ops"]),
        }))
        return

    if stat:
        print(delta["summary"] if delta["ops"] else "No differences.")
        return

    changed = _print_structured_delta(delta["ops"])

    if changed == 0:
        print("No differences.")
    else:
        print(f"\n{delta['summary']}")