gabriel / muse public
invariants.py python
251 lines 8.8 KB
bda49bdb feat: redesign .museignore as TOML with domain-scoped sections (#100) Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """Domain-agnostic invariants engine for Muse.
2
3 An *invariant* is a semantic rule that a domain's state must satisfy. Rules
4 are declared in TOML, evaluated against commit snapshots, and reported with
5 structured violations. Any domain plugin can implement invariant checking
6 by satisfying the :class:`InvariantChecker` protocol and wiring a CLI command.
7
8 This module defines the **shared vocabulary** — TypedDicts and protocols that
are domain-agnostic. Domain-specific implementations (MIDI, code, genomics…)
import these types and add their own rule types and evaluators.
13
14 Architecture
15 ------------
16 ::
17
18 muse/core/invariants.py ← this file: shared protocol
19 muse/plugins/midi/_invariants.py ← MIDI-specific rules + evaluator
20 muse/plugins/code/_invariants.py ← code-specific rules + evaluator
21 muse/cli/commands/midi_check.py ← CLI wiring for MIDI
22 muse/cli/commands/code_check.py ← CLI wiring for code
23
24 TOML rule file format (shared across all domains)::
25
26 [[rule]]
27 name = "my_rule" # unique human-readable identifier
28 severity = "error" # "info" | "warning" | "error"
29 scope = "file" # domain-specific scope tag
30 rule_type = "max_complexity" # domain-specific rule type string
31
32 [rule.params]
33 threshold = 10 # rule-specific numeric / string params
34
35 Severity levels
36 ---------------
37 - ``"error"`` — must be resolved before committing (when ``--strict`` is set).
38 - ``"warning"`` — reported but does not block commits.
39 - ``"info"`` — informational; surfaced in ``muse check`` output only.
40
41 Public API
42 ----------
43 - :data:`InvariantSeverity` — severity literal type alias.
44 - :class:`BaseViolation` — domain-agnostic violation record.
45 - :class:`BaseReport` — full check report for one commit.
46 - :class:`InvariantChecker` — Protocol every domain checker must satisfy.
47 - :func:`make_report` — build a ``BaseReport`` from a violation list.
48 - :func:`load_rules_toml` — parse any ``[[rule]]`` TOML file.
49 - :func:`format_report` — human-readable report text.
50 """
51
from __future__ import annotations

import logging
import pathlib
from typing import Literal, Protocol, TypedDict, runtime_checkable
55
# Module-level logger; load_rules_toml reports missing/unparseable rule files here.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Shared severity literal
# ---------------------------------------------------------------------------

# Closed set of severities a rule (and hence each violation it produces) may
# carry. make_report derives ``has_errors`` / ``has_warnings`` from the
# "error" and "warning" values, and format_report maps each value to a prefix.
InvariantSeverity = Literal["info", "warning", "error"]
63
64
65 # ---------------------------------------------------------------------------
66 # Domain-agnostic violation + report TypedDicts
67 # ---------------------------------------------------------------------------
68
69
class BaseViolation(TypedDict):
    """A single invariant violation, domain-agnostic.

    Domain implementations extend this with additional fields (e.g. ``track``
    for MIDI, ``file`` and ``symbol`` for code).

    ``rule_name``    The name of the rule that fired.
    ``severity``     Violation severity inherited from the rule declaration.
    ``address``      Locator string identifying the violating element; its
                     format is domain-defined (e.g. ``"src/utils.py::my_fn"``
                     or ``"piano.mid/bar:4"``). :func:`make_report` uses it as
                     the primary sort key.
    ``description``  Human-readable explanation of the violation.
    """

    rule_name: str  # identifier of the rule that produced this violation
    severity: InvariantSeverity  # one of "info" | "warning" | "error"
    address: str  # where the violation occurred (domain-specific locator)
    description: str  # message text shown by format_report after the address
87
88
class BaseReport(TypedDict):
    """Full invariant check report for one commit, domain-agnostic.

    Normally built with :func:`make_report`, which fills in the ordering and
    the two summary flags.

    ``commit_id``      The commit that was checked.
    ``domain``         Domain tag (e.g. ``"midi"``, ``"code"``).
    ``violations``     All violations found; :func:`make_report` orders them
                       by ``address``, then by ``rule_name``.
    ``rules_checked``  Number of rules evaluated.
    ``has_errors``     ``True`` when any violation has severity ``"error"``.
    ``has_warnings``   ``True`` when any violation has severity ``"warning"``.
    """

    commit_id: str  # commit the rules were evaluated against
    domain: str  # domain tag of the checker that produced this report
    violations: list[BaseViolation]  # individual violation records
    rules_checked: int  # count of rules evaluated (as supplied by the checker)
    has_errors: bool  # summary flag: any "error"-severity violation present
    has_warnings: bool  # summary flag: any "warning"-severity violation present
106
107
108 # ---------------------------------------------------------------------------
109 # InvariantChecker protocol
110 # ---------------------------------------------------------------------------
111
112
@runtime_checkable
class InvariantChecker(Protocol):
    """Structural interface that each domain's invariant checker provides.

    Conformance is structural: any object with a compatible :meth:`check`
    method satisfies the protocol without inheriting from it. Because the
    protocol is decorated with ``runtime_checkable``,
    ``isinstance(obj, InvariantChecker)`` can be used at runtime; note that
    such a check only verifies that a ``check`` attribute exists, not its
    signature.

    A checker loads its domain's rules, evaluates them against one commit
    and returns a :class:`BaseReport`. Sketch of a conforming checker::

        class MyDomainChecker:
            def check(
                self,
                repo_root: pathlib.Path,
                commit_id: str,
                *,
                rules_file: pathlib.Path | None = None,
            ) -> BaseReport:
                rules = load_rules_toml(rules_file or default_path)
                violations = _evaluate(repo_root, commit_id, rules)
                return make_report(commit_id, "mydomain", violations, len(rules))
    """

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Run the domain's invariant rules against the given commit.

        Args:
            repo_root: Root directory of the repository (the one that holds
                ``.muse/``).
            commit_id: Identifier of the commit whose snapshot is checked.
            rules_file: Explicit TOML rule file to evaluate. When ``None``,
                the domain falls back to its own default rule location.

        Returns:
            A :class:`BaseReport` listing every violation together with the
            ``has_errors`` / ``has_warnings`` summary flags.
        """
        ...
156
157
158 # ---------------------------------------------------------------------------
159 # Helpers
160 # ---------------------------------------------------------------------------
161
162
def make_report(
    commit_id: str,
    domain: str,
    violations: list[BaseViolation],
    rules_checked: int,
) -> BaseReport:
    """Assemble a :class:`BaseReport` from an unordered list of violations.

    The report's ``violations`` entry is a new list ordered by ``address``
    and, for equal addresses, by ``rule_name``, so output is deterministic.
    The caller's list is left in its original order.

    Args:
        commit_id: Identifier of the commit the violations belong to.
        domain: Domain tag recorded on the report (e.g. ``"midi"``).
        violations: Every violation produced by the domain's evaluator.
        rules_checked: How many rules were evaluated to produce *violations*.

    Returns:
        A fully populated :class:`BaseReport`; ``has_errors`` and
        ``has_warnings`` are set when any violation carries that severity.
    """

    def by_location(item: BaseViolation) -> tuple[str, str]:
        # Address first, rule name as tie-breaker.
        return item["address"], item["rule_name"]

    ordered = sorted(violations, key=by_location)
    any_error = any(item["severity"] == "error" for item in violations)
    any_warning = any(item["severity"] == "warning" for item in violations)

    report: BaseReport = {
        "commit_id": commit_id,
        "domain": domain,
        "violations": ordered,
        "rules_checked": rules_checked,
        "has_errors": any_error,
        "has_warnings": any_warning,
    }
    return report
191
192
193 def load_rules_toml(path: pathlib.Path) -> list[dict[str, str | int | float | dict[str, str | int | float]]]:
194 """Parse a ``[[rule]]`` TOML file and return the raw rule dicts.
195
196 Returns an empty list when the file does not exist (domain then uses
197 built-in defaults).
198
199 Args:
200 path: Path to the TOML file.
201
202 Returns:
203 List of raw rule dicts (``{"name": ..., "severity": ..., ...}``).
204 """
205 if not path.exists():
206 logger.debug("Invariants rules file not found at %s — using defaults", path)
207 return []
208 import tomllib # stdlib on Python ≥ 3.11; Muse requires 3.12
209 try:
210 data = tomllib.loads(path.read_text())
211 rules: list[dict[str, str | int | float | dict[str, str | int | float]]] = data.get("rule", [])
212 return rules
213 except Exception as exc:
214 logger.warning("Failed to parse invariants file %s: %s", path, exc)
215 return []
216
217
def format_report(report: BaseReport, *, color: bool = True) -> str:
    """Render *report* as multi-line text for terminal output.

    The output has one line per violation, in the order stored on the
    report, followed by a summary line. The summary counts rules,
    violations, errors and warnings. A report with no violations renders
    only the ``✅`` summary line.

    Args:
        report: The report to render.
        color: When ``True``, mark error/warning/info lines with emoji;
            otherwise use bracketed text tags such as ``[error]``.

    Returns:
        The rendered text without a trailing newline, ready for
        ``typer.echo()``.
    """
    if color:
        tags = {"error": "❌", "warning": "⚠️ ", "info": "ℹ️ "}
    else:
        # Text tags are all seven characters wide so the columns line up.
        tags = {"error": "[error]", "warning": "[warn] ", "info": "[info] "}

    found = report["violations"]

    def render(item: BaseViolation) -> str:
        # Unknown severities get a blank single-space tag.
        tag = tags.get(item["severity"], " ")
        return f" {tag} [{item['rule_name']}] {item['address']}: {item['description']}"

    body = "\n".join(map(render, found))
    n_rules = report["rules_checked"]

    if not found:
        return body + f"\n✅ {n_rules} rules checked — no violations"

    n_errors = sum(item["severity"] == "error" for item in found)
    n_warnings = sum(item["severity"] == "warning" for item in found)

    pieces = [f"\n{n_rules} rules checked — {len(found)} violation(s)"]
    if n_errors:
        pieces.append(f", {n_errors} error(s)")
    if n_warnings:
        pieces.append(f", {n_warnings} warning(s)")
    return body + "".join(pieces)