gabriel / muse public
test_domain_schema.py python
262 lines 10.4 KB
8aa515d5 refactor: consolidate schema_version to single source of truth Gabriel Cardona <gabriel@tellurstori.com> 3d ago
1 """Tests for domain schema declaration and plugin registry lookup.
2
3 Verifies that:
4 - ``MidiPlugin.schema()`` returns a fully-typed ``DomainSchema``.
5 - All 21 MIDI dimensions are declared with the correct schema types.
6 - Independence flags match the semantic MIDI merge model.
7 - The schema is JSON round-trippable (all values are JSON-serialisable).
8 - ``schema_for()`` in the plugin registry performs the correct lookup.
9 - The protocol assertion still holds after adding ``schema()``.
10 """
11
12 import json
13
14 import pytest
15
16 from muse._version import __version__
17 from muse.core.schema import DomainSchema
18 from muse.domain import MuseDomainPlugin
19 from muse.plugins.midi.plugin import MidiPlugin
20 from muse.plugins.registry import registered_domains, schema_for
21
22 # ---------------------------------------------------------------------------
23 # Expected dimension layout for the 21-dimension MIDI schema
24 # ---------------------------------------------------------------------------
25
# Reference layout for the MIDI schema, in the exact order the tests expect
# ``MidiPlugin.schema()`` to declare its dimensions.
# Each row is a (name, independent_merge, schema_kind) triple.
_EXPECTED_DIMS: list[tuple[str, bool, str]] = [
    # Expressive note content
    ("notes", True, "sequence"),
    ("pitch_bend", True, "tensor"),
    ("channel_pressure", True, "tensor"),
    ("poly_pressure", True, "tensor"),
    # Named CC controllers
    ("cc_modulation", True, "tensor"),
    ("cc_volume", True, "tensor"),
    ("cc_pan", True, "tensor"),
    ("cc_expression", True, "tensor"),
    ("cc_sustain", True, "tensor"),
    ("cc_portamento", True, "tensor"),
    ("cc_sostenuto", True, "tensor"),
    ("cc_soft_pedal", True, "tensor"),
    ("cc_reverb", True, "tensor"),
    ("cc_chorus", True, "tensor"),
    ("cc_other", True, "tensor"),
    # Patch selection
    ("program_change", True, "sequence"),
    # Non-independent timeline metadata
    ("tempo_map", False, "sequence"),
    ("time_signatures", False, "sequence"),
    # Tonal / annotation metadata
    ("key_signatures", True, "sequence"),
    ("markers", True, "sequence"),
    # Track structure (non-independent)
    ("track_structure", False, "tree"),
]

# Names of every dimension whose independent_merge flag is False.
_NON_INDEPENDENT = frozenset(
    {
        dim_name
        for dim_name, merges_alone, _schema_kind in _EXPECTED_DIMS
        if not merges_alone
    }
)
60
61
62 # ---------------------------------------------------------------------------
63 # Fixtures
64 # ---------------------------------------------------------------------------
65
66
@pytest.fixture
def midi_plugin() -> MidiPlugin:
    """Build a new ``MidiPlugin`` for the requesting test (function scope)."""
    plugin = MidiPlugin()
    return plugin
70
71
@pytest.fixture
def midi_schema(midi_plugin: MidiPlugin) -> DomainSchema:
    """Return the schema produced by the ``midi_plugin`` fixture's ``schema()``."""
    schema: DomainSchema = midi_plugin.schema()
    return schema
75
76
77 # ===========================================================================
78 # MidiPlugin.schema() — top-level structure
79 # ===========================================================================
80
81
class TestMidiPluginSchema:
    """Top-level fields of the schema returned by ``MidiPlugin.schema()``."""

    def test_schema_returns_dict(self, midi_schema: DomainSchema) -> None:
        """The schema object is a plain ``dict``."""
        assert isinstance(midi_schema, dict)

    def test_domain_is_midi(self, midi_schema: DomainSchema) -> None:
        domain = midi_schema["domain"]
        assert domain == "midi"

    def test_schema_version_matches_package(self, midi_schema: DomainSchema) -> None:
        """``schema_version`` equals the package ``__version__``."""
        assert __version__ == midi_schema["schema_version"]

    def test_merge_mode_is_three_way(self, midi_schema: DomainSchema) -> None:
        merge_mode = midi_schema["merge_mode"]
        assert merge_mode == "three_way"

    def test_description_is_non_empty(self, midi_schema: DomainSchema) -> None:
        description = midi_schema["description"]
        assert isinstance(description, str)
        assert description != ""

    def test_top_level_is_set_schema(self, midi_schema: DomainSchema) -> None:
        assert midi_schema["top_level"]["kind"] == "set"

    def test_top_level_element_type(self, midi_schema: DomainSchema) -> None:
        """The top level is a content-addressed set with the expected element type."""
        top_level = midi_schema["top_level"]
        # NOTE(review): "audio_file" looks unusual as the element type of a MIDI
        # domain -- confirm this is what MidiPlugin.schema() intends to declare.
        observed = (top_level["kind"], top_level["element_type"], top_level["identity"])
        assert observed == ("set", "audio_file", "by_content")
