checkout_symbol.py
python
| 1 | """muse checkout-symbol — restore a historical version of a specific symbol. |
| 2 | |
| 3 | Extracts a single named symbol from a historical committed snapshot and writes |
| 4 | it back into the current working-tree file, replacing the current version of |
| 5 | that symbol. |
| 6 | |
| 7 | This is a **surgical** operation: only the target symbol's lines change. |
| 8 | All surrounding code — other symbols, comments, imports, blank lines outside |
| 9 | the symbol boundary — is left untouched. |
| 10 | |
| 11 | Why this matters |
| 12 | ---------------- |
| 13 | Git's ``checkout`` restores entire files. If you need to roll back a single |
| 14 | function while keeping everything else current, you need to manually cherry- |
| 15 | pick lines. ``muse checkout-symbol`` does this atomically against Muse's |
| 16 | content-addressed symbol index. |
| 17 | |
| 18 | Usage:: |
| 19 | |
| 20 | muse checkout-symbol "src/billing.py::compute_invoice_total" --commit HEAD~3 |
| 21 | muse checkout-symbol "src/auth.py::validate_token" --commit abc12345 --dry-run |
| 22 | |
| 23 | Output (without --dry-run):: |
| 24 | |
| 25 | Restoring: src/billing.py::compute_invoice_total |
| 26 | from commit: abc12345 (2026-02-15) |
| 27 | lines 42–67 → replaced with 31 historical line(s) |
| 28 | ✅ Written to src/billing.py |
| 29 | |
| 30 | Output (with --dry-run):: |
| 31 | |
| 32 | Dry run — no files will be written. |
| 33 | |
| 34 | Restoring: src/billing.py::compute_invoice_total |
| 35 | from commit: abc12345 (2026-02-15) |
| 36 | |
| 37 | --- current |
| 38 | +++ historical |
| 39 | @@ -42,26 +42,20 @@ |
| 40 | def compute_invoice_total(...): |
| 41 | - ...current body... |
| 42 | + ...historical body... |
| 43 | |
| 44 | Flags: |
| 45 | |
| 46 | ``--commit, -c REF`` |
| 47 | Required. Commit to restore from. |
| 48 | |
| 49 | ``--dry-run`` |
| 50 | Print the diff without writing anything. |
| 51 | """ |
| 52 | from __future__ import annotations |
| 53 | |
| 54 | import difflib |
| 55 | import json |
| 56 | import logging |
| 57 | import pathlib |
| 58 | |
| 59 | import typer |
| 60 | |
| 61 | from muse.core.errors import ExitCode |
| 62 | from muse.core.object_store import read_object |
| 63 | from muse.core.repo import require_repo |
| 64 | from muse.core.store import get_commit_snapshot_manifest, resolve_commit_ref |
| 65 | from muse.plugins.code.ast_parser import parse_symbols |
| 66 | |
# Module logger. Nothing in the code of this module as shown logs through it.
logger = logging.getLogger(__name__)

# Typer app for this command. ``checkout_symbol`` is attached to it via
# ``@app.callback(invoke_without_command=True)``, so it runs when the app is
# invoked without a subcommand.
app = typer.Typer()
| 70 | |
| 71 | |
def _read_repo_id(root: pathlib.Path) -> str:
    """Return the ``repo_id`` recorded in ``<root>/.muse/repo.json``, coerced to ``str``.

    Raises ``KeyError`` if the key is absent, and propagates any file-read or
    JSON-decode error unchanged.
    """
    config_path = root / ".muse" / "repo.json"
    repo_config = json.loads(config_path.read_text())
    return str(repo_config["repo_id"])
| 74 | |
| 75 | |
def _read_branch(root: pathlib.Path) -> str:
    """Return the branch name stored in ``<root>/.muse/HEAD``.

    Surrounding whitespace is stripped and a leading ``refs/heads/`` prefix is
    removed. Any other content (e.g. a bare commit id) is returned as-is.
    """
    head_path = root / ".muse" / "HEAD"
    raw_ref = head_path.read_text().strip()
    branch_name = raw_ref.removeprefix("refs/heads/")
    return branch_name.strip()
