cgcardona / muse public
test_midi_semantic.py python
380 lines 12.9 KB
630bfa59 feat(midi): add 20 new semantic porcelain commands (#120) Gabriel Cardona <cgcardona@gmail.com> 15h ago
1 """Tests for the new MIDI semantic porcelain — analysis helpers and CLI commands.
2
3 Coverage:
4 - muse/plugins/midi/_analysis.py: all eight analysis functions
5 - CLI commands (via CliRunner): rhythm, scale, contour, density, tension,
6 cadence, motif, voice_leading, instrumentation, tempo, quantize, humanize,
7 invert, retrograde, arpeggiate, velocity_normalize, midi_compare
8 """
9
10 from __future__ import annotations
11
12 import pathlib
13
14 import pytest
15 from typer.testing import CliRunner
16
17 from muse.cli.app import cli
18 from muse.plugins.midi._analysis import (
19 analyze_contour,
20 analyze_density,
21 analyze_rhythm,
22 check_voice_leading,
23 compute_tension,
24 detect_cadences,
25 detect_scale,
26 estimate_tempo,
27 find_motifs,
28 phrase_similarity,
29 )
30 from muse.plugins.midi._query import NoteInfo
31
# Shared CLI test runner reused by every command-level test in this module.
runner = CliRunner()

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Ticks per beat for every synthetic note built by the helpers below; _note()
# uses it both to convert beat positions to ticks and as NoteInfo.ticks_per_beat.
_TPB: int = 480
39
40
def _note(pitch: int, start_beat: float, dur_beats: float = 0.5, vel: int = 80, ch: int = 0) -> NoteInfo:
    """Build a synthetic ``NoteInfo`` positioned in beats rather than ticks.

    ``start_beat`` and ``dur_beats`` are converted to ticks at ``_TPB`` ticks
    per beat, rounded to the nearest tick; the remaining fields are passed
    straight through.
    """
    onset_ticks = round(start_beat * _TPB)
    length_ticks = round(dur_beats * _TPB)
    return NoteInfo(
        pitch=pitch,
        velocity=vel,
        start_tick=onset_ticks,
        duration_ticks=length_ticks,
        channel=ch,
        ticks_per_beat=_TPB,
    )
50
51
def _make_scale_run() -> list[NoteInfo]:
    """Return an ascending C-major run of half-beat notes.

    Fourteen notes from MIDI 60 (C) up to MIDI 83 (B) — two octaves of scale
    degrees without the closing top C (84). One note starts every half beat,
    so the whole run spans 7 beats, i.e. less than two 4-beat bars.
    """
    pitches = [60, 62, 64, 65, 67, 69, 71, 72, 74, 76, 77, 79, 81, 83]
    return [_note(p, i * 0.5) for i, p in enumerate(pitches)]
56
57
def _make_chord_sequence() -> list[NoteInfo]:
    """Return a block-chord C–Am–F–G progression (I–vi–IV–V in C major).

    Each triad starts on the first beat of its own 4-beat bar and is held for
    3.5 beats, leaving a half-beat gap before the next chord.
    """
    progression = (
        (60, 64, 67),  # bar 1: C major
        (57, 60, 64),  # bar 2: A minor
        (53, 57, 60),  # bar 3: F major
        (55, 59, 62),  # bar 4: G major
    )
    beats_per_bar = 4.0
    return [
        _note(pitch, bar * beats_per_bar, dur_beats=3.5)
        for bar, triad in enumerate(progression)
        for pitch in triad
    ]
72
73
def _make_motif_track() -> list[NoteInfo]:
    """Return three copies of the four-note cell 60–62–61–64 (intervals +2, -1, +3).

    Copies start 4 beats apart and notes inside a copy are half a beat apart.
    The pitches repeat exactly, so the step back between copies is -4.
    """
    cell = (60, 62, 61, 64)
    return [
        _note(pitch, copy_idx * 4.0 + step * 0.5)
        for copy_idx in range(3)
        for step, pitch in enumerate(cell)
    ]
83
84
85 # ---------------------------------------------------------------------------
86 # _analysis.py unit tests
87 # ---------------------------------------------------------------------------
88
89
class TestDetectScale:
    """detect_scale: candidate scale matches for a set of notes."""

    def test_major_scale_detected(self) -> None:
        candidates = detect_scale(_make_scale_run())
        assert candidates, "should return at least one match"
        leading_names = [c["name"] for c in candidates[:3]]
        assert "major" in leading_names

    def test_returns_up_to_five(self) -> None:
        how_many = len(detect_scale(_make_scale_run()))
        assert 1 <= how_many <= 5

    def test_empty_notes(self) -> None:
        assert detect_scale([]) == []

    def test_confidence_bounded(self) -> None:
        for candidate in detect_scale(_make_scale_run()):
            confidence = candidate["confidence"]
            assert 0.0 <= confidence <= 1.0

    def test_single_note_still_works(self) -> None:
        lone = [_note(60, 0.0)]
        assert len(detect_scale(lone)) >= 1
115
116
class TestAnalyzeRhythm:
    """analyze_rhythm: quantization, syncopation, swing and subdivision stats."""

    def test_on_beat_notes_high_quantization(self) -> None:
        # Eight notes sitting exactly on successive beats (quarter-note grid).
        on_grid = [_note(60, float(beat)) for beat in range(8)]
        assert analyze_rhythm(on_grid)["quantization_score"] >= 0.95

    def test_empty_notes(self) -> None:
        result = analyze_rhythm([])
        assert result["total_notes"] == 0
        assert result["quantization_score"] == 1.0

    def test_syncopation_score_range(self) -> None:
        score = analyze_rhythm(_make_scale_run())["syncopation_score"]
        assert 0.0 <= score <= 1.0

    def test_swing_ratio_positive(self) -> None:
        # Only non-negativity is required; a ratio of exactly 0.0 is accepted.
        assert analyze_rhythm(_make_scale_run())["swing_ratio"] >= 0.0

    def test_dominant_subdivision_is_string(self) -> None:
        subdivision = analyze_rhythm(_make_scale_run())["dominant_subdivision"]
        assert isinstance(subdivision, str)
143
144
class TestAnalyzeContour:
    """analyze_contour: overall melodic shape label plus the interval list."""

    def test_ascending_scale(self) -> None:
        rising = [_note(60 + step, float(step)) for step in range(8)]
        assert analyze_contour(rising)["shape"] == "ascending"

    def test_descending_scale(self) -> None:
        falling = [_note(72 - step, float(step)) for step in range(8)]
        assert analyze_contour(falling)["shape"] == "descending"

    def test_arch_shape(self) -> None:
        # Rises to a single peak and falls back; either label is accepted.
        pitches = [60, 62, 65, 67, 69, 67, 65, 62, 60]
        notes = [_note(p, float(i)) for i, p in enumerate(pitches)]
        assert analyze_contour(notes)["shape"] in ("arch", "wave")

    def test_intervals_list_is_bounded(self) -> None:
        # A 14-note scale run has at most 13 note-to-note intervals, which can
        # never exceed 32, so it cannot exercise the bound. 40 chromatic notes
        # (presumably 39 consecutive intervals) only stay within 32 if
        # analyze_contour actually limits the list.
        long_run = [_note(48 + i, i * 0.25) for i in range(40)]
        analysis = analyze_contour(long_run)
        assert len(analysis["intervals"]) <= 32

    def test_single_note(self) -> None:
        analysis = analyze_contour([_note(60, 0.0)])
        assert analysis["shape"] == "flat"
        assert analysis["intervals"] == []
172
173
class TestAnalyzeDensity:
    """analyze_density: one density record per bar."""

    def test_bar_count_matches(self) -> None:
        # The chord fixture fills exactly four 4-beat bars.
        per_bar = analyze_density(_make_chord_sequence())
        assert len(per_bar) == 4

    def test_notes_per_beat_positive(self) -> None:
        for record in analyze_density(_make_chord_sequence()):
            assert record["notes_per_beat"] > 0

    def test_empty_notes(self) -> None:
        assert analyze_density([]) == []
187
188
class TestComputeTension:
    """compute_tension: per-bar tension score and coarse label."""

    def test_returns_one_entry_per_bar(self) -> None:
        tension_by_bar = compute_tension(_make_chord_sequence())
        assert len(tension_by_bar) == 4

    def test_tension_in_range(self) -> None:
        for bar in compute_tension(_make_chord_sequence()):
            value = bar["tension"]
            assert 0.0 <= value <= 1.0

    def test_label_is_string(self) -> None:
        # Stricter than the name suggests: the label must be one of the three
        # known categories, not merely any string.
        allowed = ("consonant", "mild", "tense")
        for bar in compute_tension(_make_chord_sequence()):
            assert bar["label"] in allowed
204
205
class TestDetectCadences:
    """detect_cadences: cadence points found in a track."""

    def test_short_track_no_cadences(self) -> None:
        # The scale run lasts only 7 beats (under two 4-beat bars); no cadence
        # is expected in it.
        cadences = detect_cadences(_make_scale_run())
        assert cadences == []

    def test_four_bar_chord_sequence_may_find_cadence(self) -> None:
        # Smoke test only: whether C–Am–F–G yields a cadence is left to the
        # implementation, so just the return type is checked.
        result = detect_cadences(_make_chord_sequence())
        assert isinstance(result, list)
215
216
class TestFindMotifs:
    """find_motifs: repeated interval patterns within one track."""

    def test_finds_repeated_pattern(self) -> None:
        notes = _make_motif_track()
        motifs = find_motifs(notes, min_length=3, min_occurrences=2)
        assert len(motifs) >= 1

    def test_motif_occurrences_gte_min(self) -> None:
        # Same arguments as test_finds_repeated_pattern, whose result is
        # expected to be non-empty; asserting that here keeps the per-motif
        # check from passing vacuously on an empty list.
        motifs = find_motifs(_make_motif_track(), min_length=3, min_occurrences=2)
        assert motifs, "expected at least one motif to check"
        for m in motifs:
            assert m["occurrences"] >= 2

    def test_too_short_track(self) -> None:
        notes = [_note(60, 0.0), _note(62, 0.5)]
        assert find_motifs(notes) == []

    def test_interval_pattern_is_list_of_int(self) -> None:
        # Non-empty guard for the same reason as above.
        motifs = find_motifs(_make_motif_track(), min_length=3, min_occurrences=2)
        assert motifs, "expected at least one motif to check"
        for m in motifs:
            for iv in m["interval_pattern"]:
                assert isinstance(iv, int)
237
238
class TestCheckVoiceLeading:
    """check_voice_leading: parallel-motion and leap warnings."""

    def test_no_parallel_motion_on_single_voice(self) -> None:
        # The scale run never sounds two notes at once, so parallel fifths or
        # octaves are impossible; large-leap reports are tolerated.
        parallel_kinds = ("parallel_fifths", "parallel_octaves")
        flagged = [
            issue
            for issue in check_voice_leading(_make_scale_run())
            if issue["issue_type"] in parallel_kinds
        ]
        assert flagged == []

    def test_returns_list(self) -> None:
        assert isinstance(check_voice_leading(_make_chord_sequence()), list)

    def test_issue_types_are_valid(self) -> None:
        known = {"parallel_fifths", "parallel_octaves", "large_leap"}
        for issue in check_voice_leading(_make_chord_sequence()):
            assert issue["issue_type"] in known
258
259
class TestEstimateTempo:
    """estimate_tempo: BPM estimate with a confidence label."""

    def test_regular_quarter_notes_approx_120(self) -> None:
        # One note per beat at 480 ticks per beat. Only a broad plausibility
        # window is asserted, not 120 BPM exactly.
        steady = [_note(60, float(beat)) for beat in range(8)]
        bpm = estimate_tempo(steady)["estimated_bpm"]
        assert 60.0 <= bpm <= 300.0

    def test_empty_notes(self) -> None:
        fallback = estimate_tempo([])
        assert fallback["estimated_bpm"] == 120.0
        assert fallback["confidence"] == "none"

    def test_confidence_is_valid(self) -> None:
        label = estimate_tempo(_make_scale_run())["confidence"]
        assert label in ("high", "medium", "low", "none")
276
277
class TestPhraseSimilarity:
    """phrase_similarity: similarity score between two note sequences."""

    def test_identical_phrases_score_high(self) -> None:
        phrase = _make_scale_run()
        assert phrase_similarity(phrase, phrase) >= 0.9

    def test_empty_query_returns_zero(self) -> None:
        assert phrase_similarity([], _make_scale_run()) == 0.0

    def test_score_in_range(self) -> None:
        melody = _make_scale_run()
        chords = _make_chord_sequence()
        similarity = phrase_similarity(melody, chords)
        assert 0.0 <= similarity <= 1.0
293
294
295 # ---------------------------------------------------------------------------
296 # CLI command integration tests (no real .muse repo needed for help/validation)
297 # ---------------------------------------------------------------------------
298
299
class TestCliHelpPages:
    """Verify all new commands are registered and have help text."""

    # Single source of truth for the 20 ``muse midi`` subcommands covered by
    # this module; both tests iterate over it so the two checks cannot drift.
    _COMMANDS = (
        "rhythm", "scale", "contour", "density", "tension",
        "cadence", "motif", "voice-leading", "instrumentation",
        "tempo", "compare", "quantize", "humanize",
        "invert", "retrograde", "arpeggiate", "normalize",
        "shard", "agent-map", "find-phrase",
    )

    @pytest.mark.parametrize(
        "cmd",
        [["midi", name, "--help"] for name in _COMMANDS],
        ids=_COMMANDS,  # readable test ids (e.g. "rhythm") instead of cmd0..cmd19
    )
    def test_help_exits_zero(self, cmd: list[str]) -> None:
        result = runner.invoke(cli, cmd)
        assert result.exit_code == 0, f"Help failed for {cmd}: {result.output}"

    def test_midi_namespace_lists_all_commands(self) -> None:
        result = runner.invoke(cli, ["midi", "--help"])
        assert result.exit_code == 0
        output = result.output
        for expected in self._COMMANDS:
            assert expected in output, f"'{expected}' not found in midi help"
341
342
class TestQuantizeValidation:
    """Validate --grid and --strength option guards.

    The tests pass a placeholder path (``fake.mid``) and set up no repository.
    A non-zero exit code alone is weak evidence, because CliRunner also
    reports an uncaught exception as exit code 1. Each test therefore also
    requires the failure to be a controlled exit (``SystemExit``).
    NOTE(review): the exit may still come from the missing file rather than
    the option guard; asserting on the error text would pin that down.
    """

    def test_unknown_grid_exits_error(self) -> None:
        result = runner.invoke(cli, ["midi", "quantize", "fake.mid", "--grid", "99th"])
        assert result.exit_code != 0
        assert isinstance(result.exception, SystemExit), (
            f"crashed instead of exiting: {result.exception!r}"
        )

    def test_invalid_strength_exits_error(self) -> None:
        result = runner.invoke(cli, ["midi", "quantize", "fake.mid", "--strength", "2.5"])
        assert result.exit_code != 0
        assert isinstance(result.exception, SystemExit), (
            f"crashed instead of exiting: {result.exception!r}"
        )
353
354
class TestArpeggiateValidation:
    """Validate the --rate and --order option guards of ``midi arpeggiate``.

    Each rejection must be a controlled exit (``SystemExit`` recorded by
    CliRunner). An uncaught exception also yields a non-zero exit code and
    would otherwise satisfy the exit-code check on its own.
    """

    def test_unknown_rate_exits_error(self) -> None:
        result = runner.invoke(cli, ["midi", "arpeggiate", "fake.mid", "--rate", "64th"])
        assert result.exit_code != 0
        assert isinstance(result.exception, SystemExit), (
            f"crashed instead of exiting: {result.exception!r}"
        )

    def test_unknown_order_exits_error(self) -> None:
        result = runner.invoke(cli, ["midi", "arpeggiate", "fake.mid", "--order", "zigzag"])
        assert result.exit_code != 0
        assert isinstance(result.exception, SystemExit), (
            f"crashed instead of exiting: {result.exception!r}"
        )
363
364
class TestNormalizeValidation:
    """Validate the --min/--max bounds of ``midi normalize``.

    Each rejection must be a controlled exit (``SystemExit`` recorded by
    CliRunner) rather than an uncaught exception, which would also produce a
    non-zero exit code.
    """

    def test_min_gte_max_exits_error(self) -> None:
        result = runner.invoke(cli, ["midi", "normalize", "fake.mid", "--min", "100", "--max", "50"])
        assert result.exit_code != 0
        assert isinstance(result.exception, SystemExit), (
            f"crashed instead of exiting: {result.exception!r}"
        )

    def test_out_of_range_min_exits_error(self) -> None:
        result = runner.invoke(cli, ["midi", "normalize", "fake.mid", "--min", "0"])
        assert result.exit_code != 0
        assert isinstance(result.exception, SystemExit), (
            f"crashed instead of exiting: {result.exception!r}"
        )
373
374
class TestMidiShardValidation:
    """Validate that ``midi shard`` rejects --shards combined with --bars-per-shard.

    The rejection must be a controlled exit (``SystemExit`` recorded by
    CliRunner) rather than an uncaught exception, which would also produce a
    non-zero exit code.
    """

    def test_mutually_exclusive_flags(self) -> None:
        result = runner.invoke(cli, [
            "midi", "shard", "fake.mid", "--shards", "4", "--bars-per-shard", "8"
        ])
        assert result.exit_code != 0
        assert isinstance(result.exception, SystemExit), (
            f"crashed instead of exiting: {result.exception!r}"
        )