cgcardona / muse public
_invariants.py python
560 lines 18.5 KB
6d8ca4ac feat: god-tier MIDI dimension expansion + full supercharge architecture Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """Musical invariants engine for the Muse music plugin.
2
3 Invariants are semantic rules that a MIDI track must satisfy. They are
4 evaluated at commit time, merge time, or on-demand via ``muse music-check``.
5 Violations are reported with human-readable descriptions, severity levels,
6 and structured addresses for programmatic consumers.
7
8 Rule file format (TOML)
9 -----------------------
10 Rules are declared in ``.muse/music_invariants.toml`` (default path).
11 Example::
12
13 [[rule]]
14 name = "max_polyphony"
15 severity = "error"
16 scope = "track"
17 rule_type = "max_polyphony"
18
19 [rule.params]
20 max_simultaneous = 6
21
22 [[rule]]
23 name = "keep_in_range"
24 severity = "warning"
25 scope = "track"
26 rule_type = "pitch_range"
27
28 [rule.params]
29 min_pitch = 24
30 max_pitch = 108
31
32 [[rule]]
33 name = "no_fifths"
34 severity = "warning"
35 scope = "voice_pair"
36 rule_type = "no_parallel_fifths"
37
38 [[rule]]
39 name = "consistent_key"
40 severity = "info"
41 scope = "track"
42 rule_type = "key_consistency"
43
44 [rule.params]
45 threshold = 0.15
46
47 Built-in rule types
48 -------------------
49
50 ``max_polyphony``
51 Detects bars where more than *max_simultaneous* notes overlap at any
52 tick position. Uses a sweep-line algorithm over start/end tick events.
53
54 ``pitch_range``
55 Detects any note with ``pitch < min_pitch`` or ``pitch > max_pitch``.
56
57 ``key_consistency``
58 Detects notes whose pitch class is highly inconsistent with the key
59 estimated by the Krumhansl-Schmuckler algorithm. Fires when the ratio
60 of "foreign" pitch classes exceeds *threshold*.
61
62 ``no_parallel_fifths``
63 Detects consecutive bars where the two lowest voices are a perfect fifth
64 apart in both bars and move in the same direction (a classical counterpoint
65 violation). Best-effort heuristic — voice assignment is implicit.
66
67 Severity levels
68 ---------------
69 - ``"error"`` — must be resolved before committing (when ``--strict`` is set).
70 - ``"warning"`` — reported but does not block commits.
71 - ``"info"`` — informational; surfaced in ``muse music-check`` output only.
72
73 Public API
74 ----------
75 - :class:`InvariantRule` — rule declaration TypedDict.
76 - :class:`InvariantViolation` — single violation record TypedDict.
77 - :class:`InvariantReport` — full report for one commit / track.
78 - :func:`load_invariant_rules` — load from TOML file with defaults fallback.
79 - :func:`run_invariants` — evaluate all rules against a commit.
80 """
81 from __future__ import annotations
82
83 import logging
84 import pathlib
85 from typing import Literal, TypedDict
86
87 from muse.core.object_store import read_object
88 from muse.core.store import get_commit_snapshot_manifest
89 from muse.plugins.music._query import NoteInfo, key_signature_guess, notes_by_bar
90 from muse.plugins.music.midi_diff import extract_notes
91
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

# Conventional location of the TOML rule file (see the module docstring).
# NOTE(review): not referenced elsewhere in this module — presumably callers
# resolve it against the repository root before calling the loader; confirm.
_DEFAULT_RULES_FILE = ".muse/music_invariants.toml"
95
96
97 # ---------------------------------------------------------------------------
98 # Types
99 # ---------------------------------------------------------------------------
100
101
class _InvariantRuleRequired(TypedDict):
    """Keys that every invariant rule declaration must provide.

    Acts as the required-key base for :class:`InvariantRule`, which layers
    the optional ``params`` key on top.
    """

    # Human-readable rule identifier; copied into each violation record.
    name: str
    # Severity stamped onto every violation the rule produces.
    severity: Literal["info", "warning", "error"]
    # Granularity the rule is declared for.
    scope: Literal["track", "bar", "voice_pair", "global"]
    # Name of the built-in check to run for this rule.
    rule_type: str
