gabriel / muse public
tree_edit.py python
308 lines 10.1 KB
c5b7bd6b feat(phase-2): domain schema declaration + diff algorithm library (#15) Gabriel Cardona <cgcardona@gmail.com> 6d ago
1 """LCS-based tree edit algorithm for labeled ordered trees — Phase 2.
2
3 Implements a correct tree diff that produces ``InsertOp``, ``DeleteOp``,
4 ``ReplaceOp``, and ``MoveOp`` entries for labeled ordered trees.
5
6 Algorithm
7 ---------
8 The diff proceeds top-down recursively:
9
10 1. Compare root nodes by ``content_id``. Different content_id → ``ReplaceOp``
11 on the root node.
12 2. Diff the children sequences using the same LCS algorithm as
13 :mod:`~muse.core.diff_algorithms.lcs`:
14
15 - Matched child pairs (same ``id``) → recurse into subtree; a differing ``content_id`` on a matched pair yields a ``ReplaceOp``.
16 - Unmatched inserts → ``InsertOp`` (entire subtree added).
17 - Unmatched deletes → ``DeleteOp`` (entire subtree removed).
18 - Paired insert+delete of the same node ``id`` at different positions →
19 ``MoveOp``.
20
21 This approach is O(nm) per tree level where n, m are child counts. It does
22 not find the globally minimal edit script (Zhang-Shasha is optimal), but it
23 is correct: every change is accounted for, and applying the script to the base
24 tree produces the target tree. For the bounded tree sizes typical of domain
25 objects (scenes, tracks, ASTs ≲ 10k nodes), this is more than adequate for
26 Phase 2. Zhang-Shasha optimisation is a drop-in replacement once needed.
27
28 ``TreeNode`` is defined here and re-exported by the package ``__init__``.
29
30 Public API
31 ----------
32 - :class:`TreeNode` — labeled ordered tree node (frozen dataclass).
33 - :func:`diff` — ``TreeNode`` × ``TreeNode`` → ``StructuredDelta``.
34 """
35 from __future__ import annotations
36
37 import logging
38 from dataclasses import dataclass
39 from typing import Literal
40
41 from muse.core.schema import TreeSchema
42 from muse.domain import (
43 DeleteOp,
44 DomainOp,
45 InsertOp,
46 MoveOp,
47 ReplaceOp,
48 StructuredDelta,
49 )
50
51 logger = logging.getLogger(__name__)
52
53
54 # ---------------------------------------------------------------------------
55 # TreeNode — the unit of tree-edit comparison
56 # ---------------------------------------------------------------------------
57
58
@dataclass(frozen=True)
class TreeNode:
    """A node in a labeled ordered tree for domain tree-edit algorithms.

    Instances are immutable (frozen dataclass) and compare by value,
    including their whole ``children`` tuple.

    Attributes:
        id: Stable unique identifier for the node (e.g. UUID or path).
            The diff in this module pairs child nodes across the two trees
            by ``id``.
        label: Human-readable name (e.g. element tag, node type); used in
            the summaries attached to generated ops.
        content_id: SHA-256 of this node's own value, excluding its
            children. When two paired nodes carry different
            ``content_id`` values the diff emits a ``ReplaceOp``.
        children: Ordered tuple of child nodes. Defaults to the empty
            tuple so leaf nodes can be built without spelling it out.
    """

    id: str
    label: str
    content_id: str
    # An immutable tuple is a safe shared default (no mutable-default hazard).
    children: tuple[TreeNode, ...] = ()
75
76
77 # ---------------------------------------------------------------------------
78 # Internal helpers
79 # ---------------------------------------------------------------------------
80
81
def _subtree_nodes(node: TreeNode) -> list[TreeNode]:
    """Return all nodes in *node*'s subtree (postorder).

    Children are listed before their parent and siblings keep their
    left-to-right order, so *node* itself is always the last element.

    The traversal is iterative so that very deep trees do not exhaust the
    interpreter's recursion limit.
    """
    # Walk "parent, then children right-to-left"; reversing that sequence
    # yields exactly the postorder (children left-to-right, then parent).
    reversed_postorder: list[TreeNode] = []
    pending: list[TreeNode] = [node]
    while pending:
        current = pending.pop()
        reversed_postorder.append(current)
        # Pushing left-to-right means the rightmost child is popped first.
        pending.extend(current.children)
    reversed_postorder.reverse()
    return reversed_postorder
93
94
def _lcs_children(
    base_children: tuple[TreeNode, ...],
    target_children: tuple[TreeNode, ...],
) -> list[tuple[Literal["keep", "insert", "delete"], int, int]]:
    """Compute an LCS-based edit script between two child sequences.

    Children are paired purely by ``id``: two children with equal ids form
    a "keep" step even when their ``content_id`` differs (the caller is
    expected to recurse into kept pairs to report such differences).
    Children without a partner become "insert" or "delete" steps.

    Returns:
        ``(kind, base_idx, target_idx)`` triples in script order. For a
        "keep" both indices point at the paired children; for an "insert"
        ``target_idx`` is the inserted child and ``base_idx`` is the
        current base cursor; for a "delete" ``base_idx`` is the removed
        child and ``target_idx`` is the current target cursor.
    """
    old_ids = [child.id for child in base_children]
    new_ids = [child.id for child in target_children]
    n_old, n_new = len(old_ids), len(new_ids)

    # table[r][c] = LCS length of old_ids[r:] and new_ids[c:]; the extra
    # row/column of zeros represents the empty suffixes.
    table: list[list[int]] = [[0] * (n_new + 1) for _ in range(n_old + 1)]
    for row in reversed(range(n_old)):
        current, below = table[row], table[row + 1]
        for col in reversed(range(n_new)):
            if old_ids[row] == new_ids[col]:
                current[col] = below[col + 1] + 1
            else:
                current[col] = max(below[col], current[col + 1])

    script: list[tuple[Literal["keep", "insert", "delete"], int, int]] = []
    bi = ti = 0

    # Walk forward while both sequences still have elements.
    while bi < n_old and ti < n_new:
        if old_ids[bi] == new_ids[ti]:
            script.append(("keep", bi, ti))
            bi += 1
            ti += 1
            continue
        # On a tie the insert is emitted first.
        if table[bi][ti + 1] >= table[bi + 1][ti]:
            script.append(("insert", bi, ti))
            ti += 1
        else:
            script.append(("delete", bi, ti))
            bi += 1

    # At most one side has a leftover tail: everything remaining in the
    # target is inserted, everything remaining in the base is deleted.
    script.extend(("insert", bi, k) for k in range(ti, n_new))
    script.extend(("delete", k, ti) for k in range(bi, n_old))
    return script
137
138
def _diff_nodes(
    base: TreeNode,
    target: TreeNode,
    *,
    domain: str,
    address: str,
) -> list[DomainOp]:
    """Recursively diff two tree nodes, returning a flat op list.

    The two nodes are treated as the same logical node. Steps:

    1. If their ``content_id`` differs, emit a ``ReplaceOp`` addressed at
       this node.
    2. Align the children by ``id`` with :func:`_lcs_children`. Kept pairs
       are diffed recursively.
    3. A child that shows up as both a delete and an insert (same ``id``)
       is a repositioned node: a ``MoveOp`` is emitted when its index
       changed, and the old/new versions of that child are diffed
       recursively so content or descendant changes inside a moved node
       are reported as well.
    4. Remaining inserts / deletes emit one ``InsertOp`` / ``DeleteOp`` per
       node of the affected subtree (postorder), all addressed at this
       node and positioned at the child's index.

    Args:
        base: Node from the base tree.
        target: Corresponding node from the target tree.
        domain: Domain tag; only forwarded to recursive calls.
        address: Address prefix of the parent; this node's address is
            ``"{address}/{base.id}"`` (or just ``base.id`` when empty).

    Returns:
        The ops for this node followed by ops from its children.
    """
    ops: list[DomainOp] = []
    node_addr = f"{address}/{base.id}" if address else base.id

    # Root node comparison
    if base.content_id != target.content_id:
        ops.append(
            ReplaceOp(
                op="replace",
                address=node_addr,
                position=None,
                old_content_id=base.content_id,
                new_content_id=target.content_id,
                old_summary=f"{base.label} (prev)",
                new_summary=f"{target.label} (new)",
            )
        )

    if not base.children and not target.children:
        return ops

    raw_inserts: list[tuple[int, TreeNode]] = []  # (target_idx, node)
    raw_deletes: list[tuple[int, TreeNode]] = []  # (base_idx, node)

    # Diff children via LCS
    for kind, bi, ti in _lcs_children(base.children, target.children):
        if kind == "keep":
            # Recurse into the matched child pair
            ops.extend(
                _diff_nodes(
                    base.children[bi],
                    target.children[ti],
                    domain=domain,
                    address=node_addr,
                )
            )
        elif kind == "insert":
            raw_inserts.append((ti, target.children[ti]))
        else:
            raw_deletes.append((bi, base.children[bi]))

    # Move detection: paired insert+delete of the same node id.
    # Node identity is tracked by id, not content_id, so a repositioned node
    # is detected as a move even if its content also changed.
    # Only the first delete seen for a given id is eligible for pairing.
    delete_by_id: dict[str, tuple[int, TreeNode]] = {}
    for bi, node in raw_deletes:
        delete_by_id.setdefault(node.id, (bi, node))

    consumed_ids: set[str] = set()
    for ti, node in raw_inserts:
        nid = node.id
        if nid in delete_by_id and nid not in consumed_ids:
            from_idx, old_node = delete_by_id[nid]
            if from_idx != ti:
                ops.append(
                    MoveOp(
                        op="move",
                        address=node_addr,
                        from_position=from_idx,
                        to_position=ti,
                        content_id=node.content_id,
                    )
                )
            # The moved node may itself have changed (content or
            # descendants); diff the old and new versions so those changes
            # are not silently lost. Yields nothing when they are identical.
            ops.extend(
                _diff_nodes(
                    old_node,
                    node,
                    domain=domain,
                    address=node_addr,
                )
            )
            consumed_ids.add(nid)
            continue
        # True insert — recursively add the entire subtree's nodes
        for sub_node in _subtree_nodes(node):
            ops.append(
                InsertOp(
                    op="insert",
                    address=node_addr,
                    position=ti,
                    content_id=sub_node.content_id,
                    content_summary=f"{sub_node.label} added",
                )
            )

    for bi, node in raw_deletes:
        if node.id in consumed_ids:
            # Already accounted for as (part of) a move.
            continue
        # True delete — recursively remove the entire subtree's nodes
        for sub_node in _subtree_nodes(node):
            ops.append(
                DeleteOp(
                    op="delete",
                    address=node_addr,
                    position=bi,
                    content_id=sub_node.content_id,
                    content_summary=f"{sub_node.label} removed",
                )
            )

    return ops
242
243
244 # ---------------------------------------------------------------------------
245 # Top-level diff entry point
246 # ---------------------------------------------------------------------------
247
248
def diff(
    schema: TreeSchema,
    base: TreeNode,
    target: TreeNode,
    *,
    domain: str,
    address: str = "",
) -> StructuredDelta:
    """Compute the structured delta between two labeled ordered trees.

    The result contains ``ReplaceOp`` entries for nodes whose
    ``content_id`` changed, ``InsertOp`` / ``DeleteOp`` entries for added
    and removed subtrees, and ``MoveOp`` entries for repositioned children
    (a delete and an insert of the same node ``id`` under one parent).

    Args:
        schema: The ``TreeSchema``; its ``"node_type"`` entry is used in
            the "no changes" summary text.
        base: Root of the base (ancestor) tree.
        target: Root of the target (newer) tree.
        domain: Domain tag for the returned ``StructuredDelta``.
        address: Address prefix for generated op entries.

    Returns:
        A ``StructuredDelta`` with the op list and a human-readable
        summary of how many nodes were relabelled/added/removed/moved.
    """
    # Fast path: identical trees (root content and the whole child tuple
    # compare equal), so there is nothing to walk.
    roots_match = base.content_id == target.content_id
    if roots_match and base.children == target.children:
        return StructuredDelta(
            domain=domain,
            ops=[],
            summary=f"no {schema['node_type']} changes",
        )

    ops = _diff_nodes(base, target, domain=domain, address=address)

    # Tally ops by kind in one pass; kinds outside these four are ignored.
    tally = {"replace": 0, "insert": 0, "delete": 0, "move": 0}
    for op in ops:
        kind = op["op"]
        if kind in tally:
            tally[kind] += 1

    # Summary fragments appear in this fixed order, skipping zero counts.
    wording = (
        ("replace", "relabelled"),
        ("insert", "added"),
        ("delete", "removed"),
        ("move", "moved"),
    )
    parts: list[str] = [
        f"{tally[kind]} {word}" for kind, word in wording if tally[kind]
    ]
    if parts:
        summary = ", ".join(parts)
    else:
        summary = f"no {schema['node_type']} changes"

    logger.debug(
        "tree_edit.diff: +%d -%d ~%d r%d ops on %r",
        tally["insert"],
        tally["delete"],
        tally["move"],
        tally["replace"],
        address,
    )

    return StructuredDelta(domain=domain, ops=ops, summary=summary)