_query.py
python
| 1 | """Shared music-domain query helpers for the Muse CLI. |
| 2 | |
| 3 | Provides the low-level primitives that music-domain commands share: |
| 4 | note extraction from the object store, bar-level grouping, chord detection, |
| 5 | and commit-graph walking specific to MIDI tracks. |
| 6 | |
| 7 | Nothing here belongs in the public ``MidiPlugin`` API. These are CLI-layer |
| 8 | helpers — thin adapters over ``midi_diff.extract_notes`` and the core store. |
| 9 | """ |
| 10 | |
| 11 | import logging |
| 12 | import pathlib |
| 13 | from typing import NamedTuple |
| 14 | |
| 15 | from muse.core.object_store import read_object |
| 16 | from muse.core.store import CommitRecord, read_commit, get_commit_snapshot_manifest |
| 17 | from muse.plugins.midi.midi_diff import NoteKey, _pitch_name, extract_notes # noqa: PLC2701 |
| 18 | |
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Pitch / music-theory constants
# ---------------------------------------------------------------------------

# Sharp-spelled names of the twelve pitch classes, indexed by ``pitch % 12``.
_PITCH_CLASSES = "C C# D D# E F F# G G# A A# B".split()

# Chord templates as (quality suffix, semitone offsets above the root).
# The order is significant: ``detect_chord`` scans templates in this order.
_CHORD_TEMPLATES: list[tuple[str, frozenset[int]]] = [
    (quality, frozenset(offsets))
    for quality, offsets in (
        ("maj", (0, 4, 7)),
        ("min", (0, 3, 7)),
        ("dim", (0, 3, 6)),
        ("aug", (0, 4, 8)),
        ("sus2", (0, 2, 7)),
        ("sus4", (0, 5, 7)),
        ("dom7", (0, 4, 7, 10)),
        ("maj7", (0, 4, 7, 11)),
        ("min7", (0, 3, 7, 10)),
        ("dim7", (0, 3, 6, 9)),
        ("5", (0, 7)),  # power chord: root and fifth only
    )
]
| 41 | |
| 42 | # --------------------------------------------------------------------------- |
| 43 | # NoteInfo — enriched note for display |
| 44 | # --------------------------------------------------------------------------- |
| 45 | |
| 46 | |
class NoteInfo(NamedTuple):
    """The fields of a ``NoteKey`` plus the resolution of the file it came from.

    The derived properties turn tick positions into beats and bars. Bar maths
    assumes 4/4 time. A ``ticks_per_beat`` below 1 is treated as 1, so none of
    the divisions can fail.
    """

    pitch: int  # MIDI note number; its pitch class is ``pitch % 12``
    velocity: int
    start_tick: int  # onset, in ticks from the start of the track
    duration_ticks: int
    channel: int
    ticks_per_beat: int  # resolution of the file the note was read from

    @property
    def _resolution(self) -> int:
        """``ticks_per_beat`` clamped to at least 1 (guards the divisions below)."""
        return max(self.ticks_per_beat, 1)

    @property
    def pitch_name(self) -> str:
        """Note name for ``pitch`` as rendered by ``midi_diff._pitch_name``."""
        return _pitch_name(self.pitch)

    @property
    def beat(self) -> float:
        """Onset position in beats, counted from 0 at the start of the track."""
        return self.start_tick / self._resolution

    @property
    def beat_duration(self) -> float:
        """Length of the note in beats."""
        return self.duration_ticks / self._resolution

    @property
    def bar(self) -> int:
        """1-indexed bar number (assumes 4/4 time)."""
        ticks_per_bar = 4 * self._resolution
        return int(self.start_tick // ticks_per_bar) + 1

    @property
    def beat_in_bar(self) -> float:
        """Beat position within the bar (1-indexed)."""
        resolution = self._resolution
        bar_start_tick = (self.bar - 1) * 4 * resolution
        return 1 + (self.start_tick - bar_start_tick) / resolution

    @property
    def pitch_class(self) -> int:
        """Pitch class 0-11 (0 = C)."""
        return self.pitch % 12

    @property
    def pitch_class_name(self) -> str:
        """Sharp-spelled name of the pitch class, e.g. ``"F#"``."""
        return _PITCH_CLASSES[self.pitch_class]

    @classmethod
    def from_note_key(cls, note: NoteKey, ticks_per_beat: int) -> "NoteInfo":
        """Build a ``NoteInfo`` from a ``NoteKey`` mapping and the file resolution."""
        copied = {
            field: note[field]
            for field in ("pitch", "velocity", "start_tick", "duration_ticks", "channel")
        }
        return cls(**copied, ticks_per_beat=ticks_per_beat)
| 99 | |
| 100 | |
| 101 | # --------------------------------------------------------------------------- |
| 102 | # Track loading from the object store |
| 103 | # --------------------------------------------------------------------------- |
| 104 | |
| 105 | |
def load_track(
    root: pathlib.Path,
    commit_id: str,
    track_path: str,
) -> tuple[list[NoteInfo], int] | None:
    """Read and parse the MIDI file *track_path* as stored in commit *commit_id*.

    Args:
        root: Repository root.
        commit_id: SHA-256 commit ID.
        track_path: Workspace-relative path to the ``.mid`` file.

    Returns:
        ``(notes, ticks_per_beat)`` on success. Returns ``None`` when the
        commit's snapshot manifest has no entry for the track, when
        ``read_object`` returns nothing, or when ``extract_notes`` raises
        ``ValueError``. The parse failure is logged at DEBUG level.
    """
    manifest: dict[str, str] = get_commit_snapshot_manifest(root, commit_id) or {}
    object_id = manifest.get(track_path)
    # Only touch the object store when the snapshot actually lists the track.
    payload = read_object(root, object_id) if object_id is not None else None
    if payload is None:
        return None
    try:
        note_keys, resolution = extract_notes(payload)
    except ValueError as exc:
        logger.debug("Cannot parse MIDI %r from commit %s: %s", track_path, commit_id[:8], exc)
        return None
    return [NoteInfo.from_note_key(key, resolution) for key in note_keys], resolution
| 136 | |
| 137 | |
def load_track_from_workdir(
    root: pathlib.Path,
    track_path: str,
) -> tuple[list[NoteInfo], int] | None:
    """Load notes for *track_path* from the live working tree.

    ``root / "muse-work" / track_path`` is tried first, then
    ``root / track_path``. Only regular files are considered, so a directory
    that happens to have the track's name is skipped rather than crashing the
    read.

    Args:
        root: Repository root.
        track_path: Workspace-relative path to the ``.mid`` file.

    Returns:
        ``(notes, ticks_per_beat)`` on success. Returns ``None`` when neither
        candidate is a regular file, when reading it raises ``OSError`` (for
        example, permission denied), or when ``extract_notes`` raises
        ``ValueError``. Read and parse failures are logged at DEBUG level.
    """
    for candidate in (root / "muse-work" / track_path, root / track_path):
        # is_file() rather than exists(): exists() is also true for
        # directories, and read_bytes() on a directory raises
        # IsADirectoryError.
        if candidate.is_file():
            work_path = candidate
            break
    else:
        return None

    try:
        raw = work_path.read_bytes()
    except OSError as exc:
        # The docstring promises None for unreadable tracks; don't let a
        # permission error or a race with deletion escape to the CLI.
        logger.debug("Cannot read MIDI %r from workdir: %s", track_path, exc)
        return None

    try:
        keys, tpb = extract_notes(raw)
    except ValueError as exc:
        logger.debug("Cannot parse MIDI %r from workdir: %s", track_path, exc)
        return None
    notes = [NoteInfo.from_note_key(k, tpb) for k in keys]
    return notes, tpb
| 164 | |
| 165 | |
| 166 | # --------------------------------------------------------------------------- |
| 167 | # Musical analysis helpers |
| 168 | # --------------------------------------------------------------------------- |
| 169 | |
| 170 | |
def notes_by_bar(notes: list[NoteInfo]) -> dict[int, list[NoteInfo]]:
    """Bucket *notes* by their 1-indexed ``bar`` number (4/4 time assumed).

    Notes inside each bar are ordered by onset tick, then by pitch. Because
    the input is sorted by onset before grouping, the bars appear in the
    returned dict in ascending order of their first note.
    """
    grouped: dict[int, list[NoteInfo]] = {}
    chronological = sorted(notes, key=lambda item: (item.start_tick, item.pitch))
    for item in chronological:
        bar_number = item.bar
        bucket = grouped.get(bar_number)
        if bucket is None:
            bucket = grouped[bar_number] = []
        bucket.append(item)
    return grouped
| 177 | |
| 178 | |
def detect_chord(pitch_classes: frozenset[int]) -> str:
    """Name the chord that best explains *pitch_classes*.

    Every (root, template) pair is scored by how many input pitch classes land
    on the template's offsets above that root. A strictly higher score takes
    the lead. On an equal score, a candidate whose template is matched in full
    replaces the current leader. Roots are scanned C through B and templates
    in ``_CHORD_TEMPLATES`` order.

    Returns a ``"RootQuality"`` string such as ``"Cmaj"`` or ``"Fmin7"``, or
    ``"??"`` when fewer than two distinct pitch classes are given.
    """
    if len(pitch_classes) < 2:
        return "??"

    leader, leader_score = "??", 0
    for root, root_name in enumerate(_PITCH_CLASSES):
        # Transpose the input so the candidate root sits at offset 0.
        relative = frozenset((pc - root) % 12 for pc in pitch_classes)
        for quality, template in _CHORD_TEMPLATES:
            score = len(relative & template)
            beats_leader = score > leader_score
            complete_tie = score == leader_score == len(template)
            if beats_leader or complete_tie:
                leader, leader_score = f"{root_name}{quality}", score
    return leader
| 202 | |
| 203 | |
# Krumhansl-Kessler key profiles, indexed by semitones above the tonic.
_KS_MAJOR_PROFILE = (
    6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
    2.52, 5.19, 2.39, 3.66, 2.29, 2.88,
)
_KS_MINOR_PROFILE = (
    6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
    2.54, 4.75, 3.98, 2.69, 3.34, 3.17,
)


def _pearson(xs: list[int], ys: tuple[float, ...]) -> float:
    """Pearson correlation of two equal-length sequences; 0.0 if either is constant."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    denom = (var_x * var_y) ** 0.5
    return cov / denom if denom else 0.0


def key_signature_guess(notes: list[NoteInfo]) -> str:
    """Guess the key of *notes* with the Krumhansl-Schmuckler algorithm.

    A 12-bin pitch-class histogram of note counts (not durations) is built.
    For each of the 24 major/minor keys, the histogram is rotated so that
    index 0 is the candidate tonic. The key whose profile has the highest
    Pearson correlation with the rotated histogram wins.

    Returns:
        A string like ``"G major"`` or ``"D minor"``, or ``"unknown"`` when
        *notes* is empty. If every pitch class occurs equally often, all
        correlations are 0 and the first candidate, ``"C major"``, is returned.
    """
    if not notes:
        return "unknown"

    histogram = [0] * 12
    for note in notes:
        histogram[note.pitch_class] += 1

    best_key = ""
    best_score = float("-inf")
    for root in range(12):
        # Position i holds the count for the pitch class i semitones above root.
        rotated = [histogram[(root + i) % 12] for i in range(12)]
        for mode, profile in (("major", _KS_MAJOR_PROFILE), ("minor", _KS_MINOR_PROFILE)):
            # Correlation, not a raw dot product: the minor profile's weights
            # sum higher (44.51 vs 41.79). An unnormalized dot product is
            # therefore biased towards minor keys, e.g. it reports a plain
            # C-major scale as A minor.
            score = _pearson(rotated, profile)
            if score > best_score:
                best_score = score
                best_key = f"{_PITCH_CLASSES[root]} {mode}"

    return best_key
| 246 | |
| 247 | |
| 248 | # --------------------------------------------------------------------------- |
| 249 | # Commit-graph walking (music-domain specific) |
| 250 | # --------------------------------------------------------------------------- |
| 251 | |
| 252 | |
def walk_commits_for_track(
    root: pathlib.Path,
    start_commit_id: str,
    track_path: str,
    max_commits: int = 10_000,
) -> list[tuple[CommitRecord, dict[str, str] | None]]:
    """Follow ``parent_commit_id`` links back from *start_commit_id*.

    Each visited commit is paired with its snapshot manifest. An empty or
    missing manifest is normalized to ``None``. The walk stops at the first
    of these:

    - a commit with no parent id;
    - an id that was already visited (cycle guard);
    - an id that ``read_commit`` cannot load;
    - *max_commits* entries collected.

    *track_path* is not consulted. Every commit on the chain is returned, and
    callers filter for the track themselves.
    """
    history: list[tuple[CommitRecord, dict[str, str] | None]] = []
    visited: set[str] = set()
    cursor: str | None = start_commit_id
    while cursor and cursor not in visited:
        if len(history) >= max_commits:
            break
        visited.add(cursor)
        record = read_commit(root, cursor)
        if record is None:
            break
        snapshot = get_commit_snapshot_manifest(root, record.commit_id)
        history.append((record, snapshot if snapshot else None))
        cursor = record.parent_commit_id
    return history
| 278 | |
| 279 | |
| 280 | # --------------------------------------------------------------------------- |
| 281 | # MIDI reconstruction helper (for transpose / mix) |
| 282 | # --------------------------------------------------------------------------- |
| 283 | |
| 284 | |
def notes_to_midi_bytes(notes: list[NoteInfo], ticks_per_beat: int) -> bytes:
    """Serialize *notes* back into MIDI file bytes.

    Each ``NoteInfo`` is reduced to a ``NoteKey``; its own ``ticks_per_beat``
    field is dropped in favour of the *ticks_per_beat* argument. The list is
    then passed to :func:`~muse.plugins.midi.midi_diff.reconstruct_midi`,
    which decides the file layout. The previous docstring described it as a
    Type-0, single-track file with one note_on/note_off pair per note; that
    comes from reconstruct_midi and should be confirmed there.
    """
    from muse.plugins.midi.midi_diff import NoteKey, reconstruct_midi

    def _to_key(info: NoteInfo) -> NoteKey:
        # Copy only the NoteKey fields; ticks_per_beat is passed separately.
        return NoteKey(
            pitch=info.pitch,
            velocity=info.velocity,
            start_tick=info.start_tick,
            duration_ticks=info.duration_ticks,
            channel=info.channel,
        )

    note_keys = list(map(_to_key, notes))
    return reconstruct_midi(note_keys, ticks_per_beat=ticks_per_beat)