cgcardona / muse public
midi_diff.py python
522 lines 18.2 KB
9ee9c39c refactor: rename music→midi domain, strip all 5-dim backward compat Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """MIDI note-level diff for the Muse music plugin.
2
3 Produces a ``StructuredDelta`` with note-level ``InsertOp`` and ``DeleteOp``
4 entries from two MIDI byte strings. This is what lets ``muse show`` display
5 "C4 added at beat 3.5" rather than "tracks/drums.mid modified".
6
7 Algorithm
8 ---------
9 1. Parse MIDI bytes and extract paired note events (note_on + note_off)
10 sorted by start tick.
11 2. Represent each note as a ``NoteKey`` TypedDict with five fields.
12 3. Convert each ``NoteKey`` to its deterministic content ID (SHA-256 of the
13 five fields).
14 4. Delegate to :func:`~muse.core.diff_algorithms.lcs.myers_ses` — the shared
15 LCS implementation from the diff algorithm library — for the SES.
16 5. Map edit steps to typed ``DomainOp`` instances using the note's content
17 ID and a human-readable summary string.
18 6. Wrap the ops in a ``StructuredDelta``.
19
20 Additional features
21 -------------------
22 :func:`reconstruct_midi` — the inverse of :func:`extract_notes`. Given a list
23 of :class:`NoteKey` objects and a ticks_per_beat value, produces raw MIDI bytes
24 for a Type 0 single-track file. Used by ``MidiPlugin.merge_ops()`` to
25 materialise a merged MIDI file after the OT engine has determined that
26 two branches' note-level operations commute.
27
28 Public API
29 ----------
30 - :class:`NoteKey` — typed MIDI note identity.
31 - :func:`extract_notes` — MIDI bytes → ``(sorted list[NoteKey], ticks_per_beat)``.
32 - :func:`reconstruct_midi` — ``list[NoteKey]`` → MIDI bytes.
33 - :func:`diff_midi_notes` — top-level: MIDI bytes × 2 → ``StructuredDelta``.
34 """
35 from __future__ import annotations
36
37 import hashlib
38 import io
39 import logging
40 from dataclasses import dataclass
41 from typing import TYPE_CHECKING, Literal, TypedDict
42
43 if TYPE_CHECKING:
44 from muse.plugins.midi.entity import EntityIndex
45
46 import mido
47
48 from muse.core.diff_algorithms.lcs import myers_ses
49 from muse.domain import (
50 DeleteOp,
51 DomainOp,
52 InsertOp,
53 StructuredDelta,
54 )
55
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)

#: Identifies the sub-domain for note-level operations inside a PatchOp.
#: Used as the ``domain`` of the delta returned by ``diff_midi_notes``.
_CHILD_DOMAIN = "midi_notes"

# Pitch-class names indexed by ``midi_pitch % 12``; accidentals are spelled
# as sharps only.
_PITCH_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
62
63
64 # ---------------------------------------------------------------------------
65 # NoteKey — the unit of LCS comparison
66 # ---------------------------------------------------------------------------
67
68
class NoteKey(TypedDict):
    """One fully-specified MIDI note: the element type compared by the LCS.

    Equality is all-or-nothing across the five fields below.  Changing a
    note's pitch, velocity, timing, or channel therefore shows up as a
    delete of the old note plus an insert of the new one.  That is
    conservative but correct: every LCS match is a true structural match
    and real musical changes are always surfaced.
    """

    # MIDI note number.
    pitch: int
    # Velocity of the note_on event that started the note.
    velocity: int
    # Absolute tick at which the note starts.
    start_tick: int
    # Length of the note, in ticks.
    duration_ticks: int
    # MIDI channel the note was played on.
    channel: int
84
85
86 # ---------------------------------------------------------------------------
87 # Helpers
88 # ---------------------------------------------------------------------------
89
90
def _pitch_name(midi_pitch: int) -> str:
    """Format a MIDI note number as pitch class plus octave, e.g. ``"C4"``, ``"F#5"``.

    The octave is ``midi_pitch // 12 - 1``, so MIDI note 60 is ``"C4"``.
    """
    octave_index, pitch_class = divmod(midi_pitch, 12)
    return f"{_PITCH_NAMES[pitch_class]}{octave_index - 1}"
