gabriel / muse public
_query.py python
305 lines 10.1 KB
26a36470 feat(music): 9 new semantic commands — version control that understands music Gabriel Cardona <gabriel@tellurstori.com> 5d ago
1 """Shared music-domain query helpers for the Muse CLI.
2
3 Provides the low-level primitives that music-domain commands share:
4 note extraction from the object store, bar-level grouping, chord detection,
5 and commit-graph walking specific to MIDI tracks.
6
7 Nothing here belongs in the public ``MusicPlugin`` API. These are CLI-layer
8 helpers — thin adapters over ``midi_diff.extract_notes`` and the core store.
9 """
10 from __future__ import annotations
11
12 import logging
13 import pathlib
14 from typing import NamedTuple
15
16 from muse.core.object_store import read_object
17 from muse.core.store import CommitRecord, read_commit, get_commit_snapshot_manifest
18 from muse.plugins.music.midi_diff import NoteKey, _pitch_name, extract_notes # noqa: PLC2701
19
20 logger = logging.getLogger(__name__)
21
22 # ---------------------------------------------------------------------------
23 # Pitch / music-theory constants
24 # ---------------------------------------------------------------------------
25
# Sharp-spelled pitch-class names; the list index is the pitch class (0 = C).
_PITCH_CLASSES = [
    "C", "C#", "D", "D#", "E", "F",
    "F#", "G", "G#", "A", "A#", "B",
]

