gabriel / muse public
tree_edit.py python
307 lines 10.0 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """LCS-based tree edit algorithm for labeled ordered trees.
2
3 Implements a correct tree diff that produces ``InsertOp``, ``DeleteOp``,
4 ``ReplaceOp``, and ``MoveOp`` entries for labeled ordered trees.
5
6 Algorithm
7 ---------
8 The diff proceeds top-down recursively:
9
10 1. Compare root nodes by ``content_id``. Different content_id → ``ReplaceOp``
11 on the root node.
12 2. Diff the children sequences using the same LCS algorithm as
13 :mod:`~muse.core.diff_algorithms.lcs`:
14
   - Matched child pairs (same ``id``) → recurse into the subtree.
16 - Unmatched inserts → ``InsertOp`` (entire subtree added).
17 - Unmatched deletes → ``DeleteOp`` (entire subtree removed).
   - An unmatched insert and delete of the same ``id`` at different
     positions → ``MoveOp``.
20
21 This approach is O(nm) per tree level where n, m are child counts. It does
22 not find the globally minimal edit script (Zhang-Shasha is optimal), but it
23 is correct: every change is accounted for, and applying the script to the base
24 tree produces the target tree. For the bounded tree sizes typical of domain
objects (scenes, tracks, ASTs ≲ 10k nodes), this is more than adequate;
Zhang-Shasha optimisation is a drop-in replacement once needed.
27
28 ``TreeNode`` is defined here and re-exported by the package ``__init__``.
29
30 Public API
31 ----------
32 - :class:`TreeNode` — labeled ordered tree node (frozen dataclass).
33 - :func:`diff` — ``TreeNode`` × ``TreeNode`` → ``StructuredDelta``.
34 """
35
import logging
from collections import Counter
from dataclasses import dataclass
from typing import Literal
39
40 from muse.core.schema import TreeSchema
41 from muse.domain import (
42 DeleteOp,
43 DomainOp,
44 InsertOp,
45 MoveOp,
46 ReplaceOp,
47 StructuredDelta,
48 )
49
# Module logger; ``diff`` reports its per-call op counts here at DEBUG level.
logger = logging.getLogger(__name__)
51
52
53 # ---------------------------------------------------------------------------
54 # TreeNode — the unit of tree-edit comparison
55 # ---------------------------------------------------------------------------
56
57
58 @dataclass(frozen=True)
59 class TreeNode:
60 """A node in a labeled ordered tree for domain tree-edit algorithms.
61
62 ``id`` is a stable unique identifier for the node (e.g. UUID or path).
63 ``label`` is the human-readable name (e.g. element tag, node type).
64 ``content_id`` is the SHA-256 of this node's own value — excluding its
65 children. Two nodes are considered the same iff their ``content_id``\\s
66 match; a different ``content_id`` triggers a ``ReplaceOp``.
67 ``children`` is an ordered tuple of child nodes.
68 """
69
70 id: str
71 label: str
72 content_id: str
73 children: tuple[TreeNode, ...]
74
75
76 # ---------------------------------------------------------------------------
77 # Internal helpers
78 # ---------------------------------------------------------------------------
79
80
81 def _subtree_nodes(node: TreeNode) -> list[TreeNode]:
82 """Return all nodes in *node*'s subtree (postorder)."""
83 result: list[TreeNode] = []
84
85 def _visit(n: TreeNode) -> None:
86 for child in n.children:
87 _visit(child)
88 result.append(n)
89
90 _visit(node)
91 return result
92
93
94 def _lcs_children(
95 base_children: tuple[TreeNode, ...],
96 target_children: tuple[TreeNode, ...],
97 ) -> list[tuple[Literal["keep", "insert", "delete"], int, int]]:
98 """LCS shortest-edit script on two sequences of child nodes.
99
100 Comparison is by ``id`` — children with the same id are matched (a "keep"),
101 even if their ``content_id`` differs. A kept pair that has a different
102 ``content_id`` will produce a ``ReplaceOp`` when recursed into by
103 :func:`_diff_nodes`.
104
105 Unmatched children produce insert / delete ops.
106
107 Returns a list of ``(kind, base_idx, target_idx)`` triples.
108 """
109 n, m = len(base_children), len(target_children)
110 base_ids = [c.id for c in base_children]
111 target_ids = [c.id for c in target_children]
112
113 dp: list[list[int]] = [[0] * (m + 1) for _ in range(n + 1)]
114 for i in range(n - 1, -1, -1):
115 for j in range(m - 1, -1, -1):
116 if base_ids[i] == target_ids[j]:
117 dp[i][j] = dp[i + 1][j + 1] + 1
118 else:
119 dp[i][j] = max(dp[i + 1][j], dp[i][j + 1])
120
121 result: list[tuple[Literal["keep", "insert", "delete"], int, int]] = []
122 i, j = 0, 0
123 while i < n or j < m:
124 if i < n and j < m and base_ids[i] == target_ids[j]:
125 result.append(("keep", i, j))
126 i += 1
127 j += 1
128 elif j < m and (i >= n or dp[i][j + 1] >= dp[i + 1][j]):
129 result.append(("insert", i, j))
130 j += 1
131 else:
132 result.append(("delete", i, j))
133 i += 1
134
135 return result
136
137
def _diff_nodes(
    base: TreeNode,
    target: TreeNode,
    *,
    domain: str,
    address: str,
) -> list[DomainOp]:
    """Recursively diff *base* against *target*, returning a flat op list.

    The node's address is ``f"{address}/{base.id}"`` (or just ``base.id``
    when *address* is empty). Ops about this node and about its direct
    children are emitted at that address.

    - ``ReplaceOp`` when the two nodes' ``content_id`` differ.
    - Children are aligned by ``id`` with :func:`_lcs_children`; every
      kept pair is diffed recursively.
    - An unmatched delete and an unmatched insert sharing an ``id`` are
      paired (duplicates pair up in order). A ``MoveOp`` is emitted when
      their positions differ, and the pair is then diffed recursively, so
      content or child changes inside a moved node are still reported.
    - Remaining inserts / deletes emit one ``InsertOp`` / ``DeleteOp`` per
      node of the affected subtree (postorder), all at this node's address
      and the child's position.

    Args:
        base: Node from the base tree.
        target: Node from the target tree aligned with *base*.
        domain: Forwarded unchanged to recursive calls.
        address: Address of the parent node ("" at the root).

    Returns:
        The ops for this node and its whole subtree, in emission order.
    """
    ops: list[DomainOp] = []
    node_addr = f"{address}/{base.id}" if address else base.id

    # Root node comparison
    if base.content_id != target.content_id:
        ops.append(
            ReplaceOp(
                op="replace",
                address=node_addr,
                position=None,
                old_content_id=base.content_id,
                new_content_id=target.content_id,
                old_summary=f"{base.label} (prev)",
                new_summary=f"{target.label} (new)",
            )
        )

    if not base.children and not target.children:
        return ops

    # Diff children via LCS
    script = _lcs_children(base.children, target.children)

    raw_inserts: list[tuple[int, TreeNode]] = []  # (target_idx, node)
    raw_deletes: list[tuple[int, TreeNode]] = []  # (base_idx, node)

    for kind, bi, ti in script:
        if kind == "keep":
            # Recurse into the matched child pair
            ops.extend(
                _diff_nodes(
                    base.children[bi],
                    target.children[ti],
                    domain=domain,
                    address=node_addr,
                )
            )
        elif kind == "insert":
            raw_inserts.append((ti, target.children[ti]))
        else:
            raw_deletes.append((bi, base.children[bi]))

    # Move detection: pair unmatched deletes and inserts that share a node id.
    # Identity is the id, not content_id, so a node whose content also
    # changed is still paired; the recursion below reports that change.
    deletes_by_id: dict[str, list[tuple[int, TreeNode]]] = {}
    for bi, node in raw_deletes:
        deletes_by_id.setdefault(node.id, []).append((bi, node))

    # Track consumed deletes by base index (not id) so that, with duplicate
    # ids, only the delete actually paired with an insert is suppressed.
    consumed_base_idx: set[int] = set()
    for ti, node in raw_inserts:
        candidates = deletes_by_id.get(node.id)
        if candidates:
            from_idx, base_node = candidates.pop(0)
            consumed_base_idx.add(from_idx)
            if from_idx != ti:
                ops.append(
                    MoveOp(
                        op="move",
                        address=node_addr,
                        from_position=from_idx,
                        to_position=ti,
                        content_id=node.content_id,
                    )
                )
            # Diff the paired nodes so changes inside a moved subtree survive.
            ops.extend(
                _diff_nodes(
                    base_node,
                    node,
                    domain=domain,
                    address=node_addr,
                )
            )
            continue
        # True insert — add every node of the new subtree
        for sub_node in _subtree_nodes(node):
            ops.append(
                InsertOp(
                    op="insert",
                    address=node_addr,
                    position=ti,
                    content_id=sub_node.content_id,
                    content_summary=f"{sub_node.label} added",
                )
            )

    for bi, node in raw_deletes:
        if bi in consumed_base_idx:
            continue
        # True delete — remove every node of the old subtree
        for sub_node in _subtree_nodes(node):
            ops.append(
                DeleteOp(
                    op="delete",
                    address=node_addr,
                    position=bi,
                    content_id=sub_node.content_id,
                    content_summary=f"{sub_node.label} removed",
                )
            )

    return ops
