gabriel / muse public
merge_engine.py python
480 lines 16.8 KB
8aa515d5 refactor: consolidate schema_version to single source of truth Gabriel Cardona <gabriel@tellurstori.com> 3d ago
1 """Muse VCS merge engine — fast-forward, 3-way, op-level, and CRDT merge.
2
3 Public API
4 ----------
5 Pure functions (no I/O):
6
7 - :func:`diff_snapshots` — paths that changed between two snapshot manifests.
8 - :func:`detect_conflicts` — paths changed on *both* branches since the base.
9 - :func:`apply_merge` — build merged manifest for a conflict-free 3-way merge.
10 - :func:`crdt_join_snapshots` — convergent CRDT join; always succeeds.
11
12 Operational Transformation (operation-level) merge:
13
14 - :mod:`muse.core.op_transform` — ``ops_commute``, ``transform``, ``merge_op_lists``,
15 ``merge_structured``, and :class:`~muse.core.op_transform.MergeOpsResult`.
16 Plugins that implement :class:`~muse.domain.StructuredMergePlugin` use these
17 functions to auto-merge non-conflicting ``DomainOp`` lists.
18
19 CRDT convergent merge:
20
21 - :func:`crdt_join_snapshots` — detects :class:`~muse.domain.CRDTPlugin` at
22 runtime and delegates to ``plugin.join(a, b)``. Returns a
23 :class:`~muse.domain.MergeResult` with an empty ``conflicts`` list; CRDT
24 joins never fail.
25
26 File-based helpers:
27
28 - :func:`find_merge_base` — lowest common ancestor (LCA) of two commits.
29 - :func:`read_merge_state` — detect and load an in-progress merge.
30 - :func:`write_merge_state` — persist conflict state before exiting.
31 - :func:`clear_merge_state` — remove MERGE_STATE.json after resolution.
32 - :func:`apply_resolution` — restore a specific object version to state/.
33
34 ``MERGE_STATE.json`` schema
35 ---------------------------
36
37 .. code-block:: json
38
39 {
40 "base_commit": "abc123...",
41 "ours_commit": "def456...",
42 "theirs_commit": "789abc...",
43 "conflict_paths": ["beat.mid", "lead.mp3"],
44 "other_branch": "feature/experiment"
45 }
46
47 ``other_branch`` is optional; all other fields are required when conflicts exist.
48 """
49
50 from __future__ import annotations
51
52 import json
53 import logging
54 import pathlib
55 from collections import deque
56 from dataclasses import dataclass, field
57 from typing import TYPE_CHECKING, TypedDict
58
59 from muse._version import __version__
60 from muse.core.validation import contain_path, validate_object_id, validate_ref_id
61
62 if TYPE_CHECKING:
63 from muse.domain import MergeResult, MuseDomainPlugin
64
# Module-level logger; every diagnostic emitted by the merge engine goes through it.
logger = logging.getLogger(__name__)

# Name of the file inside ``.muse/`` that records an in-progress, conflicted merge.
_MERGE_STATE_FILENAME = "MERGE_STATE.json"
68
69
70 # ---------------------------------------------------------------------------
71 # Wire-format TypedDict
72 # ---------------------------------------------------------------------------
73
74
class MergeStatePayload(TypedDict, total=False):
    """Wire format of ``.muse/MERGE_STATE.json`` for an in-progress merge.

    Declared with ``total=False`` so that the optional ``other_branch`` key
    type-checks; :func:`write_merge_state` always emits the first four keys.
    """

    # Commit ID of the common ancestor used as the 3-way merge base.
    base_commit: str
    # Commit ID of the current branch (HEAD) when the merge was started.
    ours_commit: str
    # Commit ID of the branch being merged in.
    theirs_commit: str
    # Sorted, workspace-relative POSIX paths that still need resolution.
    conflict_paths: list[str]
    # Name of the branch being merged in; informational and optional.
    other_branch: str
