cgcardona / muse public
test_midi_diff.py python
344 lines 13.5 KB
d7054e63 feat(phase-1): typed delta algebra — replace DeltaManifest with Structu… Gabriel Cardona <gabriel@tellurstori.com> 2d ago
1 """Tests for muse.plugins.music.midi_diff — Myers LCS on MIDI note sequences.
2
3 Covers:
4 - NoteKey extraction from MIDI bytes.
5 - LCS edit script correctness (keep/insert/delete).
6 - LCS minimality (length of keep steps == LCS length).
7 - diff_midi_notes() produces correct StructuredDelta.
8 - Content IDs are deterministic and unique per note.
9 - Human-readable summaries and content_summary strings.
10 """
11 from __future__ import annotations
12
13 import io
14 import struct
15
16 import mido
17 import pytest
18
19 from muse.plugins.music.midi_diff import (
20 NoteKey,
21 diff_midi_notes,
22 extract_notes,
23 lcs_edit_script,
24 )
25
26
27 # ---------------------------------------------------------------------------
28 # MIDI builder helpers
29 # ---------------------------------------------------------------------------
30
def _build_midi(notes: list[tuple[int, int, int, int]]) -> bytes:
    """Build a minimal type-0 MIDI file from (pitch, velocity, start, duration) tuples.

    All values use ticks_per_beat=480. Notes may be supplied in any order; the
    resulting events are sorted by absolute tick and written with delta times.

    Events that share a tick are ordered so every note_on/note_off pair stays
    unambiguous:

    * the note_off of a note with a positive duration is written *before* any
      note_on at the same tick, so a same-pitch note starting exactly where the
      previous one ends is not closed by the wrong note_off, even when the
      input tuples are not in chronological order;
    * the note_off of a zero-duration note is written *after* the note_ons at
      that tick, so it still follows its own note_on.

    Args:
        notes: ``(pitch, velocity, start_tick, duration_ticks)`` tuples.

    Returns:
        The serialised MIDI file as bytes.
    """
    mid = mido.MidiFile(type=0, ticks_per_beat=480)
    track = mido.MidiTrack()
    mid.tracks.append(track)

    # (tick, rank, type, note, velocity) -- rank orders events sharing a tick.
    events: list[tuple[int, int, str, int, int]] = []
    for pitch, velocity, start, duration in notes:
        events.append((start, 1, "note_on", pitch, velocity))
        # Rank 0 closes sounding notes before new ones start; rank 2 keeps a
        # zero-length note's note_off behind its own note_on.
        off_rank = 0 if duration > 0 else 2
        events.append((start + duration, off_rank, "note_off", pitch, 0))

    # list.sort is stable: events with equal (tick, rank) keep their input order.
    events.sort(key=lambda e: (e[0], e[1]))

    # MIDI tracks store delta times, so convert absolute ticks as we go.
    prev_tick = 0
    for tick, _rank, msg_type, note, vel in events:
        delta = tick - prev_tick
        track.append(mido.Message(msg_type, note=note, velocity=vel, time=delta))
        prev_tick = tick

    track.append(mido.MetaMessage("end_of_track", time=0))

    buf = io.BytesIO()
    mid.save(file=buf)
    return buf.getvalue()
59
60
def _note(pitch: int, velocity: int = 80, start: int = 0, duration: int = 480) -> NoteKey:
    """Return a channel-0 NoteKey for the given pitch, velocity, start tick and duration."""
    return NoteKey(
        pitch=pitch,
        velocity=velocity,
        start_tick=start,
        duration_ticks=duration,
        channel=0,
    )
