cgcardona / muse public
test_midi_diff.py python
343 lines 13.5 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 1d ago
1 """Tests for muse.plugins.midi.midi_diff — Myers LCS on MIDI note sequences.
2
3 Covers:
4 - NoteKey extraction from MIDI bytes.
5 - LCS edit script correctness (keep/insert/delete).
6 - LCS minimality (length of keep steps == LCS length).
7 - diff_midi_notes() produces correct StructuredDelta.
8 - Content IDs are deterministic and unique per note.
9 - Human-readable summaries and content_summary strings.
10 """
11
12 import io
13 import struct
14
15 import mido
16 import pytest
17
18 from muse.plugins.midi.midi_diff import (
19 NoteKey,
20 diff_midi_notes,
21 extract_notes,
22 lcs_edit_script,
23 )
24
25
26 # ---------------------------------------------------------------------------
27 # MIDI builder helpers
28 # ---------------------------------------------------------------------------
29
def _build_midi(notes: list[tuple[int, int, int, int]]) -> bytes:
    """Build a minimal type-0 MIDI file from (pitch, velocity, start, duration) tuples.

    All values use ticks_per_beat=480. Produces valid mido-parseable MIDI bytes.

    Events are written in tick order. When several events share a tick, a
    note_off that ends a note with positive duration is written before any
    note_on. A note that ends exactly where a same-pitch note begins is
    therefore closed before it is re-triggered, whatever order the input
    tuples were given in. A zero-duration note keeps its note_on ahead of its
    own note_off so it stays well-formed.
    """
    mid = mido.MidiFile(type=0, ticks_per_beat=480)
    track = mido.MidiTrack()
    mid.tracks.append(track)

    # (tick, tie-break rank, message type, note, velocity).  At a shared tick
    # lower ranks are written first: 0 = note_off of a positive-duration note,
    # 1 = note_on, 2 = note_off of a zero-duration note.
    events: list[tuple[int, int, str, int, int]] = []
    for pitch, velocity, start, duration in notes:
        events.append((start, 1, "note_on", pitch, velocity))
        off_rank = 0 if duration > 0 else 2
        events.append((start + duration, off_rank, "note_off", pitch, 0))

    # Sort on (tick, rank) only; the sort is stable, so remaining ties keep
    # their input order.
    events.sort(key=lambda e: (e[0], e[1]))

    prev_tick = 0
    for tick, _rank, msg_type, note, vel in events:
        # MIDI track messages carry delta times relative to the previous event.
        track.append(mido.Message(msg_type, note=note, velocity=vel, time=tick - prev_tick))
        prev_tick = tick

    track.append(mido.MetaMessage("end_of_track", time=0))

    buf = io.BytesIO()
    mid.save(file=buf)
    return buf.getvalue()
58
59
def _note(pitch: int, velocity: int = 80, start: int = 0, duration: int = 480) -> NoteKey:
    """Construct a channel-0 ``NoteKey`` from a pitch plus velocity and tick timing.

    ``start`` and ``duration`` are expressed in ticks and map onto the
    ``start_tick`` / ``duration_ticks`` fields respectively.
    """
    channel = 0  # all hand-built keys in these tests use MIDI channel 0
    return NoteKey(
        pitch=pitch,
        velocity=velocity,
        start_tick=start,
        duration_ticks=duration,
        channel=channel,
    )