108
109
110 # ===========================================================================
111 # 21-dimension layout
112 # ===========================================================================
113
114
class TestMidiDimensions:
    """Checks the dimensions declared by ``MidiPlugin.schema()`` against ``_EXPECTED_DIMS``."""

    def test_exactly_21_dimensions(self, midi_schema: DomainSchema) -> None:
        # Guard the spec table as well, so it cannot silently drift from the
        # documented 21-dimension layout.
        assert len(_EXPECTED_DIMS) == 21
        assert len(midi_schema["dimensions"]) == 21

    def test_all_expected_dimension_names_present(self, midi_schema: DomainSchema) -> None:
        names = {d["name"] for d in midi_schema["dimensions"]}
        expected = {name for name, _, _ in _EXPECTED_DIMS}
        assert names == expected

    def test_dimension_order_matches_spec(self, midi_schema: DomainSchema) -> None:
        names = [d["name"] for d in midi_schema["dimensions"]]
        expected = [name for name, _, _ in _EXPECTED_DIMS]
        assert names == expected

    @pytest.mark.parametrize("name,independent,kind", _EXPECTED_DIMS)
    def test_dimension_independence(
        self, midi_schema: DomainSchema, name: str, independent: bool, kind: str
    ) -> None:
        by_name = {d["name"]: d for d in midi_schema["dimensions"]}
        # Explicit membership check: a bare next() would surface a missing
        # dimension as an opaque StopIteration instead of a test failure.
        assert name in by_name, f"Dimension '{name}' not declared in schema"
        dim = by_name[name]
        assert dim["independent_merge"] is independent, (
            f"Dimension '{name}': expected independent_merge={independent}"
        )

    @pytest.mark.parametrize("name,independent,kind", _EXPECTED_DIMS)
    def test_dimension_schema_kind(
        self, midi_schema: DomainSchema, name: str, independent: bool, kind: str
    ) -> None:
        by_name = {d["name"]: d for d in midi_schema["dimensions"]}
        assert name in by_name, f"Dimension '{name}' not declared in schema"
        dim = by_name[name]
        assert dim["schema"]["kind"] == kind, (
            f"Dimension '{name}': expected schema kind '{kind}', got '{dim['schema']['kind']}'"
        )

    def test_all_dimensions_have_description(self, midi_schema: DomainSchema) -> None:
        for dim in midi_schema["dimensions"]:
            description = dim.get("description")
            assert isinstance(description, str), (
                f"Dimension '{dim['name']}' missing description"
            )
            assert description, f"Dimension '{dim['name']}' has an empty description"

    def test_non_independent_set(self, midi_schema: DomainSchema) -> None:
        non_indep = {
            d["name"] for d in midi_schema["dimensions"] if not d["independent_merge"]
        }
        assert non_indep == _NON_INDEPENDENT

    def test_notes_dimension_sequence_fields(self, midi_schema: DomainSchema) -> None:
        by_name = {d["name"]: d for d in midi_schema["dimensions"]}
        assert "notes" in by_name, "Dimension 'notes' not declared in schema"
        schema = by_name["notes"]["schema"]
        assert schema["kind"] == "sequence"
        assert schema["element_type"] == "note_event"
        assert schema["diff_algorithm"] == "lcs"

    def test_cc_dimensions_are_tensor_float32(self, midi_schema: DomainSchema) -> None:
        """Every tensor dimension is a sparse float32 tensor.

        Despite the name, this covers *all* tensor dimensions in the spec: the
        ``cc_*`` controllers plus ``pitch_bend``, ``channel_pressure`` and
        ``poly_pressure``.
        """
        tensor_names = {name for name, _, kind in _EXPECTED_DIMS if kind == "tensor"}
        by_name = {d["name"]: d for d in midi_schema["dimensions"]}
        # Fail loudly rather than passing vacuously when a tensor dimension is absent.
        missing = tensor_names.difference(by_name)
        assert not missing, f"Tensor dimensions missing from schema: {sorted(missing)}"
        for name in sorted(tensor_names):
            s = by_name[name]["schema"]
            assert s["kind"] == "tensor", f"Dimension '{name}': got kind {s['kind']!r}"
            assert s["dtype"] == "float32", f"Dimension '{name}': got dtype {s['dtype']!r}"
            assert s["diff_mode"] == "sparse", (
                f"Dimension '{name}': got diff_mode {s['diff_mode']!r}"
            )

    def test_track_structure_is_tree(self, midi_schema: DomainSchema) -> None:
        by_name = {d["name"]: d for d in midi_schema["dimensions"]}
        assert "track_structure" in by_name, (
            "Dimension 'track_structure' not declared in schema"
        )
        schema = by_name["track_structure"]["schema"]
        assert schema["kind"] == "tree"
        assert schema["node_type"] == "track_node"
        assert schema["diff_algorithm"] == "zhang_shasha"
