cgcardona / muse public
test_entity.py python
190 lines 7.3 KB
6d8ca4ac feat: god-tier MIDI dimension expansion + full supercharge architecture Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """Tests for muse.plugins.music.entity — NoteEntity, EntityIndex, assign/diff."""
2 from __future__ import annotations
3
4 import pathlib
5
6 import pytest
7
8 from muse.plugins.music.entity import (
9 EntityIndex,
10 EntityIndexEntry,
11 NoteEntity,
12 assign_entity_ids,
13 build_entity_index,
14 diff_with_entity_ids,
15 read_entity_index,
16 write_entity_index,
17 )
18 from muse.plugins.music.midi_diff import NoteKey, _note_content_id
19
20
def _key(pitch: int = 60, velocity: int = 80, start_tick: int = 0,
         duration_ticks: int = 480, channel: int = 0) -> NoteKey:
    """Return a ``NoteKey`` for tests.

    Every field has a default so individual tests only spell out the
    attribute(s) they actually vary; positional order is
    pitch, velocity, start_tick, duration_ticks, channel.
    """
    fields = dict(
        pitch=pitch,
        velocity=velocity,
        start_tick=start_tick,
        duration_ticks=duration_ticks,
        channel=channel,
    )
    return NoteKey(**fields)
25
26
27 # ---------------------------------------------------------------------------
28 # assign_entity_ids — first commit (no prior index)
29 # ---------------------------------------------------------------------------
30
31
class TestAssignEntityIdsFirstCommit:
    """``assign_entity_ids`` with no prior index: every note is minted fresh."""

    def test_all_notes_get_new_entity_ids(self) -> None:
        pitches = (60, 62, 64)
        result = assign_entity_ids(
            [_key(p) for p in pitches], None, commit_id="c001", op_id="op1"
        )
        assert len(result) == 3
        # On a first commit no two notes may share an entity ID.
        distinct_ids = {entity["entity_id"] for entity in result}
        assert len(distinct_ids) == 3

    def test_entity_has_correct_pitch_fields(self) -> None:
        first = assign_entity_ids(
            [_key(pitch=67, velocity=100)], None, commit_id="c001", op_id="op1"
        )[0]
        assert first["pitch"] == 67
        assert first["velocity"] == 100

    def test_origin_commit_id_set(self) -> None:
        first = assign_entity_ids(
            [_key(60)], None, commit_id="commit-abc", op_id="op1"
        )[0]
        # The commit that first introduced the note is recorded as its origin.
        assert first["origin_commit_id"] == "commit-abc"
50
51
52 # ---------------------------------------------------------------------------
53 # assign_entity_ids — with prior index (exact match)
54 # ---------------------------------------------------------------------------
55
56
class TestAssignEntityIdsExactMatch:
    """``assign_entity_ids`` against a prior index: unchanged notes keep their IDs."""

    def _make_index(self, notes: list[NoteKey], commit_id: str) -> EntityIndex:
        """Assign first-commit IDs to *notes* and wrap them in an index for ``track.mid``."""
        entities = assign_entity_ids(notes, None, commit_id=commit_id, op_id="op0")
        return build_entity_index(entities, "track.mid", commit_id)

    @staticmethod
    def _eid_by_content(index: EntityIndex) -> dict[str, str]:
        """Map each index entry's ``content_id`` to the entity ID it is stored under.

        Comparing these maps checks that a given note kept *its own* ID;
        comparing bare ID sets would also pass if IDs were permuted
        between notes.
        """
        return {entry["content_id"]: eid for eid, entry in index["entities"].items()}

    def test_same_notes_get_same_entity_ids(self) -> None:
        notes = [_key(60), _key(62), _key(64)]
        prior = self._make_index(notes, "c001")
        prior_eids = set(prior["entities"].keys())

        entities = assign_entity_ids(notes, prior, commit_id="c002", op_id="op2")
        new_eids = {e["entity_id"] for e in entities}
        assert new_eids == prior_eids

        # Each note must be re-identified with its own prior ID.
        current = build_entity_index(entities, "track.mid", "c002")
        assert self._eid_by_content(current) == self._eid_by_content(prior)

    def test_added_note_gets_new_entity_id(self) -> None:
        notes = [_key(60), _key(62)]
        prior = self._make_index(notes, "c001")
        prior_eids = set(prior["entities"].keys())

        new_notes = [_key(60), _key(62), _key(64)]
        entities = assign_entity_ids(new_notes, prior, commit_id="c002", op_id="op2")
        new_eids = {e["entity_id"] for e in entities}

        # The two original notes should retain their IDs.
        assert prior_eids.issubset(new_eids)
        # The new note gets a fresh ID.
        assert len(new_eids - prior_eids) == 1
        # Each retained ID must still belong to the same note it did before.
        current = build_entity_index(entities, "track.mid", "c002")
        assert self._eid_by_content(prior).items() <= self._eid_by_content(current).items()