65
66
67 # ---------------------------------------------------------------------------
68 # extract_notes
69 # ---------------------------------------------------------------------------
70
class TestExtractNotes:
    """extract_notes() turns raw MIDI bytes into (notes, ticks_per_beat)."""

    def test_empty_midi_returns_no_notes(self) -> None:
        midi_bytes = _build_midi([])
        notes, tpb = extract_notes(midi_bytes)
        assert notes == []
        assert tpb == 480

    def test_single_note_extracted(self) -> None:
        midi_bytes = _build_midi([(60, 80, 0, 480)])  # C4
        notes, tpb = extract_notes(midi_bytes)
        assert len(notes) == 1
        assert notes[0]["pitch"] == 60
        assert notes[0]["velocity"] == 80
        assert notes[0]["start_tick"] == 0
        assert notes[0]["duration_ticks"] == 480

    def test_multiple_notes_extracted(self) -> None:
        midi_bytes = _build_midi([
            (60, 80, 0, 480),
            (64, 90, 480, 480),
            (67, 70, 960, 480),
        ])
        notes, _ = extract_notes(midi_bytes)
        assert len(notes) == 3
        # Every written pitch must come back exactly once.
        assert sorted(n["pitch"] for n in notes) == [60, 64, 67]

    def test_notes_sorted_by_start_tick(self) -> None:
        # The 60 starts first but ends last, so its note_off is the final event
        # in the file.  An extractor that emitted notes in note_off order would
        # yield 64, 67, 60; sorting by start_tick must yield 60, 64, 67.
        midi_bytes = _build_midi([
            (64, 90, 480, 240),   # starts second, ends first
            (60, 80, 0, 1920),    # starts first, ends last
            (67, 70, 960, 240),
        ])
        notes, _ = extract_notes(midi_bytes)
        ticks = [n["start_tick"] for n in notes]
        assert ticks == sorted(ticks)
        assert [n["pitch"] for n in notes] == [60, 64, 67]

    def test_invalid_bytes_raises_value_error(self) -> None:
        with pytest.raises(ValueError):
            extract_notes(b"not a midi file")

    def test_ticks_per_beat_is_returned(self) -> None:
        midi_bytes = _build_midi([(60, 80, 0, 480)])
        _, tpb = extract_notes(midi_bytes)
        assert tpb == 480
114
115
116 # ---------------------------------------------------------------------------
117 # lcs_edit_script
118 # ---------------------------------------------------------------------------
119
class TestLCSEditScript:
    """LCS tests use start_tick=pitch so same-pitch notes always compare equal.

    NoteKey equality requires ALL five fields to match. Using start_tick=pitch
    ensures that notes with the same pitch in base and target are considered
    identical by LCS, giving intuitive edit scripts.
    """

    def _nk(self, pitch: int) -> NoteKey:
        """Make a NoteKey where start_tick equals pitch for stable matching."""
        return NoteKey(
            pitch=pitch, velocity=80,
            start_tick=pitch,  # deterministic: same pitch → same tick → same key
            duration_ticks=480, channel=0,
        )

    def _seq(self, pitches: list[int]) -> list[NoteKey]:
        """Build one ``_nk`` key per pitch, preserving order."""
        return [self._nk(p) for p in pitches]

    def test_identical_sequences_keeps_all(self) -> None:
        notes = self._seq([60, 62, 64])
        steps = lcs_edit_script(notes, notes)
        kinds = [s.kind for s in steps]
        assert kinds == ["keep", "keep", "keep"]

    def test_empty_to_sequence_all_inserts(self) -> None:
        target = self._seq([60, 62])
        steps = lcs_edit_script([], target)
        assert all(s.kind == "insert" for s in steps)
        assert len(steps) == 2

    def test_sequence_to_empty_all_deletes(self) -> None:
        base = self._seq([60, 62])
        steps = lcs_edit_script(base, [])
        assert all(s.kind == "delete" for s in steps)
        assert len(steps) == 2

    def test_single_insert_at_end(self) -> None:
        # base=[60,62], target=[60,62,64] → keep 60, keep 62, insert 64
        base = self._seq([60, 62])
        target = self._seq([60, 62, 64])
        steps = lcs_edit_script(base, target)
        assert [s.kind for s in steps] == ["keep", "keep", "insert"]
        assert steps[-1].note["pitch"] == 64

    def test_single_delete_from_middle(self) -> None:
        # base=[60,62,64], target=[60,64] → keep 60, delete 62, keep 64
        # NoteKeys with start_tick=pitch ensure 64@64 matches 64@64.
        base = self._seq([60, 62, 64])
        target = self._seq([60, 64])
        steps = lcs_edit_script(base, target)
        assert [s.kind for s in steps] == ["keep", "delete", "keep"]
        assert steps[1].note["pitch"] == 62

    def test_pitch_change_is_delete_plus_insert(self) -> None:
        # A note with a different pitch → one delete + one insert.
        base = [_note(60)]
        target = [_note(62)]
        steps = lcs_edit_script(base, target)
        kinds = {s.kind for s in steps}
        assert "delete" in kinds
        assert "insert" in kinds
        assert "keep" not in kinds

    def test_lcs_is_minimal_keeps_equal_lcs_length(self) -> None:
        # LCS of [60,62,64,65] and [60,64,65,67] is [60,64,65] (length 3)
        # because 60@60, 64@64, 65@65 all have matching counterparts in target.
        base = self._seq([60, 62, 64, 65])
        target = self._seq([60, 64, 65, 67])
        steps = lcs_edit_script(base, target)
        kinds = [s.kind for s in steps]
        assert kinds.count("keep") == 3
        # A minimal script touches only the non-LCS notes: 62 out, 67 in.
        assert kinds.count("delete") == len(base) - 3
        assert kinds.count("insert") == len(target) - 3

    def test_empty_both_returns_empty(self) -> None:
        steps = lcs_edit_script([], [])
        assert steps == []

    def test_step_indices_are_consistent(self) -> None:
        base = self._seq([60, 62, 64])
        target = self._seq([60, 64])
        steps = lcs_edit_script(base, target)
        base_indices = [s.base_index for s in steps if s.kind != "insert"]
        target_indices = [s.target_index for s in steps if s.kind != "delete"]
        assert base_indices == sorted(base_indices)
        assert target_indices == sorted(target_indices)

    def test_reorder_detected_as_delete_insert(self) -> None:
        # Base plays 60 at tick 0 then 62 at tick 480; target swaps the pitches
        # across the same two ticks.  No target note equals any base note
        # (pitch and start_tick never both match), so nothing can be kept:
        # both base notes are deleted and both target notes inserted.
        base = [NoteKey(pitch=60, velocity=80, start_tick=0, duration_ticks=480, channel=0),
                NoteKey(pitch=62, velocity=80, start_tick=480, duration_ticks=480, channel=0)]
        target = [NoteKey(pitch=62, velocity=80, start_tick=0, duration_ticks=480, channel=0),
                  NoteKey(pitch=60, velocity=80, start_tick=480, duration_ticks=480, channel=0)]
        steps = lcs_edit_script(base, target)
        kinds = [s.kind for s in steps]
        assert kinds.count("keep") == 0
        assert kinds.count("delete") == 2
        assert kinds.count("insert") == 2
