gabriel / muse public
_midi_query.py python
572 lines 18.7 KB
bda49bdb feat: redesign .museignore as TOML with domain-scoped sections (#100) Gabriel Cardona <cgcardona@gmail.com> 5d ago
"""Music-domain query DSL for the Muse VCS.

Lets agents and humans search the commit history for musical content::

    muse music-query "note.pitch_class == 'Eb' and bar == 12"
    muse music-query "harmony.quality == 'dim' and bar == 8"
    muse music-query "author == 'agent-x' and track == 'piano.mid'"
    muse music-query "note.velocity > 80 and not bar == 4"
    muse music-query "agent_id == 'counterpoint-bot'"

Grammar (EBNF)
--------------

::

    query      = or_expr
    or_expr    = and_expr ( 'or' and_expr )*
    and_expr   = not_expr ( 'and' not_expr )*
    not_expr   = 'not' not_expr | atom
    atom       = '(' query ')' | comparison
    comparison = FIELD OP VALUE
    FIELD      = 'note.pitch' | 'note.pitch_class' | 'note.velocity'
               | 'note.channel' | 'note.duration'
               | 'bar' | 'track' | 'harmony.chord' | 'harmony.quality'
               | 'author' | 'agent_id' | 'model_id' | 'toolchain_id'
    OP         = '==' | '!=' | '>' | '<' | '>=' | '<='
    VALUE      = QUOTED_STRING | NUMBER | BARE_WORD

The keywords ``and``, ``or`` and ``not`` must be lower-case.  A
``BARE_WORD`` value (an unquoted identifier such as ``Eb``) is treated
exactly like a quoted string.

Comparison semantics
--------------------

* String-valued fields support only ``==`` and ``!=``.  Both sides are
  lower-cased before comparing, so matching is case-insensitive.
* Numeric fields support all six operators.  A quoted value is accepted
  against a numeric field if it parses as a number.
* ``note.*`` fields match a bar when *any* note in that bar satisfies the
  comparison.

Supported field paths
---------------------

+---------------------+---------------------------------------------+
| Field               | Resolves to                                 |
+=====================+=============================================+
| note.pitch          | any note's MIDI pitch (integer 0–127)       |
+---------------------+---------------------------------------------+
| note.pitch_class    | pitch class name ("C", "C#", …, "B")        |
+---------------------+---------------------------------------------+
| note.velocity       | MIDI velocity (0–127)                       |
+---------------------+---------------------------------------------+
| note.channel        | MIDI channel (0–15)                         |
+---------------------+---------------------------------------------+
| note.duration       | note duration in beats (float)              |
+---------------------+---------------------------------------------+
| bar                 | 1-indexed bar number (assumes 4/4)          |
+---------------------+---------------------------------------------+
| track               | workspace-relative MIDI file path           |
+---------------------+---------------------------------------------+
| harmony.chord       | detected chord name ("Cmaj", "Fdim7", …)    |
+---------------------+---------------------------------------------+
| harmony.quality     | chord quality suffix ("maj", "min", "dim"…) |
+---------------------+---------------------------------------------+
| author              | commit author string                        |
+---------------------+---------------------------------------------+
| agent_id            | agent_id from commit provenance             |
+---------------------+---------------------------------------------+
| model_id            | model_id from commit provenance             |
+---------------------+---------------------------------------------+
| toolchain_id        | toolchain_id from commit provenance         |
+---------------------+---------------------------------------------+
"""
60
61 from __future__ import annotations
62
63 import logging
64 import pathlib
65 import re
66 from dataclasses import dataclass, field
67 from typing import Literal, TypedDict
68
69 from muse.core.object_store import read_object
70 from muse.core.store import CommitRecord, get_commit_snapshot_manifest, read_commit
71 from muse.plugins.midi._query import (
72 NoteInfo,
73 detect_chord,
74 notes_by_bar,
75 walk_commits_for_track,
76 )
77 from muse.plugins.midi.midi_diff import extract_notes
78
# Module-level logger, named after this module; run_query reports
# max_results truncation through it.
logger: logging.Logger = logging.getLogger(__name__)
80
81
82 # ---------------------------------------------------------------------------
83 # AST node dataclasses
84 # ---------------------------------------------------------------------------
85
86
@dataclass
class EqNode:
    """Leaf of the query AST: one ``FIELD OP VALUE`` comparison.

    Despite the name, *op* may be any of the six comparison operators,
    not only ``==``.
    """

    # Field path exactly as written in the query, e.g. "note.pitch".
    field: str
    # Comparison operator token.
    op: Literal["==", "!=", ">", "<", ">=", "<="]
    # Right-hand literal: str for quoted strings and bare words,
    # int/float for numbers.
    value: str | int | float
94
95
@dataclass
class AndNode:
    """Conjunction: true only when both *left* and *right* are true."""

    # Left operand (evaluated first).
    left: QueryNode
    # Right operand (evaluated only if *left* is true).
    right: QueryNode
102
103
@dataclass
class OrNode:
    """Disjunction: true when at least one of *left* / *right* is true."""

    # Left operand (evaluated first).
    left: QueryNode
    # Right operand (evaluated only if *left* is false).
    right: QueryNode
110
111
@dataclass
class NotNode:
    """Negation: inverts the truth value of *inner*."""

    # The negated sub-expression.
    inner: QueryNode
117
118
# Any node of the query AST: the comparison leaf plus the three logical
# combinators built by the parser and consumed by evaluate_node.
QueryNode = (
    EqNode
    | AndNode
    | OrNode
    | NotNode
)
120
121
122 # ---------------------------------------------------------------------------
123 # Query context and result types
124 # ---------------------------------------------------------------------------
125
126
@dataclass
class QueryContext:
    """Snapshot the evaluator tests a predicate against: one bar of one track at one commit."""

    # Commit being inspected; source of author and provenance fields.
    commit: CommitRecord
    # Manifest path of the MIDI file.
    track: str
    # Bar number as produced by notes_by_bar.
    bar: int
    # Notes that notes_by_bar assigned to this bar.
    notes: list[NoteInfo]
    # detect_chord result for the bar's pitch classes.
    chord: str
    # Resolution reported by extract_notes for the track.
    ticks_per_beat: int
137
138
class NoteDict(TypedDict, total=True):
    """JSON-friendly copy of one note, embedded in a :class:`QueryMatch`."""

    # MIDI pitch number.
    pitch: int
    # Pitch-class name, e.g. "C" or "C#".
    pitch_class: str
    velocity: int
    channel: int
    # NoteInfo.beat, rounded to 4 decimals by run_query.
    beat: float
    # NoteInfo.beat_duration, rounded to 4 decimals by run_query.
    duration_beats: float
148
149
class QueryMatch(TypedDict, total=True):
    """One (commit, track, bar) hit reported by :func:`run_query`."""

    # Full commit ID.
    commit_id: str
    # First 8 characters of commit_id.
    commit_short: str
    commit_message: str
    author: str
    agent_id: str
    # Commit timestamp rendered with isoformat().
    committed_at: str
    # MIDI file path from the commit manifest.
    track: str
    bar: int
    # Every note in the bar, not only the ones that satisfied the query.
    notes: list[NoteDict]
    # Chord detected for the bar.
    chord: str
    # The query string that produced this match.
    matched_on: str
164
165
166 # ---------------------------------------------------------------------------
167 # Tokenizer
168 # ---------------------------------------------------------------------------
169
170 _TOKEN_RE = re.compile(
171 r"""
172 (?P<LPAREN> \( ) |
173 (?P<RPAREN> \) ) |
174 (?P<OP> ==|!=|>=|<=|>|< ) |
175 (?P<KW> (?:and|or|not)(?![A-Za-z0-9_.]) ) |
176 (?P<STR> "(?:[^"\\]|\\.)*"|'(?:[^'\\]|\\.)*' ) |
177 (?P<NUM> -?\d+(?:\.\d+)? ) |
178 (?P<NAME> [A-Za-z_][A-Za-z0-9_.]* ) |
179 (?P<WS> \s+ )
180 """,
181 re.VERBOSE,
182 )
183
184 _OpLiteral = Literal["==", "!=", ">", "<", ">=", "<="]
185 _VALID_OPS: frozenset[str] = frozenset({"==", "!=", ">", "<", ">=", "<="})
186
187
188 @dataclass
189 class Token:
190 """A single lexed token."""
191
192 kind: str
193 value: str
194
195
196 def _tokenize(query: str) -> list[Token]:
197 """Convert *query* string to a flat list of :class:`Token` objects.
198
199 Whitespace tokens are discarded. Raises ``ValueError`` on unrecognised input.
200 """
201 tokens: list[Token] = []
202 for m in _TOKEN_RE.finditer(query):
203 kind = m.lastgroup
204 if kind is None or kind == "WS":
205 continue
206 tokens.append(Token(kind=kind, value=m.group()))
207 # Verify full coverage.
208 covered = sum(len(m.group()) for m in _TOKEN_RE.finditer(query) if m.lastgroup != "WS")
209 no_ws = re.sub(r"\s+", "", query)
210 if covered != len(no_ws):
211 raise ValueError(f"Unrecognised characters in query: {query!r}")
212 return tokens
213
214
215 # ---------------------------------------------------------------------------
216 # Recursive descent parser
217 # ---------------------------------------------------------------------------
218
219
class _Parser:
    """Recursive descent parser for the music query DSL.

    Consumes the token list produced by ``_tokenize`` and builds a
    :data:`QueryNode` tree.  Precedence, loosest first: ``or``, ``and``,
    ``not``; parentheses override it.  ``and``/``or`` chains associate to
    the left.
    """

    # Field paths that _resolve_field can resolve (the FIELD production of
    # the grammar).  Unknown names used to parse cleanly and then evaluate
    # to False for every bar.  A typo therefore returned nothing, and
    # ``not <typo> == x`` returned the entire history.  Rejecting them here
    # surfaces the mistake as a parse error.  Keep in sync with _resolve_field.
    _KNOWN_FIELDS: frozenset[str] = frozenset({
        "note.pitch",
        "note.pitch_class",
        "note.velocity",
        "note.channel",
        "note.duration",
        "bar",
        "track",
        "harmony.chord",
        "harmony.quality",
        "author",
        "agent_id",
        "model_id",
        "toolchain_id",
    })

    # Narrows raw operator text to the Literal type EqNode.op is declared with.
    _OPS: dict[str, _OpLiteral] = {
        "==": "==", "!=": "!=", ">": ">", "<": "<", ">=": ">=", "<=": "<=",
    }

    # Backslash escapes honoured inside quoted strings: \\ \' \"  (any other
    # backslash sequence is kept verbatim).  Processing left to right makes
    # ``'a\\'`` yield ``a\`` -- previously a trailing backslash was impossible.
    _ESCAPE_RE = re.compile(r"\\([\\'\"])")

    def __init__(self, tokens: list[Token]) -> None:
        self._tokens = tokens
        self._pos = 0

    def _peek(self) -> Token | None:
        """Return the next token without consuming it, or ``None`` at end of input."""
        if self._pos < len(self._tokens):
            return self._tokens[self._pos]
        return None

    def _consume(self, kind: str | None = None, value: str | None = None) -> Token:
        """Consume and return the next token, optionally checking its kind and value.

        Raises:
            ValueError: At end of input, or when the token's kind/value differs
                from the one requested.
        """
        tok = self._peek()
        if tok is None:
            raise ValueError("Unexpected end of query")
        if kind is not None and tok.kind != kind:
            raise ValueError(f"Expected {kind!r}, got {tok.kind!r} ({tok.value!r})")
        if value is not None and tok.value != value:
            raise ValueError(f"Expected {value!r}, got {tok.value!r}")
        self._pos += 1
        return tok

    def _at_keyword(self, word: str) -> bool:
        """Return ``True`` when the next token is the keyword *word*."""
        tok = self._peek()
        return tok is not None and tok.kind == "KW" and tok.value == word

    def parse(self) -> QueryNode:
        """Parse the full token stream; leftover tokens are an error."""
        node = self._or_expr()
        leftover = self._peek()
        if leftover is not None:
            raise ValueError(f"Unexpected token: {leftover!r}")
        return node

    def _or_expr(self) -> QueryNode:
        """or_expr = and_expr ( 'or' and_expr )*"""
        node = self._and_expr()
        while self._at_keyword("or"):
            self._consume()
            node = OrNode(left=node, right=self._and_expr())
        return node

    def _and_expr(self) -> QueryNode:
        """and_expr = not_expr ( 'and' not_expr )*"""
        node = self._not_expr()
        while self._at_keyword("and"):
            self._consume()
            node = AndNode(left=node, right=self._not_expr())
        return node

    def _not_expr(self) -> QueryNode:
        """not_expr = 'not' not_expr | atom"""
        if self._at_keyword("not"):
            self._consume()
            return NotNode(inner=self._not_expr())
        return self._atom()

    def _atom(self) -> QueryNode:
        """atom = '(' query ')' | comparison"""
        tok = self._peek()
        if tok is None:
            raise ValueError("Unexpected end of query in atom")
        if tok.kind == "LPAREN":
            self._consume("LPAREN")
            node = self._or_expr()
            self._consume("RPAREN")
            return node
        return self._comparison()

    def _comparison(self) -> QueryNode:
        """comparison = FIELD OP VALUE

        Raises:
            ValueError: For an unknown field, a missing/invalid operator, or a
                value token that is not a string, number or bare word.
        """
        field_tok = self._consume("NAME")
        if field_tok.value not in self._KNOWN_FIELDS:
            raise ValueError(
                f"Unknown field: {field_tok.value!r} "
                f"(expected one of: {', '.join(sorted(self._KNOWN_FIELDS))})"
            )

        op_tok = self._consume("OP")
        op_val = self._OPS.get(op_tok.value)
        if op_val is None:
            # Defensive: the tokenizer only emits the six valid operators.
            raise ValueError(f"Invalid operator: {op_tok.value!r}")

        val_tok = self._consume()
        value: str | int | float
        if val_tok.kind == "STR":
            value = self._ESCAPE_RE.sub(r"\1", val_tok.value[1:-1])
        elif val_tok.kind == "NUM":
            value = float(val_tok.value) if "." in val_tok.value else int(val_tok.value)
        elif val_tok.kind == "NAME":
            # Bare words are accepted as string values (e.g. ``pitch_class == Eb``).
            value = val_tok.value
        else:
            raise ValueError(f"Expected value, got {val_tok.kind!r}")

        return EqNode(field=field_tok.value, op=op_val, value=value)
312
313
def parse_query(query_str: str) -> QueryNode:
    """Turn a music query expression into its AST.

    Args:
        query_str: Query text, e.g. ``"note.velocity > 80 and not bar == 4"``.

    Returns:
        The root :data:`QueryNode`.

    Raises:
        ValueError: If the text cannot be tokenized or parsed.
    """
    parser = _Parser(_tokenize(query_str))
    return parser.parse()
328
329
330 # ---------------------------------------------------------------------------
331 # Evaluator
332 # ---------------------------------------------------------------------------
333
334
335 def _compare(actual: str | int | float, op: str, expected: str | int | float) -> bool:
336 """Apply a comparison operator to two values.
337
338 String comparisons use ``==`` / ``!=`` only (other operators raise).
339 Numeric comparisons support all six operators.
340 """
341 if isinstance(actual, str):
342 if op == "==":
343 return actual.lower() == str(expected).lower()
344 if op == "!=":
345 return actual.lower() != str(expected).lower()
346 raise ValueError(f"Operator {op!r} not supported for string values")
347
348 exp_num: float
349 if isinstance(expected, str):
350 try:
351 exp_num = float(expected)
352 except ValueError:
353 raise ValueError(f"Cannot compare numeric field to {expected!r}")
354 else:
355 exp_num = float(expected)
356
357 act_num = float(actual)
358 if op == "==":
359 return act_num == exp_num
360 if op == "!=":
361 return act_num != exp_num
362 if op == ">":
363 return act_num > exp_num
364 if op == "<":
365 return act_num < exp_num
366 if op == ">=":
367 return act_num >= exp_num
368 if op == "<=":
369 return act_num <= exp_num
370 raise ValueError(f"Unknown operator {op!r}")
371
372
373 def _resolve_field(field: str, ctx: QueryContext) -> list[str | int | float]:
374 """Resolve a field path to a list of candidate values from *ctx*.
375
376 Note fields return one value per note in the bar; all other fields
377 return a single-element list. The evaluator matches if *any* candidate
378 satisfies the predicate.
379 """
380 # --- Note fields ---
381 if field == "note.pitch":
382 return [n.pitch for n in ctx.notes]
383 if field == "note.pitch_class":
384 return [n.pitch_class_name for n in ctx.notes]
385 if field == "note.velocity":
386 return [n.velocity for n in ctx.notes]
387 if field == "note.channel":
388 return [n.channel for n in ctx.notes]
389 if field == "note.duration":
390 return [n.beat_duration for n in ctx.notes]
391 # --- Bar / track ---
392 if field == "bar":
393 return [ctx.bar]
394 if field == "track":
395 return [ctx.track]
396 # --- Harmony ---
397 if field == "harmony.chord":
398 return [ctx.chord]
399 if field == "harmony.quality":
400 for suffix in ("dim7", "maj7", "min7", "dom7", "sus2", "sus4", "aug", "dim", "maj", "min", "5"):
401 if ctx.chord.endswith(suffix):
402 return [suffix]
403 return [""]
404 # --- Commit provenance ---
405 if field == "author":
406 return [ctx.commit.author]
407 if field == "agent_id":
408 return [ctx.commit.agent_id]
409 if field == "model_id":
410 return [ctx.commit.model_id]
411 if field == "toolchain_id":
412 return [ctx.commit.toolchain_id]
413
414 raise ValueError(f"Unknown field: {field!r}")
415
416
417 def evaluate_node(node: QueryNode, ctx: QueryContext) -> bool:
418 """Recursively evaluate a query AST node against *ctx*.
419
420 Args:
421 node: The root (or sub) AST node.
422 ctx: Query context for the bar/track/commit being tested.
423
424 Returns:
425 ``True`` when the predicate matches, ``False`` otherwise.
426 """
427 if isinstance(node, EqNode):
428 try:
429 candidates = _resolve_field(node.field, ctx)
430 except ValueError:
431 return False
432 return any(_compare(c, node.op, node.value) for c in candidates)
433
434 if isinstance(node, AndNode):
435 return evaluate_node(node.left, ctx) and evaluate_node(node.right, ctx)
436
437 if isinstance(node, OrNode):
438 return evaluate_node(node.left, ctx) or evaluate_node(node.right, ctx)
439
440 if isinstance(node, NotNode):
441 return not evaluate_node(node.inner, ctx)
442
443 raise TypeError(f"Unknown AST node type: {type(node)}")
444
445
446 # ---------------------------------------------------------------------------
447 # Query runner
448 # ---------------------------------------------------------------------------
449
450
def _analyse_midi_object(
    root: pathlib.Path,
    obj_hash: str,
) -> tuple[int, list[tuple[int, list[NoteInfo], str]]] | None:
    """Decode one stored MIDI object into per-bar notes and chords.

    Args:
        root: Repository root.
        obj_hash: Object-store hash of the MIDI file.

    Returns:
        ``(ticks_per_beat, bars)``, where *bars* lists ``(bar_number, notes,
        chord)`` in ascending bar order.  Returns ``None`` when the object is
        missing or :func:`extract_notes` rejects it with ``ValueError``.
    """
    raw = read_object(root, obj_hash)
    if raw is None:
        return None
    try:
        keys, tpb = extract_notes(raw)
    except ValueError:
        return None
    notes = [NoteInfo.from_note_key(k, tpb) for k in keys]
    bars = [
        (bar_num, bar_notes, detect_chord(frozenset(n.pitch_class for n in bar_notes)))
        for bar_num, bar_notes in sorted(notes_by_bar(notes).items())
    ]
    return tpb, bars


def _build_match(
    commit: CommitRecord,
    track: str,
    bar: int,
    notes: list[NoteInfo],
    chord: str,
    query_str: str,
) -> QueryMatch:
    """Serialise one matching (commit, track, bar) triple as a :class:`QueryMatch`."""
    return QueryMatch(
        commit_id=commit.commit_id,
        commit_short=commit.commit_id[:8],
        commit_message=commit.message,
        author=commit.author,
        agent_id=commit.agent_id,
        committed_at=commit.committed_at.isoformat(),
        track=track,
        bar=bar,
        notes=[
            NoteDict(
                pitch=n.pitch,
                pitch_class=n.pitch_class_name,
                velocity=n.velocity,
                channel=n.channel,
                beat=round(n.beat, 4),
                duration_beats=round(n.beat_duration, 4),
            )
            for n in notes
        ],
        chord=chord,
        matched_on=query_str,
    )


def run_query(
    query_str: str,
    root: pathlib.Path,
    start_commit_id: str,
    *,
    track_filter: str | None = None,
    from_commit_id: str | None = None,
    max_commits: int = 10_000,
    max_results: int = 1_000,
) -> list[QueryMatch]:
    """Evaluate a music query DSL expression over the commit history.

    Follows ``parent_commit_id`` links from *start_commit_id*.  For each
    commit it loads every ``.mid`` track in the snapshot, groups its notes by
    bar, and evaluates the query against each (commit, track, bar) triple.

    Args:
        query_str: Music query expression string.
        root: Repository root.
        start_commit_id: Start of the walk (inclusive; usually HEAD).
        track_filter: Restrict search to a single MIDI file path.
        from_commit_id: Stop before this commit ID (exclusive).
        max_commits: Safety cap on commits walked.
        max_results: Safety cap on results returned; ``<= 0`` returns ``[]``.

    Returns:
        List of :class:`QueryMatch` dicts, oldest commit first.  Matches from
        the same commit are ordered by track path, then bar number.  When
        *max_results* is reached the walk stops, so the list holds the
        newest matches.

    Raises:
        ValueError: When the query string cannot be parsed, or when
            evaluation hits an unsupported operator/value combination.
    """
    ast = parse_query(query_str)
    if max_results <= 0:
        return []

    # One list of matches per commit, in walk order (newest commit first).
    # Reversing at commit granularity gives oldest-first output without also
    # flipping the ascending track/bar order inside each commit.
    per_commit: list[list[QueryMatch]] = []
    found = 0
    truncated = False

    # Objects are content-addressed and a track is usually unchanged between
    # neighbouring commits, so remember the last decoded object per track and
    # reuse it while the hash is the same.  Memory stays bounded by the number
    # of tracks rather than the length of history.
    last_decoded: dict[
        str, tuple[str, tuple[int, list[tuple[int, list[NoteInfo], str]]] | None]
    ] = {}

    commit_id: str | None = start_commit_id
    seen: set[str] = set()
    commits_walked = 0

    while (
        not truncated
        and commit_id
        and commit_id not in seen
        and commits_walked < max_commits
    ):
        seen.add(commit_id)
        commits_walked += 1

        if commit_id == from_commit_id:
            break

        commit = read_commit(root, commit_id)
        if commit is None:
            break

        manifest = get_commit_snapshot_manifest(root, commit_id) or {}
        midi_paths = sorted(
            p for p in manifest
            if p.lower().endswith(".mid")
            and (track_filter is None or p == track_filter)
        )

        commit_matches: list[QueryMatch] = []
        per_commit.append(commit_matches)

        for track_path in midi_paths:
            obj_hash = manifest.get(track_path)
            if obj_hash is None:
                continue

            cached = last_decoded.get(track_path)
            if cached is not None and cached[0] == obj_hash:
                analysis = cached[1]
            else:
                analysis = _analyse_midi_object(root, obj_hash)
                last_decoded[track_path] = (obj_hash, analysis)
            if analysis is None:
                continue
            tpb, bars = analysis

            for bar_num, bar_notes, chord in bars:
                ctx = QueryContext(
                    commit=commit,
                    track=track_path,
                    bar=bar_num,
                    notes=bar_notes,
                    chord=chord,
                    ticks_per_beat=tpb,
                )
                if not evaluate_node(ast, ctx):
                    continue
                commit_matches.append(
                    _build_match(commit, track_path, bar_num, bar_notes, chord, query_str)
                )
                found += 1
                if found >= max_results:
                    logger.warning(
                        "⚠️ music-query hit max_results=%d — truncating",
                        max_results,
                    )
                    truncated = True
                    break
            if truncated:
                break

        commit_id = commit.parent_commit_id

    return [match for group in reversed(per_commit) for match in group]