| 79 | |
| 80 | |
def _extract_lines(source: bytes, lineno: int, end_lineno: int) -> list[str]:
    """Extract lines *lineno*..*end_lineno* (1-indexed, inclusive) as a list.

    *source* is split on raw bytes at LF, CRLF and CR boundaries only, and each
    line keeps its terminator. Every selected line is then decoded as UTF-8,
    with undecodable bytes replaced by U+FFFD.

    ``str.splitlines`` is deliberately not used. It also breaks at form feed,
    vertical tab, NEL, U+2028/U+2029 and similar characters, which Python's
    tokenizer does not treat as line ends. Line numbers reported by a parser
    would then point at the wrong lines.
    """
    # LF/CR bytes never occur inside a multi-byte UTF-8 sequence, so decoding
    # line by line gives the same text as decoding the whole buffer first.
    selected = source.splitlines(keepends=True)[lineno - 1:end_lineno]
    return [line.decode("utf-8", errors="replace") for line in selected]
| 85 | |
| 86 | |
def _find_current_symbol_lines(
    working_tree_file: pathlib.Path,
    address: str,
) -> tuple[int, int] | None:
    """Return (lineno, end_lineno) for *address* in the current working-tree file.

    Args:
        working_tree_file: On-disk path of the file to parse.
        address: Full symbol address, ``"<repo-relative path>::<symbol>"``.

    Returns:
        The symbol's 1-indexed, inclusive line range as reported by
        ``parse_symbols``. Returns ``None`` if the file does not exist or the
        symbol is not found in it.
    """
    if not working_tree_file.exists():
        return None
    raw = working_tree_file.read_bytes()
    # Parse under the repo-relative path taken from *address*. checkout_symbol
    # passes that same path when parsing the historical blob, and looks up the
    # same *address* in both trees. Passing the absolute on-disk path here can
    # make the parsed keys disagree with *address*, so the lookup never matches.
    file_rel = address.partition("::")[0]
    tree = parse_symbols(raw, file_rel)
    rec = tree.get(address)
    if rec is None:
        return None
    return rec["lineno"], rec["end_lineno"]
| 103 | |
| 104 | |
def _decode_for_display(lines: list[bytes]) -> list[str]:
    """Decode raw *lines* as UTF-8 for terminal output, dropping line terminators.

    Undecodable bytes become U+FFFD. This is used only for display, never for
    the bytes written back to disk.
    """
    return [line.decode("utf-8", errors="replace").rstrip("\r\n") for line in lines]


def _terminate_last_line(lines: list[bytes]) -> list[bytes]:
    """Return a copy of *lines* whose final line ends with a line terminator.

    A symbol that ended its historical file may have no trailing newline.
    Splicing it in front of more code would glue its last line onto the next
    one, so an LF is appended in that case.
    """
    if lines and not lines[-1].endswith((b"\n", b"\r")):
        return [*lines[:-1], lines[-1] + b"\n"]
    return list(lines)


def _unified_diff_text(before: list[bytes], after: list[bytes]) -> str:
    """Render a unified diff from *before* (current) to *after* (historical)."""
    # Both sides are fed without terminators and with lineterm="". Header and
    # content lines are then uniformly unterminated and can be joined with "\n".
    # Feeding terminated lines with lineterm="" and joining with "" would run
    # the ---/+++/@@ header lines together.
    diff_lines = difflib.unified_diff(
        _decode_for_display(before),
        _decode_for_display(after),
        fromfile="current",
        tofile="historical",
        lineterm="",
    )
    return "\n".join(diff_lines)


