cgcardona / muse public
test_entity.py python
189 lines 7.2 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 1d ago
1 """Tests for muse.plugins.midi.entity — NoteEntity, EntityIndex, assign/diff."""
2
3 import pathlib
4
5 import pytest
6
7 from muse.plugins.midi.entity import (
8 EntityIndex,
9 EntityIndexEntry,
10 NoteEntity,
11 assign_entity_ids,
12 build_entity_index,
13 diff_with_entity_ids,
14 read_entity_index,
15 write_entity_index,
16 )
17 from muse.plugins.midi.midi_diff import NoteKey, _note_content_id
18
19
def _key(
    pitch: int = 60,
    velocity: int = 80,
    start_tick: int = 0,
    duration_ticks: int = 480,
    channel: int = 0,
) -> NoteKey:
    """Return a ``NoteKey`` built from the given fields.

    Every field has a default (pitch 60, velocity 80, start tick 0,
    duration 480 ticks, channel 0) so tests only spell out what they vary.
    """
    fields = {
        "pitch": pitch,
        "velocity": velocity,
        "start_tick": start_tick,
        "duration_ticks": duration_ticks,
        "channel": channel,
    }
    return NoteKey(**fields)
24
25
26 # ---------------------------------------------------------------------------
27 # assign_entity_ids — first commit (no prior index)
28 # ---------------------------------------------------------------------------
29
30
class TestAssignEntityIdsFirstCommit:
    """``assign_entity_ids`` on a first commit, i.e. with no prior index (``None``)."""

    def test_all_notes_get_new_entity_ids(self) -> None:
        notes = [_key(p) for p in (60, 62, 64)]
        assigned = assign_entity_ids(notes, None, commit_id="c001", op_id="op1")

        assert len(assigned) == 3
        # One entity per note, and no two notes may share an ID.
        distinct_ids = {entity["entity_id"] for entity in assigned}
        assert len(distinct_ids) == 3

    def test_entity_has_correct_pitch_fields(self) -> None:
        source = _key(pitch=67, velocity=100)
        first = assign_entity_ids([source], None, commit_id="c001", op_id="op1")[0]

        # The entity carries the note's own pitch/velocity values.
        assert first["pitch"] == 67
        assert first["velocity"] == 100

    def test_origin_commit_id_set(self) -> None:
        assigned = assign_entity_ids(
            [_key(60)], None, commit_id="commit-abc", op_id="op1"
        )
        # The commit passed in is recorded as the entity's origin.
        assert assigned[0]["origin_commit_id"] == "commit-abc"
49
50
51 # ---------------------------------------------------------------------------
52 # assign_entity_ids — with prior index (exact match)
53 # ---------------------------------------------------------------------------
54
55
class TestAssignEntityIdsExactMatch:
    """``assign_entity_ids`` when a prior index from an earlier commit is supplied."""

    def _make_index(self, notes: list[NoteKey], commit_id: str) -> EntityIndex:
        """Assign fresh IDs to *notes* and wrap them in an index for ``track.mid``."""
        fresh = assign_entity_ids(notes, None, commit_id=commit_id, op_id="op0")
        return build_entity_index(fresh, "track.mid", commit_id)

    def test_same_notes_get_same_entity_ids(self) -> None:
        notes = [_key(60), _key(62), _key(64)]
        prior_index = self._make_index(notes, "c001")
        # Snapshot the IDs before re-assignment reads the index.
        known_ids = set(prior_index["entities"])

        reassigned = assign_entity_ids(
            notes, prior_index, commit_id="c002", op_id="op2"
        )

        # Unchanged notes must come back with exactly the previously known IDs.
        assert {entity["entity_id"] for entity in reassigned} == known_ids

    def test_added_note_gets_new_entity_id(self) -> None:
        prior_index = self._make_index([_key(60), _key(62)], "c001")
        known_ids = set(prior_index["entities"])

        grown = [_key(60), _key(62), _key(64)]
        reassigned = assign_entity_ids(
            grown, prior_index, commit_id="c002", op_id="op2"
        )
        assigned_ids = {entity["entity_id"] for entity in reassigned}

        # Both pre-existing notes keep their IDs...
        assert known_ids <= assigned_ids
        # ...and exactly one brand-new ID appears, for the added note.
        assert len(assigned_ids - known_ids) == 1