96
97
98 def _note_content_id(note: NoteKey) -> str:
99 """Return a deterministic SHA-256 for a note's five identity fields.
100
101 This gives a stable ``content_id`` for use in ``InsertOp`` / ``DeleteOp``
102 without requiring the note to be stored as a separate blob in the object
103 store. The hash uniquely identifies "this specific note event".
104 """
105 payload = (
106 f"{note['pitch']}:{note['velocity']}:"
107 f"{note['start_tick']}:{note['duration_ticks']}:{note['channel']}"
108 )
109 return hashlib.sha256(payload.encode()).hexdigest()
110
111
def _note_summary(note: NoteKey, ticks_per_beat: int) -> str:
    """Describe *note* on one line, e.g. ``"C4 vel=80 @beat=1.00 dur=0.50"``.

    Start and duration are converted from ticks to beats using
    *ticks_per_beat* and printed with two decimals.
    """
    # Clamp the resolution to at least 1 so the divisions can never fail.
    resolution = max(ticks_per_beat, 1)
    start_beat = note["start_tick"] / resolution
    length_beats = note["duration_ticks"] / resolution
    pitch_label = _pitch_name(note["pitch"])
    return " ".join(
        (
            pitch_label,
            f"vel={note['velocity']}",
            f"@beat={start_beat:.2f}",
            f"dur={length_beats:.2f}",
        )
    )
122
123
124 # ---------------------------------------------------------------------------
125 # Note extraction
126 # ---------------------------------------------------------------------------
127
128
def extract_notes(midi_bytes: bytes) -> tuple[list[NoteKey], int]:
    """Parse *midi_bytes* and return ``(notes, ticks_per_beat)``.

    Notes are paired note_on / note_off events.  A note_on with velocity=0
    is treated as note_off.  When the same ``(channel, pitch)`` is started
    again before it has been released, every start is remembered and
    note_offs are matched to the pending starts oldest-first, so overlapping
    or retriggered notes are all preserved instead of the earlier one being
    overwritten.

    Pending starts are kept in a single table shared by all tracks, while
    the absolute tick counter restarts at 0 for each track.  A note_off with
    no pending start is ignored.  Notes are sorted by start_tick, then
    pitch, then channel for deterministic ordering.

    Args:
        midi_bytes: Raw bytes of a ``.mid`` file.

    Returns:
        A tuple of (sorted NoteKey list, ticks_per_beat integer).

    Raises:
        ValueError: When *midi_bytes* cannot be parsed as a MIDI file.
    """
    try:
        mid = mido.MidiFile(file=io.BytesIO(midi_bytes))
    except Exception as exc:
        raise ValueError(f"Cannot parse MIDI bytes: {exc}") from exc

    ticks_per_beat: int = int(mid.ticks_per_beat)
    # (channel, pitch) → FIFO of pending (start_tick, velocity) pairs.
    # A list is enough: the queue only grows when a note is retriggered
    # before its release, so it stays tiny.
    pending: dict[tuple[int, int], list[tuple[int, int]]] = {}
    notes: list[NoteKey] = []

    for track in mid.tracks:
        abs_tick = 0
        for msg in track:
            # Every message (including meta events) advances the clock.
            abs_tick += msg.time
            if msg.type not in ("note_on", "note_off"):
                continue

            key = (msg.channel, msg.note)
            if msg.type == "note_on" and msg.velocity > 0:
                pending.setdefault(key, []).append((abs_tick, msg.velocity))
                continue

            # note_off, or note_on with velocity 0: release the oldest
            # pending start for this key, if there is one.
            queue = pending.get(key)
            if not queue:
                continue
            start, vel = queue.pop(0)
            notes.append(
                NoteKey(
                    pitch=msg.note,
                    velocity=vel,
                    start_tick=start,
                    # Never emit a zero-length (or negative-length) note.
                    duration_ticks=max(abs_tick - start, 1),
                    channel=msg.channel,
                )
            )

    # Close any notes still open at end of file with duration 1.
    for (ch, pitch), queue in pending.items():
        for start, vel in queue:
            notes.append(
                NoteKey(
                    pitch=pitch,
                    velocity=vel,
                    start_tick=start,
                    duration_ticks=1,
                    channel=ch,
                )
            )

    notes.sort(key=lambda n: (n["start_tick"], n["pitch"], n["channel"]))
    return notes, ticks_per_beat
