gabriel / musehub public
test_musehub_objects.py python
335 lines 11.1 KB
c0f0b481 release: merge dev → main (#5) Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """Tests for MuseHub object endpoints (piano roll / MIDI parsing).
2
3 Covers (representative tests; see the test functions below for the full set):
4 - test_parse_midi_bytes_basic_shape — parser returns MidiParseResult shape
5 - test_parse_midi_bytes_note_data — notes have correct fields
6 - test_parse_midi_bytes_empty_track — empty MIDI file returns zero beats
7 - test_parse_midi_bytes_invalid_data_raises — bad bytes raise ValueError
8 - test_parse_midi_object_endpoint_unknown_object_404 — unknown object returns 404
9 - test_parse_midi_object_non_midi_404 — non-MIDI object returns 404
10 - test_pitch_to_name_* — pitch_to_name helper correctness
11 """
12 from __future__ import annotations
13
14 import io
15 import os
16 import struct
17 import tempfile
18
19 import mido
20 import pytest
21 from httpx import AsyncClient
22 from sqlalchemy.ext.asyncio import AsyncSession
23
24 from musehub.db.musehub_models import MusehubObject, MusehubRepo
25 from musehub.services.musehub_midi_parser import (
26 MidiNote,
27 MidiParseResult,
28 MidiTrack,
29 parse_midi_bytes,
30 pitch_to_name,
31 )
32
33
34 # ---------------------------------------------------------------------------
35 # MIDI file builder helpers
36 # ---------------------------------------------------------------------------
37
38
def _make_simple_midi() -> bytes:
    """Build the bytes of a one-track (SMF Type-0) MIDI file with a single note.

    The track holds a ``set_tempo`` event (500000 µs per beat), a note-on for
    pitch 60 at velocity 80 on channel 0, the matching note-off 480 ticks
    later, and an end-of-track marker. The file uses 480 ticks per beat.
    """
    events = mido.MidiTrack()
    events.extend(
        [
            mido.MetaMessage("set_tempo", tempo=500000, time=0),
            mido.Message("note_on", channel=0, note=60, velocity=80, time=0),
            mido.Message("note_off", channel=0, note=60, velocity=0, time=480),
            mido.MetaMessage("end_of_track", time=0),
        ]
    )

    midi_file = mido.MidiFile(type=0, ticks_per_beat=480)
    midi_file.tracks.append(events)

    # Serialize into memory rather than touching the filesystem.
    out = io.BytesIO()
    midi_file.save(file=out)
    return out.getvalue()
51
52
def _make_multi_track_midi() -> bytes:
    """Build the bytes of a two-track (SMF Type-1) MIDI file.

    Track 0 is named "Piano" and carries only meta events: a 500000 µs/beat
    tempo and a 3/4 time signature. Track 1 is named "Bass" and holds one
    note (pitch 36, velocity 100, channel 1) lasting 960 ticks. The file uses
    480 ticks per beat.
    """
    piano = mido.MidiTrack()
    piano.extend(
        [
            mido.MetaMessage("set_tempo", tempo=500000, time=0),
            mido.MetaMessage("time_signature", numerator=3, denominator=4, time=0),
            mido.MetaMessage("track_name", name="Piano", time=0),
            mido.MetaMessage("end_of_track", time=0),
        ]
    )

    bass = mido.MidiTrack()
    bass.extend(
        [
            mido.MetaMessage("track_name", name="Bass", time=0),
            mido.Message("note_on", channel=1, note=36, velocity=100, time=0),
            mido.Message("note_off", channel=1, note=36, velocity=0, time=960),
            mido.MetaMessage("end_of_track", time=0),
        ]
    )

    midi_file = mido.MidiFile(type=1, ticks_per_beat=480)
    midi_file.tracks.extend([piano, bass])

    out = io.BytesIO()
    midi_file.save(file=out)
    return out.getvalue()
74
75
76 # ---------------------------------------------------------------------------
77 # Unit tests — musehub_midi_parser
78 # ---------------------------------------------------------------------------
79
80
def test_parse_midi_bytes_basic_shape() -> None:
    """The parse result exposes every top-level MidiParseResult key."""
    parsed = parse_midi_bytes(_make_simple_midi())
    for required_key in ("tracks", "tempo_bpm", "time_signature", "total_beats"):
        assert required_key in parsed
88
89
def test_parse_midi_bytes_note_data() -> None:
    """The single note of the simple fixture is parsed with the expected fields."""
    parsed = parse_midi_bytes(_make_simple_midi())

    track_list = parsed["tracks"]
    assert len(track_list) >= 1

    first_track_notes = track_list[0]["notes"]
    assert len(first_track_notes) == 1
    only_note = first_track_notes[0]

    # Fields whose exact value is fixed by the fixture.
    exact_fields = {
        "pitch": 60,
        "velocity": 80,
        "start_beat": 0.0,
        "track_id": 0,
        "channel": 0,
    }
    for field, expected in exact_fields.items():
        assert only_note[field] == expected

    # Duration is only required to be positive.
    assert only_note["duration_beats"] > 0
104
105
def test_parse_midi_bytes_tempo() -> None:
    """A 500000 µs/beat tempo event is reported as roughly 120 BPM."""
    bpm = parse_midi_bytes(_make_simple_midi())["tempo_bpm"]
    deviation = abs(bpm - 120.0)
    assert deviation < 0.1
110
111
def test_parse_midi_bytes_time_signature() -> None:
    """The 3/4 time-signature meta event comes back as the string '3/4'."""
    parsed = parse_midi_bytes(_make_multi_track_midi())
    signature = parsed["time_signature"]
    assert signature == "3/4"
117
118
def test_parse_midi_bytes_multi_track() -> None:
    """A two-track file yields two track entries, the second being 'Bass'."""
    track_list = parse_midi_bytes(_make_multi_track_midi())["tracks"]
    assert len(track_list) == 2

    # The second SMF track is the one that carries the bass note.
    bass = track_list[1]
    assert bass["name"] == "Bass"

    bass_notes = bass["notes"]
    assert len(bass_notes) == 1
    assert bass_notes[0]["pitch"] == 36
128
129
def test_parse_midi_bytes_total_beats_positive() -> None:
    """A file that contains a note reports a total_beats value above zero."""
    total = parse_midi_bytes(_make_simple_midi())["total_beats"]
    assert total > 0
134
135
def test_parse_midi_bytes_empty_track() -> None:
    """A track holding only an end-of-track marker gives total_beats == 0.0."""
    silent_track = mido.MidiTrack()
    silent_track.append(mido.MetaMessage("end_of_track", time=0))

    midi_file = mido.MidiFile(type=0, ticks_per_beat=480)
    midi_file.tracks.append(silent_track)

    out = io.BytesIO()
    midi_file.save(file=out)

    parsed = parse_midi_bytes(out.getvalue())
    assert parsed["total_beats"] == 0.0
146
147
def test_parse_midi_bytes_invalid_data_raises() -> None:
    """Non-MIDI bytes are rejected with a 'Could not parse MIDI' ValueError."""
    not_midi = b"\x00\x01\x02\x03garbage"
    with pytest.raises(ValueError, match="Could not parse MIDI"):
        parse_midi_bytes(not_midi)
152
153
def test_parse_midi_bytes_notes_sorted_by_start_beat() -> None:
    """Notes within each track are sorted by start_beat ascending.

    The fixture uses two *overlapping* notes so that the order in which the
    note-offs arrive differs from the order in which the notes start:

    - pitch 64 starts at tick 0 and ends at tick 960
    - pitch 60 starts at tick 240 and ends at tick 480

    With back-to-back notes the note-offs arrive in start order, so a parser
    that emits notes as they close would pass without sorting anything.
    """
    mid = mido.MidiFile(type=0, ticks_per_beat=480)
    track = mido.MidiTrack()
    mid.tracks.append(track)
    # Message times are deltas from the previous event, in ticks.
    track.append(mido.Message("note_on", channel=0, note=64, velocity=70, time=0))
    track.append(mido.Message("note_on", channel=0, note=60, velocity=80, time=240))
    track.append(mido.Message("note_off", channel=0, note=60, velocity=0, time=240))
    track.append(mido.Message("note_off", channel=0, note=64, velocity=0, time=480))
    track.append(mido.MetaMessage("end_of_track", time=0))
    buf = io.BytesIO()
    mid.save(file=buf)

    result = parse_midi_bytes(buf.getvalue())
    notes = result["tracks"][0]["notes"]

    # Guard against a vacuous pass: an empty list is trivially "sorted".
    assert len(notes) == 2

    beats = [n["start_beat"] for n in notes]
    assert beats == sorted(beats)
    # The earlier-starting note (pitch 64) must come first even though its
    # note-off arrives last.
    assert [n["pitch"] for n in notes] == [64, 60]
171
172
def test_pitch_to_name_middle_c() -> None:
    """Pitch number 60 maps to the name 'C4' (middle C)."""
    name = pitch_to_name(60)
    assert name == "C4"
176
177
def test_pitch_to_name_a4() -> None:
    """Pitch number 69 maps to the name 'A4' (concert A)."""
    name = pitch_to_name(69)
    assert name == "A4"
181
182
def test_pitch_to_name_a0() -> None:
    """Pitch number 21 maps to the name 'A0' (lowest piano key)."""
    name = pitch_to_name(21)
    assert name == "A0"
186
187
188 # ---------------------------------------------------------------------------
189 # HTTP endpoint tests — parse-midi route
190 # ---------------------------------------------------------------------------
191
192
async def _seed_repo_and_obj(
    db_session: AsyncSession,
    disk_path: str = "/nonexistent/track.mid",
    path: str = "tracks/bass.mid",
) -> tuple[str, str]:
    """Insert one repo and one object row and return their ids as strings.

    Args:
        db_session: Session used to add, commit, and refresh both rows.
        disk_path: Value stored in the object's ``disk_path`` column.
        path: Value stored in the object's ``path`` column.

    Returns:
        ``(repo_id, object_id)``, both converted with ``str()``.

    A random 8-hex-character suffix from ``uuid4`` is embedded in the repo
    name, the slug, and the object id so that repeated calls do not collide.
    """
    import uuid

    unique = uuid.uuid4().hex[:8]
    repo_label = f"midi-test-{unique}"

    repo_row = MusehubRepo(
        name=repo_label,
        owner="testuser",
        slug=repo_label,
        visibility="public",
        owner_user_id="test-owner",
    )
    db_session.add(repo_row)
    await db_session.commit()
    # Refresh so repo_row.repo_id can be read after the commit.
    await db_session.refresh(repo_row)

    object_row = MusehubObject(
        object_id=f"sha256:test{unique}",
        repo_id=str(repo_row.repo_id),
        path=path,
        size_bytes=0,
        disk_path=disk_path,
    )
    db_session.add(object_row)
    await db_session.commit()
    await db_session.refresh(object_row)

    return str(repo_row.repo_id), str(object_row.object_id)
229
230
@pytest.mark.anyio
async def test_parse_midi_object_endpoint_unknown_repo_404(
    client: AsyncClient,
    auth_headers: dict[str, str],
) -> None:
    """Requesting parse-midi under a repo id that was never seeded gives 404."""
    url = "/api/v1/repos/unknown-repo/objects/unknown-obj/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 404
242
243
@pytest.mark.anyio
async def test_parse_midi_object_endpoint_unknown_object_404(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """Requesting parse-midi for an object id absent from a seeded repo gives 404."""
    # Only the repo id is needed; the seeded object's id is deliberately unused.
    seeded_repo_id, _unused_obj_id = await _seed_repo_and_obj(db_session)
    url = f"/api/v1/repos/{seeded_repo_id}/objects/missing-object-id/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 404
257
258
@pytest.mark.anyio
async def test_parse_midi_object_non_midi_404(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """An object whose path ends in .mp3 gives 404 with 'MIDI' in the detail."""
    seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
        db_session,
        path="tracks/audio.mp3",
    )
    url = f"/api/v1/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
    resp = await client.get(url, headers=auth_headers)

    assert resp.status_code == 404
    detail = resp.json()["detail"]
    assert "MIDI" in detail
275
276
@pytest.mark.anyio
async def test_parse_midi_object_missing_disk_file_410(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """An object row whose disk_path points at a nonexistent file gives 410."""
    seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
        db_session,
        disk_path="/nonexistent/missing.mid",
        path="missing.mid",
    )
    url = f"/api/v1/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 410
292
293
@pytest.mark.anyio
async def test_parse_midi_object_returns_valid_result(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """A real MIDI file on disk is served as a MidiParseResult-shaped JSON body.

    The simple one-note fixture is written to a temporary ``.mid`` file, an
    object row pointing at that file is seeded, and the endpoint's 200
    response is checked for the expected top-level and per-note keys.
    """
    # delete=False keeps the file on disk after the handle is closed; it is
    # removed in the ``finally`` block below.
    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as handle:
        handle.write(_make_simple_midi())
        midi_file_path = handle.name

    try:
        seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
            db_session,
            disk_path=midi_file_path,
            path="track.mid",
        )
        url = f"/api/v1/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
        resp = await client.get(url, headers=auth_headers)
        assert resp.status_code == 200

        payload = resp.json()
        for top_level_key in ("tracks", "tempo_bpm", "time_signature", "total_beats"):
            assert top_level_key in payload
        assert payload["total_beats"] > 0

        track_list = payload["tracks"]
        assert isinstance(track_list, list)
        assert len(track_list) >= 1

        first_track_notes = track_list[0]["notes"]
        assert isinstance(first_track_notes, list)
        assert len(first_track_notes) >= 1

        first_note = first_track_notes[0]
        for note_key in (
            "pitch",
            "start_beat",
            "duration_beats",
            "velocity",
            "track_id",
            "channel",
        ):
            assert note_key in first_note
    finally:
        os.unlink(midi_file_path)