66
67
68 # ---------------------------------------------------------------------------
69 # extract_notes
70 # ---------------------------------------------------------------------------
71
class TestExtractNotes:
    """extract_notes(): MIDI bytes in, (notes, ticks_per_beat) out."""

    def test_empty_midi_returns_no_notes(self) -> None:
        """A file without note events yields an empty list and the file's resolution."""
        parsed, ticks_per_beat = extract_notes(_build_midi([]))
        assert parsed == []
        assert ticks_per_beat == 480

    def test_single_note_extracted(self) -> None:
        """Every field of a lone note survives the round trip."""
        payload = _build_midi([(60, 80, 0, 480)])  # C4
        parsed, _tpb = extract_notes(payload)
        assert len(parsed) == 1
        only = parsed[0]
        assert only["pitch"] == 60
        assert only["velocity"] == 80
        assert only["start_tick"] == 0
        assert only["duration_ticks"] == 480

    def test_multiple_notes_extracted(self) -> None:
        """Three consecutive notes are all reported."""
        triad = [
            (60, 80, 0, 480),
            (64, 90, 480, 480),
            (67, 70, 960, 480),
        ]
        parsed, _ = extract_notes(_build_midi(triad))
        assert len(parsed) == 3

    def test_notes_sorted_by_start_tick(self) -> None:
        """Notes come back in non-decreasing start_tick order."""
        shuffled = [
            (67, 70, 960, 240),
            (60, 80, 0, 480),
            (64, 90, 480, 480),
        ]
        parsed, _ = extract_notes(_build_midi(shuffled))
        starts = [entry["start_tick"] for entry in parsed]
        assert starts == sorted(starts)

    def test_invalid_bytes_raises_value_error(self) -> None:
        """Bytes that are not a MIDI file are rejected with ValueError."""
        with pytest.raises(ValueError):
            extract_notes(b"not a midi file")

    def test_ticks_per_beat_is_returned(self) -> None:
        """The second element of the result is the file's ticks-per-beat."""
        result = extract_notes(_build_midi([(60, 80, 0, 480)]))
        _, ticks_per_beat = result
        assert ticks_per_beat == 480