191
192
193 # ---------------------------------------------------------------------------
194 # NoteKey-level edit script — adapter over the core LCS
195 # ---------------------------------------------------------------------------
196
# The three kinds of step a note-level edit script can contain.
EditKind = Literal["keep", "insert", "delete"]


@dataclass(frozen=True)
class EditStep:
    """A single keep / insert / delete entry in a note-level edit script.

    Instances are immutable and are built by :func:`lcs_edit_script`.
    """

    # Which of the three operations this step represents.
    kind: EditKind
    # Index reported by the underlying SES step for the base sequence.
    base_index: int
    # Index reported by the underlying SES step for the target sequence.
    target_index: int
    # The note concerned: taken from the target sequence for inserts and
    # from the base sequence for keeps and deletes.
    note: NoteKey
208
209
def lcs_edit_script(
    base: list[NoteKey],
    target: list[NoteKey],
) -> list[EditStep]:
    """Return the shortest edit script that turns *base* into *target*.

    Each note is reduced to its content ID, the two ID sequences are handed
    to :func:`~muse.core.diff_algorithms.lcs.myers_ses`, and every resulting
    step is re-wrapped as an :class:`EditStep` that carries the original
    ``NoteKey`` instead of the bare ID.

    Notes match only when all five ``NoteKey`` fields are equal, so a pitch,
    velocity, or timing change appears as a delete of the old note followed
    by an insert of the new one.

    Args:
        base: The base (ancestor) note sequence.
        target: The target (newer) note sequence.

    Returns:
        A list of :class:`EditStep` entries (keep / insert / delete).
    """
    ses = myers_ses(
        [_note_content_id(note) for note in base],
        [_note_content_id(note) for note in target],
    )

    def _to_edit_step(raw) -> EditStep:
        # Inserted notes exist only in *target*; kept and deleted notes are
        # read from *base*.
        if raw.kind == "insert":
            return EditStep(
                "insert", raw.base_index, raw.target_index, target[raw.target_index]
            )
        # Anything that is neither "keep" nor "insert" is a delete.
        kind: EditKind = "keep" if raw.kind == "keep" else "delete"
        return EditStep(kind, raw.base_index, raw.target_index, base[raw.base_index])

    return [_to_edit_step(raw) for raw in ses]
