midi_check.py
python
| 1 | """``muse midi-check`` — MIDI invariant enforcement. |
| 2 | |
| 3 | Evaluates the invariant rules declared in ``.muse/midi_invariants.toml`` |
| 4 | against every MIDI track in the specified commit and reports violations with |
| 5 | severity, location, and description. |
| 6 | |
| 7 | Built-in rule types (declared in TOML):: |
| 8 | |
| 9 | [[rule]] |
| 10 | name = "max_polyphony" |
| 11 | severity = "error" |
| 12 | rule_type = "max_polyphony" |
| 13 | [rule.params] |
| 14 | max_simultaneous = 6 |
| 15 | |
| 16 | [[rule]] |
| 17 | name = "pitch_range" |
| 18 | severity = "warning" |
| 19 | rule_type = "pitch_range" |
| 20 | [rule.params] |
| 21 | min_pitch = 24 |
| 22 | max_pitch = 108 |
| 23 | |
| 24 | [[rule]] |
| 25 | name = "key_consistency" |
| 26 | severity = "info" |
| 27 | rule_type = "key_consistency" |
| 28 | [rule.params] |
| 29 | threshold = 0.15 |
| 30 | |
| 31 | [[rule]] |
| 32 | name = "no_parallel_fifths" |
| 33 | severity = "warning" |
| 34 | rule_type = "no_parallel_fifths" |
| 35 | |
| 36 | Usage:: |
| 37 | |
| 38 | muse midi-check # check HEAD |
| 39 | muse midi-check abc1234 # check specific commit |
| 40 | muse midi-check --track piano.mid # check one track |
| 41 | muse midi-check --strict # exit 1 on any error-severity violation |
| 42 | muse midi-check --json # machine-readable output |
| 43 | """ |
| 44 | from __future__ import annotations |
| 45 | |
| 46 | import json |
| 47 | import logging |
| 48 | import pathlib |
| 49 | import sys |
| 50 | |
| 51 | import typer |
| 52 | |
| 53 | from muse.core.repo import require_repo |
| 54 | from muse.core.store import get_head_commit_id |
| 55 | from muse.plugins.midi._invariants import ( |
| 56 | InvariantReport, |
| 57 | load_invariant_rules, |
| 58 | run_invariants, |
| 59 | ) |
| 60 | |
# Logger named after this module (not referenced in the code visible here).
logger = logging.getLogger(name=__name__)

# Typer application that hosts the ``midi-check`` command.  With
# ``no_args_is_help=False`` a bare invocation runs the command (COMMIT then
# defaults to HEAD) instead of printing the help screen.
app = typer.Typer(no_args_is_help=False)
| 64 | |
| 65 | |
| 66 | def _read_branch(root: pathlib.Path) -> str: |
| 67 | head_ref = (root / ".muse" / "HEAD").read_text().strip() |
| 68 | return head_ref.removeprefix("refs/heads/").strip() |
| 69 | |
| 70 | |
def _resolve_commit_id(root: pathlib.Path, commit: str | None) -> str:
    """Return *commit* when given, otherwise the head commit of HEAD's branch.

    Raises:
        typer.Exit: With code 1 when no commit was given and
            ``get_head_commit_id`` reports none for the current branch.
    """
    if commit is not None:
        return commit
    commit_id = get_head_commit_id(root, _read_branch(root))
    if commit_id is None:
        typer.echo("❌ No commits in this repository.", err=True)
        raise typer.Exit(1)
    return commit_id


def _resolve_rules_path(
    root: pathlib.Path, rules_file: str | None
) -> pathlib.Path | None:
    """Choose the invariant rules file to pass to ``load_invariant_rules``.

    An explicit ``--rules`` path must name an existing file; otherwise the
    user gets a clear error instead of their rules being silently skipped or
    a traceback from the loader.  Without ``--rules``,
    ``<root>/.muse/midi_invariants.toml`` is used when it exists, else
    ``None``.

    Raises:
        typer.Exit: With code 1 when an explicit rules file does not exist.
    """
    if rules_file:
        rules_path = pathlib.Path(rules_file)
        if not rules_path.is_file():
            typer.echo(f"❌ Rules file not found: {rules_path}", err=True)
            raise typer.Exit(1)
        return rules_path
    default = root / ".muse" / "midi_invariants.toml"
    return default if default.exists() else None