83
84
85 # ---------------------------------------------------------------------------
86 # diff_with_entity_ids
87 # ---------------------------------------------------------------------------
88
89
class TestDiffWithEntityIds:
    """Tests for ``diff_with_entity_ids`` between two ID-annotated note lists."""

    def _entities_from(self, notes: list[NoteKey]) -> list[NoteEntity]:
        """Assign first-commit (``c001``) entity IDs to *notes* with no prior index."""
        return assign_entity_ids(notes, None, commit_id="c001", op_id="op1")

    def _next_commit(
        self,
        base: list[NoteEntity],
        notes: list[NoteKey],
        **thresholds: int,
    ) -> list[NoteEntity]:
        """Assign commit ``c002`` IDs to *notes*, matched against *base*'s index.

        Any keyword arguments (e.g. ``mutation_threshold_velocity``) are
        forwarded unchanged to ``assign_entity_ids``.
        """
        prior_index = build_entity_index(base, "track.mid", "c001")
        return assign_entity_ids(
            notes, prior_index, commit_id="c002", op_id="op2", **thresholds
        )

    def test_no_change_produces_no_ops(self) -> None:
        notes = [_key(60), _key(62), _key(64)]
        base = self._entities_from(notes)
        target = self._next_commit(base, notes)

        assert diff_with_entity_ids(base, target, 480) == []

    def test_added_note_produces_insert(self) -> None:
        base = self._entities_from([_key(60)])
        target = self._next_commit(base, [_key(60), _key(64)])

        op_types = [o["op"] for o in diff_with_entity_ids(base, target, 480)]
        assert "insert" in op_types
        assert "delete" not in op_types

    def test_removed_note_produces_delete(self) -> None:
        base = self._entities_from([_key(60), _key(64)])
        target = self._next_commit(base, [_key(60)])

        op_types = [o["op"] for o in diff_with_entity_ids(base, target, 480)]
        assert "delete" in op_types
        assert "insert" not in op_types

    def test_velocity_change_produces_mutate(self) -> None:
        base = self._entities_from([_key(pitch=60, velocity=80)])

        # Change velocity only: a delta of 20, with mutation_threshold_velocity=30.
        target = self._next_commit(
            base,
            [_key(pitch=60, velocity=100)],
            mutation_threshold_ticks=20,
            mutation_threshold_velocity=30,
        )
        ops = diff_with_entity_ids(base, target, 480)
        op_types = [o["op"] for o in ops]

        # The matcher may or may not pair the changed note with its
        # predecessor, so either outcome is accepted — but the emitted ops
        # must agree with the entity-ID assignment that was actually made.
        assert ops, "a velocity change must emit at least one op"
        if target[0]["entity_id"] == base[0]["entity_id"]:
            # Same entity kept: an in-place change, not a remove-and-add.
            assert "insert" not in op_types
            assert "delete" not in op_types
        else:
            # New entity assigned: the old note is removed and a new one added.
            assert "insert" in op_types
            assert "delete" in op_types
154
155
156 # ---------------------------------------------------------------------------
157 # EntityIndex I/O
158 # ---------------------------------------------------------------------------
159
160
class TestEntityIndexIO:
    """Tests for building an ``EntityIndex`` and writing/reading it under a directory."""

    def test_write_and_read_roundtrip(self, tmp_path: pathlib.Path) -> None:
        notes = [_key(60), _key(62), _key(64)]
        entities = assign_entity_ids(notes, None, commit_id="c001", op_id="op1")
        index = build_entity_index(entities, "track.mid", "c001")

        write_entity_index(tmp_path, "c001", "track.mid", index)
        recovered = read_entity_index(tmp_path, "c001", "track.mid")

        assert recovered is not None
        assert set(recovered["entities"]) == set(index["entities"])
        # Matching IDs are not enough: each entry's content ID must survive too.
        for eid, entry in index["entities"].items():
            assert recovered["entities"][eid]["content_id"] == entry["content_id"]

    def test_read_missing_returns_none(self, tmp_path: pathlib.Path) -> None:
        # Nothing was written for this commit, so there is no index to read.
        assert read_entity_index(tmp_path, "nonexistent", "track.mid") is None

    def test_index_has_all_entities(self) -> None:
        # Pure in-memory check; no filesystem (and so no tmp_path) is needed.
        notes = [_key(60 + i) for i in range(5)]
        entities = assign_entity_ids(notes, None, commit_id="c001", op_id="op1")
        index = build_entity_index(entities, "track.mid", "c001")
        assert len(index["entities"]) == 5

    def test_index_content_ids_match_notes(self) -> None:
        note = _key(60, 80)
        entities = assign_entity_ids([note], None, commit_id="c001", op_id="op1")
        index = build_entity_index(entities, "track.mid", "c001")

        # Exactly one note went in, so exactly one entry must come out.
        assert len(index["entities"]) == 1
        (entry,) = index["entities"].values()
        assert entry["content_id"] == _note_content_id(note)