245
246
247 # ---------------------------------------------------------------------------
248 # Public diff entry point
249 # ---------------------------------------------------------------------------
250
251
def diff_midi_notes(
    base_bytes: bytes,
    target_bytes: bytes,
    *,
    file_path: str = "",
) -> StructuredDelta:
    """Compute a note-level ``StructuredDelta`` between two MIDI files.

    Parses both files, converts each note to its content ID, delegates to the
    core :func:`~muse.core.diff_algorithms.lcs.myers_ses` for the SES, then
    maps the edit steps to typed ``DomainOp`` instances.

    Beat positions in ``content_summary`` are computed with the resolution of
    the file the note comes from: deleted notes use the base file's
    ``ticks_per_beat`` and inserted notes use the target file's.

    Args:
        base_bytes: Raw bytes of the base (ancestor) MIDI file.
        target_bytes: Raw bytes of the target (newer) MIDI file.
        file_path: Workspace-relative path of the file being diffed (used
            only in the debug log message).

    Returns:
        A ``StructuredDelta`` with ``InsertOp`` and ``DeleteOp`` entries for
        each note added or removed.  The ``summary`` field is human-readable,
        e.g. ``"3 notes added, 1 note removed"``.

    Raises:
        ValueError: When either byte string cannot be parsed as MIDI.
    """
    base_notes, base_tpb = extract_notes(base_bytes)
    target_notes, target_tpb = extract_notes(target_bytes)

    # Convert NoteKey → content ID, then delegate LCS to the core algorithm.
    base_ids = [_note_content_id(n) for n in base_notes]
    target_ids = [_note_content_id(n) for n in target_notes]
    steps = myers_ses(base_ids, target_ids)

    # Content-ID → NoteKey lookups for rich summaries, built from the IDs
    # computed above so that no note is hashed a second time.
    base_by_id = dict(zip(base_ids, base_notes))
    target_by_id = dict(zip(target_ids, target_notes))

    child_ops: list[DomainOp] = []
    inserts = 0
    deletes = 0

    for step in steps:
        if step.kind == "insert":
            note = target_by_id.get(step.item)
            # Inserted notes live in the target file, so their ticks must be
            # scaled by the target's resolution.  Fall back to a short ID
            # prefix if the note cannot be looked up.
            summary = (
                _note_summary(note, target_tpb) if note is not None else step.item[:12]
            )
            child_ops.append(
                InsertOp(
                    op="insert",
                    address=f"note:{step.target_index}",
                    position=step.target_index,
                    content_id=step.item,
                    content_summary=summary,
                )
            )
            inserts += 1
        elif step.kind == "delete":
            note = base_by_id.get(step.item)
            # Deleted notes live in the base file: use the base resolution.
            summary = (
                _note_summary(note, base_tpb) if note is not None else step.item[:12]
            )
            child_ops.append(
                DeleteOp(
                    op="delete",
                    address=f"note:{step.base_index}",
                    position=step.base_index,
                    content_id=step.item,
                    content_summary=summary,
                )
            )
            deletes += 1
        # "keep" steps produce no ops — the note is unchanged.

    parts: list[str] = []
    if inserts:
        parts.append(f"{inserts} note{'s' if inserts != 1 else ''} added")
    if deletes:
        parts.append(f"{deletes} note{'s' if deletes != 1 else ''} removed")
    child_summary = ", ".join(parts) if parts else "no note changes"

    logger.debug(
        "✅ MIDI diff %r: +%d -%d notes (%d SES steps)",
        file_path,
        inserts,
        deletes,
        len(steps),
    )

    return StructuredDelta(
        domain=_CHILD_DOMAIN,
        ops=child_ops,
        summary=child_summary,
    )
344
345
346 # ---------------------------------------------------------------------------
347 # Entity-aware diff — wrapper that produces MutateOp for field-level mutations
348 # ---------------------------------------------------------------------------
349
350
def diff_midi_notes_with_entities(
    base_bytes: bytes,
    target_bytes: bytes,
    *,
    prior_index: "EntityIndex | None" = None,
    commit_id: str = "",
    op_id: str = "",
    file_path: str = "",
    mutation_threshold_ticks: int = 10,
    mutation_threshold_velocity: int = 20,
) -> StructuredDelta:
    """Diff two MIDI files at note level using entity IDs.

    Both files are parsed with :func:`extract_notes`.  Each note list is
    given entity IDs by ``assign_entity_ids`` against the same *prior_index*,
    and the two entity lists are then compared by ``diff_with_entity_ids``.
    The resulting ops are counted by their ``"op"`` value (``"insert"``,
    ``"delete"``, ``"mutate"``) to build the human-readable summary.

    Args:
        base_bytes: Raw bytes of the base (ancestor) MIDI file.
        target_bytes: Raw bytes of the target (newer) MIDI file.
        prior_index: Entity index from the parent commit for *file_path*,
            or ``None`` when there is none.
        commit_id: Current commit ID, forwarded to ``assign_entity_ids``.
            The base side uses ``"base"`` when this is empty.
        op_id: Op log entry ID, forwarded to ``assign_entity_ids``.
        file_path: Workspace-relative path, used only in the debug log.
        mutation_threshold_ticks: Tick threshold forwarded to
            ``assign_entity_ids`` for entity matching.
        mutation_threshold_velocity: Velocity threshold forwarded to
            ``assign_entity_ids`` for entity matching.

    Returns:
        A ``StructuredDelta`` whose domain is ``"midi_notes_tracked"``,
        which distinguishes it from the delta of :func:`diff_midi_notes`.
        Its summary reads e.g. ``"2 notes added, 1 note mutated"``, or
        ``"no note changes"`` when nothing was counted.

    Raises:
        ValueError: When either byte string cannot be parsed as MIDI.
    """
    from muse.plugins.midi.entity import assign_entity_ids, diff_with_entity_ids

    base_notes, base_tpb = extract_notes(base_bytes)
    target_notes, _ = extract_notes(target_bytes)

    # Both sides are matched with the same thresholds.
    thresholds = {
        "mutation_threshold_ticks": mutation_threshold_ticks,
        "mutation_threshold_velocity": mutation_threshold_velocity,
    }
    base_entities = assign_entity_ids(
        base_notes,
        prior_index,
        commit_id=commit_id or "base",
        op_id=op_id or "",
        **thresholds,
    )
    target_entities = assign_entity_ids(
        target_notes,
        prior_index,
        commit_id=commit_id,
        op_id=op_id,
        **thresholds,
    )

    # The base file's resolution is the one handed to the entity diff.
    ops = diff_with_entity_ids(base_entities, target_entities, base_tpb)

    # Tally ops by kind in one pass; other kinds are not counted.
    tally = {"insert": 0, "delete": 0, "mutate": 0}
    for op in ops:
        op_kind = op["op"]
        if op_kind in tally:
            tally[op_kind] += 1
    inserts = tally["insert"]
    deletes = tally["delete"]
    mutates = tally["mutate"]

    fragments = [
        f"{count} note{'s' if count != 1 else ''} {verb}"
        for count, verb in (
            (inserts, "added"),
            (deletes, "removed"),
            (mutates, "mutated"),
        )
        if count
    ]
    summary = ", ".join(fragments) if fragments else "no note changes"

    logger.debug(
        "✅ Entity-aware MIDI diff %r: +%d -%d ~%d (%d ops)",
        file_path,
        inserts,
        deletes,
        mutates,
        len(ops),
    )

    return StructuredDelta(
        domain="midi_notes_tracked",
        ops=ops,
        summary=summary,
    )