107
108
class InvariantRule(_InvariantRuleRequired, total=False):
    """Declaration of one musical invariant rule.

    The four required keys come from the base class; ``total=False`` makes
    the key declared here (``params``) optional.

    ``name``       Human-readable rule identifier (intended to be unique within
                   a rule set; uniqueness is not enforced by the loader).
    ``severity``   Violation severity: ``"info"``, ``"warning"``, or ``"error"``.
    ``scope``      Granularity: ``"track"``, ``"bar"``, ``"voice_pair"``, ``"global"``.
    ``rule_type``  Built-in type string: ``"max_polyphony"``, ``"pitch_range"``,
                   ``"key_consistency"``, ``"no_parallel_fifths"``.
    ``params``     Rule-specific parameter dict (optional).

    NOTE(review): ``scope`` is stored on the rule but is not consulted by the
    runner in this module — confirm whether other consumers rely on it.
    """

    # Optional per-rule settings; the meaningful keys depend on ``rule_type``
    # (``max_simultaneous``, ``min_pitch`` / ``max_pitch``, ``threshold``).
    params: dict[str, str | int | float]
121
122
class InvariantViolation(TypedDict):
    """A single invariant violation record.

    ``rule_name``    The name of the rule that fired.
    ``severity``     Severity level from the rule declaration.
    ``track``        Workspace-relative MIDI file path.
    ``bar``          1-indexed bar number (0 for track-level violations).
    ``description``  Human-readable explanation of what was violated.
    ``addresses``    Note addresses or other domain addresses involved.
    """

    # Name of the rule that produced this record.
    rule_name: str
    # Severity copied from the rule declaration.
    severity: Literal["info", "warning", "error"]
    # Path of the MIDI file the violation was found in.
    track: str
    # Bar the violation is attributed to; the track-level key-consistency
    # check reports 0 here.
    bar: int
    # Explanation intended for human readers.
    description: str
    # Structured locations for programmatic consumers; the built-in checks
    # emit strings such as ``"bar:3:tick:480"``, ``"bar:3:pitch:60"``,
    # ``"bar:3"`` or the track path itself.
    addresses: list[str]
140
141
class InvariantReport(TypedDict):
    """Full invariant check report for one commit.

    ``commit_id``      The commit that was checked.
    ``violations``     All violations found, sorted by track then bar.
    ``rules_checked``  Number of rule evaluations (each rule counts once per track).
    ``has_errors``     True when any violation has severity ``"error"``.
    ``has_warnings``   True when any violation has severity ``"warning"``.
    """

    # Identifier of the commit the report describes.
    commit_id: str
    # Every violation found across all checked tracks.
    violations: list[InvariantViolation]
    # Rule-evaluation count reported by the runner.
    rules_checked: int
    # Convenience flags derived from the severities in ``violations``.
    has_errors: bool
    has_warnings: bool
157
158
159 # ---------------------------------------------------------------------------
160 # Built-in rule implementations
161 # ---------------------------------------------------------------------------
162
163
def check_max_polyphony(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    max_simultaneous: int = 6,
) -> list[InvariantViolation]:
    """Report every bar whose peak polyphony exceeds *max_simultaneous*.

    Notes are grouped with ``notes_by_bar`` and each bar is swept on its own:
    every note contributes a ``+1`` event at its start tick and a ``-1`` event
    at ``start_tick + duration_ticks``; the running sum of those events is the
    number of sounding notes. At most one violation is emitted per bar,
    addressed at the first tick where the bar's peak was reached.

    NOTE(review): because each bar is swept independently, overlap with notes
    that ``notes_by_bar`` files under a different bar is not counted.

    Args:
        notes: All notes in the track.
        track: Track file path for violation records.
        rule_name: Rule identifier string.
        severity: Violation severity.
        max_simultaneous: Maximum allowed simultaneous notes.

    Returns:
        List of :class:`InvariantViolation` records, in ascending bar order.
    """
    found: list[InvariantViolation] = []
    grouped = notes_by_bar(notes)

    for bar_num in sorted(grouped):
        # Plain tuple ordering sorts by tick first, then delta, so a note-off
        # (-1) at a given tick is processed before a note-on (+1) at that tick.
        boundaries = sorted(
            edge
            for note in grouped[bar_num]
            for edge in (
                (note.start_tick, 1),
                (note.start_tick + note.duration_ticks, -1),
            )
        )

        sounding = 0
        peak = 0
        peak_tick = 0
        for tick, delta in boundaries:
            sounding += delta
            if sounding > peak:
                peak, peak_tick = sounding, tick

        if peak > max_simultaneous:
            message = (
                f"Polyphony reached {peak} simultaneous notes at tick {peak_tick} "
                f"(max allowed: {max_simultaneous})"
            )
            found.append(
                InvariantViolation(
                    rule_name=rule_name,
                    severity=severity,
                    track=track,
                    bar=bar_num,
                    description=message,
                    addresses=[f"bar:{bar_num}:tick:{peak_tick}"],
                )
            )

    return found