# Chord templates as ``(quality, offsets)`` pairs. Every offset is a
# pitch-class distance in semitones above the chord root (root = 0).
# Order matters: ``detect_chord`` scans the list front to back.
_CHORD_TEMPLATES: list[tuple[str, frozenset[int]]] = [
    (quality, frozenset(offsets))
    for quality, offsets in (
        ("maj", (0, 4, 7)),
        ("min", (0, 3, 7)),
        ("dim", (0, 3, 6)),
        ("aug", (0, 4, 8)),
        ("sus2", (0, 2, 7)),
        ("sus4", (0, 5, 7)),
        ("dom7", (0, 4, 7, 10)),
        ("maj7", (0, 4, 7, 11)),
        ("min7", (0, 3, 7, 10)),
        ("dim7", (0, 3, 6, 9)),
        ("5", (0, 7)),  # power chord
    )
]
42
43 # ---------------------------------------------------------------------------
44 # NoteInfo — enriched note for display
45 # ---------------------------------------------------------------------------
46
47
class NoteInfo(NamedTuple):
    """A note's ``NoteKey`` fields plus the timing resolution used to display it.

    The five note fields mirror the keys read from a ``NoteKey`` in
    :meth:`from_note_key`; ``ticks_per_beat`` is stored alongside them so
    the derived properties can convert ticks into beats and bars.

    Every tick-to-beat conversion clamps ``ticks_per_beat`` to at least 1,
    so a zero or negative resolution never causes a division error.
    Bar-related properties assume 4/4 time (four beats per bar).
    """

    pitch: int
    velocity: int
    start_tick: int
    duration_ticks: int
    channel: int
    ticks_per_beat: int

    @property
    def pitch_name(self) -> str:
        """Display name of the pitch, as produced by ``_pitch_name``."""
        label = _pitch_name(self.pitch)
        return label

    @property
    def beat(self) -> float:
        """Start position in beats, counted from tick 0."""
        resolution = max(self.ticks_per_beat, 1)
        return self.start_tick / resolution

    @property
    def beat_duration(self) -> float:
        """Note length in beats."""
        resolution = max(self.ticks_per_beat, 1)
        return self.duration_ticks / resolution

    @property
    def bar(self) -> int:
        """Bar number containing the note start, 1-indexed (assumes 4/4)."""
        ticks_per_bar = 4 * max(self.ticks_per_beat, 1)
        completed_bars = int(self.start_tick // ticks_per_bar)
        return completed_bars + 1

    @property
    def beat_in_bar(self) -> float:
        """Beat position inside the note's bar, 1-indexed (assumes 4/4)."""
        resolution = max(self.ticks_per_beat, 1)
        # Tick at which the note's own bar begins.
        ticks_before_bar = (self.bar - 1) * 4 * resolution
        offset = self.start_tick - ticks_before_bar
        return offset / resolution + 1

    @property
    def pitch_class(self) -> int:
        """Pitch reduced modulo 12 (0-11)."""
        return self.pitch % 12

    @property
    def pitch_class_name(self) -> str:
        """Name of the pitch class, looked up in ``_PITCH_CLASSES``."""
        index = self.pitch_class
        return _PITCH_CLASSES[index]

    @classmethod
    def from_note_key(cls, note: NoteKey, ticks_per_beat: int) -> "NoteInfo":
        """Build a ``NoteInfo`` from a ``NoteKey`` mapping and a resolution.

        Args:
            note: Mapping providing ``pitch``, ``velocity``, ``start_tick``,
                ``duration_ticks`` and ``channel``.
            ticks_per_beat: Timing resolution stored on the new instance.
        """
        return cls(
            note["pitch"],
            note["velocity"],
            note["start_tick"],
            note["duration_ticks"],
            note["channel"],
            ticks_per_beat,
        )
100
101
102 # ---------------------------------------------------------------------------
103 # Track loading from the object store
104 # ---------------------------------------------------------------------------
105
106
def load_track(
    root: pathlib.Path,
    commit_id: str,
    track_path: str,
) -> tuple[list[NoteInfo], int] | None:
    """Read the notes of *track_path* as stored in the snapshot of *commit_id*.

    Args:
        root: Repository root.
        commit_id: SHA-256 commit ID.
        track_path: Workspace-relative path to the ``.mid`` file.

    Returns:
        ``(notes, ticks_per_beat)`` on success. ``None`` when the commit has
        no manifest entry for the track, when the referenced object cannot
        be read, or when ``extract_notes`` raises ``ValueError``.
    """
    snapshot = get_commit_snapshot_manifest(root, commit_id)
    # A missing or empty manifest is treated like "track not present".
    blob_id = snapshot.get(track_path) if snapshot else None
    if blob_id is None:
        return None

    payload = read_object(root, blob_id)
    if payload is None:
        return None

    try:
        note_keys, resolution = extract_notes(payload)
    except ValueError as exc:
        # Unparseable MIDI is an expected condition; report quietly.
        logger.debug("Cannot parse MIDI %r from commit %s: %s", track_path, commit_id[:8], exc)
        return None
    else:
        enriched = [NoteInfo.from_note_key(key, resolution) for key in note_keys]
        return enriched, resolution
137
138
def load_track_from_workdir(
    root: pathlib.Path,
    track_path: str,
) -> tuple[list[NoteInfo], int] | None:
    """Load notes for *track_path* from ``muse-work/`` (live working tree).

    The file is looked up at ``<root>/muse-work/<track_path>`` first and at
    ``<root>/<track_path>`` second; the first candidate that is a regular
    file wins.

    Args:
        root: Repository root.
        track_path: Workspace-relative path to the ``.mid`` file.

    Returns:
        ``(notes, ticks_per_beat)`` on success, ``None`` when unreadable:
        no candidate is a regular file, reading the file fails with
        ``OSError``, or ``extract_notes`` raises ``ValueError``.
    """
    candidates = (root / "muse-work" / track_path, root / track_path)
    # ``is_file`` (not ``exists``) so a directory of the same name is never
    # handed to ``read_bytes``, which would raise instead of returning None.
    work_path = next((path for path in candidates if path.is_file()), None)
    if work_path is None:
        return None
    try:
        raw = work_path.read_bytes()
    except OSError as exc:
        # The file vanished or is not readable (permissions, I/O error);
        # honour the documented "None when unreadable" contract.
        logger.debug("Cannot read MIDI %r from workdir: %s", track_path, exc)
        return None
    try:
        keys, tpb = extract_notes(raw)
    except ValueError as exc:
        logger.debug("Cannot parse MIDI %r from workdir: %s", track_path, exc)
        return None
    notes = [NoteInfo.from_note_key(k, tpb) for k in keys]
    return notes, tpb
165
166
167 # ---------------------------------------------------------------------------
168 # Musical analysis helpers
169 # ---------------------------------------------------------------------------
170
171
def notes_by_bar(notes: list[NoteInfo]) -> dict[int, list[NoteInfo]]:
    """Bucket *notes* by their 1-indexed ``bar`` (which assumes 4/4 time).

    Notes are ordered by ``(start_tick, pitch)`` before grouping, so each
    bar's list keeps that order and bars are inserted in order of first
    appearance. The input list is not modified.
    """
    ordered = list(notes)
    ordered.sort(key=lambda item: (item.start_tick, item.pitch))

    grouped: dict[int, list[NoteInfo]] = {}
    for item in ordered:
        bar_number = item.bar
        if bar_number in grouped:
            grouped[bar_number].append(item)
        else:
            grouped[bar_number] = [item]
    return grouped
178
179
def detect_chord(pitch_classes: frozenset[int]) -> str:
    """Name the chord that best matches a set of pitch classes.

    Each of the 12 chromatic roots is tried in turn: the input is transposed
    so that the candidate root sits at 0 and compared against every entry of
    ``_CHORD_TEMPLATES``. A candidate replaces the current best when it
    shares strictly more pitch classes with its template, or when it ties
    the best score while covering its template completely. Because of the
    tie rule, a later fully covered template wins over an earlier match of
    equal score.

    Returns:
        ``"RootQuality"`` such as ``"Cmaj"`` or ``"Fmin7"``; ``"??"`` when
        fewer than two distinct pitch classes are given or nothing overlaps.
    """
    if len(pitch_classes) < 2:
        return "??"

    winner = "??"
    winning_overlap = 0
    for root, root_name in enumerate(_PITCH_CLASSES):
        # Express the input relative to the candidate root.
        relative = frozenset((pc - root) % 12 for pc in pitch_classes)
        for quality, template in _CHORD_TEMPLATES:
            shared = len(relative & template)
            strictly_better = shared > winning_overlap
            full_cover_tie = shared == winning_overlap and shared == len(template)
            if not (strictly_better or full_cover_tie):
                continue
            winning_overlap = shared
            winner = f"{root_name}{quality}"
    return winner
203
204
def key_signature_guess(notes: list[NoteInfo]) -> str:
    """Guess the key signature from pitch class frequencies.

    Uses the Krumhansl-Schmuckler key-finding approach: the pitch-class
    histogram of *notes* is correlated with the Krumhansl-Kessler major and
    minor profiles rotated to each of the 12 roots, and the best-correlated
    key wins.

    Each profile is standardised (zero mean, unit norm) before scoring, so
    the score is proportional to the Pearson correlation. A raw dot product
    would favour minor keys, because the minor profile has the larger sum.
    The histogram's own mean and scale are the same for every candidate key
    and therefore do not affect the ranking.

    Args:
        notes: Notes to analyse; only ``pitch_class`` is read from each one.

    Returns:
        A string like ``"G major"`` or ``"D minor"``, or ``"unknown"`` when
        *notes* is empty. On ties the first candidate in root order (major
        before minor) is kept. A perfectly flat histogram carries no key
        information, so its result is arbitrary.
    """
    if not notes:
        return "unknown"

    # Build pitch class histogram.
    histogram = [0] * 12
    for note in notes:
        histogram[note.pitch_class] += 1

    # Krumhansl-Kessler probe-tone profiles, index 0 = tonic.
    major_profile = [
        6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
        2.52, 5.19, 2.39, 3.66, 2.29, 2.88,
    ]
    minor_profile = [
        6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
        2.54, 4.75, 3.98, 2.69, 3.34, 3.17,
    ]

    def _standardize(profile: list[float]) -> list[float]:
        """Return *profile* shifted to zero mean and scaled to unit norm."""
        mean = sum(profile) / len(profile)
        centered = [weight - mean for weight in profile]
        norm = sum(value * value for value in centered) ** 0.5
        return [value / norm for value in centered]

    modes = [
        ("major", _standardize(major_profile)),
        ("minor", _standardize(minor_profile)),
    ]

    best_key = ""
    best_score = float("-inf")

    for root in range(12):
        for mode, weights in modes:
            # Rotate profile to this root: profile index i lines up with
            # pitch class (root + i) % 12.
            score = sum(
                histogram[(root + i) % 12] * weights[i] for i in range(12)
            )
            if score > best_score:
                best_score = score
                best_key = f"{_PITCH_CLASSES[root]} {mode}"

    return best_key
247
248
249 # ---------------------------------------------------------------------------
250 # Commit-graph walking (music-domain specific)
251 # ---------------------------------------------------------------------------
252
253
def walk_commits_for_track(
    root: pathlib.Path,
    start_commit_id: str,
    track_path: str,
    max_commits: int = 10_000,
) -> list[tuple[CommitRecord, dict[str, str] | None]]:
    """Follow ``parent_commit_id`` links from *start_commit_id*, newest first.

    Every visited commit is paired with its snapshot manifest. The walk
    stops when there is no further parent, when a commit ID repeats (cycle
    guard), when a commit cannot be read, or once *max_commits* pairs have
    been collected.

    Args:
        root: Repository root.
        start_commit_id: Commit ID to start from.
        track_path: Not used for filtering here; all commits are returned
            so callers can filter by track themselves.
        max_commits: Upper bound on the number of returned pairs.

    Returns:
        ``(commit, manifest)`` pairs in walk order. ``manifest`` is ``None``
        when the lookup yields nothing or an empty manifest.
    """
    history: list[tuple[CommitRecord, dict[str, str] | None]] = []
    visited: set[str] = set()
    cursor: str | None = start_commit_id

    while True:
        if not cursor or cursor in visited:
            break
        if len(history) >= max_commits:
            break
        visited.add(cursor)

        record = read_commit(root, cursor)
        if record is None:
            break

        snapshot = get_commit_snapshot_manifest(root, record.commit_id)
        # Collapse any falsy manifest (missing or empty) to None.
        history.append((record, snapshot if snapshot else None))
        cursor = record.parent_commit_id

    return history
279
280
281 # ---------------------------------------------------------------------------
282 # MIDI reconstruction helper (for transpose / mix)
283 # ---------------------------------------------------------------------------
284
285
def notes_to_midi_bytes(notes: list[NoteInfo], ticks_per_beat: int) -> bytes:
    """Serialise ``NoteInfo`` objects back into MIDI file bytes.

    Each note is converted to a ``NoteKey`` (dropping the per-note
    ``ticks_per_beat``) and the list is handed to
    :func:`~muse.plugins.music.midi_diff.reconstruct_midi` together with
    *ticks_per_beat*. The resulting file is intended to be a Type-0
    single-track MIDI file with one note_on / note_off pair per note; the
    exact layout is determined by ``reconstruct_midi``.
    """
    from muse.plugins.music.midi_diff import NoteKey, reconstruct_midi

    keys = []
    for info in notes:
        key = NoteKey(
            pitch=info.pitch,
            velocity=info.velocity,
            start_tick=info.start_tick,
            duration_ticks=info.duration_ticks,
            channel=info.channel,
        )
        keys.append(key)
    return reconstruct_midi(keys, ticks_per_beat=ticks_per_beat)