cgcardona / muse public
test_music_invariants.py python
193 lines 7.4 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 1d ago
1 """Tests for muse.plugins.midi._invariants — rule checks and runner."""
2
3 import pathlib
4
5 import pytest
6
7 from muse.plugins.midi._invariants import (
8 InvariantRule,
9 check_key_consistency,
10 check_max_polyphony,
11 check_no_parallel_fifths,
12 check_pitch_range,
13 load_invariant_rules,
14 )
15 from muse.plugins.midi._query import NoteInfo
16 from muse.plugins.midi.midi_diff import NoteKey
17
18
def _note(pitch: int, start_tick: int = 0, duration_ticks: int = 480,
          velocity: int = 80, channel: int = 0,
          ticks_per_beat: int = 480) -> NoteInfo:
    """Build a ``NoteInfo`` test fixture from raw MIDI note fields.

    The fields are packed into a ``NoteKey`` and converted with
    ``NoteInfo.from_note_key``.

    Args:
        pitch: MIDI note number (the tests use 60 for C4).
        start_tick: Onset position in ticks.
        duration_ticks: Note length in ticks (480 is one beat at the default
            ``ticks_per_beat``).
        velocity: MIDI velocity.
        channel: MIDI channel.
        ticks_per_beat: Timing resolution forwarded to ``from_note_key``.
            Defaults to 480; tests that compute bar/beat offsets by hand
            assume this value, so keep them in sync if it is overridden.

    Returns:
        The ``NoteInfo`` produced by ``NoteInfo.from_note_key``.
    """
    key = NoteKey(
        pitch=pitch,
        velocity=velocity,
        start_tick=start_tick,
        duration_ticks=duration_ticks,
        channel=channel,
    )
    return NoteInfo.from_note_key(key, ticks_per_beat=ticks_per_beat)
31
32
33 # ---------------------------------------------------------------------------
34 # check_max_polyphony
35 # ---------------------------------------------------------------------------
36
37
class TestCheckMaxPolyphony:
    """Tests for ``check_max_polyphony`` (simultaneous-note limit)."""

    def test_no_violation_when_polyphony_ok(self) -> None:
        # Three back-to-back one-beat notes: they never sound together.
        notes = [_note(60, 0), _note(64, 480), _note(67, 960)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "warning", max_simultaneous=4)
        assert violations == []

    def test_violation_when_too_many_simultaneous(self) -> None:
        # 5 notes all starting at tick 0 with the same one-beat duration, so
        # all five overlap -- more than max_simultaneous=4.
        notes = [_note(60 + i, 0, 480) for i in range(5)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "error", max_simultaneous=4)
        assert len(violations) == 1
        assert violations[0]["severity"] == "error"
        assert violations[0]["rule_name"] == "poly"

    def test_violation_mentions_peak_count(self) -> None:
        notes = [_note(60 + i, 0, 480) for i in range(6)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "warning", max_simultaneous=4)
        # Guard before indexing so a regression fails with a clear message
        # rather than an IndexError on violations[0].
        assert violations, "expected a polyphony violation for 6 simultaneous notes"
        # NOTE(review): this substring check is loose (e.g. pitches 60-65 also
        # contain "6") -- tighten once the description format is pinned down.
        assert "6" in violations[0]["description"]

    def test_empty_notes_produces_no_violation(self) -> None:
        violations = check_max_polyphony([], "track.mid", "poly", "warning")
        assert violations == []

    def test_non_overlapping_notes_ok(self) -> None:
        # Each note starts after the previous one ends.
        notes = [_note(60, start_tick=i * 960, duration_ticks=480) for i in range(10)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "warning", max_simultaneous=4)
        assert violations == []
66
67
68 # ---------------------------------------------------------------------------
69 # check_pitch_range
70 # ---------------------------------------------------------------------------
71
72
class TestCheckPitchRange:
    """Tests for ``check_pitch_range`` (min/max pitch bounds)."""

    def test_all_in_range_produces_no_violation(self) -> None:
        notes = [_note(60), _note(72), _note(84)]
        violations = check_pitch_range(notes, "track.mid", "range", "warning",
                                       min_pitch=48, max_pitch=96)
        assert violations == []

    def test_too_low_produces_violation(self) -> None:
        notes = [_note(36)]  # below min=48
        violations = check_pitch_range(notes, "track.mid", "range", "error",
                                       min_pitch=48, max_pitch=96)
        assert len(violations) == 1
        assert "36" in violations[0]["description"]
        assert violations[0]["severity"] == "error"

    def test_too_high_produces_violation(self) -> None:
        notes = [_note(100)]  # above max=96
        violations = check_pitch_range(notes, "track.mid", "range", "warning",
                                       min_pitch=48, max_pitch=96)
        # Mirror the too-low case: the offending pitch is reported and the
        # caller-supplied severity is passed through.
        assert len(violations) == 1
        assert "100" in violations[0]["description"]
        assert violations[0]["severity"] == "warning"

    def test_multiple_out_of_range_produces_multiple_violations(self) -> None:
        # One note below, one above, one inside the range -> two violations.
        notes = [_note(30), _note(110), _note(60)]
        violations = check_pitch_range(notes, "t.mid", "r", "info",
                                       min_pitch=48, max_pitch=96)
        assert len(violations) == 2
        assert all(v["severity"] == "info" for v in violations)
