gabriel / muse public
rebase.py python
532 lines 18.5 KB
e8c4265e feat(hardening): final sweep — security, performance, API consistency, docs Gabriel Cardona <gabriel@tellurstori.com> 2d ago
1 """``muse rebase`` — replay commits from one branch onto another.
2
3 Muse rebase is cherry-pick of a range: it takes the commits unique to the
4 current branch (those not reachable from the upstream) and replays them
5 one-by-one on top of the upstream. Because commits are content-addressed,
6 each replayed commit gets a new ID — the originals are untouched in the store.
7
8 Usage::
9
10 muse rebase <upstream> # replay HEAD's unique commits onto upstream
11 muse rebase --onto <newbase> <upstream> # replay onto a different base
12 muse rebase --squash [<upstream>] # collapse all commits into one
13 muse rebase --abort # restore original HEAD
14 muse rebase --continue # resume after resolving a conflict
15
16 Exit codes::
17
18 0 — rebase completed or aborted successfully
19 1 — conflict encountered, or bad arguments
20 3 — internal error
21 """
22
23 from __future__ import annotations
24
25 import datetime
26 import json
27 import logging
28 import pathlib
29 from typing import Annotated
30
31 import typer
32
33 from muse.core.errors import ExitCode
34 from muse.core.merge_engine import find_merge_base, write_merge_state
35 from muse.core.validation import validate_branch_name
36 from muse.domain import SnapshotManifest as _SnapshotManifest
37 from muse.core.rebase import (
38 RebaseState,
39 _write_branch_ref,
40 clear_rebase_state,
41 collect_commits_to_replay,
42 load_rebase_state,
43 replay_one,
44 save_rebase_state,
45 )
46 from muse.core.reflog import append_reflog
47 from muse.core.repo import require_repo
48 from muse.domain import MuseDomainPlugin, SnapshotManifest
49 from muse.core.snapshot import compute_commit_id, compute_snapshot_id
50 from muse.core.store import (
51 CommitRecord,
52 SnapshotRecord,
53 get_head_commit_id,
54 read_commit,
55 read_current_branch,
56 read_snapshot,
57 resolve_commit_ref,
58 write_commit,
59 write_snapshot,
60 )
61 from muse.core.validation import sanitize_display
62 from muse.core.workdir import apply_manifest
63 from muse.plugins.registry import read_domain, resolve_plugin
64
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Typer application object for this command; the ``rebase`` function in this
# module is registered on it via ``@app.callback``.
app = typer.Typer(help="Replay commits from the current branch onto a new base.")
68
69
def _read_repo_id(root: pathlib.Path) -> str:
    """Return the ``repo_id`` stored in ``<root>/.muse/repo.json``, coerced to ``str``."""
    repo_json = root / ".muse" / "repo.json"
    metadata = json.loads(repo_json.read_text(encoding="utf-8"))
    return str(metadata["repo_id"])
72
73
def _resolve_ref_to_id(
    root: pathlib.Path,
    repo_id: str,
    branch: str,
    ref: str,
) -> str | None:
    """Resolve a ref string (branch name, commit SHA, or HEAD) to a commit ID.

    Resolution order:

    1. ``HEAD`` (case-insensitive) resolves via ``get_head_commit_id`` for
       *branch*.
    2. If *ref* passes ``validate_branch_name`` and a regular file exists at
       ``.muse/refs/heads/<ref>`` containing a 64-character lowercase hex
       string, that string is returned.
    3. Otherwise *ref* is handed to ``resolve_commit_ref`` (commit SHA or
       SHA-prefix resolution).

    Branch names are validated with ``validate_branch_name`` before being used
    as path components to prevent directory traversal attacks.

    Args:
        root: Repository root (the directory containing ``.muse``).
        repo_id: Repository identifier, forwarded to ``resolve_commit_ref``.
        branch: Current branch, used for ``HEAD`` and forwarded to
            ``resolve_commit_ref``.
        ref: The user-supplied ref string to resolve.

    Returns:
        The resolved commit ID, or ``None`` when nothing matches.
    """
    if ref.upper() == "HEAD":
        return get_head_commit_id(root, branch)

    # Only a validated name may be joined onto the refs directory.
    try:
        validate_branch_name(ref)
    except ValueError:
        is_valid_branch_name = False
    else:
        is_valid_branch_name = True

    if is_valid_branch_name:
        ref_path = root / ".muse" / "refs" / "heads" / ref
        # is_file() rather than exists(): the path can be a directory (e.g.
        # ``feat`` when a branch ``feat/base`` exists — assuming slash-separated
        # branch names map to nested ref paths), and read_text() on a directory
        # would raise instead of falling through to SHA resolution.
        if ref_path.is_file():
            raw = ref_path.read_text(encoding="utf-8").strip()
            if len(raw) == 64 and all(c in "0123456789abcdef" for c in raw):
                return raw

    # Not a usable branch ref — fall back to commit SHA / prefix resolution.
    rec = resolve_commit_ref(root, repo_id, branch, ref)
    return rec.commit_id if rec else None