83
84
85 # ---------------------------------------------------------------------------
86 # MergeState dataclass
87 # ---------------------------------------------------------------------------
88
89
@dataclass(frozen=True)
class MergeState:
    """Immutable view of a merge that stopped on unresolved conflicts.

    :func:`read_merge_state` builds one from ``MERGE_STATE.json``; the optional
    string fields are ``None`` when the key was missing or failed validation.
    """

    # Workspace-relative POSIX paths still awaiting resolution.
    conflict_paths: list[str] = field(default_factory=list)
    # Commit IDs recorded when the merge began.
    base_commit: str | None = None
    ours_commit: str | None = None
    theirs_commit: str | None = None
    # Name of the branch being merged in; kept for display only.
    other_branch: str | None = None
99
100
101 # ---------------------------------------------------------------------------
102 # Filesystem helpers
103 # ---------------------------------------------------------------------------
104
105
def read_merge_state(root: pathlib.Path) -> MergeState | None:
    """Return :class:`MergeState` if a merge is in progress, otherwise ``None``.

    The state file is treated as untrusted input. Conflict paths rejected by
    :func:`contain_path` (it raises ``ValueError``) are dropped. Commit IDs
    rejected by :func:`validate_ref_id` become ``None``. Each rejection is
    logged as a warning.

    Args:
        root: Repository root (parent of ``.muse/``).

    Returns:
        A :class:`MergeState` when ``.muse/MERGE_STATE.json`` exists and holds a
        JSON object. Returns ``None`` when the file is absent, unreadable, not
        valid UTF-8 JSON, or whose top-level value is not a JSON object.
    """
    merge_state_path = root / ".muse" / _MERGE_STATE_FILENAME
    if not merge_state_path.exists():
        return None
    try:
        # Explicit UTF-8: JSON text is UTF-8 by specification; the platform
        # default encoding would make decoding locale-dependent.
        data = json.loads(merge_state_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as exc:
        logger.warning("⚠️ Failed to read %s: %s", _MERGE_STATE_FILENAME, exc)
        return None

    if not isinstance(data, dict):
        # Valid JSON that is not an object (e.g. a list) has no ``.get``; treat
        # it like any other corrupt state file instead of crashing.
        logger.warning(
            "⚠️ Failed to read %s: expected a JSON object, got %s",
            _MERGE_STATE_FILENAME,
            type(data).__name__,
        )
        return None

    raw_conflicts = data.get("conflict_paths", [])
    safe_conflict_paths: list[str] = []
    if isinstance(raw_conflicts, list):
        resolved_root = root.resolve()  # loop-invariant; computed once
        for c in raw_conflicts:
            try:
                contained = contain_path(root, str(c))
                # Store as relative POSIX string for display; contain_path already validated it.
                safe_conflict_paths.append(contained.relative_to(resolved_root).as_posix())
            except ValueError:
                logger.warning(
                    "⚠️ Skipping unsafe conflict path %r from MERGE_STATE.json", c
                )
    else:
        logger.warning(
            "⚠️ Ignoring non-list conflict_paths %r in MERGE_STATE.json", raw_conflicts
        )

    def _validated_ref(key: str) -> str | None:
        # Commit IDs come from disk; reject anything validate_ref_id refuses.
        val = data.get(key)
        if val is None:
            return None
        s = str(val)
        try:
            validate_ref_id(s)
            return s
        except ValueError:
            logger.warning(
                "⚠️ Invalid %s %r in MERGE_STATE.json — ignoring", key, s
            )
            return None

    def _str_or_none(key: str) -> str | None:
        val = data.get(key)
        return str(val) if val is not None else None

    return MergeState(
        conflict_paths=safe_conflict_paths,
        base_commit=_validated_ref("base_commit"),
        ours_commit=_validated_ref("ours_commit"),
        theirs_commit=_validated_ref("theirs_commit"),
        other_branch=_str_or_none("other_branch"),
    )
155
156
def write_merge_state(
    root: pathlib.Path,
    *,
    base_commit: str,
    ours_commit: str,
    theirs_commit: str,
    conflict_paths: list[str],
    other_branch: str | None = None,
) -> None:
    """Write ``.muse/MERGE_STATE.json`` to signal an in-progress conflicted merge.

    Called by the ``muse merge`` command when the merge produces at least one
    conflict that cannot be auto-resolved. The file is read back by
    :func:`read_merge_state` on subsequent ``muse status`` and ``muse commit``
    invocations to surface conflict state to the user.

    The JSON is first written to a sibling ``.tmp`` file and then moved over
    the target. A crash mid-write therefore never leaves a truncated state file,
    which :func:`read_merge_state` would treat as "no merge in progress".

    Args:
        root: Repository root (parent of ``.muse/``).
        base_commit: Commit ID of the merge base (common ancestor).
        ours_commit: Commit ID of the current branch (HEAD) at merge time.
        theirs_commit: Commit ID of the branch being merged in.
        conflict_paths: Workspace-relative POSIX paths with unresolvable
            conflicts; stored sorted.
        other_branch: Name of the branch being merged in; stored for
            informational display but not required for resolution.

    Raises:
        OSError: When the file cannot be written or moved into place (e.g.
            ``.muse/`` does not exist). The temporary file is removed on a
            best-effort basis.
    """
    merge_state_path = root / ".muse" / _MERGE_STATE_FILENAME
    payload: MergeStatePayload = {
        "base_commit": base_commit,
        "ours_commit": ours_commit,
        "theirs_commit": theirs_commit,
        "conflict_paths": sorted(conflict_paths),
    }
    if other_branch is not None:
        payload["other_branch"] = other_branch

    tmp_path = merge_state_path.with_name(merge_state_path.name + ".tmp")
    try:
        tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        # Same directory => same filesystem, so the rename is atomic on POSIX
        # and readers see either the old file or the complete new one.
        tmp_path.replace(merge_state_path)
    except OSError:
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; surface the original error below
        raise
    logger.info("✅ Wrote MERGE_STATE.json with %d conflict(s)", len(conflict_paths))
194
195
def clear_merge_state(root: pathlib.Path) -> None:
    """Remove ``.muse/MERGE_STATE.json`` after a successful merge or resolution.

    A missing file is not an error, so the call is safe to repeat. Deletion is
    attempted directly rather than after an ``exists()`` check, so a file
    removed concurrently between the two steps cannot raise
    ``FileNotFoundError``.

    Args:
        root: Repository root (parent of ``.muse/``).
    """
    merge_state_path = root / ".muse" / _MERGE_STATE_FILENAME
    try:
        merge_state_path.unlink()
    except FileNotFoundError:
        # Never written, or already removed — nothing to clear.
        return
    logger.debug("✅ Cleared MERGE_STATE.json")
202
203
def apply_resolution(
    root: pathlib.Path,
    rel_path: str,
    object_id: str,
) -> None:
    """Materialise the stored blob *object_id* in the working tree at *rel_path*.

    Part of the ``muse merge --resolve`` flow: once the user has picked which
    version of a conflicted file to keep, this writes that version to disk so
    that a following ``muse commit`` snapshots it. Missing parent directories
    are created.

    Args:
        root: Repository root (parent of ``.muse/``).
        rel_path: Workspace-relative POSIX path of the conflicted file; it is
            checked with :func:`contain_path` before anything is written.
        object_id: SHA-256 of the chosen resolution content in the object store;
            checked with :func:`validate_object_id` first.

    Raises:
        FileNotFoundError: When *object_id* is not present in the local store.
    """
    from muse.core.object_store import read_object

    # Validate both caller-supplied values before touching the store or disk.
    validate_object_id(object_id)
    target = contain_path(root, rel_path)
    short_id = object_id[:8]

    blob = read_object(root, object_id)
    if blob is None:
        raise FileNotFoundError(
            f"Object {short_id} for '{rel_path}' not found in local store."
        )

    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(blob)
    logger.debug("✅ Restored '%s' from object %s", rel_path, short_id)
236
237
def is_conflict_resolved(merge_state: MergeState, rel_path: str) -> bool:
    """Tell whether *rel_path* is free of conflicts according to *merge_state*.

    The check is an exact membership test against
    ``merge_state.conflict_paths``; no path normalisation is applied.

    Returns:
        ``False`` while *rel_path* is still listed as conflicted, else ``True``.
    """
    still_conflicted = rel_path in merge_state.conflict_paths
    return not still_conflicted
241
242
243 # ---------------------------------------------------------------------------
244 # Pure merge functions (no I/O)
245 # ---------------------------------------------------------------------------
246
247
def diff_snapshots(
    base_manifest: dict[str, str],
    other_manifest: dict[str, str],
) -> set[str]:
    """Collect every path whose entry differs between two snapshot manifests.

    A path counts as changed when it exists on only one side (added or
    deleted) or exists on both sides with different content hashes.

    Args:
        base_manifest: Path → content-hash map for the ancestor snapshot.
        other_manifest: Path → content-hash map for the other snapshot.

    Returns:
        Set of workspace-relative POSIX paths that differ.
    """
    # Private sentinel: "path missing" must never compare equal to a real hash.
    missing = object()
    every_path = base_manifest.keys() | other_manifest.keys()
    return {
        path
        for path in every_path
        if base_manifest.get(path, missing) != other_manifest.get(path, missing)
    }
272
273
def detect_conflicts(
    ours_changed: set[str],
    theirs_changed: set[str],
    *,
    ours_manifest: dict[str, str] | None = None,
    theirs_manifest: dict[str, str] | None = None,
) -> set[str]:
    """Return paths changed on *both* branches since the merge base.

    With only the two change sets, every path touched by both sides is a
    conflict. When *both* manifests are also supplied, convergent changes are
    not reported: a path whose final content hash is identical on both
    branches, or that both branches deleted, is not a conflict. Excluding them
    is safe for :func:`apply_merge`, because both sides then contribute the
    same entry.

    Args:
        ours_changed: Paths changed by our branch (from :func:`diff_snapshots`).
        theirs_changed: Paths changed by their branch.
        ours_manifest: Optional path → content-hash map for our branch.
        theirs_manifest: Optional path → content-hash map for their branch.

    Returns:
        The set of conflicting paths.
    """
    both_changed = ours_changed & theirs_changed
    if ours_manifest is None or theirs_manifest is None:
        return both_changed
    # Sentinel so that "deleted on both sides" compares equal (identity), while
    # "deleted on one side only" never equals a real hash.
    absent = object()
    return {
        path
        for path in both_changed
        if ours_manifest.get(path, absent) != theirs_manifest.get(path, absent)
    }
280
281
def apply_merge(
    base_manifest: dict[str, str],
    ours_manifest: dict[str, str],
    theirs_manifest: dict[str, str],
    ours_changed: set[str],
    theirs_changed: set[str],
    conflict_paths: set[str],
) -> dict[str, str]:
    """Build the merged snapshot manifest for a conflict-free 3-way merge.

    Starts from a copy of *base_manifest* and replays the non-conflicting
    changes of each branch, ours first, then theirs:

    - For every path in ``changed - conflict_paths``, the branch's entry is
      copied in. If the branch no longer has the path (a deletion), the path
      is dropped.
    - Paths in *conflict_paths* are left exactly as in *base_manifest*. They
      keep their base hash when the base had them, or stay absent otherwise.
      Callers must resolve them separately before producing a final snapshot.
    - If a path is in both change sets but not in *conflict_paths*, their
      branch's entry wins, because it is applied last.

    None of the input mappings are mutated.

    Args:
        base_manifest: Path → content-hash for the common ancestor.
        ours_manifest: Path → content-hash for our branch.
        theirs_manifest: Path → content-hash for their branch.
        ours_changed: Paths changed by our branch (from :func:`diff_snapshots`).
        theirs_changed: Paths changed by their branch.
        conflict_paths: Paths with concurrent changes; not touched here.

    Returns:
        Merged path → content-hash mapping.
    """
    merged: dict[str, str] = dict(base_manifest)
    for changed, side in (
        (ours_changed, ours_manifest),
        (theirs_changed, theirs_manifest),
    ):
        for path in changed - conflict_paths:
            if path in side:
                merged[path] = side[path]
            else:
                merged.pop(path, None)
    return merged
326
327
328 # ---------------------------------------------------------------------------
329 # CRDT convergent join
330 # ---------------------------------------------------------------------------
331
332
def crdt_join_snapshots(
    plugin: MuseDomainPlugin,
    a_snapshot: dict[str, str],
    b_snapshot: dict[str, str],
    a_vclock: dict[str, int],
    b_vclock: dict[str, int],
    a_crdt_state: dict[str, str],
    b_crdt_state: dict[str, str],
    domain: str,
) -> MergeResult:
    """Join two replicas through the plugin's CRDT ``join``; never reports conflicts.

    The plugin must pass ``isinstance(plugin, CRDTPlugin)``. Each replica is
    packed into a ``CRDTSnapshotManifest`` stamped with *domain* and this
    package's ``__version__`` as ``schema_version``. The two manifests are
    handed to ``plugin.join``, and the joined state is converted back with
    ``plugin.from_crdt_state``. The returned :class:`~muse.domain.MergeResult`
    always carries ``conflicts=[]`` and ``applied_strategies={}``.

    This is the CRDT entry point for the ``muse merge`` command, used when
    ``DomainSchema.merge_mode == "crdt"`` and the plugin passes the
    ``CRDTPlugin`` check.

    Args:
        plugin: The loaded domain plugin instance.
        a_snapshot: ``files`` mapping (path → content hash) for replica A.
        b_snapshot: ``files`` mapping (path → content hash) for replica B.
        a_vclock: Vector clock ``{agent_id: count}`` for replica A.
        b_vclock: Vector clock ``{agent_id: count}`` for replica B.
        a_crdt_state: CRDT metadata hashes (path → blob hash) for replica A.
        b_crdt_state: CRDT metadata hashes (path → blob hash) for replica B.
        domain: Domain name string (e.g. ``"midi"``).

    Returns:
        A :class:`~muse.domain.MergeResult` holding the joined snapshot.

    Raises:
        TypeError: When *plugin* does not implement
            :class:`~muse.domain.CRDTPlugin`.
    """
    from muse.domain import CRDTPlugin, CRDTSnapshotManifest, MergeResult, StateSnapshot

    if not isinstance(plugin, CRDTPlugin):
        raise TypeError(
            f"crdt_join_snapshots: plugin {type(plugin).__name__!r} does not "
            "implement CRDTPlugin — cannot use CRDT join path."
        )

    def _as_crdt_manifest(
        files: dict[str, str],
        vclock: dict[str, int],
        crdt_state: dict[str, str],
    ) -> CRDTSnapshotManifest:
        # Both replicas share the caller's domain and this build's schema version.
        return {
            "files": files,
            "domain": domain,
            "vclock": vclock,
            "crdt_state": crdt_state,
            "schema_version": __version__,
        }

    replica_a = _as_crdt_manifest(a_snapshot, a_vclock, a_crdt_state)
    replica_b = _as_crdt_manifest(b_snapshot, b_vclock, b_crdt_state)

    joined = plugin.join(replica_a, replica_b)
    merged_snapshot: StateSnapshot = plugin.from_crdt_state(joined)

    return MergeResult(merged=merged_snapshot, conflicts=[], applied_strategies={})
402
403
404 # ---------------------------------------------------------------------------
405 # File-based merge base finder
406 # ---------------------------------------------------------------------------
407
408
def find_merge_base(
    repo_root: pathlib.Path,
    commit_id_a: str,
    commit_id_b: str,
) -> str | None:
    """Find a Lowest Common Ancestor (LCA) of two commits — the 3-way merge base.

    A common ancestor is a commit reachable (inclusively) from both inputs. It
    is *lowest* when it is not itself an ancestor of another common ancestor.
    Returning the first common ancestor met on a BFS from *commit_id_b* is not
    enough. If *b* also reaches an older common ancestor through a shorter path
    (e.g. a merge commit whose other parent is an old commit), that older
    commit is met first, and the merge would see spurious changes on both sides.

    Algorithm:

    1. BFS the full ancestry of *commit_id_a* (inclusive).
    2. BFS from *commit_id_b*, without expanding past commits that are already
       in *a*'s ancestry. The common ancestors met form the *frontier*, in BFS
       order.
    3. Drop frontier commits that are ancestors of another frontier commit,
       and return the first survivor. With criss-cross histories several LCAs
       can exist; the first one in BFS order from *commit_id_b* is returned.

    Commits that cannot be read are treated as having no parents. Each commit
    is read from the store at most once per call.

    Args:
        repo_root: The repository root directory.
        commit_id_a: First commit ID (e.g., current branch HEAD).
        commit_id_b: Second commit ID (e.g., target branch HEAD).

    Returns:
        An LCA commit ID, or ``None`` if the commits share no common ancestor.

    Raises:
        MuseCLIError: When an ancestry walk would visit more than 50,000
            distinct commits, which suggests a malformed DAG.
    """
    from collections.abc import Iterator

    from muse.core.errors import MuseCLIError
    from muse.core.store import read_commit

    _MAX_ANCESTORS = 50_000
    parents_cache: dict[str, tuple[str, ...]] = {}

    def _parents(cid: str) -> tuple[str, ...]:
        """Return *cid*'s parent IDs (first parent first); empty for roots/unreadable commits."""
        cached = parents_cache.get(cid)
        if cached is None:
            commit = read_commit(repo_root, cid)
            if commit is None:
                cached = ()
            else:
                cached = tuple(
                    p for p in (commit.parent_commit_id, commit.parent2_commit_id) if p
                )
            parents_cache[cid] = cached
        return cached

    def _walk(starts: list[str], stop: set[str] | None = None) -> Iterator[str]:
        """Yield commits reachable from *starts* (inclusive) in BFS order, each once.

        Parents of commits in *stop* are not explored. The same overflow guard
        applies to every walk: both sides used to differ, and an overflow on
        the *b* side silently returned ``None``, which is indistinguishable
        from "no common ancestor".
        """
        visited: set[str] = set()
        queue: deque[str] = deque(starts)
        while queue:
            cid = queue.popleft()
            # Skip duplicates before the size check, so that stale queue entries
            # cannot trip the guard.
            if cid in visited:
                continue
            if len(visited) >= _MAX_ANCESTORS:
                raise MuseCLIError(
                    f"Ancestor graph exceeds {_MAX_ANCESTORS} commits during "
                    "merge-base search. The repository DAG may be malformed."
                )
            visited.add(cid)
            yield cid
            if stop is not None and cid in stop:
                continue
            queue.extend(p for p in _parents(cid) if p not in visited)

    a_ancestors = set(_walk([commit_id_a]))

    # Common ancestors reachable from b without passing through another
    # common ancestor, in BFS order from b.
    frontier = [
        cid for cid in _walk([commit_id_b], stop=a_ancestors) if cid in a_ancestors
    ]
    if not frontier:
        return None
    if len(frontier) == 1:
        return frontier[0]

    # Every proper ancestor of a frontier commit is dominated (not "lowest").
    # Parents are cached from the walks above, so this adds no store reads.
    dominated = set(_walk([p for cid in frontier for p in _parents(cid)]))
    for cid in frontier:
        if cid not in dominated:
            return cid
    # Only reachable when the commit graph contains a cycle.
    return frontier[0]