99
100
101 # ---------------------------------------------------------------------------
102 # check_key_consistency
103 # ---------------------------------------------------------------------------
104
105
class TestCheckKeyConsistency:
    """Tests for ``check_key_consistency`` on diatonic and empty input."""

    def test_cmajor_notes_no_violation(self) -> None:
        # The C major scale from C4 (60) up to B4 (71), played four times in
        # a row: every pitch class belongs to the key.
        scale = (60, 62, 64, 65, 67, 69, 71)
        notes = [_note(pitch) for _ in range(4) for pitch in scale]
        result = check_key_consistency(notes, "t.mid", "key", "info", threshold=0.2)
        assert result == []

    def test_empty_notes_produces_no_violation(self) -> None:
        assert check_key_consistency([], "t.mid", "key", "warning") == []
117
118
119 # ---------------------------------------------------------------------------
120 # check_no_parallel_fifths
121 # ---------------------------------------------------------------------------
122
123
class TestCheckNoParallelFifths:
    """Tests for ``check_no_parallel_fifths`` using two-note chords in consecutive bars."""

    @staticmethod
    def _two_bar_dyads(first: tuple[int, int], second: tuple[int, int]) -> list[NoteInfo]:
        """Put *first* at the start of bar 1 and *second* at the start of bar 2.

        Every note lasts one beat (480 ticks); a bar is taken as four beats.
        """
        beat = 480
        bar = 4 * beat
        return [
            _note(pitch, start_tick=onset, duration_ticks=beat)
            for onset, dyad in ((0, first), (bar, second))
            for pitch in dyad
        ]

    def test_no_violation_without_parallel_fifths(self) -> None:
        # C4+G4 (a fifth, 7 semitones) followed by D4+E4 (a second, 2 semitones).
        notes = self._two_bar_dyads((60, 67), (62, 64))
        assert check_no_parallel_fifths(notes, "t.mid", "fifths", "warning") == []

    def test_parallel_fifths_detected(self) -> None:
        # C4+G4 then D4+A4: both are perfect fifths and both voices rise.
        notes = self._two_bar_dyads((60, 67), (62, 69))
        found = check_no_parallel_fifths(notes, "t.mid", "fifths", "warning")
        assert len(found) >= 1
        assert found[0]["rule_name"] == "fifths"

    def test_not_enough_notes_produces_no_violation(self) -> None:
        single = [_note(60)]
        assert check_no_parallel_fifths(single, "t.mid", "fifths", "warning") == []
158
159
160 # ---------------------------------------------------------------------------
161 # load_invariant_rules
162 # ---------------------------------------------------------------------------
163
164
class TestLoadInvariantRules:
    """Tests for ``load_invariant_rules`` defaults and TOML parsing."""

    def test_default_rules_returned_when_no_file(self) -> None:
        rules = load_invariant_rules(None)
        assert len(rules) >= 1
        rule_types = {r["rule_type"] for r in rules}
        assert "max_polyphony" in rule_types

    def test_missing_file_returns_defaults(self, tmp_path: pathlib.Path) -> None:
        rules = load_invariant_rules(tmp_path / "nonexistent.toml")
        assert rules
        # A bare truthiness check would accept any fallback; compare with the
        # no-file defaults so the test verifies what its name claims.
        assert rules == load_invariant_rules(None)

    def test_toml_file_parsed_correctly(self, tmp_path: pathlib.Path) -> None:
        toml_content = """
[[rule]]
name = "test_rule"
severity = "error"
scope = "track"
rule_type = "max_polyphony"

[rule.params]
max_simultaneous = 4
"""
        rules_file = tmp_path / "invariants.toml"
        # TOML documents are UTF-8 by specification; don't rely on the locale
        # default encoding.
        rules_file.write_text(toml_content, encoding="utf-8")
        rules = load_invariant_rules(rules_file)
        assert len(rules) == 1
        assert rules[0]["name"] == "test_rule"
        assert rules[0]["severity"] == "error"
        assert rules[0]["rule_type"] == "max_polyphony"
        assert rules[0].get("params", {}).get("max_simultaneous") == 4