106
107
def _run_replay_loop(
    root: pathlib.Path,
    state: RebaseState,
    repo_id: str,
    branch: str,
    plugin: "MuseDomainPlugin",
    domain: str,
) -> bool:
    """Replay the commits listed in ``state["remaining"]`` one by one.

    Each commit is replayed with ``replay_one`` on top of the most recently
    completed commit (or ``state["onto"]`` when nothing has completed yet).
    *state* is mutated in place and persisted with ``save_rebase_state`` after
    every step, so an interrupted rebase can be resumed.

    Args:
        root: Repository root.
        state: In-progress rebase state; ``remaining`` and ``completed`` are
            updated as commits are processed.
        repo_id: Repository identifier, forwarded to ``replay_one``.
        branch: Branch being rebased; used for reflog entries.
        plugin: Domain plugin, forwarded to ``replay_one``.
        domain: Domain name, forwarded to ``replay_one``.

    Returns:
        ``True`` if every remaining commit was replayed (or skipped because it
        could not be read); ``False`` if a conflict paused the rebase.
    """
    current_parent = state["completed"][-1] if state["completed"] else state["onto"]

    while state["remaining"]:
        orig_commit_id = state["remaining"][0]
        commit = read_commit(root, orig_commit_id)
        if commit is None:
            typer.echo(f"⚠️ Commit {orig_commit_id[:12]} not found — skipping.")
            state["remaining"].pop(0)
            save_rebase_state(root, state)
            continue

        typer.echo(f" Replaying {orig_commit_id[:12]}: {sanitize_display(commit.message)}")

        result = replay_one(
            root, commit, current_parent, plugin, domain, repo_id, branch
        )

        if isinstance(result, list):
            # Conflict: ``replay_one`` returned the conflicting paths. The
            # commit is deliberately left at the head of ``remaining`` so that
            # ``--continue`` knows which step was paused.
            save_rebase_state(root, state)

            write_merge_state(
                root,
                base_commit=commit.parent_commit_id or "",
                ours_commit=current_parent,
                theirs_commit=orig_commit_id,
                conflict_paths=result,
            )
            typer.echo(f"\n❌ Rebase stopped at {orig_commit_id[:12]} due to conflict(s):")
            for p in sorted(result):
                # Paths come from repository content; sanitise them for the
                # terminal the same way commit messages are sanitised above.
                typer.echo(f" CONFLICT: {sanitize_display(p)}")
            typer.echo(
                "\nResolve conflicts then run:\n"
                " muse rebase --continue to resume\n"
                " muse rebase --abort to restore original HEAD"
            )
            return False

        # Clean replay — advance to the newly created commit.
        previous_parent = current_parent
        current_parent = result.commit_id
        state["remaining"].pop(0)
        state["completed"].append(current_parent)
        save_rebase_state(root, state)

        append_reflog(
            root, branch,
            old_id=previous_parent,
            new_id=current_parent,
            author="user",
            operation=f"rebase: replayed {orig_commit_id[:12]} onto {state['onto'][:12]}",
        )

    return True
172
173
def _abort_rebase(root: pathlib.Path, active_state: RebaseState | None) -> None:
    """Restore the branch ref and working tree recorded when the rebase started."""
    if active_state is None:
        typer.echo("❌ No rebase in progress.")
        raise typer.Exit(code=ExitCode.USER_ERROR)

    original_head = active_state["original_head"]
    original_branch = active_state["original_branch"]
    _write_branch_ref(root, original_branch, original_head)

    # Restore working tree to original HEAD.
    orig_commit = read_commit(root, original_head)
    snap = read_snapshot(root, orig_commit.snapshot_id) if orig_commit else None
    if snap:
        apply_manifest(root, snap.manifest)
    else:
        # The ref is already restored; make the skipped restore visible
        # instead of silently leaving the working tree as it was.
        logger.warning(
            "rebase --abort: commit or snapshot for %s not found; working tree left unchanged",
            original_head[:12],
        )

    append_reflog(
        root, original_branch,
        old_id=active_state["completed"][-1] if active_state["completed"] else active_state["onto"],
        new_id=original_head,
        author="user",
        operation="rebase: abort",
    )
    clear_rebase_state(root)
    typer.echo(f"✅ Rebase aborted. HEAD restored to {original_head[:12]}.")