223
224
def check_pitch_range(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    min_pitch: int = 0,
    max_pitch: int = 127,
) -> list[InvariantViolation]:
    """Report every note whose MIDI pitch lies outside ``[min_pitch, max_pitch]``.

    Both bounds are inclusive. Violations are emitted in the order the notes
    appear in *notes*, each attributed to the offending note's own bar.

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.
        min_pitch: Lowest allowed MIDI pitch (inclusive).
        max_pitch: Highest allowed MIDI pitch (inclusive).

    Returns:
        One :class:`InvariantViolation` per out-of-range note.
    """
    return [
        InvariantViolation(
            rule_name=rule_name,
            severity=severity,
            track=track,
            bar=offender.bar,
            description=(
                f"Note {offender.pitch_name} (MIDI {offender.pitch}) is outside "
                f"allowed range [{min_pitch}, {max_pitch}]"
            ),
            addresses=[f"bar:{offender.bar}:pitch:{offender.pitch}"],
        )
        for offender in notes
        if offender.pitch < min_pitch or offender.pitch > max_pitch
    ]
264
265
def check_key_consistency(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    threshold: float = 0.15,
) -> list[InvariantViolation]:
    """Flag a track whose notes stray too often from its estimated key.

    The key comes from ``key_signature_guess`` (documented in this module as
    a Krumhansl-Schmuckler estimate) and is expected to read like
    ``"G major"`` or ``"D minor"``. The diatonic pitch-class set of that key
    is built (major scale for ``"major"``, natural minor for anything else)
    and the share of notes outside it is compared against *threshold*.

    The check yields nothing when *notes* is empty, when the guess has fewer
    than two words, or when the tonic is not one of the twelve sharp-spelled
    pitch-class names (so flat spellings such as ``"Bb"`` are not evaluated).

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.
        threshold: Maximum allowed ratio of foreign pitch classes (0.0–1.0).

    Returns:
        Zero or one :class:`InvariantViolation` for the track (``bar`` is 0).
    """
    if not notes:
        return []

    key_guess = key_signature_guess(notes)
    tokens = key_guess.split()
    if len(tokens) < 2:
        return []
    tonic, mode = tokens[0], tokens[1]

    # Pitch-class names indexed by semitone distance above C.
    chromatic = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    if tonic not in chromatic:
        return []
    tonic_pc = chromatic.index(tonic)

    # Semitone offsets of the scale degrees: major vs. natural minor.
    if mode == "major":
        degrees = (0, 2, 4, 5, 7, 9, 11)
    else:
        degrees = (0, 2, 3, 5, 7, 8, 10)
    in_key = frozenset((tonic_pc + offset) % 12 for offset in degrees)

    total = len(notes)
    foreign = len([note for note in notes if note.pitch_class not in in_key])
    ratio = foreign / total

    exceeded = ratio > threshold
    if not exceeded:
        return []

    summary = (
        f"{foreign}/{total} notes ({ratio:.0%}) use pitch classes "
        f"foreign to estimated key {key_guess} "
        f"(threshold: {threshold:.0%})"
    )
    return [
        InvariantViolation(
            rule_name=rule_name,
            severity=severity,
            track=track,
            bar=0,
            description=summary,
            addresses=[track],
        )
    ]