115
116
117 # ---------------------------------------------------------------------------
118 # lcs_edit_script
119 # ---------------------------------------------------------------------------
120
class TestLCSEditScript:
    """Edit-script tests for lcs_edit_script().

    NoteKey equality requires ALL five fields to match. The ``_nk`` helper
    therefore sets start_tick equal to the pitch, so two notes with the same
    pitch in base and target are always identical keys and the resulting edit
    scripts are easy to reason about.
    """

    def _nk(self, pitch: int) -> NoteKey:
        """Make a NoteKey whose start_tick equals its pitch (stable matching)."""
        return NoteKey(
            pitch=pitch,
            velocity=80,
            start_tick=pitch,  # same pitch -> same tick -> same key
            duration_ticks=480,
            channel=0,
        )

    def _seq(self, pitches: list[int]) -> list[NoteKey]:
        """Map a list of pitches to the corresponding ``_nk`` NoteKeys."""
        sequence: list[NoteKey] = []
        for pitch in pitches:
            sequence.append(self._nk(pitch))
        return sequence

    def test_identical_sequences_keeps_all(self) -> None:
        same = self._seq([60, 62, 64])
        script = lcs_edit_script(same, same)
        assert [step.kind for step in script] == ["keep", "keep", "keep"]

    def test_empty_to_sequence_all_inserts(self) -> None:
        script = lcs_edit_script([], self._seq([60, 62]))
        assert all(step.kind == "insert" for step in script)
        assert len(script) == 2

    def test_sequence_to_empty_all_deletes(self) -> None:
        script = lcs_edit_script(self._seq([60, 62]), [])
        assert all(step.kind == "delete" for step in script)
        assert len(script) == 2

    def test_single_insert_at_end(self) -> None:
        # base=[60,62], target=[60,62,64] -> keep 60, keep 62, insert 64
        script = lcs_edit_script(self._seq([60, 62]), self._seq([60, 62, 64]))
        kept = [step for step in script if step.kind == "keep"]
        added = [step for step in script if step.kind == "insert"]
        assert len(kept) == 2
        assert len(added) == 1
        assert added[0].note["pitch"] == 64

    def test_single_delete_from_middle(self) -> None:
        # base=[60,62,64], target=[60,64] -> keep 60, delete 62, keep 64.
        # With start_tick == pitch, 64@64 in base equals 64@64 in target.
        script = lcs_edit_script(self._seq([60, 62, 64]), self._seq([60, 64]))
        removed = [step for step in script if step.kind == "delete"]
        assert len(removed) == 1
        assert removed[0].note["pitch"] == 62

    def test_pitch_change_is_delete_plus_insert(self) -> None:
        # Changing only the pitch makes the keys differ: one delete + one insert.
        script = lcs_edit_script([_note(60)], [_note(62)])
        seen_kinds = {step.kind for step in script}
        assert "delete" in seen_kinds
        assert "insert" in seen_kinds
        assert "keep" not in seen_kinds

    def test_lcs_is_minimal_keeps_equal_lcs_length(self) -> None:
        # LCS of [60,62,64,65] and [60,64,65,67] is [60,64,65] (length 3):
        # 60@60, 64@64 and 65@65 each have a matching counterpart in target.
        script = lcs_edit_script(
            self._seq([60, 62, 64, 65]),
            self._seq([60, 64, 65, 67]),
        )
        kept = [step for step in script if step.kind == "keep"]
        assert len(kept) == 3

    def test_empty_both_returns_empty(self) -> None:
        assert lcs_edit_script([], []) == []

    def test_step_indices_are_consistent(self) -> None:
        # Indices on each side must never go backwards along the script.
        script = lcs_edit_script(self._seq([60, 62, 64]), self._seq([60, 64]))
        from_base = [step.base_index for step in script if step.kind != "insert"]
        from_target = [step.target_index for step in script if step.kind != "delete"]
        assert from_base == sorted(from_base)
        assert from_target == sorted(from_target)

    def test_reorder_detected_as_delete_insert(self) -> None:
        # Both sequences occupy ticks 0 and 480; only the pitches are swapped.
        # No (pitch, start_tick) combination occurs in both sequences, so the
        # pitch mismatch at each tick forces deletes + inserts and no keeps.
        first_tick, second_tick = 0, 480

        def at(pitch: int, tick: int) -> NoteKey:
            return NoteKey(
                pitch=pitch, velocity=80, start_tick=tick,
                duration_ticks=480, channel=0,
            )

        before = [at(60, first_tick), at(62, second_tick)]
        after = [at(62, first_tick), at(60, second_tick)]
        script = lcs_edit_script(before, after)
        kept = [step for step in script if step.kind == "keep"]
        # No note matches exactly (same pitch at same tick is absent from both).
        assert len(kept) == 0
