midi_diff.py
python
| 1 | """MIDI note-level diff for the Muse music plugin. |
| 2 | |
| 3 | Implements an LCS shortest-edit-script algorithm on MIDI note |
| 4 | sequences, producing a ``StructuredDelta`` with note-level ``InsertOp`` |
| 5 | and ``DeleteOp`` entries (a modified note is reported as delete + insert). |
| 6 | |
| 7 | This is what lets ``muse show`` display "C4 added at beat 3.5" rather than |
| 8 | "tracks/drums.mid modified". |
| 9 | |
| 10 | Algorithm |
| 11 | --------- |
| 12 | 1. Parse MIDI bytes and extract paired note events (note_on + note_off) |
| 13 | sorted by start tick. |
| 14 | 2. Represent each note as a ``NoteKey`` TypedDict with five fields. |
| 15 | 3. Run the O(nm) LCS dynamic-programming algorithm on the two note sequences. |
| 16 | 4. Traceback to produce a shortest edit script of keep / insert / delete steps. |
| 17 | 5. Map edit steps to typed ``DomainOp`` instances. |
| 18 | 6. Wrap the ops in a ``StructuredDelta`` with a human-readable summary. |
| 19 | |
| 20 | Public API |
| 21 | ---------- |
| 22 | - :class:`NoteKey` — five-field note identity (a ``TypedDict``; compared by value, not hashable). |
| 23 | - :func:`extract_notes` — MIDI bytes → sorted ``list[NoteKey]``. |
| 24 | - :func:`lcs_edit_script` — LCS shortest edit script on two note lists. |
| 25 | - :func:`diff_midi_notes` — top-level: MIDI bytes × 2 → ``StructuredDelta``. |
| 26 | """ |
| 27 | from __future__ import annotations |
| 28 | |
| 29 | import hashlib |
| 30 | import io |
| 31 | import logging |
| 32 | from dataclasses import dataclass |
| 33 | from typing import Literal, TypedDict |
| 34 | |
| 35 | import mido |
| 36 | |
| 37 | from muse.domain import ( |
| 38 | DeleteOp, |
| 39 | DomainOp, |
| 40 | InsertOp, |
| 41 | StructuredDelta, |
| 42 | ) |
| 43 | |
logger = logging.getLogger(__name__)

#: Domain tag stamped on every note-level ``StructuredDelta`` built here.
_CHILD_DOMAIN = "midi_notes"