332
333
def check_no_parallel_fifths(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
) -> list[InvariantViolation]:
    """Detect parallel perfect fifths between directly consecutive bars.

    Heuristic: within each bar the two lowest-pitched notes stand in for the
    bass and tenor voices. A violation is reported for a pair of adjacent bars
    (``n`` and ``n + 1``) when

    * the two voices are a perfect fifth apart (7 semitones, modulo the
      octave) in *both* bars, and
    * both voices move in the same direction (both up or both down) from the
      first bar to the second.

    Bars holding fewer than two notes are ignored, and bars separated by one
    or more bars without notes are never paired with each other.

    This is a best-effort approximation — accurate voice separation would
    require dedicated voice-leading analysis beyond this scope.

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.

    Returns:
        One :class:`InvariantViolation` per detected parallel-fifth bar pair,
        attributed to the second bar of the pair.
    """
    violations: list[InvariantViolation] = []
    bars = notes_by_bar(notes)
    ordered_bars = sorted(bars)

    for bar_a, bar_b in zip(ordered_bars, ordered_bars[1:]):
        # Neighbouring keys of the mapping are only "consecutive bars" when
        # their numbers differ by exactly one; otherwise there is a gap of
        # silent bars between them and no direct voice motion to judge.
        if bar_b != bar_a + 1:
            continue

        notes_a = sorted(bars[bar_a], key=lambda n: n.pitch)
        notes_b = sorted(bars[bar_b], key=lambda n: n.pitch)
        if len(notes_a) < 2 or len(notes_b) < 2:
            continue

        # Two lowest pitches approximate the bass and tenor voices.
        low_a, high_a = notes_a[0], notes_a[1]
        low_b, high_b = notes_b[0], notes_b[1]

        # Both bars must present a perfect fifth (compound fifths included).
        interval_a = abs(high_a.pitch - low_a.pitch) % 12
        interval_b = abs(high_b.pitch - low_b.pitch) % 12
        if interval_a != 7 or interval_b != 7:
            continue

        # Parallel motion: both voices strictly rise or both strictly fall.
        motion_low = low_b.pitch - low_a.pitch
        motion_high = high_b.pitch - high_a.pitch
        both_up = motion_low > 0 and motion_high > 0
        both_down = motion_low < 0 and motion_high < 0
        if not (both_up or both_down):
            continue

        violations.append(
            InvariantViolation(
                rule_name=rule_name,
                severity=severity,
                track=track,
                bar=bar_b,
                description=(
                    f"Parallel fifths between bars {bar_a} and {bar_b}: "
                    f"lower voice {low_a.pitch_name}→{low_b.pitch_name}, "
                    f"upper voice {high_a.pitch_name}→{high_b.pitch_name}"
                ),
                addresses=[f"bar:{bar_a}", f"bar:{bar_b}"],
            )
        )

    return violations
401
402
403 # ---------------------------------------------------------------------------
404 # Rule loading
405 # ---------------------------------------------------------------------------
406
# Built-in fallback rules, returned by the loader whenever no usable rule
# file is available. Both entries are warnings: a polyphony cap of 8 and the
# full MIDI pitch range 0–127.
_DEFAULT_RULE_SET: list[InvariantRule] = [
    {
        "name": "max_polyphony",
        "severity": "warning",
        "scope": "track",
        "rule_type": "max_polyphony",
        "params": {"max_simultaneous": 8},
    },
    {
        "name": "pitch_range",
        "severity": "warning",
        "scope": "track",
        "rule_type": "pitch_range",
        "params": {"min_pitch": 0, "max_pitch": 127},
    },
]
423
424
def load_invariant_rules(rules_file: pathlib.Path | None = None) -> list[InvariantRule]:
    """Load invariant rules from a TOML file, falling back to defaults.

    Requires ``tomllib`` (Python 3.11+) for TOML parsing. The default rule
    set is returned when *rules_file* is ``None``, does not exist, cannot be
    read or parsed (including ``tomllib`` being unavailable), or declares no
    ``[[rule]]`` entries at all.

    Each ``[[rule]]`` entry is normalised leniently: a missing name becomes
    ``"unnamed"``, an unrecognised severity becomes ``"warning"``, an
    unrecognised scope becomes ``"track"`` and a missing rule type becomes
    ``""``. A ``params`` value that is not a TOML table is dropped with a
    warning so that later ``params.get(...)`` lookups cannot fail on it.

    Args:
        rules_file: Path to the TOML rule file. ``None`` means use defaults.

    Returns:
        List of :class:`InvariantRule` dicts (a fresh list on every call).
    """
    if rules_file is None or not rules_file.exists():
        return list(_DEFAULT_RULE_SET)

    # Identity maps used as validators: a lookup both checks membership and
    # yields a value of the precise ``Literal`` type. Built once per call
    # rather than once per rule.
    valid_severities: dict[str, Literal["info", "warning", "error"]] = {
        "info": "info", "warning": "warning", "error": "error",
    }
    valid_scopes: dict[str, Literal["track", "bar", "voice_pair", "global"]] = {
        "track": "track", "bar": "bar", "voice_pair": "voice_pair", "global": "global",
    }

    try:
        # Imported lazily so the module still loads on interpreters that
        # lack ``tomllib``; the ImportError lands in the handler below.
        import tomllib

        with rules_file.open("rb") as fh:
            data = tomllib.load(fh)

        rules: list[InvariantRule] = []
        for raw in data.get("rule", []):
            name = str(raw.get("name", "unnamed"))
            rule = InvariantRule(
                name=name,
                severity=valid_severities.get(str(raw.get("severity", "")), "warning"),
                scope=valid_scopes.get(str(raw.get("scope", "")), "track"),
                rule_type=str(raw.get("rule_type", "")),
            )
            if "params" in raw:
                params = raw["params"]
                if isinstance(params, dict):
                    rule["params"] = params
                else:
                    logger.warning(
                        "⚠️ Ignoring non-table params for invariant rule %r in %s",
                        name,
                        rules_file,
                    )
            rules.append(rule)
        return rules if rules else list(_DEFAULT_RULE_SET)

    except Exception as exc:
        # Deliberately broad: rule loading is best-effort and must never
        # prevent a check from running with the default rules.
        logger.warning("⚠️ Could not load invariant rules from %s: %s", rules_file, exc)
        return list(_DEFAULT_RULE_SET)
