invariants.py
python
| 1 | """Domain-agnostic invariants engine for Muse. |
| 2 | |
| 3 | An *invariant* is a semantic rule that a domain's state must satisfy. Rules |
| 4 | are declared in TOML, evaluated against commit snapshots, and reported with |
| 5 | structured violations. Any domain plugin can implement invariant checking |
| 6 | by satisfying the :class:`InvariantChecker` protocol and wiring a CLI command. |
| 7 | |
| 8 | This module defines the **shared vocabulary** — TypedDicts and protocols that |
| 9 | are domain-agnostic. Domain-specific implementations (MIDI, code, genomics…) |
| 10 | import these types and add their own rule types and evaluators. |
| 11 | |
| 12 | Architecture |
| 13 | ------------ |
| 14 | :: |
| 15 | |
| 16 | muse/core/invariants.py ← this file: shared protocol |
| 17 | muse/plugins/midi/_invariants.py ← MIDI-specific rules + evaluator |
| 18 | muse/plugins/code/_invariants.py ← code-specific rules + evaluator |
| 19 | muse/cli/commands/midi_check.py ← CLI wiring for MIDI |
| 20 | muse/cli/commands/code_check.py ← CLI wiring for code |
| 21 | |
| 22 | TOML rule file format (shared across all domains):: |
| 23 | |
| 24 | [[rule]] |
| 25 | name = "my_rule" # unique human-readable identifier |
| 26 | severity = "error" # "info" | "warning" | "error" |
| 27 | scope = "file" # domain-specific scope tag |
| 28 | rule_type = "max_complexity" # domain-specific rule type string |
| 29 | |
| 30 | [rule.params] |
| 31 | threshold = 10 # rule-specific numeric / string params |
| 32 | |
| 33 | Severity levels |
| 34 | --------------- |
| 35 | - ``"error"`` — must be resolved before committing (when ``--strict`` is set). |
| 36 | - ``"warning"`` — reported but does not block commits. |
| 37 | - ``"info"`` — informational; surfaced in ``muse check`` output only. |
| 38 | |
| 39 | Public API |
| 40 | ---------- |
| 41 | - :data:`InvariantSeverity` — severity literal type alias. |
| 42 | - :class:`BaseViolation` — domain-agnostic violation record. |
| 43 | - :class:`BaseReport` — full check report for one commit. |
| 44 | - :class:`InvariantChecker` — Protocol every domain checker must satisfy. |
| 45 | - :func:`make_report` — build a ``BaseReport`` from a violation list. |
| 46 | - :func:`load_rules_toml` — parse any ``[[rule]]`` TOML file. |
| 47 | - :func:`format_report` — human-readable report text. |
| 48 | """ |
| 49 | |
| 50 | import logging |
| 51 | import pathlib |
| 52 | from typing import Literal, Protocol, TypedDict, runtime_checkable |
| 53 | |
logger: logging.Logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Shared severity literal
# ---------------------------------------------------------------------------

# Closed set of severities a rule may declare and a violation may carry,
# listed from least to most severe.
InvariantSeverity = Literal["info", "warning", "error"]
| 61 | |
| 62 | |
| 63 | # --------------------------------------------------------------------------- |
| 64 | # Domain-agnostic violation + report TypedDicts |
| 65 | # --------------------------------------------------------------------------- |
| 66 | |
| 67 | |
class BaseViolation(TypedDict):
    """A single invariant violation, domain-agnostic.

    Domain implementations extend this with additional fields (e.g. ``track``
    for MIDI, ``file`` and ``symbol`` for code).

    ``rule_name``    The name of the rule that fired.
    ``severity``     Violation severity inherited from the rule declaration.
    ``address``      Domain-specific locator string for the violating element
                     (e.g. ``"src/utils.py::my_fn"`` or ``"piano.mid/bar:4"``).
                     The separator syntax is up to each domain; reports built
                     with ``make_report`` sort violations by this string.
    ``description``  Human-readable explanation of the violation.
    """

    rule_name: str
    severity: InvariantSeverity
    address: str
    description: str
| 85 | |
| 86 | |
class BaseReport(TypedDict):
    """Full invariant check report for one commit, domain-agnostic.

    ``commit_id``      The commit that was checked.
    ``domain``         Domain tag (e.g. ``"midi"``, ``"code"``).
    ``violations``     All violations found. When the report is built with
                       ``make_report`` they are sorted by address, then by
                       rule name.
    ``rules_checked``  Number of rules evaluated.
    ``has_errors``     ``True`` when any violation has severity ``"error"``.
    ``has_warnings``   ``True`` when any violation has severity ``"warning"``.
                       The two flags are independent of each other;
                       ``"info"`` violations set neither.
    """

    commit_id: str
    domain: str
    violations: list[BaseViolation]
    rules_checked: int
    has_errors: bool
    has_warnings: bool
