tree_edit.py
python
| 1 | """LCS-based tree edit algorithm for labeled ordered trees — Phase 2. |
| 2 | |
| 3 | Implements a correct tree diff that produces ``InsertOp``, ``DeleteOp``, |
| 4 | ``ReplaceOp``, and ``MoveOp`` entries for labeled ordered trees. |
| 5 | |
| 6 | Algorithm |
| 7 | --------- |
| 8 | The diff proceeds top-down recursively: |
| 9 | |
| 10 | 1. Compare root nodes by ``content_id``. Different content_id → ``ReplaceOp`` |
| 11 | on the root node. |
| 12 | 2. Diff the children sequences using the same LCS algorithm as |
| 13 | :mod:`~muse.core.diff_algorithms.lcs`: |
| 14 | |
| 15 | - Matched child pairs (same node ``id``) → recurse into subtree. |
| 16 | - Unmatched inserts → ``InsertOp`` (entire subtree added). |
| 17 | - Unmatched deletes → ``DeleteOp`` (entire subtree removed). |
| 18 | - Paired insert+delete of the same node ``id`` at different positions → |
| 19 | ``MoveOp``. |
| 20 | |
| 21 | This approach is O(nm) per tree level where n, m are child counts. It does |
| 22 | not find the globally minimal edit script (Zhang-Shasha is optimal), but it |
| 23 | is correct: every change is accounted for, and applying the script to the base |
| 24 | tree produces the target tree. For the bounded tree sizes typical of domain |
| 25 | objects (scenes, tracks, ASTs ≲ 10k nodes), this is more than adequate for |
| 26 | Phase 2. Zhang-Shasha optimisation is a drop-in replacement once needed. |
| 27 | |
| 28 | ``TreeNode`` is defined here and re-exported by the package ``__init__``. |
| 29 | |
| 30 | Public API |
| 31 | ---------- |
| 32 | - :class:`TreeNode` — labeled ordered tree node (frozen dataclass). |
| 33 | - :func:`diff` — ``TreeNode`` × ``TreeNode`` → ``StructuredDelta``. |
| 34 | """ |
| 35 | from __future__ import annotations |
| 36 | |
| 37 | import logging |
| 38 | from dataclasses import dataclass |
| 39 | from typing import Literal |
| 40 | |
| 41 | from muse.core.schema import TreeSchema |
| 42 | from muse.domain import ( |
| 43 | DeleteOp, |
| 44 | DomainOp, |
| 45 | InsertOp, |
| 46 | MoveOp, |
| 47 | ReplaceOp, |
| 48 | StructuredDelta, |
| 49 | ) |
| 50 | |
| 51 | logger = logging.getLogger(__name__) |
| 52 | |
| 53 | |
| 54 | # --------------------------------------------------------------------------- |
| 55 | # TreeNode — the unit of tree-edit comparison |
| 56 | # --------------------------------------------------------------------------- |
| 57 | |
| 58 | |
@dataclass(frozen=True)
class TreeNode:
    """A node in a labeled ordered tree for domain tree-edit algorithms.

    Instances are immutable (``frozen=True``) and therefore hashable. The
    generated ``__eq__`` compares all four fields, including ``children``
    recursively.

    Attributes:
        id: Stable unique identifier for the node (e.g. UUID or path). The
            tree diff matches sibling nodes by ``id``, so ids are expected
            to be unique among the children of a node.
        label: Human-readable name (e.g. element tag, node type), used in
            the summaries of generated ops.
        content_id: SHA-256 of this node's own value, excluding its
            children. Two nodes matched by ``id`` whose ``content_id``
            differs produce a ``ReplaceOp``.
        children: Ordered tuple of child nodes.
    """

    id: str
    label: str
    content_id: str
    children: tuple[TreeNode, ...]