470
471
472 # ---------------------------------------------------------------------------
473 # Main runner
474 # ---------------------------------------------------------------------------
475
476
def _evaluate_rule(
    rule: InvariantRule,
    notes: list[NoteInfo],
    track_path: str,
) -> list[InvariantViolation] | None:
    """Run the built-in check named by ``rule["rule_type"]`` on one track.

    Returns the check's violations, or ``None`` when the rule type is not one
    of the built-in checks (the rule is then skipped with a debug log).
    """
    rule_type = rule["rule_type"]
    severity = rule["severity"]
    name = rule["name"]
    params = rule.get("params", {})

    if rule_type == "max_polyphony":
        return check_max_polyphony(
            notes, track_path, name, severity,
            max_simultaneous=int(params.get("max_simultaneous", 8)),
        )
    if rule_type == "pitch_range":
        return check_pitch_range(
            notes, track_path, name, severity,
            min_pitch=int(params.get("min_pitch", 0)),
            max_pitch=int(params.get("max_pitch", 127)),
        )
    if rule_type == "key_consistency":
        return check_key_consistency(
            notes, track_path, name, severity,
            threshold=float(params.get("threshold", 0.15)),
        )
    if rule_type == "no_parallel_fifths":
        return check_no_parallel_fifths(notes, track_path, name, severity)

    logger.debug("Unknown rule_type %r in rule %r — skipped", rule_type, name)
    return None


def run_invariants(
    root: pathlib.Path,
    commit_id: str,
    rules: list[InvariantRule],
    *,
    track_filter: str | None = None,
) -> InvariantReport:
    """Evaluate all *rules* against every MIDI track in *commit_id*.

    Tracks are the entries of the commit's snapshot manifest whose path ends
    in ``.mid`` (case-insensitive), optionally narrowed to the single path
    equal to *track_filter*. A track is skipped when its manifest entry or
    stored object is missing, or when ``extract_notes`` raises ``ValueError``
    for its content. Rules with an unrecognised ``rule_type`` are skipped.

    ``rules_checked`` in the report counts the rule evaluations that were
    actually performed, so skipped tracks and unrecognised rule types do not
    inflate it.

    Args:
        root: Repository root.
        commit_id: Commit to check.
        rules: List of :class:`InvariantRule` declarations.
        track_filter: Restrict check to a single MIDI file path.

    Returns:
        An :class:`InvariantReport` with all violations found, sorted by
        track and then bar.
    """
    all_violations: list[InvariantViolation] = []
    evaluated = 0
    manifest = get_commit_snapshot_manifest(root, commit_id) or {}

    midi_paths = [
        p for p in manifest
        if p.lower().endswith(".mid")
        and (track_filter is None or p == track_filter)
    ]

    for track_path in sorted(midi_paths):
        obj_hash = manifest.get(track_path)
        if obj_hash is None:
            continue
        raw = read_object(root, obj_hash)
        if raw is None:
            continue
        try:
            keys, tpb = extract_notes(raw)
        except ValueError as exc:
            logger.debug("Cannot parse MIDI %r: %s", track_path, exc)
            continue

        notes = [NoteInfo.from_note_key(k, tpb) for k in keys]

        for rule in rules:
            found = _evaluate_rule(rule, notes, track_path)
            if found is None:
                # Unknown rule type: nothing was evaluated.
                continue
            evaluated += 1
            all_violations.extend(found)

    # list.sort is stable, so violations sharing a track and bar keep the
    # order in which the rules produced them.
    all_violations.sort(key=lambda v: (v["track"], v["bar"]))
    has_errors = any(v["severity"] == "error" for v in all_violations)
    has_warnings = any(v["severity"] == "warning" for v in all_violations)

    return InvariantReport(
        commit_id=commit_id,
        violations=all_violations,
        rules_checked=evaluated,
        has_errors=has_errors,
        has_warnings=has_warnings,
    )