223
224
225 # ---------------------------------------------------------------------------
226 # diff_midi_notes
227 # ---------------------------------------------------------------------------
228
class TestDiffMidiNotes:
    """diff_midi_notes(): ops, summaries and content IDs of the returned delta."""

    def test_no_change_returns_empty_ops(self) -> None:
        same = _build_midi([(60, 80, 0, 480)])
        assert diff_midi_notes(same, same)["ops"] == []

    def test_no_change_summary(self) -> None:
        same = _build_midi([(60, 80, 0, 480)])
        result = diff_midi_notes(same, same)
        assert "no note changes" in result["summary"]

    def test_add_note_returns_insert_op(self) -> None:
        before = _build_midi([(60, 80, 0, 480)])
        after = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        result = diff_midi_notes(before, after)
        added = [entry for entry in result["ops"] if entry["op"] == "insert"]
        assert len(added) == 1

    def test_remove_note_returns_delete_op(self) -> None:
        before = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        after = _build_midi([(60, 80, 0, 480)])
        result = diff_midi_notes(before, after)
        removed = [entry for entry in result["ops"] if entry["op"] == "delete"]
        assert len(removed) == 1

    def test_change_pitch_produces_delete_and_insert(self) -> None:
        before = _build_midi([(60, 80, 0, 480)])
        after = _build_midi([(62, 80, 0, 480)])
        result = diff_midi_notes(before, after)
        op_kinds = {entry["op"] for entry in result["ops"]}
        assert "delete" in op_kinds
        assert "insert" in op_kinds

    def test_summary_mentions_added_notes(self) -> None:
        before = _build_midi([(60, 80, 0, 480)])
        after = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        assert "added" in diff_midi_notes(before, after)["summary"]

    def test_summary_mentions_removed_notes(self) -> None:
        before = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        after = _build_midi([(60, 80, 0, 480)])
        assert "removed" in diff_midi_notes(before, after)["summary"]

    def test_summary_singular_for_one_note(self) -> None:
        nothing = _build_midi([])
        one_note = _build_midi([(60, 80, 0, 480)])
        result = diff_midi_notes(nothing, one_note)
        assert "1 note added" in result["summary"]

    def test_summary_plural_for_multiple_notes(self) -> None:
        nothing = _build_midi([])
        two_notes = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        result = diff_midi_notes(nothing, two_notes)
        assert "2 notes added" in result["summary"]

    def test_content_id_is_deterministic(self) -> None:
        # Diffing the same inputs twice must yield the same content IDs.
        one_note = _build_midi([(60, 80, 0, 480)])
        nothing = _build_midi([])
        first_run = diff_midi_notes(nothing, one_note)
        second_run = diff_midi_notes(nothing, one_note)
        first_ids = [entry["content_id"] for entry in first_run["ops"]]
        second_ids = [entry["content_id"] for entry in second_run["ops"]]
        assert first_ids == second_ids

    def test_content_ids_differ_for_different_notes(self) -> None:
        nothing = _build_midi([])
        c4_file = _build_midi([(60, 80, 0, 480)])
        d4_file = _build_midi([(62, 80, 0, 480)])
        c4_id = diff_midi_notes(nothing, c4_file)["ops"][0]["content_id"]
        d4_id = diff_midi_notes(nothing, d4_file)["ops"][0]["content_id"]
        assert c4_id != d4_id

    def test_content_summary_is_human_readable(self) -> None:
        nothing = _build_midi([])
        c4_file = _build_midi([(60, 80, 0, 480)])  # C4
        result = diff_midi_notes(nothing, c4_file)
        text = result["ops"][0]["content_summary"]
        assert "C4" in text
        assert "vel=80" in text

    def test_domain_is_midi_notes(self) -> None:
        one_note = _build_midi([(60, 80, 0, 480)])
        nothing = _build_midi([])
        result = diff_midi_notes(nothing, one_note)
        assert result["domain"] == "midi_notes"

    def test_invalid_base_raises_value_error(self) -> None:
        good = _build_midi([(60, 80, 0, 480)])
        with pytest.raises(ValueError):
            diff_midi_notes(b"garbage", good)

    def test_invalid_target_raises_value_error(self) -> None:
        good = _build_midi([(60, 80, 0, 480)])
        with pytest.raises(ValueError):
            diff_midi_notes(good, b"garbage")

    def test_file_path_appears_in_content_summary_context(self) -> None:
        # file_path is used only for logging; no crash expected.
        nothing = _build_midi([])
        one_note = _build_midi([(60, 80, 0, 480)])
        result = diff_midi_notes(
            nothing,
            one_note,
            file_path="tracks/piano.mid",
        )
        assert len(result["ops"]) == 1

    def test_position_reflects_sequence_index(self) -> None:
        nothing = _build_midi([])
        two_notes = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        result = diff_midi_notes(nothing, two_notes)
        seen_positions = [entry["position"] for entry in result["ops"]]
        assert 0 in seen_positions
        assert 1 in seen_positions