cgcardona / muse public
_query.py python
304 lines 10.1 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 1d ago
1 """Shared music-domain query helpers for the Muse CLI.
2
3 Provides the low-level primitives that music-domain commands share:
4 note extraction from the object store, bar-level grouping, chord detection,
5 and commit-graph walking specific to MIDI tracks.
6
7 Nothing here belongs in the public ``MidiPlugin`` API. These are CLI-layer
8 helpers — thin adapters over ``midi_diff.extract_notes`` and the core store.
9 """
10
11 import logging
12 import pathlib
13 from typing import NamedTuple
14
15 from muse.core.object_store import read_object
16 from muse.core.store import CommitRecord, read_commit, get_commit_snapshot_manifest
17 from muse.plugins.midi.midi_diff import NoteKey, _pitch_name, extract_notes # noqa: PLC2701
18
# Module-level logger; the track loaders in this module report MIDI parse
# failures through it at DEBUG level.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Pitch / music-theory constants
# ---------------------------------------------------------------------------

# Sharp-spelled pitch-class names, indexed by pitch class (0 = C ... 11 = B).
_PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

# Chord templates: (quality suffix, frozenset of pitch-class offsets from the
# root, where the root itself is offset 0).  ``detect_chord`` scans this list
# in order for every candidate root, and a later fully-matched template
# replaces an earlier match with the same overlap, so ordering matters.
_CHORD_TEMPLATES: list[tuple[str, frozenset[int]]] = [
    ("maj", frozenset({0, 4, 7})),
    ("min", frozenset({0, 3, 7})),
    ("dim", frozenset({0, 3, 6})),
    ("aug", frozenset({0, 4, 8})),
    ("sus2", frozenset({0, 2, 7})),
    ("sus4", frozenset({0, 5, 7})),
    ("dom7", frozenset({0, 4, 7, 10})),
    ("maj7", frozenset({0, 4, 7, 11})),
    ("min7", frozenset({0, 3, 7, 10})),
    ("dim7", frozenset({0, 3, 6, 9})),
    ("5", frozenset({0, 7})),  # power chord: root + fifth only
]
41
42 # ---------------------------------------------------------------------------
43 # NoteInfo — enriched note for display
44 # ---------------------------------------------------------------------------
45
46
class NoteInfo(NamedTuple):
    """Display-oriented view of a single MIDI note.

    Holds the five ``NoteKey`` fields together with the ``ticks_per_beat``
    resolution so that beat and bar positions can be derived on demand.
    Bar arithmetic assumes 4/4 time, and a ``ticks_per_beat`` below 1 is
    treated as 1 so the derived properties never divide by zero.
    """

    pitch: int
    velocity: int
    start_tick: int
    duration_ticks: int
    channel: int
    ticks_per_beat: int

    @property
    def pitch_name(self) -> str:
        """Pitch rendered by ``midi_diff._pitch_name``."""
        return _pitch_name(self.pitch)

    @property
    def beat(self) -> float:
        """Onset position in beats, counted from tick 0."""
        resolution = max(self.ticks_per_beat, 1)
        return self.start_tick / resolution

    @property
    def beat_duration(self) -> float:
        """Note length in beats."""
        resolution = max(self.ticks_per_beat, 1)
        return self.duration_ticks / resolution

    @property
    def bar(self) -> int:
        """1-indexed bar number (assumes 4/4 time)."""
        ticks_per_bar = 4 * max(self.ticks_per_beat, 1)
        return int(self.start_tick // ticks_per_bar) + 1

    @property
    def beat_in_bar(self) -> float:
        """Beat position within the bar (1-indexed, assumes 4/4 time)."""
        resolution = max(self.ticks_per_beat, 1)
        bar_start_tick = (self.bar - 1) * 4 * resolution
        ticks_into_bar = self.start_tick - bar_start_tick
        return ticks_into_bar / resolution + 1

    @property
    def pitch_class(self) -> int:
        """Pitch reduced modulo 12 (0 = C ... 11 = B)."""
        return self.pitch % 12

    @property
    def pitch_class_name(self) -> str:
        """Sharp-spelled name of :attr:`pitch_class`, octave omitted."""
        return _PITCH_CLASSES[self.pitch_class]

    @classmethod
    def from_note_key(cls, note: "NoteKey", ticks_per_beat: int) -> "NoteInfo":
        """Build a ``NoteInfo`` from a ``NoteKey`` mapping.

        Args:
            note: Mapping providing ``pitch``, ``velocity``, ``start_tick``,
                ``duration_ticks`` and ``channel``.
            ticks_per_beat: Resolution of the file the note came from.
        """
        return cls(
            note["pitch"],
            note["velocity"],
            note["start_tick"],
            note["duration_ticks"],
            note["channel"],
            ticks_per_beat,
        )
99
100
101 # ---------------------------------------------------------------------------
102 # Track loading from the object store
103 # ---------------------------------------------------------------------------
104
105
def load_track(
    root: pathlib.Path,
    commit_id: str,
    track_path: str,
) -> tuple[list[NoteInfo], int] | None:
    """Load the notes of *track_path* as stored in the snapshot of *commit_id*.

    Args:
        root: Repository root.
        commit_id: ID of the commit whose snapshot manifest is consulted.
        track_path: Workspace-relative path to the ``.mid`` file; used as the
            lookup key in the manifest.

    Returns:
        ``(notes, ticks_per_beat)`` on success.  ``None`` when the manifest
        has no entry for the track, the referenced object cannot be read, or
        ``extract_notes`` rejects the bytes with ``ValueError``.
    """
    # A commit without a snapshot manifest is treated like an empty one.
    snapshot: dict[str, str] = get_commit_snapshot_manifest(root, commit_id) or {}
    blob_id = snapshot.get(track_path)
    blob = read_object(root, blob_id) if blob_id is not None else None
    if blob is None:
        return None

    try:
        note_keys, resolution = extract_notes(blob)
    except ValueError as err:
        logger.debug("Cannot parse MIDI %r from commit %s: %s", track_path, commit_id[:8], err)
        return None

    return [NoteInfo.from_note_key(key, resolution) for key in note_keys], resolution
136
137
def load_track_from_workdir(
    root: pathlib.Path,
    track_path: str,
) -> tuple[list[NoteInfo], int] | None:
    """Load notes for *track_path* from the live working tree.

    ``<root>/muse-work/<track_path>`` is tried first, then
    ``<root>/<track_path>``; the first candidate that is a regular file wins.

    Args:
        root: Repository root.
        track_path: Workspace-relative path to the ``.mid`` file.

    Returns:
        ``(notes, ticks_per_beat)`` on success.  ``None`` when neither
        candidate is a regular file, the file cannot be read, or
        ``extract_notes`` rejects its content with ``ValueError``.
    """
    candidates = (root / "muse-work" / track_path, root / track_path)
    # is_file() rather than exists(): a directory with the track's name must
    # not be selected, since reading it could only fail.
    work_path = next((path for path in candidates if path.is_file()), None)
    if work_path is None:
        return None

    try:
        raw = work_path.read_bytes()
    except OSError as exc:
        # The file may vanish or be unreadable between the check and the
        # read; the documented contract for "unreadable" is None, not a crash.
        logger.debug("Cannot read MIDI %r from workdir: %s", track_path, exc)
        return None

    try:
        keys, tpb = extract_notes(raw)
    except ValueError as exc:
        logger.debug("Cannot parse MIDI %r from workdir: %s", track_path, exc)
        return None

    notes = [NoteInfo.from_note_key(k, tpb) for k in keys]
    return notes, tpb
164
165
166 # ---------------------------------------------------------------------------
167 # Musical analysis helpers
168 # ---------------------------------------------------------------------------
169
170
def notes_by_bar(notes: list[NoteInfo]) -> dict[int, list[NoteInfo]]:
    """Bucket *notes* by their 1-indexed ``bar`` (4/4 time assumed).

    Notes are first ordered by ``(start_tick, pitch)``; each bucket keeps
    that order, and dictionary keys appear in the order their first note is
    encountered.  The input list is not modified.
    """
    ordered = list(notes)
    ordered.sort(key=lambda item: (item.start_tick, item.pitch))

    grouped: dict[int, list[NoteInfo]] = {}
    for item in ordered:
        bar_number = item.bar
        if bar_number not in grouped:
            grouped[bar_number] = []
        grouped[bar_number].append(item)
    return grouped
177
178
def detect_chord(pitch_classes: frozenset[int]) -> str:
    """Name the chord that best fits a set of pitch classes.

    Every chromatic root is combined with every entry of
    ``_CHORD_TEMPLATES``; a candidate's score is the number of template
    tones present once the input is transposed so the root sits at 0.
    A candidate replaces the current best when it scores strictly higher,
    or when it ties the best score *and* its template is fully covered.
    Among equally good complete matches the last one scanned therefore wins.

    Args:
        pitch_classes: Pitch classes (0-11) sounding together.

    Returns:
        ``"RootQuality"`` such as ``"Cmaj"`` or ``"Fmin7"``, or ``"??"`` when
        fewer than two distinct pitch classes are given.
    """
    if len(pitch_classes) < 2:
        return "??"

    winner = "??"
    top_overlap = 0
    for root, root_name in enumerate(_PITCH_CLASSES):
        # Transpose so that the candidate root becomes pitch class 0.
        relative = frozenset((pc - root) % 12 for pc in pitch_classes)
        for quality, template in _CHORD_TEMPLATES:
            shared = len(relative & template)
            if shared < top_overlap:
                continue
            if shared == top_overlap and shared != len(template):
                # A tie only displaces the incumbent for a complete match.
                continue
            top_overlap = shared
            winner = f"{root_name}{quality}"
    return winner
202
203
def key_signature_guess(notes: list[NoteInfo]) -> str:
    """Guess the key signature from pitch-class frequencies.

    Implements the Krumhansl-Schmuckler key-finding algorithm: the 12-bin
    pitch-class histogram of *notes* is compared with the major and minor
    key profiles rotated to each of the 12 chromatic tonics, and the
    best-matching ``(tonic, mode)`` pair is returned.

    The match is measured with the Pearson correlation coefficient, as the
    algorithm specifies.  A raw dot product is not a substitute: the two
    profiles have different sums and spreads, so unnormalised scores favour
    minor keys (an evenly weighted C-major scale would be reported as
    A minor).

    Args:
        notes: Notes to analyse.  Only ``pitch_class`` is read; every note
            counts once regardless of its duration.

    Returns:
        A string such as ``"G major"`` or ``"D minor"``.  ``"unknown"`` when
        *notes* is empty or when all twelve pitch classes occur equally
        often, in which case the correlation is undefined.  Exact ties are
        resolved in favour of the first candidate scanned (lower tonic,
        major before minor).
    """
    if not notes:
        return "unknown"

    # Pitch-class histogram.
    histogram = [0] * 12
    for note in notes:
        histogram[note.pitch_class] += 1

    # Krumhansl-Kessler key profiles (raw ratings, index 0 = tonic).
    profiles = (
        (
            "major",
            (6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88),
        ),
        (
            "minor",
            (6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17),
        ),
    )

    def _centered(values):
        """Return (deviations from the mean, Euclidean norm of the deviations)."""
        mean = sum(values) / len(values)
        deviations = [value - mean for value in values]
        return deviations, sum(dev * dev for dev in deviations) ** 0.5

    hist_dev, hist_norm = _centered(histogram)
    if hist_norm == 0:
        # Flat histogram: zero variance, so no key is preferred over another.
        return "unknown"

    # Centre each profile once; rotation does not change its mean or norm.
    centered_profiles = [(mode, *_centered(profile)) for mode, profile in profiles]

    best_key = ""
    best_score = float("-inf")
    for root in range(12):
        for mode, prof_dev, prof_norm in centered_profiles:
            # Profile index i corresponds to pitch class (root + i) % 12.
            covariance = sum(
                hist_dev[(root + i) % 12] * prof_dev[i] for i in range(12)
            )
            score = covariance / (hist_norm * prof_norm)
            if score > best_score:
                best_score = score
                best_key = f"{_PITCH_CLASSES[root]} {mode}"

    return best_key
246
247
248 # ---------------------------------------------------------------------------
249 # Commit-graph walking (music-domain specific)
250 # ---------------------------------------------------------------------------
251
252
def walk_commits_for_track(
    root: pathlib.Path,
    start_commit_id: str,
    track_path: str,
    max_commits: int = 10_000,
) -> list[tuple[CommitRecord, dict[str, str] | None]]:
    """Follow ``parent_commit_id`` links from *start_commit_id*, newest first.

    Every visited commit is returned together with its snapshot manifest;
    no filtering by track happens here, so callers decide which entries
    matter.  The walk stops at a commit without a parent, at an ID that
    cannot be read, when an ID repeats (cycle guard), or after
    *max_commits* entries.

    Args:
        root: Repository root.
        start_commit_id: Commit to start from.
        track_path: Track the caller is interested in.  Currently unused by
            the walk itself.
        max_commits: Upper bound on the number of entries returned.

    Returns:
        ``(commit, manifest)`` pairs; ``manifest`` is ``None`` when the
        commit has no snapshot manifest or the manifest is empty.
    """
    history: list[tuple[CommitRecord, dict[str, str] | None]] = []
    visited: set[str] = set()
    cursor: str | None = start_commit_id

    while True:
        if not cursor or cursor in visited or len(history) >= max_commits:
            break
        visited.add(cursor)

        record = read_commit(root, cursor)
        if record is None:
            break

        snapshot = get_commit_snapshot_manifest(root, record.commit_id)
        # Falsy manifests (missing or empty) are normalised to None.
        history.append((record, snapshot if snapshot else None))
        cursor = record.parent_commit_id

    return history
278
279
280 # ---------------------------------------------------------------------------
281 # MIDI reconstruction helper (for transpose / mix)
282 # ---------------------------------------------------------------------------
283
284
def notes_to_midi_bytes(notes: list[NoteInfo], ticks_per_beat: int) -> bytes:
    """Serialise ``NoteInfo`` objects back into MIDI file bytes.

    Each note is reduced to a ``NoteKey`` (its own ``ticks_per_beat`` field
    is dropped) and the list is handed to
    :func:`~muse.plugins.midi.midi_diff.reconstruct_midi` together with the
    *ticks_per_beat* argument.  The encoding itself lives in that function;
    it is expected to emit a Type-0 single-track file with one
    note_on / note_off pair per note (not verifiable from this module).

    Args:
        notes: Notes to write, in the order they should be passed on.
        ticks_per_beat: Resolution forwarded to ``reconstruct_midi``.
    """
    from muse.plugins.midi.midi_diff import NoteKey, reconstruct_midi

    def _as_key(info: NoteInfo) -> NoteKey:
        # Keep only the five fields that make up a NoteKey.
        return NoteKey(
            pitch=info.pitch,
            velocity=info.velocity,
            start_tick=info.start_tick,
            duration_ticks=info.duration_ticks,
            channel=info.channel,
        )

    note_keys = [_as_key(info) for info in notes]
    return reconstruct_midi(note_keys, ticks_per_beat=ticks_per_beat)