gabriel / muse public
midi_check.py python
180 lines 5.0 KB
9ee9c39c refactor: rename music→midi domain, strip all 5-dim backward compat Gabriel Cardona <gabriel@tellurstori.com> 5d ago
1 """``muse midi-check`` — MIDI invariant enforcement.
2
3 Evaluates the invariant rules declared in ``.muse/midi_invariants.toml``
4 against every MIDI track in the specified commit and reports violations with
5 severity, location, and description.
6
7 Built-in rule types (declared in TOML)::
8
9 [[rule]]
10 name = "max_polyphony"
11 severity = "error"
12 rule_type = "max_polyphony"
13 [rule.params]
14 max_simultaneous = 6
15
16 [[rule]]
17 name = "pitch_range"
18 severity = "warning"
19 rule_type = "pitch_range"
20 [rule.params]
21 min_pitch = 24
22 max_pitch = 108
23
24 [[rule]]
25 name = "key_consistency"
26 severity = "info"
27 rule_type = "key_consistency"
28 [rule.params]
29 threshold = 0.15
30
31 [[rule]]
32 name = "no_parallel_fifths"
33 severity = "warning"
34 rule_type = "no_parallel_fifths"
35
36 Usage::
37
38 muse midi-check # check HEAD
39 muse midi-check abc1234 # check specific commit
40 muse midi-check --track piano.mid # check one track
41 muse midi-check --strict # exit 1 on any error-severity violation
42 muse midi-check --json # machine-readable output
43 """
44 from __future__ import annotations
45
46 import json
47 import logging
48 import pathlib
49 import sys
50
51 import typer
52
53 from muse.core.repo import require_repo
54 from muse.core.store import get_head_commit_id
55 from muse.plugins.midi._invariants import (
56 InvariantReport,
57 load_invariant_rules,
58 run_invariants,
59 )
60
# Module-level logger (not referenced elsewhere in this module at present).
logger = logging.getLogger(__name__)

# Typer application on which the ``midi-check`` command is registered.
# ``no_args_is_help=False``: invoking without arguments does not just print
# help, because COMMIT is optional (it defaults to HEAD).
app = typer.Typer(no_args_is_help=False)
64
65
def _read_branch(root: pathlib.Path) -> str:
    """Return the current branch name recorded in ``.muse/HEAD``.

    The file is decoded as UTF-8 (explicitly, so the result does not depend on
    the platform's locale encoding). Surrounding whitespace and a leading
    ``refs/heads/`` prefix are removed. Content without that prefix is
    returned as-is, apart from whitespace stripping.

    Args:
        root: Repository root directory (the one containing ``.muse``).

    Returns:
        The branch name.
    """
    head_ref = (root / ".muse" / "HEAD").read_text(encoding="utf-8").strip()
    return head_ref.removeprefix("refs/heads/").strip()
69
70
def _resolve_rules_path(
    root: pathlib.Path, rules_file: str | None
) -> pathlib.Path | None:
    """Return the invariant rules file to load, or ``None`` if none applies.

    An explicitly requested ``--rules`` path must exist. Otherwise a missing
    file would reach ``load_invariant_rules`` and either crash with a
    traceback or be silently ignored. Without ``--rules``, the repository
    default ``.muse/midi_invariants.toml`` is used when it exists.

    Args:
        root: Repository root directory.
        rules_file: Value of the ``--rules`` option (may be ``None``).

    Returns:
        Path of the rules file, or ``None`` when no rules file is configured.

    Raises:
        typer.Exit: With code 1 when ``rules_file`` is given but is not a file.
    """
    if rules_file:
        explicit = pathlib.Path(rules_file)
        if not explicit.is_file():
            typer.echo(f"❌ Rules file not found: {rules_file}", err=True)
            raise typer.Exit(1)
        return explicit

    default = root / ".muse" / "midi_invariants.toml"
    return default if default.exists() else None


