invariants.py
python
| 1 | """Domain-agnostic invariants engine for Muse. |
| 2 | |
| 3 | An *invariant* is a semantic rule that a domain's state must satisfy. Rules |
| 4 | are declared in TOML, evaluated against commit snapshots, and reported with |
| 5 | structured violations. Any domain plugin can implement invariant checking |
| 6 | by satisfying the :class:`InvariantChecker` protocol and wiring a CLI command. |
| 7 | |
| 8 | This module defines the **shared vocabulary** — TypedDicts and protocols that |
are domain-agnostic. Domain-specific implementations (MIDI, code, genomics…)
import these types and add their own rule types and evaluators.
| 13 | |
| 14 | Architecture |
| 15 | ------------ |
| 16 | :: |
| 17 | |
| 18 | muse/core/invariants.py ← this file: shared protocol |
| 19 | muse/plugins/midi/_invariants.py ← MIDI-specific rules + evaluator |
| 20 | muse/plugins/code/_invariants.py ← code-specific rules + evaluator |
| 21 | muse/cli/commands/midi_check.py ← CLI wiring for MIDI |
| 22 | muse/cli/commands/code_check.py ← CLI wiring for code |
| 23 | |
| 24 | TOML rule file format (shared across all domains):: |
| 25 | |
| 26 | [[rule]] |
| 27 | name = "my_rule" # unique human-readable identifier |
| 28 | severity = "error" # "info" | "warning" | "error" |
| 29 | scope = "file" # domain-specific scope tag |
| 30 | rule_type = "max_complexity" # domain-specific rule type string |
| 31 | |
| 32 | [rule.params] |
| 33 | threshold = 10 # rule-specific numeric / string params |
| 34 | |
| 35 | Severity levels |
| 36 | --------------- |
| 37 | - ``"error"`` — must be resolved before committing (when ``--strict`` is set). |
| 38 | - ``"warning"`` — reported but does not block commits. |
| 39 | - ``"info"`` — informational; surfaced in ``muse check`` output only. |
| 40 | |
| 41 | Public API |
| 42 | ---------- |
| 43 | - :data:`InvariantSeverity` — severity literal type alias. |
| 44 | - :class:`BaseViolation` — domain-agnostic violation record. |
| 45 | - :class:`BaseReport` — full check report for one commit. |
| 46 | - :class:`InvariantChecker` — Protocol every domain checker must satisfy. |
| 47 | - :func:`make_report` — build a ``BaseReport`` from a violation list. |
| 48 | - :func:`load_rules_toml` — parse any ``[[rule]]`` TOML file. |
| 49 | - :func:`format_report` — human-readable report text. |
| 50 | """ |
| 51 | |
from __future__ import annotations

import logging
import pathlib
from collections.abc import Iterable
from typing import Literal, Protocol, TypedDict, runtime_checkable
| 55 | |
# Module-level logger, named after this module so its verbosity can be tuned
# through the standard ``logging`` configuration.
logger = logging.getLogger(name=__name__)

# ---------------------------------------------------------------------------
# Shared severity literal
# ---------------------------------------------------------------------------

# Closed set of severity labels, listed from least to most severe.  Rule
# declarations and the violations they produce both carry one of these.
InvariantSeverity = Literal["info", "warning", "error"]
| 63 | |
| 64 | |
| 65 | # --------------------------------------------------------------------------- |
| 66 | # Domain-agnostic violation + report TypedDicts |
| 67 | # --------------------------------------------------------------------------- |
| 68 | |
| 69 | |
class BaseViolation(TypedDict):
    """A single invariant violation, domain-agnostic.

    Domain implementations extend this TypedDict with additional fields
    (e.g. ``track`` for MIDI, ``file`` and ``symbol`` for code).

    ``rule_name``    The name of the rule that fired.
    ``severity``     Violation severity inherited from the rule declaration.
    ``address``      Domain-specific locator of the violating element.  Its
                     syntax is chosen by each domain, e.g.
                     ``"src/utils.py::my_fn"`` or ``"piano.mid/bar:4"``;
                     :func:`make_report` uses it as the primary sort key.
    ``description``  Human-readable explanation of the violation.
    """

    rule_name: str  # name of the rule that fired
    severity: InvariantSeverity  # one of "info" | "warning" | "error"
    address: str  # domain-chosen locator string; primary sort key in reports
    description: str  # free-text explanation shown to the user