448
449
450 # ---------------------------------------------------------------------------
451 # MIDI reconstruction — inverse of extract_notes
452 # ---------------------------------------------------------------------------
453
454
def reconstruct_midi(
    notes: list[NoteKey],
    *,
    ticks_per_beat: int = 480,
) -> bytes:
    """Produce raw MIDI bytes from a list of :class:`NoteKey` objects.

    Creates a Type 0 (single-track) MIDI file.  One ``note_on`` and one
    ``note_off`` event are emitted per note.  Events are put into a
    canonical order before writing — by tick, then note_off before note_on,
    then channel, pitch, and velocity — so the same set of notes always
    produces the same event sequence regardless of the order in which the
    notes were supplied.

    This is the inverse of :func:`extract_notes`.

    Args:
        notes: Note events to write.  May be in any order.
        ticks_per_beat: Timing resolution written to the file.  Preserve the
            base file's value so that beat positions remain meaningful.

    Returns:
        Raw MIDI bytes of the generated file.
    """
    mid = mido.MidiFile(ticks_per_beat=ticks_per_beat, type=0)
    track = mido.MidiTrack()
    mid.tracks.append(track)

    # Flat (abs_tick, is_note_on, channel, pitch, velocity) event tuples.
    events: list[tuple[int, bool, int, int, int]] = []
    for note in notes:
        on_tick = note["start_tick"]
        events.append(
            (on_tick, True, note["channel"], note["pitch"], note["velocity"])
        )
        events.append(
            (on_tick + note["duration_ticks"], False, note["channel"], note["pitch"], 0)
        )

    # Plain tuple ordering: by tick first, with note_off (False) before
    # note_on (True) at the same tick so that retriggered notes are handled
    # correctly; the remaining fields break ties so the result does not
    # depend on the input order.
    events.sort()

    prev_tick = 0
    for abs_tick, is_on, channel, pitch, velocity in events:
        track.append(
            mido.Message(
                "note_on" if is_on else "note_off",
                channel=channel,
                note=pitch,
                velocity=velocity,  # always 0 for note_off events
                time=abs_tick - prev_tick,  # MIDI stores delta times
            )
        )
        prev_tick = abs_tick

    buf = io.BytesIO()
    mid.save(file=buf)
    return buf.getvalue()
521
522