@app.command(name="midi-check")
def midi_check_cmd(
    commit: str | None = typer.Argument(
        None,
        metavar="COMMIT",
        help="Commit ID to check (default: HEAD).",
    ),
    track: str | None = typer.Option(
        None,
        "--track",
        "-t",
        metavar="PATH",
        help="Restrict check to a single MIDI file path.",
    ),
    rules_file: str | None = typer.Option(
        None,
        "--rules",
        "-r",
        metavar="FILE",
        help="Path to a TOML invariant rules file (default: .muse/midi_invariants.toml).",
    ),
    strict: bool = typer.Option(
        False,
        "--strict",
        help="Exit with code 1 when any error-severity violations are found.",
    ),
    as_json: bool = typer.Option(
        False,
        "--json",
        help="Output machine-readable JSON instead of formatted text.",
    ),
) -> None:
    """Enforce MIDI invariant rules against a commit's MIDI tracks."""
    # NOTE: Typer shows this function's docstring as the command's --help
    # text, so maintainer documentation lives in these comments instead.
    root = require_repo()

    # Target commit: the explicit argument, else the head commit of the
    # branch currently named in .muse/HEAD.
    commit_id = commit
    if commit_id is None:
        branch = _read_branch(root)
        commit_id = get_head_commit_id(root, branch)
        if commit_id is None:
            typer.echo("❌ No commits in this repository.", err=True)
            raise typer.Exit(1)

    # Load rules (None lets load_invariant_rules apply its own behaviour when
    # no rules file is configured).
    rules_path = _resolve_rules_path(root, rules_file)
    rules = load_invariant_rules(rules_path)
    report = run_invariants(root, commit_id, rules, track_filter=track)

    if as_json:
        sys.stdout.write(json.dumps(report, indent=2) + "\n")
    else:
        _print_report(report)

    # With --strict, the exit status is driven solely by report["has_errors"].
    if strict and report["has_errors"]:
        raise typer.Exit(1)
133
134
# Icon shown in front of each printed violation, keyed by its ``severity``
# string. Severities not listed here are rendered with a generic "•" bullet.
_SEVERITY_ICON: dict[str, str] = dict(
    error="❌",
    warning="⚠️",
    info="ℹ️",
)
140
141
def _print_report(report: InvariantReport) -> None:
    """Format and print an invariant report to stdout.

    Violations are printed under one header line per track. Tracks appear in
    the order they are first seen in ``report["violations"]``. A summary line
    follows with per-severity counts, the short commit ID and the number of
    rule-track checks.

    Args:
        report: Report produced by ``run_invariants``. Reads the
            ``violations``, ``rules_checked`` and ``commit_id`` keys; each
            violation supplies ``track``, ``severity``, ``bar``,
            ``rule_name`` and ``description``.
    """
    violations = report["violations"]

    if not violations:
        typer.echo(
            f"✅ No violations found ({report['rules_checked']} rule-track checks)"
        )
        return

    # Group by track without assuming the input is already sorted. A stable
    # sort keyed on each track's first appearance preserves relative order,
    # so every track header is printed exactly once.
    first_seen: dict[str, int] = {}
    for v in violations:
        first_seen.setdefault(v["track"], len(first_seen))
    ordered = sorted(violations, key=lambda v: first_seen[v["track"]])

    current_track: str | None = None
    for v in ordered:
        if v["track"] != current_track:
            current_track = v["track"]
            typer.echo(f"\n {current_track}")
        icon = _SEVERITY_ICON.get(v["severity"], "•")
        # Bars <= 0 are labelled as track-level rather than a bar number.
        bar_label = f"bar {v['bar']}" if v["bar"] > 0 else "track"
        typer.echo(
            f" {icon} [{v['rule_name']}] {bar_label}: {v['description']}"
        )

    # Tally all severities in a single pass.
    counts: dict[str, int] = {}
    for v in violations:
        counts[v["severity"]] = counts.get(v["severity"], 0) + 1
    error_count = counts.pop("error", 0)
    warn_count = counts.pop("warning", 0)
    info_count = counts.pop("info", 0)

    parts: list[str] = []
    if error_count:
        parts.append(f"{error_count} error{'s' if error_count != 1 else ''}")
    if warn_count:
        parts.append(f"{warn_count} warning{'s' if warn_count != 1 else ''}")
    if info_count:
        parts.append(f"{info_count} info")
    # Severities other than error/warning/info (shown with the generic "•"
    # icon above) are reported verbatim, so the summary is never empty when
    # violations exist.
    parts.extend(f"{count} {severity}" for severity, count in counts.items())

    summary = ", ".join(parts)
    icon = "❌" if error_count else "⚠️" if warn_count else "ℹ️"
    typer.echo(
        f"\n{icon} {summary} in commit {report['commit_id'][:8]} "
        f"({report['rules_checked']} rule-track checks)"
    )