| 75 | |
| 76 | |
| 77 | # --------------------------------------------------------------------------- |
| 78 | # Internal helpers |
| 79 | # --------------------------------------------------------------------------- |
| 80 | |
| 81 | |
| 82 | def _subtree_nodes(node: TreeNode) -> list[TreeNode]: |
| 83 | """Return all nodes in *node*'s subtree (postorder).""" |
| 84 | result: list[TreeNode] = [] |
| 85 | |
| 86 | def _visit(n: TreeNode) -> None: |
| 87 | for child in n.children: |
| 88 | _visit(child) |
| 89 | result.append(n) |
| 90 | |
| 91 | _visit(node) |
| 92 | return result |
| 93 | |
| 94 | |
| 95 | def _lcs_children( |
| 96 | base_children: tuple[TreeNode, ...], |
| 97 | target_children: tuple[TreeNode, ...], |
| 98 | ) -> list[tuple[Literal["keep", "insert", "delete"], int, int]]: |
| 99 | """LCS shortest-edit script on two sequences of child nodes. |
| 100 | |
| 101 | Comparison is by ``id`` — children with the same id are matched (a "keep"), |
| 102 | even if their ``content_id`` differs. A kept pair that has a different |
| 103 | ``content_id`` will produce a ``ReplaceOp`` when recursed into by |
| 104 | :func:`_diff_nodes`. |
| 105 | |
| 106 | Unmatched children produce insert / delete ops. |
| 107 | |
| 108 | Returns a list of ``(kind, base_idx, target_idx)`` triples. |
| 109 | """ |
| 110 | n, m = len(base_children), len(target_children) |
| 111 | base_ids = [c.id for c in base_children] |
| 112 | target_ids = [c.id for c in target_children] |
| 113 | |
| 114 | dp: list[list[int]] = [[0] * (m + 1) for _ in range(n + 1)] |
| 115 | for i in range(n - 1, -1, -1): |
| 116 | for j in range(m - 1, -1, -1): |
| 117 | if base_ids[i] == target_ids[j]: |
| 118 | dp[i][j] = dp[i + 1][j + 1] + 1 |
| 119 | else: |
| 120 | dp[i][j] = max(dp[i + 1][j], dp[i][j + 1]) |
| 121 | |
| 122 | result: list[tuple[Literal["keep", "insert", "delete"], int, int]] = [] |
| 123 | i, j = 0, 0 |
| 124 | while i < n or j < m: |
| 125 | if i < n and j < m and base_ids[i] == target_ids[j]: |
| 126 | result.append(("keep", i, j)) |
| 127 | i += 1 |
| 128 | j += 1 |
| 129 | elif j < m and (i >= n or dp[i][j + 1] >= dp[i + 1][j]): |
| 130 | result.append(("insert", i, j)) |
| 131 | j += 1 |
| 132 | else: |
| 133 | result.append(("delete", i, j)) |
| 134 | i += 1 |
| 135 | |
| 136 | return result |
| 137 | |
| 138 | |
def _diff_nodes(
    base: TreeNode,
    target: TreeNode,
    *,
    domain: str,
    address: str,
) -> list[DomainOp]:
    """Recursively diff two tree nodes, returning a flat op list.

    Ops are emitted in this order:

    1. A ``ReplaceOp`` for the node itself when its ``content_id`` changed.
    2. The recursive diffs of children paired as "keep" by
       :func:`_lcs_children`.
    3. In target order, either a ``MoveOp`` followed by the recursive diff of
       the moved child, or one ``InsertOp`` per node of an inserted subtree.
    4. In base order, one ``DeleteOp`` per node of each removed subtree.

    Args:
        base: Node from the base tree.
        target: Node from the target tree being compared against *base*.
            Only ``base.id`` is used to build addresses.
        domain: Domain tag, threaded through recursive calls.
        address: Address of the parent node; ``""`` for the root.

    Returns:
        Flat list of ``DomainOp`` entries describing the edit.
    """
    ops: list[DomainOp] = []
    node_addr = f"{address}/{base.id}" if address else base.id

    # The node's own value changed (children are handled separately below).
    if base.content_id != target.content_id:
        ops.append(
            ReplaceOp(
                op="replace",
                address=node_addr,
                position=None,
                old_content_id=base.content_id,
                new_content_id=target.content_id,
                old_summary=f"{base.label} (prev)",
                new_summary=f"{target.label} (new)",
            )
        )

    if not base.children and not target.children:
        return ops

    # Diff the children via an LCS keyed on node id.
    script = _lcs_children(base.children, target.children)

    raw_inserts: list[tuple[int, TreeNode]] = []  # (target_idx, node)
    raw_deletes: list[tuple[int, TreeNode]] = []  # (base_idx, node)

    for kind, bi, ti in script:
        if kind == "keep":
            # Same id on both sides: recurse into the matched child pair.
            ops.extend(
                _diff_nodes(
                    base.children[bi],
                    target.children[ti],
                    domain=domain,
                    address=node_addr,
                )
            )
        elif kind == "insert":
            raw_inserts.append((ti, target.children[ti]))
        else:
            raw_deletes.append((bi, base.children[bi]))

    # Move detection: an unmatched delete and an unmatched insert sharing the
    # same node id at different positions form a move. Identity is tracked
    # by id, not content_id, so a repositioned node is detected as a move
    # even if its content also changed. Only the first unmatched base
    # occurrence of each id is a move candidate.
    delete_by_id: dict[str, tuple[int, TreeNode]] = {}
    for bi, node in raw_deletes:
        delete_by_id.setdefault(node.id, (bi, node))

    # Base indices consumed by a move. Tracked per index rather than per id,
    # so a duplicate sibling id cannot suppress an unrelated delete.
    moved_base_idx: set[int] = set()
    for ti, node in raw_inserts:
        candidate = delete_by_id.get(node.id)
        if candidate is not None and candidate[0] not in moved_base_idx:
            from_idx, base_node = candidate
            if from_idx != ti:
                ops.append(
                    MoveOp(
                        op="move",
                        address=node_addr,
                        from_position=from_idx,
                        to_position=ti,
                        content_id=node.content_id,
                    )
                )
                moved_base_idx.add(from_idx)
                # A move only relocates the subtree. Diff the pair too, so
                # that changes to the moved node itself or to any of its
                # descendants are not silently dropped.
                ops.extend(
                    _diff_nodes(
                        base_node,
                        node,
                        domain=domain,
                        address=node_addr,
                    )
                )
                continue
        # True insert: one InsertOp per node of the added subtree.
        for sub_node in _subtree_nodes(node):
            ops.append(
                InsertOp(
                    op="insert",
                    address=node_addr,
                    position=ti,
                    content_id=sub_node.content_id,
                    content_summary=f"{sub_node.label} added",
                )
            )

    for bi, node in raw_deletes:
        if bi in moved_base_idx:
            continue
        # True delete: one DeleteOp per node of the removed subtree.
        for sub_node in _subtree_nodes(node):
            ops.append(
                DeleteOp(
                    op="delete",
                    address=node_addr,
                    position=bi,
                    content_id=sub_node.content_id,
                    content_summary=f"{sub_node.label} removed",
                )
            )

    return ops
