gabriel / muse public
invariants.py python
356 lines 12.1 KB
bda49bdb feat: redesign .museignore as TOML with domain-scoped sections (#100) Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """muse invariants — enforce architectural rules from .muse/invariants.toml.
2
3 Loads invariant rules from ``.muse/invariants.toml`` and checks them against
4 the committed snapshot at HEAD (or a given commit). Rules are declarative
5 architectural constraints — enforced by analysis, not by the runtime.
6
7 Supported rule types
8 --------------------
9 ``no_cycles``
10 The import graph must have no cycles. Detects import cycle violations.
11
12 ``forbidden_dependency``
13 A file (or file pattern) must not import from another file (or pattern).
14 Enforces layer boundaries (e.g. "core must not import from cli").
15
16 ``required_test``
17 Every public function in ``source_pattern`` must have a corresponding test
18 in ``test_pattern`` (matched by function name).
19
20 ``layer_boundary``
21 Enforce import direction between layers: ``lower`` may not import from
22 ``upper``.
23
24 Rule file format (``.muse/invariants.toml``)
25 --------------------------------------------
26
27 .. code-block:: toml
28
29 [[rules]]
30 type = "no_cycles"
31 name = "no import cycles"
32
33 [[rules]]
34 type = "forbidden_dependency"
35 name = "core must not import cli"
36 source_pattern = "muse/core/"
37 forbidden_pattern = "muse/cli/"
38
39 [[rules]]
40 type = "layer_boundary"
41 name = "plugins must not import from cli"
42 lower = "muse/plugins/"
43 upper = "muse/cli/"
44
45 [[rules]]
46 type = "required_test"
47 name = "all billing functions must have tests"
48 source_pattern = "src/billing.py"
49 test_pattern = "tests/test_billing.py"
50
51 Usage::
52
53 muse invariants
54 muse invariants --commit HEAD~5
55 muse invariants --json
56
57 Output::
58
59 Invariant check — commit a1b2c3d4
60 ──────────────────────────────────────────────────────────────
61
62 ✅ no import cycles passed
63 🔴 core must not import cli VIOLATED
64 muse/core/snapshot.py imports muse/cli/app (1 violation)
65 ✅ plugins must not import from cli passed
66
67 2 rule(s) passed · 1 rule(s) violated
68
69 Flags:
70
71 ``--commit, -c REF``
72 Check a historical snapshot instead of HEAD.
73
74 ``--json``
75 Emit results as JSON.
76 """
77
78 from __future__ import annotations
79
80 import json
81 import logging
82 import pathlib
83 import re
84
85 import typer
86
87 from muse.core.errors import ExitCode
88 from muse.core.object_store import read_object
89 from muse.core.repo import require_repo
90 from muse.core.store import get_commit_snapshot_manifest, resolve_commit_ref
91 from muse.plugins.code._query import is_semantic, symbols_for_snapshot
92 from muse.plugins.code.ast_parser import parse_symbols
93
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Typer application object; the ``invariants`` function in this file is
# registered on it as the callback.
app = typer.Typer()

# Repository-relative location of the rules file, expressed as a POSIX path.
# NOTE(review): not referenced by the code visible in this file — confirm
# whether it is still needed or should replace the hard-coded path below.
_INVARIANTS_FILE = pathlib.PurePosixPath(".muse") / "invariants.toml"
99
100
def _read_repo_id(root: pathlib.Path) -> str:
    """Return the ``repo_id`` stored in ``<root>/.muse/repo.json`` as a string.

    The file is parsed as JSON and the ``"repo_id"`` entry is converted with
    ``str``; a missing file, invalid JSON, or a missing key propagates as the
    underlying exception.
    """
    repo_json = root / ".muse" / "repo.json"
    metadata = json.loads(repo_json.read_text())
    return str(metadata["repo_id"])
103
104
def _read_branch(root: pathlib.Path) -> str:
    """Return the branch name recorded in ``<root>/.muse/HEAD``.

    The file content is stripped of surrounding whitespace, a leading
    ``refs/heads/`` is removed when present, and the remainder is stripped
    again. Content without that prefix is returned as-is (stripped).
    """
    prefix = "refs/heads/"
    ref_name = (root / ".muse" / "HEAD").read_text().strip()
    if ref_name.startswith(prefix):
        ref_name = ref_name[len(prefix):]
    return ref_name.strip()
108
109
class _RuleResult:
    """Outcome of evaluating a single invariant rule.

    Attributes are plain instance fields: ``name`` (the rule's display name),
    ``rule_type`` (the rule's ``type`` string), ``passed`` (whether the rule
    held) and ``violations`` (human-readable violation messages).
    """

    def __init__(self, name: str, rule_type: str, passed: bool, violations: list[str]) -> None:
        self.name, self.rule_type = name, rule_type
        self.passed, self.violations = passed, violations

    def to_dict(self) -> dict[str, str | bool | list[str]]:
        """Return the result as a JSON-serialisable mapping.

        Keys appear in the order ``name``, ``rule_type``, ``passed``,
        ``violations``; the ``violations`` list is the same object held by
        the instance, not a copy.
        """
        payload: dict[str, str | bool | list[str]] = {}
        for field in ("name", "rule_type", "passed", "violations"):
            payload[field] = getattr(self, field)
        return payload
124
125
def _strip_inline_comment(line: str) -> str:
    """Return *line* without a trailing ``# comment`` that sits outside quotes.

    A ``#`` inside a single- or double-quoted string is kept. Inside a
    double-quoted string a backslash protects the following character, so an
    escaped quote does not end the string.
    """
    quote = ""
    skip_next = False
    for pos, ch in enumerate(line):
        if skip_next:
            skip_next = False
            continue
        if quote:
            if ch == "\\" and quote == '"':
                skip_next = True
            elif ch == quote:
                quote = ""
        elif ch in "\"'":
            quote = ch
        elif ch == "#":
            return line[:pos].rstrip()
    return line


def _parse_toml_rules(text: str) -> list[dict[str, str]]:
    """Minimal TOML parser for [[rules]] sections (no external dependencies).

    Parses ``key = "value"`` lines within ``[[rules]]`` blocks. Does not support
    nested tables, arrays, or multi-line strings — the invariants format is
    intentionally simple.

    Handling of the simple cases:

    * Full-line and inline ``#`` comments are ignored (a ``#`` inside a quoted
      value is part of the value).
    * Any other table header (``[name]`` or ``[[name]]``) ends the current
      rule, so keys that belong to a different table are not merged into it.
    * One matching pair of surrounding quotes is removed from a value; escape
      sequences inside the value are left untouched.

    Args:
        text: Full contents of the invariants file.

    Returns:
        One ``{key: value}`` dict per ``[[rules]]`` block, in file order.
        Keys that appear before the first ``[[rules]]`` header are ignored.
    """
    rules: list[dict[str, str]] = []
    current: dict[str, str] | None = None
    for raw_line in text.splitlines():
        line = _strip_inline_comment(raw_line.strip())
        if not line:
            continue
        if line.startswith("[") and line.endswith("]"):
            # A table header always closes the rule collected so far.
            if current is not None:
                rules.append(current)
            is_array_table = line.startswith("[[") and line.endswith("]]")
            table_name = line.strip("[]").strip()
            current = {} if is_array_table and table_name == "rules" else None
            continue
        if current is None or "=" not in line:
            continue
        key, _, val = line.partition("=")
        val = val.strip()
        # Remove exactly one layer of matching quotes, never mixed ones.
        if len(val) >= 2 and val[0] == val[-1] and val[0] in "\"'":
            val = val[1:-1]
        current[key.strip()] = val
    if current is not None:
        rules.append(current)
    return rules
150
151
def _build_import_map(root: pathlib.Path, manifest: dict[str, str]) -> dict[str, list[str]]:
    """Return ``{file_path: [imported_file_paths]}`` for a snapshot manifest.

    Every manifest path gets an entry (possibly an empty list). Files are
    processed in sorted path order; objects that ``read_object`` returns as
    ``None`` are skipped. For each symbol record whose ``kind`` is
    ``"import"``, the last dotted component of its ``qualified_name`` (with
    any ``import::`` text removed) is looked up among the manifest's file
    stems; a hit on a different file is appended as a dependency.

    When several manifest paths share a stem, the one iterated last wins the
    lookup. Repeated imports of the same target are appended repeatedly.
    """
    by_stem: dict[str, str] = {}
    for path in manifest:
        by_stem[pathlib.PurePosixPath(path).stem] = path

    graph: dict[str, list[str]] = {path: [] for path in manifest}
    for source_path in sorted(manifest):
        blob = read_object(root, manifest[source_path])
        if blob is None:
            continue
        for record in parse_symbols(blob, source_path).values():
            if record["kind"] != "import":
                continue
            leaf = record["qualified_name"].split(".")[-1].replace("import::", "")
            resolved = by_stem.get(leaf)
            if resolved and resolved != source_path:
                graph[source_path].append(resolved)
    return graph
171
172
def _find_cycles(imports: dict[str, list[str]]) -> list[list[str]]:
    """Detect import cycles with an iterative depth-first search.

    The traversal keeps an explicit stack instead of recursing, so long import
    chains cannot exhaust the interpreter's recursion limit, and the current
    path is maintained in place rather than copied at every step.

    Args:
        imports: Adjacency mapping ``{file: [files it imports]}``. Targets
            that are not keys of the mapping are treated as having no imports.

    Returns:
        One entry per back edge encountered, in discovery order. Each entry is
        the path from the node that closes the cycle back to itself, e.g.
        ``["a", "b", "a"]``. Each node is expanded only once, so cycles that
        are reachable only through an already-explored node are not listed
        again.
    """
    cycles: list[list[str]] = []
    visited: set[str] = set()
    exhausted = object()  # sentinel: the current node has no more edges

    for start in imports:
        if start in visited:
            continue
        visited.add(start)
        trail: list[str] = [start]        # nodes on the current DFS path, in order
        on_trail: set[str] = {start}      # same nodes, for O(1) membership tests
        pending = [iter(imports.get(start, []))]
        while pending:
            target = next(pending[-1], exhausted)
            if target is exhausted:
                # Finished this node: step back up the path.
                pending.pop()
                on_trail.discard(trail.pop())
            elif target in on_trail:
                # Back edge: the cycle runs from the target's position to here.
                cycles.append(trail[trail.index(target):] + [target])
            elif target not in visited:
                visited.add(target)
                trail.append(target)
                on_trail.add(target)
                pending.append(iter(imports.get(target, [])))
    return cycles
196
197
def _check_rule(
    rule: dict[str, str],
    manifest: dict[str, str],
    import_map: dict[str, list[str]],
    root: pathlib.Path,
) -> _RuleResult:
    """Evaluate one parsed rule and return its result.

    Args:
        rule: Key/value pairs of a single ``[[rules]]`` block. ``name``
            defaults to ``"unnamed"`` and ``type`` to ``""``.
        manifest: Snapshot manifest mapping file path to object id.
        import_map: ``{file: [imported files]}`` for the same snapshot.
        root: Repository root, passed through to ``read_object``.

    Returns:
        A ``_RuleResult`` that passes when no violations were collected. An
        unrecognised ``type`` yields a failed result with a single
        ``unknown rule type`` message.

    Patterns are matched by plain substring containment against file paths.
    An empty source-side pattern matches every file; an empty target-side
    pattern (``forbidden_pattern`` / ``upper``) matches nothing.
    """
    rule_name = rule.get("name", "unnamed")
    kind = rule.get("type", "")

    def outcome(found: list[str]) -> _RuleResult:
        # A rule passes exactly when nothing was found.
        return _RuleResult(rule_name, kind, not found, found)

    def function_names(pattern: str, public_only: bool) -> set[str]:
        # Names of (async) functions in every manifest file matching *pattern*.
        names: set[str] = set()
        for path, obj_id in manifest.items():
            if pattern and pattern not in path:
                continue
            blob = read_object(root, obj_id)
            if blob is None:
                continue
            for record in parse_symbols(blob, path).values():
                if record["kind"] not in ("function", "async_function"):
                    continue
                if public_only and record["name"].startswith("_"):
                    continue
                names.add(record["name"])
        return names

    if kind == "no_cycles":
        return outcome([" → ".join(cycle) for cycle in _find_cycles(import_map)])

    if kind == "forbidden_dependency":
        origin = rule.get("source_pattern", "")
        banned = rule.get("forbidden_pattern", "")
        return outcome([
            f"{fp} imports {dep}"
            for fp, deps in sorted(import_map.items())
            if not origin or origin in fp
            for dep in deps
            if banned and banned in dep
        ])

    if kind == "layer_boundary":
        below = rule.get("lower", "")
        above = rule.get("upper", "")
        return outcome([
            f"{fp} (lower layer) imports {dep} (upper layer)"
            for fp, deps in sorted(import_map.items())
            if not below or below in fp
            for dep in deps
            if above and above in dep
        ])

    if kind == "required_test":
        # Source files are scanned first, then test files.
        public_funcs = function_names(rule.get("source_pattern", ""), True)
        available_tests = function_names(rule.get("test_pattern", ""), False)
        # A function counts as tested when ``test_<name>`` or ``<name>`` exists.
        return outcome([
            f"no test found for function '{func}'"
            for func in sorted(public_funcs)
            if f"test_{func}" not in available_tests and func not in available_tests
        ])

    # Unknown rule type.
    return _RuleResult(rule_name, kind, False, [f"unknown rule type: {kind!r}"])
272
273
@app.callback(invoke_without_command=True)
def invariants(
    ctx: typer.Context,
    ref: str | None = typer.Option(
        None, "--commit", "-c", metavar="REF",
        help="Check a historical snapshot instead of HEAD.",
    ),
    as_json: bool = typer.Option(False, "--json", help="Emit results as JSON."),
) -> None:
    """Check architectural invariants from .muse/invariants.toml.

    Loads declarative rules and verifies them against the committed snapshot:

    * **no_cycles** — the import graph must be acyclic
    * **forbidden_dependency** — enforces layer boundaries
    * **layer_boundary** — lower layers must not import from upper layers
    * **required_test** — public functions must have corresponding tests

    Create ``.muse/invariants.toml`` with ``[[rules]]`` blocks to define your
    architectural constraints. All rules run against the committed snapshot;
    no working-tree parsing or code execution required.
    """
    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = _read_branch(root)

    # A missing or empty rules file is reported and treated as a normal return,
    # not as an error.
    invariants_path = root / ".muse" / "invariants.toml"
    if not invariants_path.exists():
        typer.echo(
            "⚠️ .muse/invariants.toml not found.\n"
            "Create it with [[rules]] blocks to define architectural constraints.\n"
            "See: muse invariants --help for the rule format."
        )
        return

    # TOML documents are UTF-8 by specification; decode explicitly instead of
    # relying on the platform's locale encoding.
    rules = _parse_toml_rules(invariants_path.read_text(encoding="utf-8"))
    if not rules:
        typer.echo(" (no rules defined in .muse/invariants.toml)")
        return

    commit = resolve_commit_ref(root, repo_id, branch, ref)
    if commit is None:
        typer.echo(f"❌ Commit '{ref or 'HEAD'}' not found.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # The import graph is built once and shared by every rule.
    manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
    import_map = _build_import_map(root, manifest)

    results: list[_RuleResult] = [
        _check_rule(rule, manifest, import_map, root) for rule in rules
    ]
    passed = sum(1 for r in results if r.passed)
    violated = len(results) - passed

    if as_json:
        typer.echo(json.dumps(
            {
                "schema_version": 1,
                "commit": commit.commit_id[:8],
                "rules_checked": len(results),
                "passed": passed,
                "violated": violated,
                "results": [r.to_dict() for r in results],
            },
            indent=2,
        ))
        return

    typer.echo(f"\nInvariant check — commit {commit.commit_id[:8]}")
    typer.echo("─" * 62)

    for result in results:
        icon = "✅" if result.passed else "🔴"
        status = "passed" if result.passed else "VIOLATED"
        typer.echo(f"\n{icon} {result.name:<40} {status}")
        if not result.passed:
            # Show at most five violations per rule, then a count of the rest.
            for v in result.violations[:5]:
                typer.echo(f" {v}")
            if len(result.violations) > 5:
                typer.echo(f" … and {len(result.violations) - 5} more")

    typer.echo(f"\n {passed} rule(s) passed · {violated} rule(s) violated")