241
242
243 # ---------------------------------------------------------------------------
244 # Top-level diff entry point
245 # ---------------------------------------------------------------------------
246
247
def diff(
    schema: TreeSchema,
    base: TreeNode,
    target: TreeNode,
    *,
    domain: str,
    address: str = "",
) -> StructuredDelta:
    """Diff two labeled ordered trees, returning a ``StructuredDelta``.

    Produces ``ReplaceOp`` for nodes whose ``content_id`` changed,
    ``InsertOp`` / ``DeleteOp`` for subtree insertions and deletions, and
    ``MoveOp`` for children repositioned among their siblings (detected as
    an unmatched delete and insert of the same node ``id``).

    Args:
        schema: The ``TreeSchema`` for this domain; its ``node_type`` entry
            is used in the "no changes" summary.
        base: Root of the base (ancestor) tree.
        target: Root of the target (newer) tree.
        domain: Domain tag for the returned ``StructuredDelta``.
        address: Address prefix for generated op entries; when empty, the
            root node's ``id`` is the top-level address.

    Returns:
        A ``StructuredDelta`` with typed ops and a human-readable summary
        that counts ops by kind (``"N relabelled, N added, N removed,
        N moved"``, omitting zero counts), or ``"no <node_type> changes"``.
    """
    # Fast path: same root content and structurally equal children
    # (dataclass equality compares whole subtrees).
    if base.content_id == target.content_id and base.children == target.children:
        return StructuredDelta(
            domain=domain,
            ops=[],
            summary=f"no {schema['node_type']} changes",
        )

    ops = _diff_nodes(base, target, domain=domain, address=address)

    # One pass over the ops; missing kinds count as 0.
    counts = Counter(op["op"] for op in ops)
    n_replace = counts["replace"]
    n_insert = counts["insert"]
    n_delete = counts["delete"]
    n_move = counts["move"]

    parts: list[str] = []
    if n_replace:
        parts.append(f"{n_replace} relabelled")
    if n_insert:
        parts.append(f"{n_insert} added")
    if n_delete:
        parts.append(f"{n_delete} removed")
    if n_move:
        parts.append(f"{n_move} moved")
    summary = ", ".join(parts) if parts else f"no {schema['node_type']} changes"

    logger.debug(
        "tree_edit.diff: +%d -%d ~%d r%d ops on %r",
        n_insert,
        n_delete,
        n_move,
        n_replace,
        address,
    )

    return StructuredDelta(domain=domain, ops=ops, summary=summary)