222
223
224 # ---------------------------------------------------------------------------
225 # diff_midi_notes
226 # ---------------------------------------------------------------------------
227
class TestDiffMidiNotes:
    """diff_midi_notes() compares two MIDI byte strings note-by-note."""

    def test_no_change_returns_empty_ops(self) -> None:
        midi_bytes = _build_midi([(60, 80, 0, 480)])
        delta = diff_midi_notes(midi_bytes, midi_bytes)
        assert delta["ops"] == []

    def test_no_change_summary(self) -> None:
        midi_bytes = _build_midi([(60, 80, 0, 480)])
        delta = diff_midi_notes(midi_bytes, midi_bytes)
        assert "no note changes" in delta["summary"]

    def test_add_note_returns_insert_op(self) -> None:
        base_bytes = _build_midi([(60, 80, 0, 480)])
        target_bytes = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        delta = diff_midi_notes(base_bytes, target_bytes)
        inserts = [op for op in delta["ops"] if op["op"] == "insert"]
        assert len(inserts) == 1
        # The untouched C4 is identical in both files and must not be deleted.
        assert not any(op["op"] == "delete" for op in delta["ops"])

    def test_remove_note_returns_delete_op(self) -> None:
        base_bytes = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        target_bytes = _build_midi([(60, 80, 0, 480)])
        delta = diff_midi_notes(base_bytes, target_bytes)
        deletes = [op for op in delta["ops"] if op["op"] == "delete"]
        assert len(deletes) == 1

    def test_change_pitch_produces_delete_and_insert(self) -> None:
        base_bytes = _build_midi([(60, 80, 0, 480)])
        target_bytes = _build_midi([(62, 80, 0, 480)])
        delta = diff_midi_notes(base_bytes, target_bytes)
        kinds = {op["op"] for op in delta["ops"]}
        assert "delete" in kinds
        assert "insert" in kinds

    def test_summary_mentions_added_notes(self) -> None:
        base_bytes = _build_midi([(60, 80, 0, 480)])
        target_bytes = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        delta = diff_midi_notes(base_bytes, target_bytes)
        assert "added" in delta["summary"]

    def test_summary_mentions_removed_notes(self) -> None:
        base_bytes = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        target_bytes = _build_midi([(60, 80, 0, 480)])
        delta = diff_midi_notes(base_bytes, target_bytes)
        assert "removed" in delta["summary"]

    def test_summary_singular_for_one_note(self) -> None:
        base_bytes = _build_midi([])
        target_bytes = _build_midi([(60, 80, 0, 480)])
        delta = diff_midi_notes(base_bytes, target_bytes)
        assert "1 note added" in delta["summary"]

    def test_summary_plural_for_multiple_notes(self) -> None:
        base_bytes = _build_midi([])
        target_bytes = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        delta = diff_midi_notes(base_bytes, target_bytes)
        assert "2 notes added" in delta["summary"]

    def test_content_id_is_deterministic(self) -> None:
        midi_bytes = _build_midi([(60, 80, 0, 480)])
        empty_bytes = _build_midi([])
        delta1 = diff_midi_notes(empty_bytes, midi_bytes)
        delta2 = diff_midi_notes(empty_bytes, midi_bytes)
        ids1 = [op["content_id"] for op in delta1["ops"]]
        ids2 = [op["content_id"] for op in delta2["ops"]]
        # Guard against a vacuous pass: two empty op lists would compare equal.
        assert len(ids1) == 1
        assert ids1 == ids2

    def test_content_ids_differ_for_different_notes(self) -> None:
        empty_bytes = _build_midi([])
        midi_c4 = _build_midi([(60, 80, 0, 480)])
        midi_d4 = _build_midi([(62, 80, 0, 480)])
        delta_c4 = diff_midi_notes(empty_bytes, midi_c4)
        delta_d4 = diff_midi_notes(empty_bytes, midi_d4)
        id_c4 = delta_c4["ops"][0]["content_id"]
        id_d4 = delta_d4["ops"][0]["content_id"]
        assert id_c4 != id_d4

    def test_content_summary_is_human_readable(self) -> None:
        empty_bytes = _build_midi([])
        target_bytes = _build_midi([(60, 80, 0, 480)])  # C4
        delta = diff_midi_notes(empty_bytes, target_bytes)
        summary = delta["ops"][0]["content_summary"]
        assert "C4" in summary
        assert "vel=80" in summary

    def test_domain_is_midi_notes(self) -> None:
        midi_bytes = _build_midi([(60, 80, 0, 480)])
        empty_bytes = _build_midi([])
        delta = diff_midi_notes(empty_bytes, midi_bytes)
        assert delta["domain"] == "midi_notes"

    def test_invalid_base_raises_value_error(self) -> None:
        valid = _build_midi([(60, 80, 0, 480)])
        with pytest.raises(ValueError):
            diff_midi_notes(b"garbage", valid)

    def test_invalid_target_raises_value_error(self) -> None:
        valid = _build_midi([(60, 80, 0, 480)])
        with pytest.raises(ValueError):
            diff_midi_notes(valid, b"garbage")

    def test_file_path_appears_in_content_summary_context(self) -> None:
        # NOTE(review): despite the name, this only checks that the optional
        # ``file_path`` keyword is accepted and the diff still works; it does
        # not inspect any summary text for the path.
        base_bytes = _build_midi([])
        target_bytes = _build_midi([(60, 80, 0, 480)])
        delta = diff_midi_notes(
            base_bytes, target_bytes, file_path="tracks/piano.mid"
        )
        assert len(delta["ops"]) == 1

    def test_position_reflects_sequence_index(self) -> None:
        empty = _build_midi([])
        two_notes = _build_midi([(60, 80, 0, 480), (64, 80, 480, 480)])
        delta = diff_midi_notes(empty, two_notes)
        positions = [op["position"] for op in delta["ops"]]
        # Exactly one insert per target note, at target indices 0 and 1.
        assert sorted(positions) == [0, 1]