rebase.py
python
| 1 | """``muse rebase`` — replay commits from one branch onto another. |
| 2 | |
| 3 | Muse rebase is cherry-pick of a range: it takes the commits unique to the |
| 4 | current branch (those not reachable from the upstream) and replays them |
| 5 | one-by-one on top of the upstream. Because commits are content-addressed, |
| 6 | each replayed commit gets a new ID — the originals are untouched in the store. |
| 7 | |
| 8 | Usage:: |
| 9 | |
| 10 | muse rebase <upstream> # replay HEAD's unique commits onto upstream |
| 11 | muse rebase --onto <newbase> <upstream> # replay onto a different base |
| 12 | muse rebase --squash [<upstream>] # collapse all commits into one |
| 13 | muse rebase --abort # restore original HEAD |
| 14 | muse rebase --continue # resume after resolving a conflict |
| 15 | |
| 16 | Exit codes:: |
| 17 | |
| 18 | 0 — rebase completed or aborted successfully |
| 19 | 1 — conflict encountered, or bad arguments |
| 20 | 3 — internal error |
| 21 | """ |
| 22 | |
| 23 | from __future__ import annotations |
| 24 | |
| 25 | import datetime |
| 26 | import json |
| 27 | import logging |
| 28 | import pathlib |
| 29 | from typing import Annotated |
| 30 | |
| 31 | import typer |
| 32 | |
| 33 | from muse.core.errors import ExitCode |
| 34 | from muse.core.merge_engine import find_merge_base, write_merge_state |
| 35 | from muse.core.validation import validate_branch_name |
| 36 | from muse.domain import SnapshotManifest as _SnapshotManifest |
| 37 | from muse.core.rebase import ( |
| 38 | RebaseState, |
| 39 | _write_branch_ref, |
| 40 | clear_rebase_state, |
| 41 | collect_commits_to_replay, |
| 42 | load_rebase_state, |
| 43 | replay_one, |
| 44 | save_rebase_state, |
| 45 | ) |
| 46 | from muse.core.reflog import append_reflog |
| 47 | from muse.core.repo import require_repo |
| 48 | from muse.domain import MuseDomainPlugin, SnapshotManifest |
| 49 | from muse.core.snapshot import compute_commit_id, compute_snapshot_id |
| 50 | from muse.core.store import ( |
| 51 | CommitRecord, |
| 52 | SnapshotRecord, |
| 53 | get_head_commit_id, |
| 54 | read_commit, |
| 55 | read_current_branch, |
| 56 | read_snapshot, |
| 57 | resolve_commit_ref, |
| 58 | write_commit, |
| 59 | write_snapshot, |
| 60 | ) |
| 61 | from muse.core.validation import sanitize_display |
| 62 | from muse.core.workdir import apply_manifest |
| 63 | from muse.plugins.registry import read_domain, resolve_plugin |
| 64 | |
| 65 | logger = logging.getLogger(__name__) |
| 66 | |
| 67 | app = typer.Typer(help="Replay commits from the current branch onto a new base.") |
| 68 | |
| 69 | |
def _read_repo_id(root: pathlib.Path) -> str:
    """Return the ``repo_id`` stored in ``<root>/.muse/repo.json`` as a string."""
    repo_json = root / ".muse" / "repo.json"
    metadata = json.loads(repo_json.read_text(encoding="utf-8"))
    return str(metadata["repo_id"])
| 72 | |
| 73 | |
def _resolve_ref_to_id(
    root: pathlib.Path,
    repo_id: str,
    branch: str,
    ref: str,
) -> str | None:
    """Resolve a ref string (branch name, commit SHA, or HEAD) to a commit ID.

    Resolution order:

    1. ``HEAD`` (case-insensitive) is delegated to ``get_head_commit_id`` for
       *branch*.
    2. If *ref* passes ``validate_branch_name`` and
       ``.muse/refs/heads/<ref>`` is a regular file holding a 64-character
       lowercase hex string, that string is returned.
    3. Everything else is delegated to ``resolve_commit_ref``.

    Branch names are validated with ``validate_branch_name`` before being used
    as path components to prevent directory traversal attacks.

    Returns:
        The resolved commit ID, or ``None`` when nothing matched.
    """
    if ref.upper() == "HEAD":
        return get_head_commit_id(root, branch)

    # Only a ref that passes validation may be joined onto the refs directory.
    try:
        validate_branch_name(ref)
    except ValueError:
        is_branch_name = False
    else:
        is_branch_name = True

    if is_branch_name:
        ref_path = root / ".muse" / "refs" / "heads" / ref
        # is_file() rather than exists(): a namespace prefix such as "feat"
        # (from a branch "feat/base") is a directory under refs/heads, and
        # read_text() on a directory raises instead of letting us fall through
        # to commit resolution below.
        if ref_path.is_file():
            raw = ref_path.read_text(encoding="utf-8").strip()
            if len(raw) == 64 and all(c in "0123456789abcdef" for c in raw):
                return raw

    # Not a usable branch ref — fall back to commit ref / SHA prefix resolution.
    rec = resolve_commit_ref(root, repo_id, branch, ref)
    return rec.commit_id if rec else None