def _continue_rebase(
    root: pathlib.Path,
    active_state: RebaseState | None,
    repo_id: str,
    branch: str,
    plugin: "MuseDomainPlugin",
    domain: str,
) -> None:
    """Commit the resolved working tree for the paused step, then resume the replay."""
    if active_state is None:
        typer.echo("❌ No rebase in progress. Nothing to continue.")
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # The user has resolved the conflict manually. Snapshot the current
    # working tree and create the commit for the paused step.
    current_parent = (
        active_state["completed"][-1]
        if active_state["completed"]
        else active_state["onto"]
    )
    orig_commit_id = active_state["remaining"][0] if active_state["remaining"] else ""
    orig_commit = read_commit(root, orig_commit_id) if orig_commit_id else None

    snap_result = plugin.snapshot(root)
    manifest: dict[str, str] = snap_result["files"]
    snapshot_id = compute_snapshot_id(manifest)
    committed_at = datetime.datetime.now(datetime.timezone.utc)
    # Reuse the paused commit's message and author when it can still be read.
    message = orig_commit.message if orig_commit else "rebase: continued"
    new_commit_id = compute_commit_id(
        parent_ids=[current_parent] if current_parent else [],
        snapshot_id=snapshot_id,
        message=message,
        committed_at_iso=committed_at.isoformat(),
    )
    write_snapshot(root, SnapshotRecord(snapshot_id=snapshot_id, manifest=manifest))
    write_commit(root, CommitRecord(
        commit_id=new_commit_id,
        repo_id=repo_id,
        branch=branch,
        snapshot_id=snapshot_id,
        message=message,
        committed_at=committed_at,
        parent_commit_id=current_parent if current_parent else None,
        author=orig_commit.author if orig_commit else "",
    ))
    active_state["completed"].append(new_commit_id)
    if active_state["remaining"]:
        active_state["remaining"].pop(0)
    save_rebase_state(root, active_state)

    append_reflog(
        root, branch,
        old_id=current_parent,
        new_id=new_commit_id,
        author="user",
        operation=f"rebase: continue — replayed {orig_commit_id[:12] if orig_commit_id else '?'}",
    )

    # More commits to replay? A fresh conflict pauses the rebase again and is
    # reported with the documented conflict exit code.
    if active_state["remaining"]:
        clean = _run_replay_loop(root, active_state, repo_id, branch, plugin, domain)
        if not clean:
            raise typer.Exit(code=ExitCode.USER_ERROR)

    final_id = active_state["completed"][-1]
    _write_branch_ref(root, branch, final_id)
    clear_rebase_state(root)
    typer.echo(f"✅ Rebase complete. HEAD is now {final_id[:12]}.")


