test_musehub_objects.py
python
"""Tests for Muse Hub object endpoints: piano roll / MIDI parsing.

Covers:
- test_parse_midi_bytes_basic_shape — parser output has every MidiParseResult key
- test_parse_midi_bytes_note_data — parsed notes carry the correct fields
- test_parse_midi_bytes_tempo — 500000 µs/beat is reported as 120 BPM
- test_parse_midi_bytes_time_signature — time signature is returned as "N/D"
- test_parse_midi_bytes_multi_track — one MidiTrack entry per SMF track
- test_parse_midi_bytes_total_beats_positive — total_beats > 0 when notes exist
- test_parse_midi_bytes_empty_track — a file without notes returns zero beats
- test_parse_midi_bytes_invalid_data_raises — garbage bytes raise ValueError
- test_parse_midi_bytes_notes_sorted_by_start_beat — notes ordered by start_beat
- test_pitch_to_name_middle_c / _a4 / _a0 — pitch_to_name helper correctness
- test_parse_midi_object_endpoint_unknown_repo_404 — unknown repo returns 404
- test_parse_midi_object_endpoint_unknown_object_404 — unknown object returns 404
- test_parse_midi_object_non_midi_404 — non-MIDI object returns 404
- test_parse_midi_object_missing_disk_file_410 — missing disk file returns 410
- test_parse_midi_object_returns_valid_result — valid MIDI object returns parse JSON
"""
| 12 | from __future__ import annotations |
| 13 | |
| 14 | import io |
| 15 | import os |
| 16 | import struct |
| 17 | import tempfile |
| 18 | |
| 19 | import mido |
| 20 | import pytest |
| 21 | from httpx import AsyncClient |
| 22 | from sqlalchemy.ext.asyncio import AsyncSession |
| 23 | |
| 24 | from musehub.db.musehub_models import MusehubObject, MusehubRepo |
| 25 | from musehub.services.musehub_midi_parser import ( |
| 26 | MidiNote, |
| 27 | MidiParseResult, |
| 28 | MidiTrack, |
| 29 | parse_midi_bytes, |
| 30 | pitch_to_name, |
| 31 | ) |
| 32 | |
| 33 | |
| 34 | # --------------------------------------------------------------------------- |
| 35 | # MIDI file builder helpers |
| 36 | # --------------------------------------------------------------------------- |
| 37 | |
| 38 | |
def _make_simple_midi() -> bytes:
    """Build a single-track (SMF type 0) MIDI file containing one note.

    The track sets a tempo of 500000 µs per beat, then plays pitch 60 on
    channel 0 at velocity 80 for 480 ticks (one beat at 480 ticks per beat),
    and finally ends the track.

    Returns:
        The serialized Standard MIDI File bytes.
    """
    events = [
        mido.MetaMessage("set_tempo", tempo=500000, time=0),
        mido.Message("note_on", channel=0, note=60, velocity=80, time=0),
        mido.Message("note_off", channel=0, note=60, velocity=0, time=480),
        mido.MetaMessage("end_of_track", time=0),
    ]
    midi_file = mido.MidiFile(type=0, ticks_per_beat=480)
    only_track = mido.MidiTrack()
    only_track.extend(events)
    midi_file.tracks.append(only_track)

    out = io.BytesIO()
    midi_file.save(file=out)
    return out.getvalue()
| 51 | |
| 52 | |
def _make_multi_track_midi() -> bytes:
    """Build a two-track (SMF type 1) MIDI file.

    Track 0 has no notes. It carries the tempo (500000 µs/beat), a 3/4 time
    signature and the track name "Piano". Track 1 is named "Bass" and holds a
    single note: pitch 36 on channel 1, velocity 100, lasting 960 ticks.

    Returns:
        The serialized Standard MIDI File bytes.
    """
    per_track_events = (
        [
            mido.MetaMessage("set_tempo", tempo=500000, time=0),
            mido.MetaMessage("time_signature", numerator=3, denominator=4, time=0),
            mido.MetaMessage("track_name", name="Piano", time=0),
            mido.MetaMessage("end_of_track", time=0),
        ],
        [
            mido.MetaMessage("track_name", name="Bass", time=0),
            mido.Message("note_on", channel=1, note=36, velocity=100, time=0),
            mido.Message("note_off", channel=1, note=36, velocity=0, time=960),
            mido.MetaMessage("end_of_track", time=0),
        ],
    )

    midi_file = mido.MidiFile(type=1, ticks_per_beat=480)
    for events in per_track_events:
        smf_track = mido.MidiTrack()
        smf_track.extend(events)
        midi_file.tracks.append(smf_track)

    out = io.BytesIO()
    midi_file.save(file=out)
    return out.getvalue()