| 106 | |
| 107 | |
def _run_replay_loop(
    root: pathlib.Path,
    state: RebaseState,
    repo_id: str,
    branch: str,
    plugin: "MuseDomainPlugin",
    domain: str,
) -> bool:
    """Replay the commits queued in ``state["remaining"]``, front to back.

    Each commit is replayed with ``replay_one`` on top of the current tip: the
    last entry of ``state["completed"]``, or ``state["onto"]`` when nothing has
    been completed yet. *state* is mutated in place and persisted with
    ``save_rebase_state`` after every step.

    Returns:
        ``True`` when the queue was drained, ``False`` when a replay reported
        conflicts. On conflict the commit stays at the front of the queue and
        merge state is written so the step can be resumed.
    """
    tip = state["completed"][-1] if state["completed"] else state["onto"]

    while state["remaining"]:
        source_id = state["remaining"][0]
        short_id = source_id[:12]
        source = read_commit(root, source_id)

        if source is None:
            # Unreadable commit: drop it from the queue and move on.
            typer.echo(f"⚠️ Commit {short_id} not found — skipping.")
            state["remaining"].pop(0)
            save_rebase_state(root, state)
            continue

        typer.echo(f" Replaying {short_id}: {sanitize_display(source.message)}")
        outcome = replay_one(root, source, tip, plugin, domain, repo_id, branch)

        if isinstance(outcome, list):
            # A list result is the set of conflicting paths. The commit is left
            # at the head of the queue so that --continue picks it up again.
            save_rebase_state(root, state)
            write_merge_state(
                root,
                base_commit=source.parent_commit_id or "",
                ours_commit=tip,
                theirs_commit=source_id,
                conflict_paths=outcome,
            )
            typer.echo(f"\n❌ Rebase stopped at {short_id} due to conflict(s):")
            for path in sorted(outcome):
                typer.echo(f" CONFLICT: {path}")
            typer.echo(
                "\nResolve conflicts then run:\n"
                " muse rebase --continue to resume\n"
                " muse rebase --abort to restore original HEAD"
            )
            return False

        # Clean replay: the new commit becomes the tip for the next step.
        previous_tip, tip = tip, outcome.commit_id
        state["remaining"].pop(0)
        state["completed"].append(tip)
        save_rebase_state(root, state)

        append_reflog(
            root, branch,
            old_id=previous_tip,
            new_id=tip,
            author="user",
            operation=f"rebase: replayed {short_id} onto {state['onto'][:12]}",
        )

    return True