182
183
184 # ===========================================================================
185 # JSON round-trip
186 # ===========================================================================
187
188
class TestSchemaJsonRoundtrip:
    """The MIDI schema must survive ``json.dumps`` followed by ``json.loads`` unchanged."""

    def test_schema_is_json_serialisable(self, midi_schema: DomainSchema) -> None:
        """Serialising succeeds and decoding yields a value equal to the original."""
        # json.dumps raises TypeError on values with no JSON mapping (sets, bytes,
        # arbitrary objects), so reaching the asserts proves serialisability.
        serialised = json.dumps(midi_schema)
        restored = json.loads(serialised)
        # Spot-check key fields first so a failure points at something readable.
        assert restored["domain"] == midi_schema["domain"]
        assert restored["schema_version"] == midi_schema["schema_version"]
        assert len(restored["dimensions"]) == len(midi_schema["dimensions"])
        assert restored["top_level"]["kind"] == midi_schema["top_level"]["kind"]
        # Whole-structure equality catches lossy conversions that dumps accepts
        # silently: tuples decode as lists and non-string keys decode as strings.
        assert restored == midi_schema

    def test_all_dimension_schemas_survive_roundtrip(self, midi_schema: DomainSchema) -> None:
        """Each dimension's full sub-schema, not just its kind, round-trips intact."""
        restored = json.loads(json.dumps(midi_schema))
        original_schemas = {d["name"]: d["schema"] for d in midi_schema["dimensions"]}
        restored_schemas = {d["name"]: d["schema"] for d in restored["dimensions"]}
        assert restored_schemas == original_schemas
204
205
206 # ===========================================================================
207 # Plugin registry schema lookup
208 # ===========================================================================
209
210
class TestPluginRegistrySchemaLookup:
    """``schema_for()`` and ``registered_domains()`` expose plugin schemas by domain key."""

    def test_schema_for_midi_returns_domain_schema(self) -> None:
        result = schema_for("midi")
        assert result is not None
        assert result["domain"] == "midi"

    def test_schema_for_unknown_domain_returns_none(self) -> None:
        result = schema_for("nonexistent_domain_xyz")
        assert result is None

    def test_schema_for_matches_direct_plugin_call(self) -> None:
        """The registry returns the same schema the plugin produces directly."""
        plugin = MidiPlugin()
        direct = plugin.schema()
        via_registry = schema_for("midi")
        assert via_registry is not None
        # Spot-check key fields for a readable failure message ...
        assert via_registry["domain"] == direct["domain"]
        assert via_registry["schema_version"] == direct["schema_version"]
        assert len(via_registry["dimensions"]) == len(direct["dimensions"])
        # ... then compare everything: equal dimension counts alone would not
        # catch a renamed, re-flagged or re-typed dimension in the registry path.
        assert via_registry == direct

    def test_registered_domains_contains_midi(self) -> None:
        assert "midi" in registered_domains()

    def test_music_key_not_in_registry(self) -> None:
        """Ensure the old 'music' key was fully removed."""
        assert "music" not in registered_domains()

    def test_schema_for_all_registered_domains_returns_non_none(self) -> None:
        for domain in registered_domains():
            result = schema_for(domain)
            assert result is not None, f"schema_for({domain!r}) returned None"
241
242
243 # ===========================================================================
244 # Protocol conformance
245 # ===========================================================================
246
247
class TestProtocolConformance:
    """``MidiPlugin`` still satisfies ``MuseDomainPlugin`` and exposes ``schema()``."""

    def test_midi_plugin_satisfies_protocol(self) -> None:
        assert isinstance(MidiPlugin(), MuseDomainPlugin)

    def test_schema_method_is_callable(self) -> None:
        schema_method = MidiPlugin().schema
        assert callable(schema_method)

    def test_schema_returns_domain_schema(self) -> None:
        """``schema()`` returns a dict carrying at least ``domain`` and ``dimensions``."""
        schema = MidiPlugin().schema()
        assert isinstance(schema, dict)
        assert schema.keys() >= {"domain", "dimensions"}