| 87 | |
| 88 | |
class BaseReport(TypedDict):
    """Full invariant check report for one commit, domain-agnostic.

    ``commit_id``      The commit that was checked.
    ``domain``         Domain tag (e.g. ``"midi"``, ``"code"``).
    ``violations``     All violations found.  Reports built with
                       :func:`make_report` hold them sorted by
                       ``(address, rule_name)``.
    ``rules_checked``  Number of rules evaluated.
    ``has_errors``     ``True`` when any violation has severity ``"error"``.
    ``has_warnings``   ``True`` when any violation has severity ``"warning"``.

    There is no flag for ``"info"`` violations; they are visible only through
    ``violations``.
    """

    commit_id: str  # commit the rules were evaluated against
    domain: str  # domain tag of the checker that produced the report
    violations: list[BaseViolation]  # every violation found
    rules_checked: int  # how many rules were evaluated
    has_errors: bool  # any violation with severity "error"
    has_warnings: bool  # any violation with severity "warning"
| 106 | |
| 107 | |
| 108 | # --------------------------------------------------------------------------- |
| 109 | # InvariantChecker protocol |
| 110 | # --------------------------------------------------------------------------- |
| 111 | |
| 112 | |
@runtime_checkable
class InvariantChecker(Protocol):
    """Structural interface that every domain's invariant checker provides.

    A checker is any object with a ``check`` method matching the signature
    below; no inheritance is needed.  ``check`` loads the domain's invariant
    rules, evaluates them against one commit and returns a
    :class:`BaseReport`.  The CLI ``muse check`` command dispatches to the
    domain's registered checker through this protocol.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    InvariantChecker)`` reports whether *obj* has a ``check`` attribute; the
    method's signature is not inspected at runtime.

    Sketch of a conforming implementation::

        class MyDomainChecker:
            def check(
                self,
                repo_root: pathlib.Path,
                commit_id: str,
                *,
                rules_file: pathlib.Path | None = None,
            ) -> BaseReport:
                rules = load_rules_toml(rules_file or DEFAULT_RULES_PATH)
                violations = evaluate(repo_root, commit_id, rules)
                return make_report(commit_id, "mydomain", violations, len(rules))
    """

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Run the domain's invariant rules against one commit.

        Args:
            repo_root: Repository root (the directory containing ``.muse/``).
            commit_id: Identifier of the commit to check.
            rules_file: TOML rule file to use; ``None`` means the domain's
                default location.

        Returns:
            A :class:`BaseReport` listing every violation plus summary flags.
        """
| 156 | |
| 157 | |
| 158 | # --------------------------------------------------------------------------- |
| 159 | # Helpers |
| 160 | # --------------------------------------------------------------------------- |
| 161 | |
| 162 | |
def make_report(
    commit_id: str,
    domain: str,
    violations: Iterable[BaseViolation],
    rules_checked: int,
) -> BaseReport:
    """Build a :class:`BaseReport` from a flat collection of violations.

    Violations are sorted by ``(address, rule_name)`` for deterministic
    output, and the ``has_errors`` / ``has_warnings`` flags are derived from
    their severities.

    ``violations`` is iterated exactly once, so any iterable (including a
    generator) is accepted.  A list passed in is not modified; the report
    holds a new sorted list.

    Args:
        commit_id: Commit that was checked.
        domain: Domain tag (e.g. ``"midi"``, ``"code"``).
        violations: All violations found, in any order.
        rules_checked: Number of rules that were evaluated.

    Returns:
        A fully populated :class:`BaseReport`.

    Raises:
        KeyError: If a violation lacks ``address``, ``rule_name`` or
            ``severity``.
    """
    # Materialize once.  Deriving the flags from ``ordered`` (instead of
    # re-iterating ``violations``) keeps them correct for one-shot iterators,
    # which ``sorted`` would otherwise have exhausted already.
    ordered = sorted(violations, key=lambda v: (v["address"], v["rule_name"]))
    # A TypedDict is a plain dict at runtime; this literal builds the same
    # object (same key order) as calling ``BaseReport(...)``.
    return {
        "commit_id": commit_id,
        "domain": domain,
        "violations": ordered,
        "rules_checked": rules_checked,
        "has_errors": any(v["severity"] == "error" for v in ordered),
        "has_warnings": any(v["severity"] == "warning" for v in ordered),
    }
| 191 | |
| 192 | |
| 193 | def load_rules_toml(path: pathlib.Path) -> list[dict[str, str | int | float | dict[str, str | int | float]]]: |
| 194 | """Parse a ``[[rule]]`` TOML file and return the raw rule dicts. |
| 195 | |
| 196 | Returns an empty list when the file does not exist (domain then uses |
| 197 | built-in defaults). |
| 198 | |
| 199 | Args: |
| 200 | path: Path to the TOML file. |
| 201 | |
| 202 | Returns: |
| 203 | List of raw rule dicts (``{"name": ..., "severity": ..., ...}``). |
| 204 | """ |
| 205 | if not path.exists(): |
| 206 | logger.debug("Invariants rules file not found at %s — using defaults", path) |
| 207 | return [] |
| 208 | import tomllib # stdlib on Python ≥ 3.11; Muse requires 3.12 |
| 209 | try: |
| 210 | data = tomllib.loads(path.read_text()) |
| 211 | rules: list[dict[str, str | int | float | dict[str, str | int | float]]] = data.get("rule", []) |
| 212 | return rules |
| 213 | except Exception as exc: |
| 214 | logger.warning("Failed to parse invariants file %s: %s", path, exc) |
| 215 | return [] |
| 216 | |
| 217 | |
def format_report(report: BaseReport, *, color: bool = True) -> str:
    """Render *report* as plain multi-line text.

    Each violation becomes one indented line of the form
    ``<tag> [<rule_name>] <address>: <description>``, in the order stored in
    the report.  The text ends with a one-line summary that starts on its own
    line; when there are no violations, only that summary (prefixed by a
    newline) is returned.

    Args:
        report: The report to render.
        color: When ``True`` the per-violation tag is an emoji; when
            ``False`` it is a bracketed label (``[error]``, ``[warn]``,
            ``[info]``).  Unknown severities get a blank tag in both modes.

    Returns:
        Formatted string ready for ``typer.echo()``.
    """
    if color:
        tags = {"error": "❌", "warning": "⚠️ ", "info": "ℹ️ "}
    else:
        tags = {"error": "[error]", "warning": "[warn] ", "info": "[info] "}

    violations = report["violations"]
    rendered: list[str] = []
    for item in violations:
        tag = tags.get(item["severity"], "  ")
        rendered.append(f"  {tag} [{item['rule_name']}] {item['address']}: {item['description']}")
    body = "\n".join(rendered)

    n_checked = report["rules_checked"]
    n_total = len(violations)
    if n_total == 0:
        # ``body`` is empty here, so the output begins with the newline.
        return body + f"\n✅ {n_checked} rules checked — no violations"

    n_errors = sum(item["severity"] == "error" for item in violations)
    n_warnings = sum(item["severity"] == "warning" for item in violations)

    summary = f"\n{n_checked} rules checked — {n_total} violation(s)"
    if n_errors:
        summary += f", {n_errors} error(s)"
    if n_warnings:
        summary += f", {n_warnings} warning(s)"
    return body + summary