@app.callback(invoke_without_command=True)
def checkout_symbol(
    ctx: typer.Context,
    address: str = typer.Argument(
        ..., metavar="ADDRESS",
        help='Symbol address, e.g. "src/billing.py::compute_invoice_total".',
    ),
    ref: str = typer.Option(
        ..., "--commit", "-c", metavar="REF",
        help="Commit to restore the symbol from (required).",
    ),
    dry_run: bool = typer.Option(
        False, "--dry-run",
        help="Print the diff without writing anything.",
    ),
) -> None:
    """Restore a historical version of a specific symbol into the working tree.

    Extracts the symbol body from the given historical commit and splices it
    into the current working-tree file at the symbol's current location.
    Only the target symbol's lines change; everything else is left untouched.

    The splice works on raw bytes split at LF / CRLF / CR boundaries. Lines
    outside the symbol are therefore written back exactly as they were read:
    there is no lossy UTF-8 round trip and no platform newline translation.

    If the symbol does not exist at ``--commit``, the command exits with an
    error. If the symbol does not exist in the current working tree (perhaps
    it was deleted), the historical version is appended to the end of the file.
    The file, and any missing parent directories, are created if absent.

    Raises:
        typer.Exit: with ``ExitCode.USER_ERROR`` for a malformed address, an
            unknown commit, a file missing from the snapshot, a missing blob, or
            a symbol missing from the historical file.
    """
    root = require_repo()

    # Validate the address before touching repository metadata. partition()
    # splits on the first "::", like split("::", 1), and also exposes an
    # empty file or symbol part.
    file_rel, separator, symbol_name = address.partition("::")
    if not (separator and file_rel and symbol_name):
        typer.echo("❌ ADDRESS must be a symbol address like 'src/billing.py::func'.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    repo_id = _read_repo_id(root)
    branch = _read_branch(root)

    commit = resolve_commit_ref(root, repo_id, branch, ref)
    if commit is None:
        typer.echo(f"❌ Commit '{ref}' not found.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Read the historical blob.
    manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
    obj_id = manifest.get(file_rel)
    if obj_id is None:
        typer.echo(
            f"❌ '{file_rel}' is not in snapshot {commit.commit_id[:8]}.", err=True
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    historical_raw = read_object(root, obj_id)
    if historical_raw is None:
        typer.echo(f"❌ Blob {obj_id[:8]} missing from object store.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Find the symbol in the historical blob.
    hist_tree = parse_symbols(historical_raw, file_rel)
    hist_rec = hist_tree.get(address)
    if hist_rec is None:
        typer.echo(
            f"❌ Symbol '{address}' not found in commit {commit.commit_id[:8]}.", err=True
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Slice on raw bytes. bytes.splitlines() breaks only at LF/CRLF/CR, the
    # line ends Python's tokenizer counts (presumably what parse_symbols counts
    # too). str.splitlines() would also break at form feeds, U+2028, etc. and
    # shift the line numbers.
    historical_lines = _terminate_last_line(
        historical_raw.splitlines(keepends=True)[hist_rec["lineno"] - 1 : hist_rec["end_lineno"]]
    )

    # Find the symbol in the current working tree.
    working_file = root / file_rel
    current_raw = working_file.read_bytes() if working_file.exists() else b""
    current_lines = current_raw.splitlines(keepends=True)

    current_sym_range = _find_current_symbol_lines(working_file, address)

    if dry_run:
        typer.echo("Dry run — no files will be written.\n")

    typer.echo(f"Restoring: {address}")
    typer.echo(f" from commit: {commit.commit_id[:8]} ({commit.committed_at.date()})")

    if current_sym_range is not None:
        cur_start, cur_end = current_sym_range
        typer.echo(
            f" lines {cur_start}–{cur_end} → replaced with "
            f"{len(historical_lines)} historical line(s)"
        )
        new_lines = current_lines[: cur_start - 1] + historical_lines + current_lines[cur_end:]
    else:
        typer.echo(" symbol not found in working tree — appending at end of file")
        # Terminate an unterminated last line first, so the blank separator
        # really is a blank line. Skip the separator for an empty or new file.
        existing = _terminate_last_line(current_lines)
        separator_lines = [b"\n"] if existing else []
        new_lines = existing + separator_lines + historical_lines

    if dry_run:
        typer.echo("\n" + _unified_diff_text(current_lines, new_lines))
        return

    # Write the patched file as bytes so untouched lines keep their exact
    # encoding and line endings.
    working_file.parent.mkdir(parents=True, exist_ok=True)
    working_file.write_bytes(b"".join(new_lines))
    typer.echo(f"✅ Written to {file_rel}")