gabriel / muse public
_invariants.py python
591 lines 19.6 KB
bda49bdb feat: redesign .museignore as TOML with domain-scoped sections (#100) Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """Musical invariants engine for the Muse MIDI plugin.
2
3 Invariants are semantic rules that a MIDI track must satisfy. They are
4 evaluated at commit time, merge time, or on-demand via ``muse music-check``.
5 Violations are reported with human-readable descriptions, severity levels,
6 and structured addresses for programmatic consumers.
7
8 Rule file format (TOML)
9 -----------------------
10 Rules are declared in ``.muse/midi_invariants.toml`` (default path).
11 Example::
12
13 [[rule]]
14 name = "max_polyphony"
15 severity = "error"
16 scope = "track"
17 rule_type = "max_polyphony"
18
19 [rule.params]
20 max_simultaneous = 6
21
22 [[rule]]
23 name = "keep_in_range"
24 severity = "warning"
25 scope = "track"
26 rule_type = "pitch_range"
27
28 [rule.params]
29 min_pitch = 24
30 max_pitch = 108
31
32 [[rule]]
33 name = "no_fifths"
34 severity = "warning"
35 scope = "voice_pair"
36 rule_type = "no_parallel_fifths"
37
38 [[rule]]
39 name = "consistent_key"
40 severity = "info"
41 scope = "track"
42 rule_type = "key_consistency"
43
44 [rule.params]
45 threshold = 0.15
46
47 Built-in rule types
48 -------------------
49
50 ``max_polyphony``
51 Detects bars where more than *max_simultaneous* notes overlap at any
52 tick position. Uses a sweep-line algorithm over start/end tick events.
53
54 ``pitch_range``
55 Detects any note with ``pitch < min_pitch`` or ``pitch > max_pitch``.
56
57 ``key_consistency``
58 Detects notes whose pitch class is highly inconsistent with the key
59 estimated by the Krumhansl-Schmuckler algorithm. Fires when the ratio
60 of "foreign" pitch classes exceeds *threshold*.
61
62 ``no_parallel_fifths``
63 Detects consecutive bars in which the two lowest notes form a perfect fifth
64 (modulo octaves) and both move in the same direction from one bar to the next
65 (a classical counterpoint violation). Best-effort heuristic — voice assignment is implicit.
66
67 Severity levels
68 ---------------
69 - ``"error"`` — must be resolved before committing (when ``--strict`` is set).
70 - ``"warning"`` — reported but does not block commits.
71 - ``"info"`` — informational; surfaced in ``muse music-check`` output only.
72
73 Public API
74 ----------
75 - :class:`InvariantRule` — rule declaration TypedDict.
76 - :class:`InvariantViolation` — single violation record TypedDict.
77 - :class:`InvariantReport` — full report for one commit / track.
78 - :func:`load_invariant_rules` — load from TOML file with defaults fallback.
79 - :func:`run_invariants` — evaluate all rules against a commit.
80 """
81
82 from __future__ import annotations
83
84 import logging
85 import pathlib
86 from typing import Literal, TypedDict
87
88 from muse.core.invariants import BaseReport, BaseViolation, make_report
89 from muse.core.object_store import read_object
90 from muse.core.store import get_commit_snapshot_manifest
91 from muse.plugins.midi._query import NoteInfo, key_signature_guess, notes_by_bar
92 from muse.plugins.midi.midi_diff import extract_notes
93
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Default location of the TOML rule file, as described in the module docstring.
# NOTE(review): nothing in this module reads this constant; presumably callers
# resolve it against the repository root — confirm.
_DEFAULT_RULES_FILE = ".muse/midi_invariants.toml"
97
98
99 # ---------------------------------------------------------------------------
100 # Types
101 # ---------------------------------------------------------------------------
102
103
class _InvariantRuleRequired(TypedDict):
    """Keys that every rule declaration must carry.

    Kept separate from :class:`InvariantRule` so that the subclass can be
    declared with ``total=False`` and add optional keys on top of these
    required ones.
    """

    # Rule identifier, copied into each violation as ``rule_name``.
    name: str
    # Severity stamped on every violation this rule produces.
    severity: Literal["info", "warning", "error"]
    # Declared granularity of the rule.
    scope: Literal["track", "bar", "voice_pair", "global"]
    # Selects which built-in check is run for the rule.
    rule_type: str
109
110
class InvariantRule(_InvariantRuleRequired, total=False):
    """Declaration of one MIDI invariant rule.

    ``name``       Human-readable rule identifier (unique within a rule set).
    ``severity``   Violation severity: ``"info"``, ``"warning"``, or ``"error"``.
    ``scope``      Granularity: ``"track"``, ``"bar"``, ``"voice_pair"``, ``"global"``.
    ``rule_type``  Built-in type string: ``"max_polyphony"``, ``"pitch_range"``,
                   ``"key_consistency"``, ``"no_parallel_fifths"``.
    ``params``     Rule-specific parameter dict. Optional (this class is declared
                   with ``total=False``); the other four keys are required.

    NOTE(review): within this module rules are dispatched on ``rule_type``
    only; ``scope`` is stored but not read here.
    """

    # Optional: :func:`run_invariants` falls back to an empty dict when absent.
    params: dict[str, str | int | float]
123
124
class InvariantViolation(TypedDict):
    """A single invariant violation record.

    ``rule_name``    The name of the rule that fired.
    ``severity``     Severity level from the rule declaration.
    ``track``        Workspace-relative MIDI file path.
    ``bar``          1-indexed bar number (0 for track-level violations).
    ``description``  Human-readable explanation of what was violated.
    ``addresses``    Note addresses or other domain addresses involved.
    """

    rule_name: str
    severity: Literal["info", "warning", "error"]
    track: str
    # 0 is used by track-level checks such as ``key_consistency``.
    bar: int
    description: str
    # The built-in checks emit strings such as ``"bar:3:tick:480"``,
    # ``"bar:3:pitch:110"``, ``"bar:3"`` or the track path itself.
    addresses: list[str]
142
143
class InvariantReport(TypedDict):
    """Full invariant check report for one commit.

    ``commit_id``      The commit that was checked.
    ``violations``     All violations found, sorted by track then bar.
    ``rules_checked``  Number of rule evaluations; :func:`run_invariants`
                       computes it as ``len(rules)`` times the number of
                       MIDI paths selected from the manifest.
    ``has_errors``     True when any violation has severity ``"error"``.
    ``has_warnings``   True when any violation has severity ``"warning"``.
    """

    commit_id: str
    violations: list[InvariantViolation]
    rules_checked: int
    has_errors: bool
    has_warnings: bool
159
160
161 # ---------------------------------------------------------------------------
162 # Built-in rule implementations
163 # ---------------------------------------------------------------------------
164
165
def check_max_polyphony(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    max_simultaneous: int = 6,
) -> list[InvariantViolation]:
    """Report every bar whose peak polyphony exceeds *max_simultaneous*.

    Notes are grouped with ``notes_by_bar`` and each group is swept
    independently: every note contributes a ``+1`` event at its start tick
    and a ``-1`` event at ``start_tick + duration_ticks``. The running sum of
    those events is the number of sounding notes; its maximum is the bar's
    peak polyphony.

    Args:
        notes: All notes in the track.
        track: Track file path for violation records.
        rule_name: Rule identifier string.
        severity: Violation severity.
        max_simultaneous: Maximum allowed simultaneous notes.

    Returns:
        At most one :class:`InvariantViolation` per bar, in ascending bar
        order. Each records the peak count and the first tick at which that
        peak was reached.
    """
    found: list[InvariantViolation] = []

    for bar_num, bar_notes in sorted(notes_by_bar(notes).items()):
        # Plain tuple ordering sorts by tick first and delta second, so a
        # note-off (-1) is processed before a note-on (+1) at the same tick
        # and back-to-back notes are not counted as overlapping.
        timeline = sorted(
            event
            for note in bar_notes
            for event in (
                (note.start_tick, 1),
                (note.start_tick + note.duration_ticks, -1),
            )
        )

        sounding = 0
        peak = 0
        peak_tick = 0
        for tick, step in timeline:
            sounding += step
            if sounding > peak:
                peak, peak_tick = sounding, tick

        if peak <= max_simultaneous:
            continue

        message = (
            f"Polyphony reached {peak} simultaneous notes at tick {peak_tick} "
            f"(max allowed: {max_simultaneous})"
        )
        found.append(
            InvariantViolation(
                rule_name=rule_name,
                severity=severity,
                track=track,
                bar=bar_num,
                description=message,
                addresses=[f"bar:{bar_num}:tick:{peak_tick}"],
            )
        )

    return found
225
226
def check_pitch_range(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    min_pitch: int = 0,
    max_pitch: int = 127,
) -> list[InvariantViolation]:
    """Report each note whose pitch lies outside ``[min_pitch, max_pitch]``.

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.
        min_pitch: Lowest allowed MIDI pitch (inclusive).
        max_pitch: Highest allowed MIDI pitch (inclusive).

    Returns:
        One :class:`InvariantViolation` per out-of-range note, in the order
        the notes were given. The violation's bar is the note's own bar.
    """
    # The bounds are the same for every note, so render them once.
    allowed = f"[{min_pitch}, {max_pitch}]"

    return [
        InvariantViolation(
            rule_name=rule_name,
            severity=severity,
            track=track,
            bar=offender.bar,
            description=(
                f"Note {offender.pitch_name} (MIDI {offender.pitch}) is outside "
                f"allowed range {allowed}"
            ),
            addresses=[f"bar:{offender.bar}:pitch:{offender.pitch}"],
        )
        for offender in notes
        if offender.pitch < min_pitch or offender.pitch > max_pitch
    ]
266
267
def check_key_consistency(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    threshold: float = 0.15,
) -> list[InvariantViolation]:
    """Flag a track whose notes stray too often from its estimated key.

    The key comes from ``key_signature_guess`` and is expected to read like
    ``"G major"`` or ``"D minor"``. The share of notes whose pitch class is
    not in that key's scale (major, or natural minor for any other mode
    string) is compared against *threshold*.

    The check is skipped (empty result) when there are no notes, when the
    guess has fewer than two words, or when the tonic is not one of the
    twelve sharp-spelled pitch-class names.

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.
        threshold: Maximum allowed ratio of foreign pitch classes (0.0–1.0).

    Returns:
        Zero or one :class:`InvariantViolation` for the track; a violation
        is reported with ``bar=0`` and the track path as its address.
    """
    if not notes:
        return []

    key_guess = key_signature_guess(notes)
    words = key_guess.split()
    if len(words) < 2:
        return []
    tonic, mode = words[0], words[1]

    # NOTE(review): only sharp spellings are recognised; a flat-named tonic
    # (e.g. "Bb") would make the check skip silently — confirm the helper
    # never returns flats.
    chromatic = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    if tonic not in chromatic:
        return []
    tonic_pc = chromatic.index(tonic)

    # Semitone offsets of the major scale, else the natural minor scale.
    if mode == "major":
        degrees = (0, 2, 4, 5, 7, 9, 11)
    else:
        degrees = (0, 2, 3, 5, 7, 8, 10)
    in_key = frozenset((tonic_pc + offset) % 12 for offset in degrees)

    total = len(notes)
    foreign = len([note for note in notes if note.pitch_class not in in_key])
    ratio = foreign / total

    if not ratio > threshold:
        return []

    summary = (
        f"{foreign}/{total} notes ({ratio:.0%}) use pitch classes "
        f"foreign to estimated key {key_guess} "
        f"(threshold: {threshold:.0%})"
    )
    return [
        InvariantViolation(
            rule_name=rule_name,
            severity=severity,
            track=track,
            bar=0,
            description=summary,
            addresses=[track],
        )
    ]
334
335
def check_no_parallel_fifths(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
) -> list[InvariantViolation]:
    """Detect parallel perfect fifths between neighbouring non-empty bars.

    Heuristic: in each bar the two lowest-pitched notes stand in for the bass
    and tenor voices. A violation is reported when those two notes are a
    perfect fifth apart (7 semitones, modulo octaves) in both bars and both
    voices move in the same direction from the first bar to the second.

    Bars are paired in ascending order of the bar numbers that
    ``notes_by_bar`` returns; pairs where either bar holds fewer than two
    notes are skipped. This is a best-effort approximation — no real voice
    separation is performed.

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.

    Returns:
        One :class:`InvariantViolation` per detected bar pair, attributed to
        the second bar of the pair and addressing both bars.
    """
    found: list[InvariantViolation] = []
    by_bar = notes_by_bar(notes)
    bar_numbers = sorted(by_bar.keys())

    for bar_a, bar_b in zip(bar_numbers, bar_numbers[1:]):
        ascending_a = sorted(by_bar[bar_a], key=lambda n: n.pitch)
        ascending_b = sorted(by_bar[bar_b], key=lambda n: n.pitch)
        if len(ascending_a) < 2 or len(ascending_b) < 2:
            continue

        # Two lowest notes of each bar: approximated bass and tenor.
        bass_a, tenor_a = ascending_a[0], ascending_a[1]
        bass_b, tenor_b = ascending_b[0], ascending_b[1]

        # Interval between the voices, folded into a single octave.
        gap_a = abs(tenor_a.pitch - bass_a.pitch) % 12
        gap_b = abs(tenor_b.pitch - bass_b.pitch) % 12
        if gap_a != 7 or gap_b != 7:
            continue

        bass_motion = bass_b.pitch - bass_a.pitch
        tenor_motion = tenor_b.pitch - tenor_a.pitch
        both_rise = bass_motion > 0 and tenor_motion > 0
        both_fall = bass_motion < 0 and tenor_motion < 0
        if not (both_rise or both_fall):
            continue

        found.append(
            InvariantViolation(
                rule_name=rule_name,
                severity=severity,
                track=track,
                bar=bar_b,
                description=(
                    f"Parallel fifths between bars {bar_a} and {bar_b}: "
                    f"lower voice {bass_a.pitch_name}→{bass_b.pitch_name}, "
                    f"upper voice {tenor_a.pitch_name}→{tenor_b.pitch_name}"
                ),
                addresses=[f"bar:{bar_a}", f"bar:{bar_b}"],
            )
        )

    return found
403
404
405 # ---------------------------------------------------------------------------
406 # Rule loading
407 # ---------------------------------------------------------------------------
408
# Rules applied when no rule file is given, or when the file is missing,
# unreadable, or declares no rules: a polyphony cap of 8 and the full MIDI
# pitch range, both reported as warnings.
_DEFAULT_RULE_SET: list[InvariantRule] = [
    {
        "name": "max_polyphony",
        "severity": "warning",
        "scope": "track",
        "rule_type": "max_polyphony",
        "params": {"max_simultaneous": 8},
    },
    {
        "name": "pitch_range",
        "severity": "warning",
        "scope": "track",
        "rule_type": "pitch_range",
        "params": {"min_pitch": 0, "max_pitch": 127},
    },
]
425
426
427 def load_invariant_rules(rules_file: pathlib.Path | None = None) -> list[InvariantRule]:
428 """Load invariant rules from a TOML file, falling back to defaults.
429
430 Requires ``tomllib`` (Python 3.11+) for TOML parsing. If the file does
431 not exist or cannot be parsed, the default rule set is returned.
432
433 Args:
434 rules_file: Path to the TOML rule file. ``None`` means use defaults.
435
436 Returns:
437 List of :class:`InvariantRule` dicts.
438 """
439 if rules_file is None or not rules_file.exists():
440 return list(_DEFAULT_RULE_SET)
441
442 try:
443 import tomllib
444
445 with rules_file.open("rb") as fh:
446 data = tomllib.load(fh)
447
448 rules: list[InvariantRule] = []
449 for raw in data.get("rule", []):
450 _valid_severities: dict[str, Literal["info", "warning", "error"]] = {
451 "info": "info", "warning": "warning", "error": "error",
452 }
453 _valid_scopes: dict[str, Literal["track", "bar", "voice_pair", "global"]] = {
454 "track": "track", "bar": "bar", "voice_pair": "voice_pair", "global": "global",
455 }
456 sev = _valid_severities.get(str(raw.get("severity", "")), "warning")
457 scope = _valid_scopes.get(str(raw.get("scope", "")), "track")
458 rule = InvariantRule(
459 name=str(raw.get("name", "unnamed")),
460 severity=sev,
461 scope=scope,
462 rule_type=str(raw.get("rule_type", "")),
463 )
464 if "params" in raw:
465 rule["params"] = raw["params"]
466 rules.append(rule)
467 return rules if rules else list(_DEFAULT_RULE_SET)
468
469 except Exception as exc:
470 logger.warning("⚠️ Could not load invariant rules from %s: %s", rules_file, exc)
471 return list(_DEFAULT_RULE_SET)
472
473
474 # ---------------------------------------------------------------------------
475 # Main runner
476 # ---------------------------------------------------------------------------
477
478
def _run_rule_on_track(
    rule: InvariantRule,
    notes: list[NoteInfo],
    track_path: str,
) -> list[InvariantViolation]:
    """Dispatch one rule to its built-in check and return that check's violations.

    Missing parameters fall back to the same defaults as the default rule
    set. A rule whose ``rule_type`` is not a built-in is logged at debug
    level and yields no violations.
    """
    rule_type = rule["rule_type"]
    severity = rule["severity"]
    params = rule.get("params", {})
    name = rule["name"]

    if rule_type == "max_polyphony":
        return check_max_polyphony(
            notes,
            track_path,
            name,
            severity,
            max_simultaneous=int(params.get("max_simultaneous", 8)),
        )
    if rule_type == "pitch_range":
        return check_pitch_range(
            notes,
            track_path,
            name,
            severity,
            min_pitch=int(params.get("min_pitch", 0)),
            max_pitch=int(params.get("max_pitch", 127)),
        )
    if rule_type == "key_consistency":
        return check_key_consistency(
            notes,
            track_path,
            name,
            severity,
            threshold=float(params.get("threshold", 0.15)),
        )
    if rule_type == "no_parallel_fifths":
        return check_no_parallel_fifths(notes, track_path, name, severity)

    logger.debug("Unknown rule_type %r in rule %r — skipped", rule_type, name)
    return []


def run_invariants(
    root: pathlib.Path,
    commit_id: str,
    rules: list[InvariantRule],
    *,
    track_filter: str | None = None,
) -> InvariantReport:
    """Evaluate all *rules* against every MIDI track in *commit_id*.

    Tracks are the manifest paths ending in ``.mid`` (case-insensitive),
    optionally narrowed to the single path equal to *track_filter*. A track
    is skipped when its object hash or content is missing, or when
    ``extract_notes`` raises ``ValueError`` (logged at debug level).

    Args:
        root: Repository root.
        commit_id: Commit to check.
        rules: List of :class:`InvariantRule` declarations.
        track_filter: Restrict check to a single MIDI file path.

    Returns:
        An :class:`InvariantReport` with all violations found, sorted by
        track then bar.
    """
    all_violations: list[InvariantViolation] = []
    # A commit without a manifest is treated as having no tracks.
    manifest = get_commit_snapshot_manifest(root, commit_id) or {}

    midi_paths = [
        p for p in manifest
        if p.lower().endswith(".mid")
        and (track_filter is None or p == track_filter)
    ]

    for track_path in sorted(midi_paths):
        obj_hash = manifest.get(track_path)
        if obj_hash is None:
            continue
        raw = read_object(root, obj_hash)
        if raw is None:
            continue
        try:
            keys, tpb = extract_notes(raw)
        except ValueError as exc:
            logger.debug("Cannot parse MIDI %r: %s", track_path, exc)
            continue

        notes = [NoteInfo.from_note_key(k, tpb) for k in keys]

        for rule in rules:
            all_violations.extend(_run_rule_on_track(rule, notes, track_path))

    all_violations.sort(key=lambda v: (v["track"], v["bar"]))
    has_errors = any(v["severity"] == "error" for v in all_violations)
    has_warnings = any(v["severity"] == "warning" for v in all_violations)

    return InvariantReport(
        commit_id=commit_id,
        violations=all_violations,
        # NOTE(review): counts every selected path, including tracks skipped
        # above and rules with an unknown rule_type.
        rules_checked=len(rules) * len(midi_paths),
        has_errors=has_errors,
        has_warnings=has_warnings,
    )
563
564
class MidiChecker:
    """MIDI-domain implementation of :class:`~muse.core.invariants.InvariantChecker`.

    Lets the generic ``muse check`` command run the MIDI invariants through
    :func:`run_invariants` without depending on any MIDI-specific types.
    """

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Check *commit_id* against the MIDI rules and return a :class:`~muse.core.invariants.BaseReport`.

        Rules come from :func:`load_invariant_rules`. Each MIDI violation is
        reduced to a :class:`~muse.core.invariants.BaseViolation` whose
        address is the track path; bar numbers and note addresses are not
        carried over.
        """
        midi_report = run_invariants(
            repo_root,
            commit_id,
            load_invariant_rules(rules_file),
        )

        converted: list[BaseViolation] = []
        for violation in midi_report["violations"]:
            converted.append(
                BaseViolation(
                    rule_name=violation["rule_name"],
                    severity=violation["severity"],
                    address=violation["track"],
                    description=violation["description"],
                )
            )

        return make_report(commit_id, "midi", converted, midi_report["rules_checked"])