def _squash_rebase(
    root: pathlib.Path,
    commits_to_replay: list[CommitRecord],
    onto_id: str,
    head_commit_id: str,
    repo_id: str,
    branch: str,
    plugin: "MuseDomainPlugin",
    domain: str,
    squash_message: str | None,
    fmt: str,
) -> None:
    """Merge every commit to replay into one manifest and write a single commit on *onto_id*."""
    # Start from the snapshot of the new base (empty if it cannot be read).
    squash_manifest: dict[str, str] = {}
    onto_commit = read_commit(root, onto_id)
    if onto_commit:
        onto_snap = read_snapshot(root, onto_commit.snapshot_id)
        if onto_snap:
            squash_manifest = dict(onto_snap.manifest)

    for commit in commits_to_replay:
        # Three-way merge: base = the commit's own parent, ours = the manifest
        # accumulated so far, theirs = the commit being folded in.
        base_manifest: dict[str, str] = {}
        if commit.parent_commit_id:
            pc = read_commit(root, commit.parent_commit_id)
            if pc:
                ps = read_snapshot(root, pc.snapshot_id)
                if ps:
                    base_manifest = ps.manifest

        theirs_snap = read_snapshot(root, commit.snapshot_id)
        theirs_manifest = theirs_snap.manifest if theirs_snap else {}

        result = plugin.merge(
            _SnapshotManifest(files=base_manifest, domain=domain),
            _SnapshotManifest(files=squash_manifest, domain=domain),
            _SnapshotManifest(files=theirs_manifest, domain=domain),
            repo_root=root,
        )
        if not result.is_clean:
            typer.echo(f"❌ Conflict during squash at {commit.commit_id[:12]}:")
            for p in sorted(result.conflicts):
                # Paths originate from repository content; sanitise for display.
                typer.echo(f" CONFLICT: {sanitize_display(p)}")
            typer.echo("Resolve conflicts and try again. Squash does not support --continue.")
            raise typer.Exit(code=ExitCode.USER_ERROR)
        squash_manifest = result.merged["files"]

    apply_manifest(root, squash_manifest)
    snapshot_id = compute_snapshot_id(squash_manifest)
    committed_at = datetime.datetime.now(datetime.timezone.utc)
    final_message = squash_message or commits_to_replay[-1].message
    new_commit_id = compute_commit_id(
        parent_ids=[onto_id],
        snapshot_id=snapshot_id,
        message=final_message,
        committed_at_iso=committed_at.isoformat(),
    )
    write_snapshot(root, SnapshotRecord(snapshot_id=snapshot_id, manifest=squash_manifest))
    # NOTE(review): no ``author`` is passed here, unlike the --continue path —
    # confirm whether the squashed commit should inherit one.
    write_commit(root, CommitRecord(
        commit_id=new_commit_id,
        repo_id=repo_id,
        branch=branch,
        snapshot_id=snapshot_id,
        message=final_message,
        committed_at=committed_at,
        parent_commit_id=onto_id,
    ))
    _write_branch_ref(root, branch, new_commit_id)
    append_reflog(
        root, branch,
        old_id=head_commit_id,
        new_id=new_commit_id,
        author="user",
        operation=f"rebase --squash onto {onto_id[:12]}",
    )
    if fmt == "json":
        typer.echo(json.dumps({
            "status": "completed",
            "branch": branch,
            "new_head": new_commit_id,
            "onto": onto_id,
            "squash": True,
            "conflicts": [],
        }))
    else:
        typer.echo(f"✅ Squash-rebase complete. HEAD is now {new_commit_id[:12]}.")