| 74 | |
| 75 | |
| 76 | # --------------------------------------------------------------------------- |
| 77 | # Unit tests — musehub_midi_parser |
| 78 | # --------------------------------------------------------------------------- |
| 79 | |
| 80 | |
def test_parse_midi_bytes_basic_shape() -> None:
    """Every top-level MidiParseResult key appears in the parser output."""
    parsed = parse_midi_bytes(_make_simple_midi())
    for required_key in ("tracks", "tempo_bpm", "time_signature", "total_beats"):
        assert required_key in parsed
| 88 | |
| 89 | |
def test_parse_midi_bytes_note_data() -> None:
    """The simple fixture's single note round-trips with the expected fields.

    Pitch, velocity, start beat, track id and channel must match exactly, and
    the duration must be positive.
    """
    parsed_tracks = parse_midi_bytes(_make_simple_midi())["tracks"]
    assert len(parsed_tracks) >= 1

    first_track_notes = parsed_tracks[0]["notes"]
    assert len(first_track_notes) == 1

    only_note = first_track_notes[0]
    assert only_note["pitch"] == 60
    assert only_note["velocity"] == 80
    assert only_note["duration_beats"] > 0
    assert only_note["start_beat"] == 0.0
    assert only_note["track_id"] == 0
    assert only_note["channel"] == 0
| 104 | |
| 105 | |
def test_parse_midi_bytes_tempo() -> None:
    """A set_tempo of 500000 µs per beat is reported as 120 BPM (within 0.1)."""
    bpm = parse_midi_bytes(_make_simple_midi())["tempo_bpm"]
    assert abs(bpm - 120.0) < 0.1
| 110 | |
| 111 | |
def test_parse_midi_bytes_time_signature() -> None:
    """The 3/4 time_signature meta message is reported as the string '3/4'."""
    signature = parse_midi_bytes(_make_multi_track_midi())["time_signature"]
    assert signature == "3/4"
| 117 | |
| 118 | |
def test_parse_midi_bytes_multi_track() -> None:
    """Each SMF track becomes one MidiTrack entry, and the Bass note survives."""
    parsed_tracks = parse_midi_bytes(_make_multi_track_midi())["tracks"]
    assert len(parsed_tracks) == 2

    # The second SMF track ("Bass") is the one that carries the only note.
    bass = parsed_tracks[1]
    assert bass["name"] == "Bass"

    bass_notes = bass["notes"]
    assert len(bass_notes) == 1
    assert bass_notes[0]["pitch"] == 36
| 128 | |
| 129 | |
def test_parse_midi_bytes_total_beats_positive() -> None:
    """A file that contains at least one note reports total_beats > 0."""
    total = parse_midi_bytes(_make_simple_midi())["total_beats"]
    assert total > 0
| 134 | |
| 135 | |
def test_parse_midi_bytes_empty_track() -> None:
    """A type-0 file whose only event is end_of_track yields total_beats 0.0."""
    silent_track = mido.MidiTrack()
    silent_track.append(mido.MetaMessage("end_of_track", time=0))

    silent_file = mido.MidiFile(type=0, ticks_per_beat=480)
    silent_file.tracks.append(silent_track)

    encoded = io.BytesIO()
    silent_file.save(file=encoded)
    assert parse_midi_bytes(encoded.getvalue())["total_beats"] == 0.0