| 172 | |
| 173 | |
def _abort_rebase(root: pathlib.Path, active_state: RebaseState | None) -> None:
    """Handle ``muse rebase --abort``.

    Points the branch recorded in the rebase state back at its pre-rebase
    head, re-applies that commit's snapshot to the working tree when both the
    commit and its snapshot can be read, records the move in the reflog and
    clears the rebase state.

    Raises:
        typer.Exit: with ``ExitCode.USER_ERROR`` when no rebase is in progress.
    """
    if active_state is None:
        typer.echo("❌ No rebase in progress.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    original_head = active_state["original_head"]
    original_branch = active_state["original_branch"]
    _write_branch_ref(root, original_branch, original_head)

    # Restore working tree to original HEAD.
    orig_commit = read_commit(root, original_head)
    if orig_commit:
        snap = read_snapshot(root, orig_commit.snapshot_id)
        if snap:
            apply_manifest(root, snap.manifest)

    append_reflog(
        root, original_branch,
        old_id=active_state["completed"][-1] if active_state["completed"] else active_state["onto"],
        new_id=original_head,
        author="user",
        operation="rebase: abort",
    )
    clear_rebase_state(root)
    typer.echo(f"✅ Rebase aborted. HEAD restored to {original_head[:12]}.")


def _continue_rebase(
    root: pathlib.Path,
    active_state: RebaseState | None,
    repo_id: str,
    branch: str,
    plugin: MuseDomainPlugin,
    domain: str,
) -> None:
    """Handle ``muse rebase --continue``.

    Snapshots the working tree via ``plugin.snapshot`` and commits it as the
    resolution of the paused step, reusing the paused commit's message and
    author when that commit can be read. Any commits still queued are then
    replayed with ``_run_replay_loop``.

    Raises:
        typer.Exit: with ``ExitCode.USER_ERROR`` when no rebase is in progress
            or when a later commit conflicts during the resumed replay.
    """
    if active_state is None:
        typer.echo("❌ No rebase in progress. Nothing to continue.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # The user has resolved the conflict manually. Snapshot the current
    # working tree and create the commit for the paused step.
    current_parent = (
        active_state["completed"][-1]
        if active_state["completed"]
        else active_state["onto"]
    )
    orig_commit_id = active_state["remaining"][0] if active_state["remaining"] else ""
    orig_commit = read_commit(root, orig_commit_id) if orig_commit_id else None

    snap_result = plugin.snapshot(root)
    manifest: dict[str, str] = snap_result["files"]
    snapshot_id = compute_snapshot_id(manifest)
    committed_at = datetime.datetime.now(datetime.timezone.utc)
    message = orig_commit.message if orig_commit else "rebase: continued"
    new_commit_id = compute_commit_id(
        parent_ids=[current_parent] if current_parent else [],
        snapshot_id=snapshot_id,
        message=message,
        committed_at_iso=committed_at.isoformat(),
    )
    write_snapshot(root, SnapshotRecord(snapshot_id=snapshot_id, manifest=manifest))
    write_commit(root, CommitRecord(
        commit_id=new_commit_id,
        repo_id=repo_id,
        branch=branch,
        snapshot_id=snapshot_id,
        message=message,
        committed_at=committed_at,
        parent_commit_id=current_parent if current_parent else None,
        author=orig_commit.author if orig_commit else "",
    ))
    active_state["completed"].append(new_commit_id)
    if active_state["remaining"]:
        active_state["remaining"].pop(0)
    save_rebase_state(root, active_state)

    append_reflog(
        root, branch,
        old_id=current_parent,
        new_id=new_commit_id,
        author="user",
        operation=f"rebase: continue — replayed {orig_commit_id[:12] if orig_commit_id else '?'}",
    )

    if not active_state["remaining"]:
        _write_branch_ref(root, branch, new_commit_id)
        clear_rebase_state(root)
        typer.echo(f"✅ Rebase complete. HEAD is now {new_commit_id[:12]}.")
        return

    # More commits to replay.
    if not _run_replay_loop(root, active_state, repo_id, branch, plugin, domain):
        # The loop has already persisted state and printed the conflict
        # report; exit non-zero as documented for conflicts.
        raise typer.Exit(code=ExitCode.USER_ERROR)

    final_id = active_state["completed"][-1]
    _write_branch_ref(root, branch, final_id)
    clear_rebase_state(root)
    typer.echo(f"✅ Rebase complete. HEAD is now {final_id[:12]}.")


def _squash_rebase(
    root: pathlib.Path,
    *,
    repo_id: str,
    branch: str,
    plugin: MuseDomainPlugin,
    domain: str,
    head_commit_id: str,
    onto_id: str,
    commits_to_replay: list[CommitRecord],
    squash_message: str | None,
    fmt: str,
) -> None:
    """Handle ``muse rebase --squash``: fold all commits into one on *onto_id*.

    Starting from the snapshot of *onto_id*, each commit is merged in turn
    with ``plugin.merge`` (base = the commit's parent snapshot, ours = the
    accumulated manifest, theirs = the commit's snapshot). The final manifest
    is applied to the working tree and committed with *onto_id* as its parent.
    No rebase state is written, so a conflict cannot be resumed.

    Raises:
        typer.Exit: with ``ExitCode.USER_ERROR`` on the first conflicting commit.
    """
    squash_manifest: dict[str, str] = {}

    # Get onto base snapshot.
    onto_commit = read_commit(root, onto_id)
    if onto_commit:
        onto_snap = read_snapshot(root, onto_commit.snapshot_id)
        if onto_snap:
            squash_manifest = dict(onto_snap.manifest)

    for commit in commits_to_replay:
        # Missing parent commit or snapshot degrades to an empty base manifest.
        base_manifest: dict[str, str] = {}
        if commit.parent_commit_id:
            pc = read_commit(root, commit.parent_commit_id)
            if pc:
                ps = read_snapshot(root, pc.snapshot_id)
                if ps:
                    base_manifest = ps.manifest

        theirs_snap = read_snapshot(root, commit.snapshot_id)
        theirs_manifest = theirs_snap.manifest if theirs_snap else {}

        result = plugin.merge(
            _SnapshotManifest(files=base_manifest, domain=domain),
            _SnapshotManifest(files=squash_manifest, domain=domain),
            _SnapshotManifest(files=theirs_manifest, domain=domain),
            repo_root=root,
        )
        if not result.is_clean:
            typer.echo(f"❌ Conflict during squash at {commit.commit_id[:12]}:")
            for p in sorted(result.conflicts):
                typer.echo(f" CONFLICT: {p}")
            typer.echo("Resolve conflicts and try again. Squash does not support --continue.")
            raise typer.Exit(code=ExitCode.USER_ERROR)
        squash_manifest = result.merged["files"]

    apply_manifest(root, squash_manifest)
    snapshot_id = compute_snapshot_id(squash_manifest)
    committed_at = datetime.datetime.now(datetime.timezone.utc)
    final_message = squash_message or commits_to_replay[-1].message
    new_commit_id = compute_commit_id(
        parent_ids=[onto_id],
        snapshot_id=snapshot_id,
        message=final_message,
        committed_at_iso=committed_at.isoformat(),
    )
    write_snapshot(root, SnapshotRecord(snapshot_id=snapshot_id, manifest=squash_manifest))
    write_commit(root, CommitRecord(
        commit_id=new_commit_id,
        repo_id=repo_id,
        branch=branch,
        snapshot_id=snapshot_id,
        message=final_message,
        committed_at=committed_at,
        parent_commit_id=onto_id,
    ))
    _write_branch_ref(root, branch, new_commit_id)
    append_reflog(
        root, branch,
        old_id=head_commit_id,
        new_id=new_commit_id,
        author="user",
        operation=f"rebase --squash onto {onto_id[:12]}",
    )
    if fmt == "json":
        typer.echo(json.dumps({
            "status": "completed",
            "branch": branch,
            "new_head": new_commit_id,
            "onto": onto_id,
            "squash": True,
            "conflicts": [],
        }))
    else:
        typer.echo(f"✅ Squash-rebase complete. HEAD is now {new_commit_id[:12]}.")


@app.callback(invoke_without_command=True)
def rebase(
    ctx: typer.Context,
    upstream: Annotated[
        str | None,
        typer.Argument(help="Branch or commit to rebase onto."),
    ] = None,
    onto: Annotated[
        str | None,
        typer.Option("--onto", "-o", help="New base commit (replay commits between <upstream> and HEAD onto this)."),
    ] = None,
    squash: Annotated[
        bool,
        typer.Option("--squash", "-s", help="Collapse all replayed commits into one."),
    ] = False,
    squash_message: Annotated[
        str | None,
        typer.Option("--message", "-m", help="Commit message for --squash (default: last commit's message)."),
    ] = None,
    abort: Annotated[
        bool,
        typer.Option("--abort", "-a", help="Abort an in-progress rebase and restore the original HEAD."),
    ] = False,
    continue_: Annotated[
        bool,
        typer.Option("--continue", "-c", help="Resume after resolving a conflict."),
    ] = False,
    fmt: Annotated[
        str,
        typer.Option("--format", "-f", help="Output format: text or json."),
    ] = "text",
) -> None:
    """Replay commits from the current branch onto a new base.

    The most common invocation replays all commits unique to the current
    branch on top of *upstream*'s HEAD::

        muse rebase main # rebase current branch onto main

    Use ``--onto`` when you need to replay onto a commit that is not the
    tip of *upstream*::

        muse rebase --onto newbase upstream

    Use ``--squash`` to collapse all replayed commits into a single commit
    for a clean merge workflow::

        muse rebase --squash main

    When a conflict is encountered, the rebase pauses and the command exits
    with a non-zero status. Resolve the conflict, then::

        muse rebase --continue

    Or discard the entire rebase::

        muse rebase --abort

    Examples::

        muse rebase main
        muse rebase --onto main feat/base
        muse rebase --squash --message "feat: combined" main
        muse rebase --abort
        muse rebase --continue
        muse rebase --format json main # machine-readable result

    Agents should pass ``--format json`` to receive a structured result when
    a new rebase completes. Squash rebase payload::

        {
            "status": "completed",
            "branch": "<name>",
            "new_head": "<sha256>",
            "onto": "<sha256>",
            "squash": true,
            "conflicts": []
        }

    Normal rebase payload adds ``"replayed": <n>`` (number of commits applied).
    Conflicts are reported as text and a non-zero exit code; no JSON payload
    is emitted for them.
    """
    if fmt not in ("text", "json"):
        typer.echo(f"❌ Unknown --format '{sanitize_display(fmt)}'. Choose text or json.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = read_current_branch(root)
    plugin = resolve_plugin(root)
    domain = read_domain(root)

    active_state = load_rebase_state(root)

    # --abort takes precedence over --continue when both are given.
    if abort:
        _abort_rebase(root, active_state)
        return

    if continue_:
        _continue_rebase(root, active_state, repo_id, branch, plugin, domain)
        return

    # New rebase — check not already in progress.
    if active_state is not None:
        typer.echo("❌ Rebase in progress. Use --continue or --abort.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    if upstream is None:
        typer.echo("❌ Provide an upstream branch or commit to rebase onto.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Resolve HEAD and upstream.
    head_commit_id = get_head_commit_id(root, branch)
    if head_commit_id is None:
        typer.echo("❌ Current branch has no commits.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    upstream_id = _resolve_ref_to_id(root, repo_id, branch, upstream)
    if upstream_id is None:
        typer.echo(f"❌ Upstream '{sanitize_display(upstream)}' not found.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Determine the new base.
    if onto is not None:
        onto_id = _resolve_ref_to_id(root, repo_id, branch, onto)
        if onto_id is None:
            typer.echo(f"❌ --onto '{sanitize_display(onto)}' not found.", err=True)
            raise typer.Exit(code=ExitCode.USER_ERROR)
    else:
        onto_id = upstream_id

    # Trivial case first: no need to search for a merge base.
    if head_commit_id == upstream_id or head_commit_id == onto_id:
        typer.echo("Already up to date.")
        return

    # Find merge base to determine which commits to replay.
    merge_base_id = find_merge_base(root, head_commit_id, upstream_id)
    stop_at = merge_base_id or ""

    commits_to_replay = collect_commits_to_replay(root, stop_at, head_commit_id)
    if not commits_to_replay:
        typer.echo("Already up to date.")
        return

    # In JSON mode the banner goes to stderr so it does not precede the JSON
    # document on stdout. The per-commit lines printed by _run_replay_loop
    # are not affected by this.
    typer.echo(
        f"Rebasing {len(commits_to_replay)} commit(s) "
        f"onto {onto_id[:12]} (from {branch})",
        err=fmt == "json",
    )

    if squash:
        _squash_rebase(
            root,
            repo_id=repo_id,
            branch=branch,
            plugin=plugin,
            domain=domain,
            head_commit_id=head_commit_id,
            onto_id=onto_id,
            commits_to_replay=commits_to_replay,
            squash_message=squash_message,
            fmt=fmt,
        )
        return

    # Normal replay loop.
    state = RebaseState(
        original_branch=branch,
        original_head=head_commit_id,
        onto=onto_id,
        remaining=[c.commit_id for c in commits_to_replay],
        completed=[],
        squash=False,
    )
    save_rebase_state(root, state)

    if not _run_replay_loop(root, state, repo_id, branch, plugin, domain):
        # The loop has already persisted state and printed the conflict
        # report; exit non-zero as documented for conflicts.
        raise typer.Exit(code=ExitCode.USER_ERROR)

    final_id = state["completed"][-1] if state["completed"] else onto_id
    _write_branch_ref(root, branch, final_id)
    clear_rebase_state(root)
    append_reflog(
        root, branch,
        old_id=head_commit_id,
        new_id=final_id,
        author="user",
        operation=f"rebase: finished onto {onto_id[:12]}",
    )
    if fmt == "json":
        typer.echo(json.dumps({
            "status": "completed",
            "branch": branch,
            "new_head": final_id,
            "onto": onto_id,
            "squash": False,
            "replayed": len(state["completed"]),
            "conflicts": [],
        }))
    else:
        typer.echo(f"✅ Rebase complete. HEAD is now {final_id[:12]}.")