gabriel / muse public
numerical.py python
191 lines 6.5 KB
c5b7bd6b feat(phase-2): domain schema declaration + diff algorithm library (#15) Gabriel Cardona <cgcardona@gmail.com> 6d ago
1 """Sparse / block / full tensor diff for numerical arrays — Phase 2.
2
3 Diffs flat 1-D numerical arrays element-wise with an epsilon tolerance.
4 Floating-point values within ``schema.epsilon`` of each other are not
5 considered changed — this prevents noise from triggering spurious diffs in
6 simulation state, velocity curves, and weight matrices.
7
8 Three output modes (``schema.diff_mode``):
9
10 - ``"sparse"`` — one ``ReplaceOp`` per changed element. Best for data where
11 a small fraction of elements change (e.g. sparse gradient updates).
12 - ``"block"`` — groups adjacent changed elements into contiguous range ops.
13 Best for data where changes cluster (e.g. a section of a velocity curve
14 was edited).
15 - ``"full"`` — emits a single ``ReplaceOp`` for the entire array if any
16 element changed. Best for very large tensors where per-element ops are
17 prohibitively expensive, or when the domain only cares "did anything change?"
18
19 Public API
20 ----------
21 - :func:`diff` — ``list[float]`` × ``list[float]`` → ``StructuredDelta``.
22 """
from __future__ import annotations

import hashlib
import logging
import math

from muse.core.schema import TensorSchema
from muse.domain import DomainOp, ReplaceOp, StructuredDelta
30
31 logger = logging.getLogger(__name__)
32
33
34 # ---------------------------------------------------------------------------
35 # Internal helpers
36 # ---------------------------------------------------------------------------
37
38
39 def _float_content_id(values: list[float]) -> str:
40 """Deterministic SHA-256 for a list of float values."""
41 payload = ",".join(f"{v:.10g}" for v in values)
42 return hashlib.sha256(payload.encode()).hexdigest()
43
44
45 def _single_content_id(value: float) -> str:
46 """Deterministic SHA-256 for a single float value."""
47 return hashlib.sha256(f"{value:.10g}".encode()).hexdigest()
48
49
50 # ---------------------------------------------------------------------------
51 # Top-level diff entry point
52 # ---------------------------------------------------------------------------
53
54
def _element_changed(base_val: float, target_val: float, eps: float) -> bool:
    """Return True when ``base_val`` → ``target_val`` is a real change.

    Strict ``>`` so that ``eps=0.0`` means exact equality: identical values
    (``|b-t| = 0``) are never flagged, while any actual difference is.

    NaN is handled explicitly: ``abs(b - t) > eps`` is always False when
    either side is NaN (IEEE 754 comparisons with NaN are False), which
    would silently hide a value changing to or from NaN. A NaN↔number
    transition counts as a change; NaN↔NaN does not.
    """
    b_nan = math.isnan(base_val)
    t_nan = math.isnan(target_val)
    if b_nan or t_nan:
        return b_nan != t_nan
    return abs(base_val - target_val) > eps


def _full_replace_op(
    address: str, base: list[float], target: list[float]
) -> ReplaceOp:
    """Build the single ``ReplaceOp`` covering an entire array."""
    return ReplaceOp(
        op="replace",
        address=address,
        position=None,
        old_content_id=_float_content_id(base),
        new_content_id=_float_content_id(target),
        old_summary=f"tensor[{len(base)}] (prev)",
        new_summary=f"tensor[{len(target)}] (new)",
    )


def _count_summary(n: int) -> str:
    """Human-readable ``"<n> element(s) changed"`` fragment."""
    return f"{n} element{'s' if n != 1 else ''} changed"


def diff(
    schema: TensorSchema,
    base: list[float],
    target: list[float],
    *,
    domain: str,
    address: str = "",
) -> StructuredDelta:
    """Diff two 1-D numerical arrays under the given ``TensorSchema``.

    Length mismatches are treated as a full replacement. For equal-length
    arrays, the ``diff_mode`` on the schema controls the output granularity.
    Elements within ``schema.epsilon`` of each other are not considered
    changed; a NaN↔number transition always is (NaN↔NaN never is).

    Args:
        schema: The ``TensorSchema`` declaring dtype, epsilon, and diff_mode.
        base: Base (ancestor) array of float values.
        target: Target (newer) array of float values.
        domain: Domain tag for the returned ``StructuredDelta``.
        address: Address prefix for generated op entries.

    Returns:
        A ``StructuredDelta`` with ``ReplaceOp`` entries for changed elements
        and a human-readable summary.
    """
    eps = schema["epsilon"]
    ops: list[DomainOp] = []

    # Length mismatch → full replacement regardless of diff_mode.
    if len(base) != len(target):
        return StructuredDelta(
            domain=domain,
            ops=[_full_replace_op(address, base, target)],
            summary=f"tensor length changed {len(base)}→{len(target)}",
        )

    # Indices whose values differ beyond tolerance (NaN-aware).
    changed: list[int] = [
        i
        for i, (b, t) in enumerate(zip(base, target))
        if _element_changed(b, t, eps)
    ]

    if not changed:
        return StructuredDelta(
            domain=domain, ops=[], summary="no numerical changes"
        )

    mode = schema["diff_mode"]

    if mode == "full":
        # One op for the whole array — cheapest representation for huge
        # tensors, or when the domain only cares "did anything change?".
        ops = [_full_replace_op(address, base, target)]
        summary = _count_summary(len(changed))

    elif mode == "sparse":
        # One op per changed element.
        for i in changed:
            ops.append(
                ReplaceOp(
                    op="replace",
                    address=address,
                    position=i,
                    old_content_id=_single_content_id(base[i]),
                    new_content_id=_single_content_id(target[i]),
                    old_summary=f"[{i}]={base[i]:.6g}",
                    new_summary=f"[{i}]={target[i]:.6g}",
                )
            )
        summary = _count_summary(len(changed))

    else:  # "block"
        # Group adjacent changed indices into contiguous (start, end)
        # inclusive ranges; `changed` is already sorted ascending.
        blocks: list[tuple[int, int]] = []
        for idx in changed:
            if blocks and idx == blocks[-1][1] + 1:
                blocks[-1] = (blocks[-1][0], idx)  # extend the current run
            else:
                blocks.append((idx, idx))  # start a new run

        for start, end in blocks:
            # Labels use half-open slice convention for multi-element runs
            # (e.g. "[3:7]"); single elements render as "[3]".
            label = f"[{start}]" if start == end else f"[{start}:{end + 1}]"
            ops.append(
                ReplaceOp(
                    op="replace",
                    address=address,
                    position=start,
                    old_content_id=_float_content_id(base[start : end + 1]),
                    new_content_id=_float_content_id(target[start : end + 1]),
                    old_summary=f"{label} (prev)",
                    new_summary=f"{label} (new)",
                )
            )
        summary = (
            f"{_count_summary(len(changed))} "
            f"in {len(blocks)} block{'s' if len(blocks) != 1 else ''}"
        )

    logger.debug(
        "numerical.diff %r mode=%r: %d changed of %d elements",
        address,
        mode,
        len(changed),
        len(base),
    )

    return StructuredDelta(domain=domain, ops=ops, summary=summary)