| 242 | |
| 243 | |
| 244 | # --------------------------------------------------------------------------- |
| 245 | # Top-level diff entry point |
| 246 | # --------------------------------------------------------------------------- |
| 247 | |
| 248 | |
def diff(
    schema: TreeSchema,
    base: TreeNode,
    target: TreeNode,
    *,
    domain: str,
    address: str = "",
) -> StructuredDelta:
    """Diff two labeled ordered trees, returning a ``StructuredDelta``.

    Produces:

    - ``ReplaceOp`` for nodes whose ``content_id`` changed.
    - ``InsertOp`` / ``DeleteOp`` for subtree insertions and deletions, one
      op per node of the affected subtree.
    - ``MoveOp`` for repositioned children, detected as a paired
      delete+insert of the same node ``id``.

    Args:
        schema: The ``TreeSchema`` for these trees; only its ``node_type``
            entry is read here (for the summary text).
        base: Root of the base (ancestor) tree.
        target: Root of the target (newer) tree.
        domain: Domain tag for the returned ``StructuredDelta``.
        address: Address prefix for generated op entries.

    Returns:
        A ``StructuredDelta`` with typed ops and a human-readable summary.
    """
    # Fast path: identical trees. ``children`` equality is deep dataclass
    # equality, so this covers the entire subtree below the root.
    if base.content_id == target.content_id and base.children == target.children:
        return StructuredDelta(
            domain=domain,
            ops=[],
            summary=f"no {schema['node_type']} changes",
        )

    ops = _diff_nodes(base, target, domain=domain, address=address)

    # Tally the op kinds in a single pass; any other op kinds are ignored.
    counts = {"replace": 0, "insert": 0, "delete": 0, "move": 0}
    for op in ops:
        kind = op["op"]
        if kind in counts:
            counts[kind] += 1
    n_replace = counts["replace"]
    n_insert = counts["insert"]
    n_delete = counts["delete"]
    n_move = counts["move"]

    parts: list[str] = []
    if n_replace:
        parts.append(f"{n_replace} relabelled")
    if n_insert:
        parts.append(f"{n_insert} added")
    if n_delete:
        parts.append(f"{n_delete} removed")
    if n_move:
        parts.append(f"{n_move} moved")
    summary = ", ".join(parts) if parts else f"no {schema['node_type']} changes"

    logger.debug(
        "tree_edit.diff: +%d -%d ~%d r%d ops on %r",
        n_insert,
        n_delete,
        n_move,
        n_replace,
        address,
    )

    return StructuredDelta(domain=domain, ops=ops, summary=summary)