@app.command(name="midi-check")
def midi_check_cmd(
    commit: str | None = typer.Argument(
        None,
        metavar="COMMIT",
        help="Commit ID to check (default: HEAD).",
    ),
    track: str | None = typer.Option(
        None,
        "--track",
        "-t",
        metavar="PATH",
        help="Restrict check to a single MIDI file path.",
    ),
    rules_file: str | None = typer.Option(
        None,
        "--rules",
        "-r",
        metavar="FILE",
        help="Path to a TOML invariant rules file (default: .muse/midi_invariants.toml).",
    ),
    strict: bool = typer.Option(
        False,
        "--strict",
        help="Exit with code 1 when any error-severity violations are found.",
    ),
    as_json: bool = typer.Option(
        False,
        "--json",
        help="Output machine-readable JSON instead of formatted text.",
    ),
) -> None:
    """Enforce MIDI invariant rules against a commit's MIDI tracks."""
    # NOTE: Typer renders the docstring above as this command's --help text,
    # so implementation notes are kept in comments instead.
    root = require_repo()

    commit_id = _resolve_commit_id(root, commit)
    rules = load_invariant_rules(_resolve_rules_path(root, rules_file))
    report = run_invariants(root, commit_id, rules, track_filter=track)

    if as_json:
        sys.stdout.write(json.dumps(report, indent=2) + "\n")
    else:
        _print_report(report)

    # Only report["has_errors"] affects the exit status, and only under
    # --strict; without it the command exits 0 even when violations exist.
    if strict and report["has_errors"]:
        raise typer.Exit(1)
| 133 | |
| 134 | |
| 135 | _SEVERITY_ICON = { |
| 136 | "error": "❌", |
| 137 | "warning": "⚠️", |
| 138 | "info": "ℹ️", |
| 139 | } |
| 140 | |
| 141 | |
def _print_report(report: InvariantReport) -> None:
    """Format and print an invariant report to stdout.

    With no violations, prints a single success line.  Otherwise prints one
    header per track (tracks in the order they first appear), one line per
    violation with severity icon, rule name, location and description, and
    finally a summary line counting violations per severity.

    Args:
        report: Result of ``run_invariants``.  Reads ``violations``,
            ``rules_checked`` and ``commit_id``; each violation is read for
            ``track``, ``severity``, ``bar``, ``rule_name`` and
            ``description``.
    """
    from collections import Counter

    violations = report["violations"]

    if not violations:
        typer.echo(
            f"✅ No violations found ({report['rules_checked']} rule-track checks)"
        )
        return

    # Group explicitly instead of relying on the violations arriving sorted
    # by track; dicts keep insertion order, so tracks stay in first-seen order
    # and each track header is printed exactly once.
    by_track: dict[str, list] = {}
    for v in violations:
        by_track.setdefault(v["track"], []).append(v)

    for track_path, track_violations in by_track.items():
        typer.echo(f"\n {track_path}")
        for v in track_violations:
            icon = _SEVERITY_ICON.get(v["severity"], "•")
            # A non-positive bar number is shown as a track-level location.
            bar_label = f"bar {v['bar']}" if v["bar"] > 0 else "track"
            typer.echo(
                f" {icon} [{v['rule_name']}] {bar_label}: {v['description']}"
            )

    counts = Counter(v["severity"] for v in violations)
    error_count = counts["error"]
    warn_count = counts["warning"]
    info_count = counts["info"]
    # Severities outside error/warning/info (which get the "•" icon above)
    # are counted too, so the summary can never come out empty.
    other_count = len(violations) - error_count - warn_count - info_count

    parts: list[str] = []
    if error_count:
        parts.append(f"{error_count} error{'s' if error_count != 1 else ''}")
    if warn_count:
        parts.append(f"{warn_count} warning{'s' if warn_count != 1 else ''}")
    if info_count:
        parts.append(f"{info_count} info")
    if other_count:
        parts.append(f"{other_count} other")

    summary = ", ".join(parts)
    summary_icon = "❌" if error_count else "⚠️" if warn_count else "ℹ️"
    typer.echo(
        f"\n{summary_icon} {summary} in commit {report['commit_id'][:8]} "
        f"({report['rules_checked']} rule-track checks)"
    )