gabriel / musehub public
test_musehub_objects.py python
334 lines 11.1 KB
cd448303 Initial extraction of MuseHub from maestro monorepo. Gabriel Cardona <gabriel@tellurstori.com> 7d ago
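"""Tests for Muse Hub object endpoints (piano roll / MIDI parsing).

Covers:
- test_parse_midi_bytes_basic_shape — parser returns MidiParseResult shape
- test_parse_midi_bytes_note_data — notes have correct fields
- test_parse_midi_bytes_tempo — tempo meta message is reported in BPM
- test_parse_midi_bytes_time_signature — time signature is reported as "N/D"
- test_parse_midi_bytes_multi_track — one parsed track per SMF track
- test_parse_midi_bytes_total_beats_positive — total_beats > 0 when notes exist
- test_parse_midi_bytes_empty_track — MIDI file with no notes returns zero beats
- test_parse_midi_bytes_invalid_data_raises — bad bytes raise ValueError
- test_parse_midi_bytes_notes_sorted_by_start_beat — notes are ordered by start
- test_pitch_to_name_middle_c / _a4 / _a0 — pitch_to_name helper correctness
- test_parse_midi_object_endpoint_unknown_repo_404 — unknown repo returns 404
- test_parse_midi_object_endpoint_unknown_object_404 — unknown object returns 404
- test_parse_midi_object_non_midi_404 — non-MIDI object returns 404
- test_parse_midi_object_missing_disk_file_410 — missing disk file returns 410
- test_parse_midi_object_returns_valid_result — valid MIDI returns parsed JSON
"""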
12 from __future__ import annotations
13
14 import io
15 import os
16 import struct
17 import tempfile
18
19 import mido
20 import pytest
21 from httpx import AsyncClient
22 from sqlalchemy.ext.asyncio import AsyncSession
23
24 from musehub.db.musehub_models import MusehubObject, MusehubRepo
25 from musehub.services.musehub_midi_parser import (
26 MidiNote,
27 MidiParseResult,
28 MidiTrack,
29 parse_midi_bytes,
30 pitch_to_name,
31 )
32
33
34 # ---------------------------------------------------------------------------
35 # MIDI file builder helpers
36 # ---------------------------------------------------------------------------
37
38
def _make_simple_midi() -> bytes:
    """Build a single-track (SMF Type-0) MIDI file and return its raw bytes.

    The track sets the tempo to 500000 µs per beat and then plays one note
    (pitch 60, velocity 80, channel 0) that lasts 480 ticks, with the file
    resolution set to 480 ticks per beat.
    """
    events = [
        mido.MetaMessage("set_tempo", tempo=500000, time=0),
        mido.Message("note_on", channel=0, note=60, velocity=80, time=0),
        mido.Message("note_off", channel=0, note=60, velocity=0, time=480),
        mido.MetaMessage("end_of_track", time=0),
    ]

    single_track = mido.MidiTrack()
    single_track.extend(events)

    midi_file = mido.MidiFile(type=0, ticks_per_beat=480)
    midi_file.tracks.append(single_track)

    # Serialise to memory rather than to disk.
    out = io.BytesIO()
    midi_file.save(file=out)
    return out.getvalue()
51
52
def _make_multi_track_midi() -> bytes:
    """Build a two-track (SMF Type-1) MIDI file and return its raw bytes.

    Track 0 is named "Piano" and carries only meta events: a tempo of
    500000 µs per beat and a 3/4 time signature.  Track 1 is named "Bass"
    and plays one note (pitch 36, velocity 100, channel 1) for 960 ticks.
    """
    piano_events = [
        mido.MetaMessage("set_tempo", tempo=500000, time=0),
        mido.MetaMessage("time_signature", numerator=3, denominator=4, time=0),
        mido.MetaMessage("track_name", name="Piano", time=0),
        mido.MetaMessage("end_of_track", time=0),
    ]
    bass_events = [
        mido.MetaMessage("track_name", name="Bass", time=0),
        mido.Message("note_on", channel=1, note=36, velocity=100, time=0),
        mido.Message("note_off", channel=1, note=36, velocity=0, time=960),
        mido.MetaMessage("end_of_track", time=0),
    ]

    midi_file = mido.MidiFile(type=1, ticks_per_beat=480)
    # Track order matters: the tests index the bass track as tracks[1].
    for events in (piano_events, bass_events):
        smf_track = mido.MidiTrack()
        smf_track.extend(events)
        midi_file.tracks.append(smf_track)

    out = io.BytesIO()
    midi_file.save(file=out)
    return out.getvalue()
74
75
76 # ---------------------------------------------------------------------------
77 # Unit tests — musehub_midi_parser
78 # ---------------------------------------------------------------------------
79
80
def test_parse_midi_bytes_basic_shape() -> None:
    """Every top-level MidiParseResult key is present in the parsed output."""
    parsed = parse_midi_bytes(_make_simple_midi())
    for required_key in ("tracks", "tempo_bpm", "time_signature", "total_beats"):
        assert required_key in parsed
88
89
def test_parse_midi_bytes_note_data() -> None:
    """The single note of the simple file is parsed with the expected fields."""
    parsed_tracks = parse_midi_bytes(_make_simple_midi())["tracks"]
    assert len(parsed_tracks) >= 1

    first_track_notes = parsed_tracks[0]["notes"]
    assert len(first_track_notes) == 1
    only_note = first_track_notes[0]

    # Fields whose exact value is fixed by the fixture.
    expected_fields = {
        "pitch": 60,
        "velocity": 80,
        "start_beat": 0.0,
        "track_id": 0,
        "channel": 0,
    }
    for field, expected in expected_fields.items():
        assert only_note[field] == expected

    # Duration is only required to be positive, not an exact value.
    assert only_note["duration_beats"] > 0
104
105
def test_parse_midi_bytes_tempo() -> None:
    """A tempo of 500000 µs per beat is reported as 120 BPM (within 0.1)."""
    parsed = parse_midi_bytes(_make_simple_midi())
    deviation = abs(parsed["tempo_bpm"] - 120.0)
    assert deviation < 0.1
110
111
def test_parse_midi_bytes_time_signature() -> None:
    """The time_signature meta message is surfaced as an 'N/D' string."""
    parsed = parse_midi_bytes(_make_multi_track_midi())
    reported_signature = parsed["time_signature"]
    assert reported_signature == "3/4"
117
118
def test_parse_midi_bytes_multi_track() -> None:
    """A two-track SMF file yields two parsed tracks, the second being the bass."""
    parsed_tracks = parse_midi_bytes(_make_multi_track_midi())["tracks"]
    assert len(parsed_tracks) == 2

    # The fixture puts the only note-bearing track (named "Bass") second.
    bass = parsed_tracks[1]
    assert bass["name"] == "Bass"

    bass_notes = bass["notes"]
    assert len(bass_notes) == 1
    assert bass_notes[0]["pitch"] == 36
128
129
def test_parse_midi_bytes_total_beats_positive() -> None:
    """A file that contains a note reports a total_beats above zero."""
    total_beats = parse_midi_bytes(_make_simple_midi())["total_beats"]
    assert total_beats > 0
134
135
def test_parse_midi_bytes_empty_track() -> None:
    """A file whose only track holds just end_of_track reports zero total_beats."""
    silent_track = mido.MidiTrack()
    silent_track.append(mido.MetaMessage("end_of_track", time=0))

    midi_file = mido.MidiFile(type=0, ticks_per_beat=480)
    midi_file.tracks.append(silent_track)

    out = io.BytesIO()
    midi_file.save(file=out)

    parsed = parse_midi_bytes(out.getvalue())
    assert parsed["total_beats"] == 0.0
146
147
def test_parse_midi_bytes_invalid_data_raises() -> None:
    """Bytes that are not a MIDI file raise ValueError mentioning the parse failure."""
    not_midi = b"\x00\x01\x02\x03garbage"
    with pytest.raises(ValueError, match="Could not parse MIDI"):
        parse_midi_bytes(not_midi)
152
153
def test_parse_midi_bytes_notes_sorted_by_start_beat() -> None:
    """Notes within each track are sorted by start_beat ascending.

    The two notes overlap: the note that starts first (pitch 64) is the last
    one to end, so the note-off order is the reverse of the start order.  A
    parser that emitted notes in completion order without sorting them would
    therefore fail here, which the previous back-to-back fixture could not
    detect.
    """
    mid = mido.MidiFile(type=0, ticks_per_beat=480)
    track = mido.MidiTrack()
    mid.tracks.append(track)
    # Delta times are in ticks: pitch 64 spans ticks 0-960, pitch 60 spans 240-480.
    track.append(mido.Message("note_on", channel=0, note=64, velocity=70, time=0))
    track.append(mido.Message("note_on", channel=0, note=60, velocity=80, time=240))
    track.append(mido.Message("note_off", channel=0, note=60, velocity=0, time=240))
    track.append(mido.Message("note_off", channel=0, note=64, velocity=0, time=480))
    track.append(mido.MetaMessage("end_of_track", time=0))
    buf = io.BytesIO()
    mid.save(file=buf)

    result = parse_midi_bytes(buf.getvalue())
    notes = result["tracks"][0]["notes"]
    # Guard against a vacuous pass: an empty list is trivially "sorted".
    assert len(notes) == 2
    beats = [n["start_beat"] for n in notes]
    assert beats == sorted(beats)
171
172
def test_pitch_to_name_middle_c() -> None:
    """Pitch number 60 is named C4 (middle C)."""
    note_name = pitch_to_name(60)
    assert note_name == "C4"
176
177
def test_pitch_to_name_a4() -> None:
    """Pitch number 69 is named A4 (concert A)."""
    note_name = pitch_to_name(69)
    assert note_name == "A4"
181
182
def test_pitch_to_name_a0() -> None:
    """Pitch number 21 is named A0 (the lowest piano key)."""
    note_name = pitch_to_name(21)
    assert note_name == "A0"
186
187
188 # ---------------------------------------------------------------------------
189 # HTTP endpoint tests — parse-midi route
190 # ---------------------------------------------------------------------------
191
192
# Module-level sequence number, bumped on every seed call so that each seeded
# repo name/slug and object id differs from the previous ones.
_OBJ_COUNTER = 0


async def _seed_repo_and_obj(
    db_session: AsyncSession,
    disk_path: str = "/nonexistent/track.mid",
    path: str = "tracks/bass.mid",
) -> tuple[str, str]:
    """Insert one repo and one object belonging to it; return their ids.

    Args:
        db_session: Session used to add, commit and refresh both rows.
        disk_path: Stored as the object's ``disk_path``; the default points
            at a location named "/nonexistent/...".
        path: Stored as the object's ``path`` within the repo.

    Returns:
        ``(repo_id, object_id)``, both converted to ``str``.
    """
    global _OBJ_COUNTER
    _OBJ_COUNTER += 1
    seq = _OBJ_COUNTER
    unique_name = f"midi-test-{seq}"

    repo_row = MusehubRepo(
        name=unique_name,
        owner="testuser",
        slug=unique_name,
        visibility="public",
        owner_user_id="test-owner",
    )
    db_session.add(repo_row)
    await db_session.commit()
    # Refresh before reading repo_id below.
    await db_session.refresh(repo_row)

    object_row = MusehubObject(
        object_id=f"sha256:test{seq:04d}",
        repo_id=str(repo_row.repo_id),
        path=path,
        size_bytes=0,
        disk_path=disk_path,
    )
    db_session.add(object_row)
    await db_session.commit()
    await db_session.refresh(object_row)

    return str(repo_row.repo_id), str(object_row.object_id)
228
229
@pytest.mark.anyio
async def test_parse_midi_object_endpoint_unknown_repo_404(
    client: AsyncClient,
    auth_headers: dict[str, str],
) -> None:
    """Requesting parse-midi under a repo id that was never seeded yields 404."""
    url = "/api/v1/musehub/repos/unknown-repo/objects/unknown-obj/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 404
241
242
@pytest.mark.anyio
async def test_parse_midi_object_endpoint_unknown_object_404(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """Requesting parse-midi for an object id absent from a real repo yields 404."""
    # Only the repo id is needed; the seeded object is deliberately not requested.
    seeded_repo_id, _unused_obj_id = await _seed_repo_and_obj(db_session)
    url = f"/api/v1/musehub/repos/{seeded_repo_id}/objects/missing-object-id/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 404
256
257
@pytest.mark.anyio
async def test_parse_midi_object_non_midi_404(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """An object whose path ends in .mp3 yields 404 with a detail mentioning MIDI."""
    seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
        db_session,
        path="tracks/audio.mp3",
    )
    url = f"/api/v1/musehub/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
    resp = await client.get(url, headers=auth_headers)

    assert resp.status_code == 404
    detail = resp.json()["detail"]
    assert "MIDI" in detail
274
275
@pytest.mark.anyio
async def test_parse_midi_object_missing_disk_file_410(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """A .mid object whose disk_path is "/nonexistent/missing.mid" yields 410."""
    seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
        db_session,
        disk_path="/nonexistent/missing.mid",
        path="missing.mid",
    )
    url = f"/api/v1/musehub/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 410
291
292
@pytest.mark.anyio
async def test_parse_midi_object_returns_valid_result(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """A real .mid file on disk is parsed and returned as MidiParseResult JSON."""
    payload = _make_simple_midi()
    # delete=False keeps the file after the handle closes, so the endpoint can
    # read it by path; the finally block below removes it.
    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as midi_file:
        midi_file.write(payload)
        midi_path = midi_file.name

    try:
        seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
            db_session,
            disk_path=midi_path,
            path="track.mid",
        )
        url = f"/api/v1/musehub/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
        resp = await client.get(url, headers=auth_headers)
        assert resp.status_code == 200

        body = resp.json()
        for top_level_key in ("tracks", "tempo_bpm", "time_signature", "total_beats"):
            assert top_level_key in body
        assert body["total_beats"] > 0

        parsed_tracks = body["tracks"]
        assert isinstance(parsed_tracks, list)
        assert len(parsed_tracks) >= 1

        first_track_notes = parsed_tracks[0]["notes"]
        assert isinstance(first_track_notes, list)
        assert len(first_track_notes) >= 1

        first_note = first_track_notes[0]
        note_fields = (
            "pitch",
            "start_beat",
            "duration_beats",
            "velocity",
            "track_id",
            "channel",
        )
        for note_field in note_fields:
            assert note_field in first_note
    finally:
        os.unlink(midi_path)