| 104 | |
| 105 | |
| 106 | # --------------------------------------------------------------------------- |
| 107 | # InvariantChecker protocol |
| 108 | # --------------------------------------------------------------------------- |
| 109 | |
| 110 | |
@runtime_checkable
class InvariantChecker(Protocol):
    """Structural interface that every domain invariant checker implements.

    A domain plugin satisfies this protocol simply by defining a compatible
    :meth:`check` method; no subclassing is required. ``check`` loads the
    domain's invariant rules, evaluates them against one commit and returns a
    :class:`BaseReport`. The CLI ``muse check`` command dispatches to the
    domain's registered checker through this interface.

    Because the protocol is ``@runtime_checkable``,
    ``isinstance(obj, InvariantChecker)`` is supported. Such a check only
    confirms that a ``check`` attribute is present, not that its signature
    matches.

    Sketch of a conforming implementation (``default_path`` and ``_evaluate``
    stand in for domain-specific pieces)::

        class MyDomainChecker:
            def check(
                self,
                repo_root: pathlib.Path,
                commit_id: str,
                *,
                rules_file: pathlib.Path | None = None,
            ) -> BaseReport:
                rules = load_rules_toml(rules_file or default_path)
                violations = _evaluate(repo_root, commit_id, rules)
                return make_report(commit_id, "mydomain", violations, len(rules))
    """

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Run the domain's invariant rules against *commit_id*.

        Args:
            repo_root: Repository root directory (the one holding ``.muse/``).
            commit_id: Identifier of the commit to check.
            rules_file: Explicit TOML rules file. When ``None``, the
                implementation falls back to its domain's default location.

        Returns:
            A :class:`BaseReport` carrying every violation plus the summary
            flags.
        """
        ...
| 154 | |
| 155 | |
| 156 | # --------------------------------------------------------------------------- |
| 157 | # Helpers |
| 158 | # --------------------------------------------------------------------------- |
| 159 | |
| 160 | |
def make_report(
    commit_id: str,
    domain: str,
    violations: list[BaseViolation],
    rules_checked: int,
) -> BaseReport:
    """Build a :class:`BaseReport` from a flat violation list.

    Violations are sorted by ``(address, rule_name)`` so output is
    deterministic regardless of evaluation order. The sort is stable, so
    violations sharing both keys keep their input order. The input is not
    mutated.

    Args:
        commit_id: Commit that was checked.
        domain: Domain tag.
        violations: All violations found. Any iterable of violations works;
            it is consumed exactly once.
        rules_checked: Number of rules that were evaluated.

    Returns:
        A fully populated :class:`BaseReport`.
    """
    # Materialise the input once. The summary flags are derived from the
    # sorted copy, so a one-shot iterator (e.g. a generator) is not already
    # exhausted by ``sorted`` when the flags are computed.
    sorted_violations = sorted(violations, key=lambda v: (v["address"], v["rule_name"]))
    severities = {v["severity"] for v in sorted_violations}
    return BaseReport(
        commit_id=commit_id,
        domain=domain,
        violations=sorted_violations,
        rules_checked=rules_checked,
        has_errors="error" in severities,
        has_warnings="warning" in severities,
    )
| 189 | |
| 190 | |
| 191 | def load_rules_toml(path: pathlib.Path) -> list[dict[str, str | int | float | dict[str, str | int | float]]]: |
| 192 | """Parse a ``[[rule]]`` TOML file and return the raw rule dicts. |
| 193 | |
| 194 | Returns an empty list when the file does not exist (domain then uses |
| 195 | built-in defaults). |
| 196 | |
| 197 | Args: |
| 198 | path: Path to the TOML file. |
| 199 | |
| 200 | Returns: |
| 201 | List of raw rule dicts (``{"name": ..., "severity": ..., ...}``). |
| 202 | """ |
| 203 | if not path.exists(): |
| 204 | logger.debug("Invariants rules file not found at %s — using defaults", path) |
| 205 | return [] |
| 206 | import tomllib # stdlib on Python ≥ 3.11; Muse requires 3.12 |
| 207 | try: |
| 208 | data = tomllib.loads(path.read_text()) |
| 209 | rules: list[dict[str, str | int | float | dict[str, str | int | float]]] = data.get("rule", []) |
| 210 | return rules |
| 211 | except Exception as exc: |
| 212 | logger.warning("Failed to parse invariants file %s: %s", path, exc) |
| 213 | return [] |
| 214 | |
| 215 | |
def format_report(report: BaseReport, *, color: bool = True) -> str:
    """Return a human-readable multi-line report string.

    The output has one line per violation, in the report's order, followed by
    a summary line with the number of rules checked and violation counts.

    Args:
        report: The report to format.
        color: If ``True``, decorate output with emoji: per-severity prefixes
            and the "no violations" check mark. If ``False``, use plain
            bracket tags (``[error]``, ``[warn]``, ``[info]``) and emit no
            emoji anywhere.

    Returns:
        Formatted string ready for ``typer.echo()``.
    """
    violations = report["violations"]
    checked = report["rules_checked"]

    # The plain-text tags are padded to a common width ("[error]" is 7 chars)
    # so the rule names line up. Unknown severities get two blank columns.
    prefix = {
        "error": "❌" if color else "[error]",
        "warning": "⚠️ " if color else "[warn] ",
        "info": "ℹ️ " if color else "[info] ",
    }

    if not violations:
        # Honour ``color=False`` here too; the check mark is an emoji.
        check_mark = "✅ " if color else ""
        return f"\n{check_mark}{checked} rules checked — no violations"

    lines: list[str] = []
    for v in violations:
        p = prefix.get(v["severity"], "  ")
        lines.append(f"  {p} [{v['rule_name']}] {v['address']}: {v['description']}")

    n_errors = sum(1 for v in violations if v["severity"] == "error")
    n_warnings = sum(1 for v in violations if v["severity"] == "warning")

    summary = f"\n{checked} rules checked — {len(violations)} violation(s)"
    if n_errors:
        summary += f", {n_errors} error(s)"
    if n_warnings:
        summary += f", {n_warnings} warning(s)"

    return "\n".join(lines) + summary