# The docstring of ``rebase`` doubles as the command's ``--help`` text, so it
# is kept user-facing; implementation notes live in comments instead.
@app.callback(invoke_without_command=True)
def rebase(
    ctx: typer.Context,
    upstream: Annotated[
        str | None,
        typer.Argument(help="Branch or commit to rebase onto."),
    ] = None,
    onto: Annotated[
        str | None,
        typer.Option("--onto", "-o", help="New base commit (replay commits between <upstream> and HEAD onto this)."),
    ] = None,
    squash: Annotated[
        bool,
        typer.Option("--squash", "-s", help="Collapse all replayed commits into one."),
    ] = False,
    squash_message: Annotated[
        str | None,
        typer.Option("--message", "-m", help="Commit message for --squash (default: last commit's message)."),
    ] = None,
    abort: Annotated[
        bool,
        typer.Option("--abort", "-a", help="Abort an in-progress rebase and restore the original HEAD."),
    ] = False,
    continue_: Annotated[
        bool,
        typer.Option("--continue", "-c", help="Resume after resolving a conflict."),
    ] = False,
    fmt: Annotated[
        str,
        typer.Option("--format", "-f", help="Output format: text or json."),
    ] = "text",
) -> None:
    """Replay commits from the current branch onto a new base.

    The most common invocation replays all commits unique to the current
    branch on top of *upstream*'s HEAD::

        muse rebase main  # rebase current branch onto main

    Use ``--onto`` when you need to replay onto a commit that is not the
    tip of *upstream*::

        muse rebase --onto newbase upstream

    Use ``--squash`` to collapse all replayed commits into a single commit
    for a clean merge workflow::

        muse rebase --squash main

    When a conflict is encountered, the rebase pauses and the command exits
    with code 1. Resolve the conflict, then::

        muse rebase --continue

    Or discard the entire rebase::

        muse rebase --abort

    Examples::

        muse rebase main
        muse rebase --onto main feat/base
        muse rebase --squash --message "feat: combined" main
        muse rebase --abort
        muse rebase --continue
        muse rebase --format json main  # machine-readable result

    Agents should pass ``--format json`` to receive a structured result when
    a new rebase completes. Squash rebase payload::

        {
          "status": "completed",
          "branch": "<name>",
          "new_head": "<sha256>",
          "onto": "<sha256>",
          "squash": true,
          "conflicts": []
        }

    Normal rebase payload adds ``"replayed": <n>`` (number of commits applied).
    Conflicts, ``--abort`` and ``--continue`` report in plain text.
    """
    if fmt not in ("text", "json"):
        typer.echo(f"❌ Unknown --format '{sanitize_display(fmt)}'. Choose text or json.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = read_current_branch(root)
    plugin = resolve_plugin(root)
    domain = read_domain(root)

    active_state = load_rebase_state(root)

    # --abort takes precedence over --continue when both are given.
    if abort:
        _abort_rebase(root, active_state)
        return

    if continue_:
        _continue_rebase(root, active_state, repo_id, branch, plugin, domain)
        return

    # New rebase — check not already in progress.
    if active_state is not None:
        typer.echo(
            "❌ Rebase in progress. Use --continue or --abort."
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    if upstream is None:
        typer.echo("❌ Provide an upstream branch or commit to rebase onto.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Resolve HEAD and upstream.
    head_commit_id = get_head_commit_id(root, branch)
    if head_commit_id is None:
        typer.echo("❌ Current branch has no commits.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    upstream_id = _resolve_ref_to_id(root, repo_id, branch, upstream)
    if upstream_id is None:
        typer.echo(f"❌ Upstream '{sanitize_display(upstream)}' not found.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Determine the new base.
    if onto is not None:
        onto_id = _resolve_ref_to_id(root, repo_id, branch, onto)
        if onto_id is None:
            typer.echo(f"❌ --onto '{sanitize_display(onto)}' not found.", err=True)
            raise typer.Exit(code=ExitCode.USER_ERROR)
    else:
        onto_id = upstream_id

    # Nothing to do when HEAD already is the upstream or the new base; checked
    # before the merge-base walk so the trivial case does no extra work.
    if head_commit_id in (upstream_id, onto_id):
        typer.echo("Already up to date.")
        return

    # Find merge base to determine which commits to replay.
    merge_base_id = find_merge_base(root, head_commit_id, upstream_id)
    stop_at = merge_base_id or ""

    commits_to_replay = collect_commits_to_replay(root, stop_at, head_commit_id)
    if not commits_to_replay:
        typer.echo("Already up to date.")
        return

    # NOTE(review): this progress line (and the per-commit lines printed by the
    # replay loop) go to stdout even with --format json, ahead of the payload.
    typer.echo(
        f"Rebasing {len(commits_to_replay)} commit(s) "
        f"onto {onto_id[:12]} (from {branch})"
    )

    if squash:
        _squash_rebase(
            root,
            commits_to_replay,
            onto_id,
            head_commit_id,
            repo_id,
            branch,
            plugin,
            domain,
            squash_message,
            fmt,
        )
        return

    # Normal replay loop. State is persisted first so that a conflict (or a
    # crash) leaves something for --continue / --abort to work with.
    state = RebaseState(
        original_branch=branch,
        original_head=head_commit_id,
        onto=onto_id,
        remaining=[c.commit_id for c in commits_to_replay],
        completed=[],
        squash=False,
    )
    save_rebase_state(root, state)

    clean = _run_replay_loop(root, state, repo_id, branch, plugin, domain)
    if not clean:
        # The loop has already printed the conflict report and saved state;
        # signal the pause with the documented conflict exit code.
        raise typer.Exit(code=ExitCode.USER_ERROR)

    final_id = state["completed"][-1] if state["completed"] else onto_id
    _write_branch_ref(root, branch, final_id)
    clear_rebase_state(root)
    append_reflog(
        root, branch,
        old_id=head_commit_id,
        new_id=final_id,
        author="user",
        operation=f"rebase: finished onto {onto_id[:12]}",
    )
    if fmt == "json":
        typer.echo(json.dumps({
            "status": "completed",
            "branch": branch,
            "new_head": final_id,
            "onto": onto_id,
            "squash": False,
            "replayed": len(state["completed"]),
            "conflicts": [],
        }))
    else:
        typer.echo(f"✅ Rebase complete. HEAD is now {final_id[:12]}.")