gabriel / muse public
midi_diff.py python
521 lines 18.1 KB
04004b82 Rename MusicRGA → MidiRGA and purge all 'music plugin' terminology Gabriel Cardona <gabriel@tellurstori.com> 5d ago
"""MIDI note-level diff for the Muse MIDI plugin.

Produces a ``StructuredDelta`` with note-level ``InsertOp`` and ``DeleteOp``
entries from two MIDI byte strings. This is what lets ``muse show`` display
"C4 added at beat 3.5" rather than "tracks/drums.mid modified".

Algorithm
---------
1. Parse MIDI bytes and extract paired note events (note_on + note_off)
   sorted by start tick.
2. Represent each note as a ``NoteKey`` TypedDict with five fields.
3. Convert each ``NoteKey`` to its deterministic content ID (SHA-256 of the
   five fields).
4. Delegate to :func:`~muse.core.diff_algorithms.lcs.myers_ses` — the shared
   LCS implementation from the diff algorithm library — to compute the
   shortest edit script (SES).
5. Map edit steps to typed ``DomainOp`` instances using the note's content
   ID and a human-readable summary string.
6. Wrap the ops in a ``StructuredDelta``.

:func:`diff_midi_notes_with_entities` is an alternative entry point. It uses
the entity index from the parent commit to report field-level note changes
as ``MutateOp`` entries instead of delete + insert pairs.

Additional features
-------------------
:func:`reconstruct_midi` — the inverse of :func:`extract_notes`. Given a list
of :class:`NoteKey` objects and a ticks_per_beat value, produces raw MIDI bytes
for a Type 0 single-track file. Used by ``MidiPlugin.merge_ops()`` to
materialise a merged MIDI file after the OT engine has determined that
two branches' note-level operations commute.

Public API
----------
- :class:`NoteKey` — typed MIDI note identity.
- :func:`extract_notes` — MIDI bytes → ``(sorted list[NoteKey], ticks_per_beat)``.
- :func:`lcs_edit_script` — ``list[NoteKey]`` × 2 → ``list[EditStep]``.
- :func:`reconstruct_midi` — ``list[NoteKey]`` → MIDI bytes.
- :func:`diff_midi_notes` — top-level: MIDI bytes × 2 → ``StructuredDelta``.
- :func:`diff_midi_notes_with_entities` — entity-aware variant of
  :func:`diff_midi_notes` that can also emit ``MutateOp`` entries.
"""
35
36 import hashlib
37 import io
38 import logging
39 from dataclasses import dataclass
40 from typing import TYPE_CHECKING, Literal, TypedDict
41
42 if TYPE_CHECKING:
43 from muse.plugins.midi.entity import EntityIndex
44
45 import mido
46
47 from muse.core.diff_algorithms.lcs import myers_ses
48 from muse.domain import (
49 DeleteOp,
50 DomainOp,
51 InsertOp,
52 StructuredDelta,
53 )
54
55 logger = logging.getLogger(__name__)
56
57 #: Identifies the sub-domain for note-level operations inside a PatchOp.
58 _CHILD_DOMAIN = "midi_notes"
59
60 _PITCH_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
61
62
63 # ---------------------------------------------------------------------------
64 # NoteKey — the unit of LCS comparison
65 # ---------------------------------------------------------------------------
66
67
68 class NoteKey(TypedDict):
69 """Fully-specified MIDI note used as the LCS comparison unit.
70
71 Two notes are considered identical in LCS iff all five fields match.
72 A pitch change, velocity change, timing shift, or channel change
73 counts as a delete of the old note and an insert of the new one.
74 This is conservative but correct — it means the LCS finds true
75 structural matches and surfaces real musical changes.
76 """
77
78 pitch: int
79 velocity: int
80 start_tick: int
81 duration_ticks: int
82 channel: int
83
84
85 # ---------------------------------------------------------------------------
86 # Helpers
87 # ---------------------------------------------------------------------------
88
89
90 def _pitch_name(midi_pitch: int) -> str:
91 """Return a human-readable pitch string, e.g. ``"C4"``, ``"F#5"``."""
92 octave = midi_pitch // 12 - 1
93 name = _PITCH_NAMES[midi_pitch % 12]
94 return f"{name}{octave}"
95
96
97 def _note_content_id(note: NoteKey) -> str:
98 """Return a deterministic SHA-256 for a note's five identity fields.
99
100 This gives a stable ``content_id`` for use in ``InsertOp`` / ``DeleteOp``
101 without requiring the note to be stored as a separate blob in the object
102 store. The hash uniquely identifies "this specific note event".
103 """
104 payload = (
105 f"{note['pitch']}:{note['velocity']}:"
106 f"{note['start_tick']}:{note['duration_ticks']}:{note['channel']}"
107 )
108 return hashlib.sha256(payload.encode()).hexdigest()
109
110
111 def _note_summary(note: NoteKey, ticks_per_beat: int) -> str:
112 """Return a human-readable one-liner for a note, e.g. ``"C4 vel=80 @beat=1.00"``."""
113 beat = note["start_tick"] / max(ticks_per_beat, 1)
114 dur = note["duration_ticks"] / max(ticks_per_beat, 1)
115 return (
116 f"{_pitch_name(note['pitch'])} "
117 f"vel={note['velocity']} "
118 f"@beat={beat:.2f} "
119 f"dur={dur:.2f}"
120 )
121
122
123 # ---------------------------------------------------------------------------
124 # Note extraction
125 # ---------------------------------------------------------------------------
126
127
128 def extract_notes(midi_bytes: bytes) -> tuple[list[NoteKey], int]:
129 """Parse *midi_bytes* and return ``(notes, ticks_per_beat)``.
130
131 Notes are paired note_on / note_off events. A note_on with velocity=0
132 is treated as note_off. Notes are sorted by start_tick then pitch for
133 deterministic ordering.
134
135 Args:
136 midi_bytes: Raw bytes of a ``.mid`` file.
137
138 Returns:
139 A tuple of (sorted NoteKey list, ticks_per_beat integer).
140
141 Raises:
142 ValueError: When *midi_bytes* cannot be parsed as a MIDI file.
143 """
144 try:
145 mid = mido.MidiFile(file=io.BytesIO(midi_bytes))
146 except Exception as exc:
147 raise ValueError(f"Cannot parse MIDI bytes: {exc}") from exc
148
149 ticks_per_beat: int = int(mid.ticks_per_beat)
150 # (channel, pitch) → (start_tick, velocity)
151 active: dict[tuple[int, int], tuple[int, int]] = {}
152 notes: list[NoteKey] = []
153
154 for track in mid.tracks:
155 abs_tick = 0
156 for msg in track:
157 abs_tick += msg.time
158 if msg.type == "note_on" and msg.velocity > 0:
159 active[(msg.channel, msg.note)] = (abs_tick, msg.velocity)
160 elif msg.type == "note_off" or (
161 msg.type == "note_on" and msg.velocity == 0
162 ):
163 key = (msg.channel, msg.note)
164 if key in active:
165 start, vel = active.pop(key)
166 notes.append(
167 NoteKey(
168 pitch=msg.note,
169 velocity=vel,
170 start_tick=start,
171 duration_ticks=max(abs_tick - start, 1),
172 channel=msg.channel,
173 )
174 )
175
176 # Close any notes still open at end of file with duration 1.
177 for (ch, pitch), (start, vel) in active.items():
178 notes.append(
179 NoteKey(
180 pitch=pitch,
181 velocity=vel,
182 start_tick=start,
183 duration_ticks=1,
184 channel=ch,
185 )
186 )
187
188 notes.sort(key=lambda n: (n["start_tick"], n["pitch"], n["channel"]))
189 return notes, ticks_per_beat
190
191
192 # ---------------------------------------------------------------------------
193 # NoteKey-level edit script — adapter over the core LCS
194 # ---------------------------------------------------------------------------
195
196 EditKind = Literal["keep", "insert", "delete"]
197
198
199 @dataclass(frozen=True)
200 class EditStep:
201 """One step in the note-level edit script produced by :func:`lcs_edit_script`."""
202
203 kind: EditKind
204 base_index: int
205 target_index: int
206 note: NoteKey
207
208
209 def lcs_edit_script(
210 base: list[NoteKey],
211 target: list[NoteKey],
212 ) -> list[EditStep]:
213 """Compute the shortest edit script transforming *base* into *target*.
214
215 Converts each ``NoteKey`` to its content ID, delegates to
216 :func:`~muse.core.diff_algorithms.lcs.myers_ses` for the SES, then maps
217 the result back to :class:`EditStep` entries carrying the original
218 ``NoteKey`` values.
219
220 Two notes are matched iff all five ``NoteKey`` fields are equal. This is
221 correct: a pitch change, velocity change, or timing shift is a delete of
222 the old note and an insert of the new one.
223
224 Args:
225 base: The base (ancestor) note sequence.
226 target: The target (newer) note sequence.
227
228 Returns:
229 A list of :class:`EditStep` entries (keep / insert / delete).
230 """
231 base_ids = [_note_content_id(n) for n in base]
232 target_ids = [_note_content_id(n) for n in target]
233 raw_steps = myers_ses(base_ids, target_ids)
234
235 result: list[EditStep] = []
236 for step in raw_steps:
237 if step.kind == "keep":
238 result.append(EditStep("keep", step.base_index, step.target_index, base[step.base_index]))
239 elif step.kind == "insert":
240 result.append(EditStep("insert", step.base_index, step.target_index, target[step.target_index]))
241 else:
242 result.append(EditStep("delete", step.base_index, step.target_index, base[step.base_index]))
243 return result
244
245
246 # ---------------------------------------------------------------------------
247 # Public diff entry point
248 # ---------------------------------------------------------------------------
249
250
251 def diff_midi_notes(
252 base_bytes: bytes,
253 target_bytes: bytes,
254 *,
255 file_path: str = "",
256 ) -> StructuredDelta:
257 """Compute a note-level ``StructuredDelta`` between two MIDI files.
258
259 Parses both files, converts each note to its content ID, delegates to the
260 core :func:`~muse.core.diff_algorithms.lcs.myers_ses` for the SES, then
261 maps the edit steps to typed ``DomainOp`` instances.
262
263 Args:
264 base_bytes: Raw bytes of the base (ancestor) MIDI file.
265 target_bytes: Raw bytes of the target (newer) MIDI file.
266 file_path: Workspace-relative path of the file being diffed (used
267 only in log messages and ``content_summary`` strings).
268
269 Returns:
270 A ``StructuredDelta`` with ``InsertOp`` and ``DeleteOp`` entries for
271 each note added or removed. The ``summary`` field is human-readable,
272 e.g. ``"3 notes added, 1 note removed"``.
273
274 Raises:
275 ValueError: When either byte string cannot be parsed as MIDI.
276 """
277 base_notes, base_tpb = extract_notes(base_bytes)
278 target_notes, target_tpb = extract_notes(target_bytes)
279 tpb = base_tpb # use base ticks_per_beat for human-readable summaries
280
281 # Convert NoteKey → content ID, then delegate LCS to the core algorithm.
282 base_ids = [_note_content_id(n) for n in base_notes]
283 target_ids = [_note_content_id(n) for n in target_notes]
284 steps = myers_ses(base_ids, target_ids)
285
286 # Build a content-ID → NoteKey lookup so we can produce rich summaries.
287 base_by_id = {_note_content_id(n): n for n in base_notes}
288 target_by_id = {_note_content_id(n): n for n in target_notes}
289
290 child_ops: list[DomainOp] = []
291 inserts = 0
292 deletes = 0
293
294 for step in steps:
295 if step.kind == "insert":
296 note = target_by_id.get(step.item)
297 summary = _note_summary(note, tpb) if note else step.item[:12]
298 child_ops.append(
299 InsertOp(
300 op="insert",
301 address=f"note:{step.target_index}",
302 position=step.target_index,
303 content_id=step.item,
304 content_summary=summary,
305 )
306 )
307 inserts += 1
308 elif step.kind == "delete":
309 note = base_by_id.get(step.item)
310 summary = _note_summary(note, tpb) if note else step.item[:12]
311 child_ops.append(
312 DeleteOp(
313 op="delete",
314 address=f"note:{step.base_index}",
315 position=step.base_index,
316 content_id=step.item,
317 content_summary=summary,
318 )
319 )
320 deletes += 1
321 # "keep" steps produce no ops — the note is unchanged.
322
323 parts: list[str] = []
324 if inserts:
325 parts.append(f"{inserts} note{'s' if inserts != 1 else ''} added")
326 if deletes:
327 parts.append(f"{deletes} note{'s' if deletes != 1 else ''} removed")
328 child_summary = ", ".join(parts) if parts else "no note changes"
329
330 logger.debug(
331 "✅ MIDI diff %r: +%d -%d notes (%d SES steps)",
332 file_path,
333 inserts,
334 deletes,
335 len(steps),
336 )
337
338 return StructuredDelta(
339 domain=_CHILD_DOMAIN,
340 ops=child_ops,
341 summary=child_summary,
342 )
343
344
345 # ---------------------------------------------------------------------------
346 # Entity-aware diff — wrapper that produces MutateOp for field-level mutations
347 # ---------------------------------------------------------------------------
348
349
350 def diff_midi_notes_with_entities(
351 base_bytes: bytes,
352 target_bytes: bytes,
353 *,
354 prior_index: "EntityIndex | None" = None,
355 commit_id: str = "",
356 op_id: str = "",
357 file_path: str = "",
358 mutation_threshold_ticks: int = 10,
359 mutation_threshold_velocity: int = 20,
360 ) -> StructuredDelta:
361 """Compute a note-level ``StructuredDelta`` with stable entity identity.
362
363 Unlike :func:`diff_midi_notes` which maps every field-level change to a
364 ``DeleteOp + InsertOp`` pair, this function uses the entity index from the
365 parent commit to detect *mutations* — notes that are logically the same
366 entity with changed properties — and emits ``MutateOp`` entries for them.
367
368 When ``prior_index`` is ``None`` or entity tracking is unavailable for a
369 note, this function falls back to the content-hash-only diff for that note
370 (same semantics as :func:`diff_midi_notes`).
371
372 The returned ``StructuredDelta`` also includes updated entity tracking
373 metadata in the ``domain`` field tag so consumers know which delta type
374 they are receiving.
375
376 Args:
377 base_bytes: Raw bytes of the base (ancestor) MIDI file.
378 target_bytes: Raw bytes of the target (newer) MIDI file.
379 prior_index: Entity index from the parent commit for *file_path*.
380 ``None`` for first-commit or untracked tracks.
381 commit_id: Current commit ID for provenance metadata.
382 op_id: Op log entry ID that produced this diff.
383 file_path: Workspace-relative path for log messages.
384 mutation_threshold_ticks: Max |Δtick| for fuzzy entity matching.
385 mutation_threshold_velocity: Max |Δvelocity| for fuzzy entity matching.
386
387 Returns:
388 A ``StructuredDelta`` with ``InsertOp``, ``DeleteOp``, and ``MutateOp``
389 entries. Domain tag is ``"midi_notes_tracked"`` to distinguish from
390 the plain content-hash diff.
391
392 Raises:
393 ValueError: When either byte string cannot be parsed as MIDI.
394 """
395 from muse.plugins.midi.entity import assign_entity_ids, diff_with_entity_ids
396
397 base_notes, base_tpb = extract_notes(base_bytes)
398 target_notes, _ = extract_notes(target_bytes)
399 tpb = base_tpb
400
401 base_entities = assign_entity_ids(
402 base_notes,
403 prior_index,
404 commit_id=commit_id or "base",
405 op_id=op_id or "",
406 mutation_threshold_ticks=mutation_threshold_ticks,
407 mutation_threshold_velocity=mutation_threshold_velocity,
408 )
409 target_entities = assign_entity_ids(
410 target_notes,
411 prior_index,
412 commit_id=commit_id,
413 op_id=op_id,
414 mutation_threshold_ticks=mutation_threshold_ticks,
415 mutation_threshold_velocity=mutation_threshold_velocity,
416 )
417
418 ops = diff_with_entity_ids(base_entities, target_entities, tpb)
419
420 inserts = sum(1 for op in ops if op["op"] == "insert")
421 deletes = sum(1 for op in ops if op["op"] == "delete")
422 mutates = sum(1 for op in ops if op["op"] == "mutate")
423
424 parts: list[str] = []
425 if inserts:
426 parts.append(f"{inserts} note{'s' if inserts != 1 else ''} added")
427 if deletes:
428 parts.append(f"{deletes} note{'s' if deletes != 1 else ''} removed")
429 if mutates:
430 parts.append(f"{mutates} note{'s' if mutates != 1 else ''} mutated")
431 summary = ", ".join(parts) if parts else "no note changes"
432
433 logger.debug(
434 "✅ Entity-aware MIDI diff %r: +%d -%d ~%d (%d ops)",
435 file_path,
436 inserts,
437 deletes,
438 mutates,
439 len(ops),
440 )
441
442 return StructuredDelta(
443 domain="midi_notes_tracked",
444 ops=ops,
445 summary=summary,
446 )
447
448
449 # ---------------------------------------------------------------------------
450 # MIDI reconstruction — inverse of extract_notes
451 # ---------------------------------------------------------------------------
452
453
454 def reconstruct_midi(
455 notes: list[NoteKey],
456 *,
457 ticks_per_beat: int = 480,
458 ) -> bytes:
459 """Produce raw MIDI bytes from a list of :class:`NoteKey` objects.
460
461 Creates a Type 0 (single-track) MIDI file. One ``note_on`` and one
462 ``note_off`` event are emitted per note. Events are sorted by absolute
463 tick time so the output is a valid MIDI stream regardless of the input
464 order.
465
466 This is the inverse of :func:`extract_notes`. Used by
467 :func:`~muse.plugins.midi.plugin._merge_patch_ops` after the OT
468 engine has confirmed that two branches' note sequences commute, allowing
469 the merged note list to be materialised as actual MIDI bytes.
470
471 Args:
472 notes: Note events to write. May be in any order; the
473 function sorts by ``start_tick`` before writing.
474 ticks_per_beat: Timing resolution. Preserve the base file's value so
475 that beat positions remain meaningful.
476
477 Returns:
478 Raw MIDI bytes ready to be written to the object store.
479 """
480 mid = mido.MidiFile(ticks_per_beat=ticks_per_beat, type=0)
481 track = mido.MidiTrack()
482 mid.tracks.append(track)
483
484 # Build flat (abs_tick, note_on, channel, pitch, velocity) event tuples.
485 raw_events: list[tuple[int, bool, int, int, int]] = []
486 for note in notes:
487 raw_events.append(
488 (note["start_tick"], True, note["channel"], note["pitch"], note["velocity"])
489 )
490 raw_events.append(
491 (
492 note["start_tick"] + note["duration_ticks"],
493 False,
494 note["channel"],
495 note["pitch"],
496 0,
497 )
498 )
499
500 # Sort: by tick, with note_off (False) before note_on (True) at the same
501 # tick so that retriggered notes are handled correctly.
502 raw_events.sort(key=lambda e: (e[0], e[1]))
503
504 prev_tick = 0
505 for abs_tick, is_on, channel, pitch, velocity in raw_events:
506 delta = abs_tick - prev_tick
507 if is_on:
508 track.append(
509 mido.Message("note_on", channel=channel, note=pitch, velocity=velocity, time=delta)
510 )
511 else:
512 track.append(
513 mido.Message("note_off", channel=channel, note=pitch, velocity=0, time=delta)
514 )
515 prev_tick = abs_tick
516
517 buf = io.BytesIO()
518 mid.save(file=buf)
519 return buf.getvalue()
520
521