| 146 | |
| 147 | |
def test_parse_midi_bytes_invalid_data_raises() -> None:
    """Non-MIDI bytes raise a ValueError mentioning 'Could not parse MIDI'."""
    garbage = b"\x00\x01\x02\x03garbage"
    with pytest.raises(ValueError, match="Could not parse MIDI"):
        parse_midi_bytes(garbage)
| 152 | |
| 153 | |
def test_parse_midi_bytes_notes_sorted_by_start_beat() -> None:
    """Notes within each track are sorted by start_beat ascending.

    The two notes overlap so that they *finish* in the opposite order to the
    one in which they *start*. Pitch 64 starts on beat 0 and ends on beat 2.
    Pitch 60 starts on beat 0.5 and ends on beat 1. A parser that appends
    notes as their note_off arrives would therefore produce start beats
    [0.5, 0.0], so only a parser that really sorts by start_beat passes.
    """
    mid = mido.MidiFile(type=0, ticks_per_beat=480)
    track = mido.MidiTrack()
    mid.tracks.append(track)
    # Message times are deltas in ticks; 480 ticks == 1 beat here.
    # pitch 64 on at tick 0
    track.append(mido.Message("note_on", channel=0, note=64, velocity=70, time=0))
    # pitch 60 on at tick 240
    track.append(mido.Message("note_on", channel=0, note=60, velocity=80, time=240))
    # pitch 60 off at tick 480
    track.append(mido.Message("note_off", channel=0, note=60, velocity=0, time=240))
    # pitch 64 off at tick 960
    track.append(mido.Message("note_off", channel=0, note=64, velocity=0, time=480))
    track.append(mido.MetaMessage("end_of_track", time=0))
    buf = io.BytesIO()
    mid.save(file=buf)

    result = parse_midi_bytes(buf.getvalue())
    notes = result["tracks"][0]["notes"]
    # Guard against a vacuous pass: an empty or one-element list is trivially sorted.
    assert len(notes) == 2
    beats = [n["start_beat"] for n in notes]
    assert beats == sorted(beats)
    assert [n["pitch"] for n in notes] == [64, 60]
| 171 | |
| 172 | |
def test_pitch_to_name_middle_c() -> None:
    """pitch_to_name maps MIDI note 60 (middle C) to 'C4'."""
    name = pitch_to_name(60)
    assert name == "C4"
| 176 | |
| 177 | |
def test_pitch_to_name_a4() -> None:
    """pitch_to_name maps MIDI note 69 (concert A) to 'A4'."""
    name = pitch_to_name(69)
    assert name == "A4"
| 181 | |
| 182 | |
def test_pitch_to_name_a0() -> None:
    """pitch_to_name maps MIDI note 21 (lowest piano key) to 'A0'."""
    name = pitch_to_name(21)
    assert name == "A0"
| 186 | |
| 187 | |
| 188 | # --------------------------------------------------------------------------- |
| 189 | # HTTP endpoint tests — parse-midi route |
| 190 | # --------------------------------------------------------------------------- |
| 191 | |
| 192 | |
# Incremented on every _seed_repo_and_obj call so that each seeded repo
# name/slug and object id differ from those created by earlier calls.
_OBJ_COUNTER = 0


async def _seed_repo_and_obj(
    db_session: AsyncSession,
    disk_path: str = "/nonexistent/track.mid",
    path: str = "tracks/bass.mid",
) -> tuple[str, str]:
    """Seed a public repo and one object row; return ``(repo_id, object_id)``.

    Args:
        db_session: Session used to insert, commit and refresh both rows.
        disk_path: Value stored in the object's ``disk_path`` column. The
            default points at a path that does not exist.
        path: Path stored on the object (e.g. ``tracks/audio.mp3`` to seed a
            non-MIDI object).

    Returns:
        The new repo's id and the new object's id, both as strings.
    """
    global _OBJ_COUNTER
    _OBJ_COUNTER += 1
    seq = _OBJ_COUNTER
    object_id = f"sha256:test{seq:04d}"

    repo = MusehubRepo(
        name=f"midi-test-{seq}",
        owner="testuser",
        slug=f"midi-test-{seq}",
        visibility="public",
        owner_user_id="test-owner",
    )
    db_session.add(repo)
    await db_session.commit()
    await db_session.refresh(repo)
    # Read the id once, while the instance is freshly refreshed. The next
    # commit may expire it, and reloading an expired attribute through an
    # AsyncSession performs implicit IO (MissingGreenlet) unless the session
    # was created with expire_on_commit=False.
    repo_id = str(repo.repo_id)

    obj = MusehubObject(
        object_id=object_id,
        repo_id=repo_id,
        path=path,
        size_bytes=0,
        disk_path=disk_path,
    )
    db_session.add(obj)
    await db_session.commit()
    await db_session.refresh(obj)
    return repo_id, str(obj.object_id)