84
85
86 # ---------------------------------------------------------------------------
87 # diff_with_entity_ids
88 # ---------------------------------------------------------------------------
89
90
class TestDiffWithEntityIds:
    """``diff_with_entity_ids`` between a base commit and an entity-tracked target."""

    def _entities_from(self, notes: list[NoteKey]) -> list[NoteEntity]:
        """Assign first-commit (``c001``) entity IDs to *notes*."""
        return assign_entity_ids(notes, None, commit_id="c001", op_id="op1")

    def _target_from(
        self, base: list[NoteEntity], notes: list[NoteKey], **kwargs: int
    ) -> list[NoteEntity]:
        """Assign commit ``c002`` IDs to *notes*, matched against an index built from *base*.

        Extra keyword arguments (e.g. mutation thresholds) are forwarded to
        ``assign_entity_ids``.
        """
        return assign_entity_ids(
            notes,
            build_entity_index(base, "track.mid", "c001"),
            commit_id="c002",
            op_id="op2",
            **kwargs,
        )

    def test_no_change_produces_no_ops(self) -> None:
        notes = [_key(60), _key(62), _key(64)]
        base = self._entities_from(notes)
        target = self._target_from(base, notes)
        ops = diff_with_entity_ids(base, target, 480)
        assert ops == []

    def test_added_note_produces_insert(self) -> None:
        base = self._entities_from([_key(60)])
        target = self._target_from(base, [_key(60), _key(64)])
        ops = diff_with_entity_ids(base, target, 480)
        op_types = [o["op"] for o in ops]
        assert "insert" in op_types
        assert "delete" not in op_types

    def test_removed_note_produces_delete(self) -> None:
        base = self._entities_from([_key(60), _key(64)])
        target = self._target_from(base, [_key(60)])
        ops = diff_with_entity_ids(base, target, 480)
        op_types = [o["op"] for o in ops]
        assert "delete" in op_types
        assert "insert" not in op_types

    def test_velocity_change_produces_mutate(self) -> None:
        base = self._entities_from([_key(pitch=60, velocity=80)])
        # Change velocity only; the 20-unit change is within the velocity threshold.
        target = self._target_from(
            base,
            [_key(pitch=60, velocity=100)],
            mutation_threshold_ticks=20,
            mutation_threshold_velocity=30,
        )
        ops = diff_with_entity_ids(base, target, 480)
        op_types = [o["op"] for o in ops]
        # Depending on the match heuristic, the change is either a mutate
        # of the tracked entity or a delete of the old note plus an insert
        # of the new one. Accept either, but require one of them.
        assert ops
        assert "mutate" in op_types or {"insert", "delete"} <= set(op_types)
155
156
157 # ---------------------------------------------------------------------------
158 # EntityIndex I/O
159 # ---------------------------------------------------------------------------
160
161
class TestEntityIndexIO:
    """Persistence of ``EntityIndex`` plus structural checks on ``build_entity_index``."""

    def test_write_and_read_roundtrip(self, tmp_path: pathlib.Path) -> None:
        notes = [_key(60), _key(62), _key(64)]
        entities = assign_entity_ids(notes, None, commit_id="c001", op_id="op1")
        index = build_entity_index(entities, "track.mid", "c001")

        write_entity_index(tmp_path, "c001", "track.mid", index)
        recovered = read_entity_index(tmp_path, "c001", "track.mid")

        assert recovered is not None
        assert set(recovered["entities"].keys()) == set(index["entities"].keys())
        # Matching keys alone would not catch entry payloads lost or
        # scrambled on disk, so compare each entry's content_id too.
        for eid, entry in index["entities"].items():
            assert recovered["entities"][eid]["content_id"] == entry["content_id"]

    def test_read_missing_returns_none(self, tmp_path: pathlib.Path) -> None:
        result = read_entity_index(tmp_path, "nonexistent", "track.mid")
        assert result is None

    def test_index_has_all_entities(self) -> None:
        # Pure in-memory check; no filesystem fixture needed.
        notes = [_key(60 + i) for i in range(5)]
        entities = assign_entity_ids(notes, None, commit_id="c001", op_id="op1")
        index = build_entity_index(entities, "track.mid", "c001")
        assert len(index["entities"]) == 5

    def test_index_content_ids_match_notes(self) -> None:
        note = _key(60, 80)
        entities = assign_entity_ids([note], None, commit_id="c001", op_id="op1")
        index = build_entity_index(entities, "track.mid", "c001")
        # Unpacking fails loudly unless the index holds exactly one entity.
        (eid,) = index["entities"]
        assert index["entities"][eid]["content_id"] == _note_content_id(note)