#: Pitch-class names indexed by ``midi_pitch % 12`` (sharp spellings only).
_PITCH_NAMES = "C C# D D# E F F# G G# A A# B".split()
| 50 | |
| 51 | |
| 52 | # --------------------------------------------------------------------------- |
| 53 | # NoteKey — the unit of LCS comparison |
| 54 | # --------------------------------------------------------------------------- |
| 55 | |
| 56 | |
| 57 | class NoteKey(TypedDict): |
| 58 | """Fully-specified MIDI note used as the LCS comparison unit. |
| 59 | |
| 60 | Two notes are considered identical in LCS iff all five fields match. |
| 61 | A pitch change, velocity change, timing shift, or channel change |
| 62 | counts as a delete of the old note and an insert of the new one. |
| 63 | This is conservative but correct — it means the LCS finds true |
| 64 | structural matches and surfaces real musical changes. |
| 65 | """ |
| 66 | |
| 67 | pitch: int |
| 68 | velocity: int |
| 69 | start_tick: int |
| 70 | duration_ticks: int |
| 71 | channel: int |
| 72 | |
| 73 | |
| 74 | # --------------------------------------------------------------------------- |
| 75 | # Edit step — output of the LCS traceback |
| 76 | # --------------------------------------------------------------------------- |
| 77 | |
| 78 | EditKind = Literal["keep", "insert", "delete"] |
| 79 | |
| 80 | |
| 81 | @dataclass(frozen=True) |
| 82 | class EditStep: |
| 83 | """One step in the shortest edit script.""" |
| 84 | |
| 85 | kind: EditKind |
| 86 | base_index: int # index in the base note sequence |
| 87 | target_index: int # index in the target note sequence |
| 88 | note: NoteKey |
| 89 | |
| 90 | |
| 91 | # --------------------------------------------------------------------------- |
| 92 | # Helpers |
| 93 | # --------------------------------------------------------------------------- |
| 94 | |
| 95 | |
| 96 | def _pitch_name(midi_pitch: int) -> str: |
| 97 | """Return a human-readable pitch string, e.g. ``"C4"``, ``"F#5"``.""" |
| 98 | octave = midi_pitch // 12 - 1 |
| 99 | name = _PITCH_NAMES[midi_pitch % 12] |
| 100 | return f"{name}{octave}" |
| 101 | |
| 102 | |
| 103 | def _note_content_id(note: NoteKey) -> str: |
| 104 | """Return a deterministic SHA-256 for a note's five identity fields. |
| 105 | |
| 106 | This gives a stable ``content_id`` for use in ``InsertOp`` / ``DeleteOp`` |
| 107 | without requiring the note to be stored as a separate blob in the object |
| 108 | store. The hash uniquely identifies "this specific note event". |
| 109 | """ |
| 110 | payload = ( |
| 111 | f"{note['pitch']}:{note['velocity']}:" |
| 112 | f"{note['start_tick']}:{note['duration_ticks']}:{note['channel']}" |
| 113 | ) |
| 114 | return hashlib.sha256(payload.encode()).hexdigest() |
| 115 | |
| 116 | |
| 117 | def _note_summary(note: NoteKey, ticks_per_beat: int) -> str: |
| 118 | """Return a human-readable one-liner for a note, e.g. ``"C4 vel=80 @beat=1.00"``.""" |
| 119 | beat = note["start_tick"] / max(ticks_per_beat, 1) |
| 120 | dur = note["duration_ticks"] / max(ticks_per_beat, 1) |
| 121 | return ( |
| 122 | f"{_pitch_name(note['pitch'])} " |
| 123 | f"vel={note['velocity']} " |
| 124 | f"@beat={beat:.2f} " |
| 125 | f"dur={dur:.2f}" |
| 126 | ) |
| 127 | |
| 128 | |
| 129 | # --------------------------------------------------------------------------- |
| 130 | # Note extraction |
| 131 | # --------------------------------------------------------------------------- |
| 132 | |
| 133 | |
| 134 | def extract_notes(midi_bytes: bytes) -> tuple[list[NoteKey], int]: |
| 135 | """Parse *midi_bytes* and return ``(notes, ticks_per_beat)``. |
| 136 | |
| 137 | Notes are paired note_on / note_off events. A note_on with velocity=0 |
| 138 | is treated as note_off. Notes are sorted by start_tick then pitch for |
| 139 | deterministic ordering. |
| 140 | |
| 141 | Args: |
| 142 | midi_bytes: Raw bytes of a ``.mid`` file. |
| 143 | |
| 144 | Returns: |
| 145 | A tuple of (sorted NoteKey list, ticks_per_beat integer). |
| 146 | |
| 147 | Raises: |
| 148 | ValueError: When *midi_bytes* cannot be parsed as a MIDI file. |
| 149 | """ |
| 150 | try: |
| 151 | mid = mido.MidiFile(file=io.BytesIO(midi_bytes)) |
| 152 | except Exception as exc: |
| 153 | raise ValueError(f"Cannot parse MIDI bytes: {exc}") from exc |
| 154 | |
| 155 | ticks_per_beat: int = int(mid.ticks_per_beat) |
| 156 | # (channel, pitch) → (start_tick, velocity) |
| 157 | active: dict[tuple[int, int], tuple[int, int]] = {} |
| 158 | notes: list[NoteKey] = [] |
| 159 | |
| 160 | for track in mid.tracks: |
| 161 | abs_tick = 0 |
| 162 | for msg in track: |
| 163 | abs_tick += msg.time |
| 164 | if msg.type == "note_on" and msg.velocity > 0: |
| 165 | active[(msg.channel, msg.note)] = (abs_tick, msg.velocity) |
| 166 | elif msg.type == "note_off" or ( |
| 167 | msg.type == "note_on" and msg.velocity == 0 |
| 168 | ): |
| 169 | key = (msg.channel, msg.note) |
| 170 | if key in active: |
| 171 | start, vel = active.pop(key) |
| 172 | notes.append( |
| 173 | NoteKey( |
| 174 | pitch=msg.note, |
| 175 | velocity=vel, |
| 176 | start_tick=start, |
| 177 | duration_ticks=max(abs_tick - start, 1), |
| 178 | channel=msg.channel, |
| 179 | ) |
| 180 | ) |
| 181 | |
| 182 | # Close any notes still open at end of file with duration 1. |
| 183 | for (ch, pitch), (start, vel) in active.items(): |
| 184 | notes.append( |
| 185 | NoteKey( |
| 186 | pitch=pitch, |
| 187 | velocity=vel, |
| 188 | start_tick=start, |
| 189 | duration_ticks=1, |
| 190 | channel=ch, |
| 191 | ) |
| 192 | ) |
| 193 | |
| 194 | notes.sort(key=lambda n: (n["start_tick"], n["pitch"], n["channel"])) |
| 195 | return notes, ticks_per_beat |
| 196 | |
| 197 | |
| 198 | # --------------------------------------------------------------------------- |
| 199 | # LCS edit-script algorithm (dynamic programming) |
| 200 | # --------------------------------------------------------------------------- |
| 201 | |
| 202 | |
| 203 | def lcs_edit_script( |
| 204 | base: list[NoteKey], |
| 205 | target: list[NoteKey], |
| 206 | ) -> list[EditStep]: |
| 207 | """Compute the shortest edit script transforming *base* into *target*. |
| 208 | |
| 209 | Uses the standard O(n·m) LCS dynamic-programming algorithm followed by |
| 210 | linear-time traceback. Two notes are matched iff all five ``NoteKey`` |
| 211 | fields are equal. |
| 212 | |
| 213 | Args: |
| 214 | base: The base (ancestor) note sequence. |
| 215 | target: The target (newer) note sequence. |
| 216 | |
| 217 | Returns: |
| 218 | A list of ``EditStep`` with kind ``"keep"``, ``"insert"``, or |
| 219 | ``"delete"`` that transforms *base* into *target* in order. |
| 220 | The list is minimal: ``len(keep steps) == LCS length``. |
| 221 | """ |
| 222 | n, m = len(base), len(target) |
| 223 | |
| 224 | # dp[i][j] = length of LCS of base[i:] and target[j:] |
| 225 | dp: list[list[int]] = [[0] * (m + 1) for _ in range(n + 1)] |
| 226 | for i in range(n - 1, -1, -1): |
| 227 | for j in range(m - 1, -1, -1): |
| 228 | if base[i] == target[j]: |
| 229 | dp[i][j] = dp[i + 1][j + 1] + 1 |
| 230 | else: |
| 231 | dp[i][j] = max(dp[i + 1][j], dp[i][j + 1]) |
| 232 | |
| 233 | # Traceback: reconstruct the edit script. |
| 234 | steps: list[EditStep] = [] |
| 235 | i, j = 0, 0 |
| 236 | while i < n or j < m: |
| 237 | if i < n and j < m and base[i] == target[j]: |
| 238 | steps.append(EditStep("keep", i, j, base[i])) |
| 239 | i += 1 |
| 240 | j += 1 |
| 241 | elif j < m and (i >= n or dp[i][j + 1] >= dp[i + 1][j]): |
| 242 | steps.append(EditStep("insert", i, j, target[j])) |
| 243 | j += 1 |
| 244 | else: |
| 245 | steps.append(EditStep("delete", i, j, base[i])) |
| 246 | i += 1 |
| 247 | |
| 248 | return steps |
| 249 | |
| 250 | |
| 251 | # --------------------------------------------------------------------------- |
| 252 | # Public diff entry point |
| 253 | # --------------------------------------------------------------------------- |
| 254 | |
| 255 | |
| 256 | def diff_midi_notes( |
| 257 | base_bytes: bytes, |
| 258 | target_bytes: bytes, |
| 259 | *, |
| 260 | file_path: str = "", |
| 261 | ) -> StructuredDelta: |
| 262 | """Compute a note-level ``StructuredDelta`` between two MIDI files. |
| 263 | |
| 264 | Parses both files, runs LCS on their note sequences, and returns a |
| 265 | ``StructuredDelta`` suitable for embedding in a ``PatchOp.child_ops`` |
| 266 | list or storing directly as a commit's ``structured_delta``. |
| 267 | |
| 268 | Args: |
| 269 | base_bytes: Raw bytes of the base (ancestor) MIDI file. |
| 270 | target_bytes: Raw bytes of the target (newer) MIDI file. |
| 271 | file_path: Workspace-relative path of the file being diffed. |
| 272 | Used only in log messages and ``content_summary`` strings. |
| 273 | |
| 274 | Returns: |
| 275 | A ``StructuredDelta`` with ``InsertOp`` and ``DeleteOp`` entries for |
| 276 | each note added or removed. The ``summary`` field is human-readable, |
| 277 | e.g. ``"3 notes added, 1 note removed"``. |
| 278 | |
| 279 | Raises: |
| 280 | ValueError: When either byte string cannot be parsed as MIDI. |
| 281 | """ |
| 282 | base_notes, base_tpb = extract_notes(base_bytes) |
| 283 | target_notes, target_tpb = extract_notes(target_bytes) |
| 284 | tpb = base_tpb # use base ticks_per_beat for summary formatting |
| 285 | |
| 286 | steps = lcs_edit_script(base_notes, target_notes) |
| 287 | |
| 288 | child_ops: list[DomainOp] = [] |
| 289 | inserts = 0 |
| 290 | deletes = 0 |
| 291 | |
| 292 | for step in steps: |
| 293 | if step.kind == "insert": |
| 294 | child_ops.append( |
| 295 | InsertOp( |
| 296 | op="insert", |
| 297 | address=f"note:{step.target_index}", |
| 298 | position=step.target_index, |
| 299 | content_id=_note_content_id(step.note), |
| 300 | content_summary=_note_summary(step.note, tpb), |
| 301 | ) |
| 302 | ) |
| 303 | inserts += 1 |
| 304 | elif step.kind == "delete": |
| 305 | child_ops.append( |
| 306 | DeleteOp( |
| 307 | op="delete", |
| 308 | address=f"note:{step.base_index}", |
| 309 | position=step.base_index, |
| 310 | content_id=_note_content_id(step.note), |
| 311 | content_summary=_note_summary(step.note, tpb), |
| 312 | ) |
| 313 | ) |
| 314 | deletes += 1 |
| 315 | # "keep" steps produce no ops — the note is unchanged. |
| 316 | |
| 317 | parts: list[str] = [] |
| 318 | if inserts: |
| 319 | parts.append(f"{inserts} note{'s' if inserts != 1 else ''} added") |
| 320 | if deletes: |
| 321 | parts.append(f"{deletes} note{'s' if deletes != 1 else ''} removed") |
| 322 | child_summary = ", ".join(parts) if parts else "no note changes" |
| 323 | |
| 324 | logger.debug( |
| 325 | "✅ MIDI diff %r: +%d -%d notes (%d LCS steps)", |
| 326 | file_path, |
| 327 | inserts, |
| 328 | deletes, |
| 329 | len(steps), |
| 330 | ) |
| 331 | |
| 332 | return StructuredDelta( |
| 333 | domain=_CHILD_DOMAIN, |
| 334 | ops=child_ops, |
| 335 | summary=child_summary, |
| 336 | ) |
| 337 | |
| 338 |