| 228 | |
| 229 | |
@pytest.mark.anyio
async def test_parse_midi_object_endpoint_unknown_repo_404(
    client: AsyncClient,
    auth_headers: dict[str, str],
) -> None:
    """The parse-midi route answers 404 when the repo id does not exist."""
    url = "/api/v1/musehub/repos/unknown-repo/objects/unknown-obj/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 404
| 241 | |
| 242 | |
@pytest.mark.anyio
async def test_parse_midi_object_endpoint_unknown_object_404(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """The parse-midi route answers 404 for an object id absent from a real repo."""
    seeded_repo_id, _unused_obj_id = await _seed_repo_and_obj(db_session)
    url = f"/api/v1/musehub/repos/{seeded_repo_id}/objects/missing-object-id/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 404
| 256 | |
| 257 | |
@pytest.mark.anyio
async def test_parse_midi_object_non_midi_404(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """An object stored under a non-MIDI path (``.mp3``) yields 404.

    The error detail must also mention 'MIDI'.
    """
    seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
        db_session, path="tracks/audio.mp3"
    )
    url = f"/api/v1/musehub/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
    resp = await client.get(url, headers=auth_headers)

    assert resp.status_code == 404
    detail = resp.json()["detail"]
    assert "MIDI" in detail
| 274 | |
| 275 | |
@pytest.mark.anyio
async def test_parse_midi_object_missing_disk_file_410(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """A .mid object whose backing file is absent from disk yields 410 Gone."""
    seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
        db_session,
        path="missing.mid",
        disk_path="/nonexistent/missing.mid",
    )
    url = f"/api/v1/musehub/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
    resp = await client.get(url, headers=auth_headers)
    assert resp.status_code == 410
| 291 | |
| 292 | |
@pytest.mark.anyio
async def test_parse_midi_object_returns_valid_result(
    client: AsyncClient,
    db_session: AsyncSession,
    auth_headers: dict[str, str],
) -> None:
    """A .mid object backed by a real file parses to MidiParseResult-shaped JSON.

    The fixture bytes go into a NamedTemporaryFile created with
    ``delete=False``, so the file outlives the ``with`` block that closes it.
    The file is removed explicitly in ``finally``.
    """
    payload_bytes = _make_simple_midi()
    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as handle:
        handle.write(payload_bytes)
        tmp_path = handle.name

    try:
        seeded_repo_id, seeded_obj_id = await _seed_repo_and_obj(
            db_session, disk_path=tmp_path, path="track.mid"
        )
        url = f"/api/v1/musehub/repos/{seeded_repo_id}/objects/{seeded_obj_id}/parse-midi"
        resp = await client.get(url, headers=auth_headers)
        assert resp.status_code == 200

        body = resp.json()
        for top_key in ("tracks", "tempo_bpm", "time_signature", "total_beats"):
            assert top_key in body
        assert body["total_beats"] > 0

        parsed_tracks = body["tracks"]
        assert isinstance(parsed_tracks, list)
        assert len(parsed_tracks) >= 1

        first_notes = parsed_tracks[0]["notes"]
        assert isinstance(first_notes, list)
        assert len(first_notes) >= 1

        first_note = first_notes[0]
        for note_key in (
            "pitch",
            "start_beat",
            "duration_beats",
            "velocity",
            "track_id",
            "channel",
        ):
            assert note_key in first_note
    finally:
        os.unlink(tmp_path)