invariants.py
python
| 1 | """muse invariants — enforce architectural rules from .muse/invariants.toml. |
| 2 | |
| 3 | Loads invariant rules from ``.muse/invariants.toml`` and checks them against |
| 4 | the committed snapshot at HEAD (or a given commit). Rules are declarative |
| 5 | architectural constraints — enforced by analysis, not by the runtime. |
| 6 | |
| 7 | Supported rule types |
| 8 | -------------------- |
| 9 | ``no_cycles`` |
| 10 | The import graph must have no cycles. Detects import cycle violations. |
| 11 | |
| 12 | ``forbidden_dependency`` |
| 13 | A file (or file pattern) must not import from another file (or pattern). |
| 14 | Enforces layer boundaries (e.g. "core must not import from cli"). |
| 15 | |
| 16 | ``required_test`` |
| 17 | Every public function in ``source_pattern`` must have a corresponding test |
| 18 | in ``test_pattern`` (matched by function name). |
| 19 | |
| 20 | ``layer_boundary`` |
| 21 | Enforce import direction between layers: ``lower`` may not import from |
| 22 | ``upper``. |
| 23 | |
| 24 | Rule file format (``.muse/invariants.toml``) |
| 25 | -------------------------------------------- |
| 26 | |
| 27 | .. code-block:: toml |
| 28 | |
| 29 | [[rules]] |
| 30 | type = "no_cycles" |
| 31 | name = "no import cycles" |
| 32 | |
| 33 | [[rules]] |
| 34 | type = "forbidden_dependency" |
| 35 | name = "core must not import cli" |
| 36 | source_pattern = "muse/core/" |
| 37 | forbidden_pattern = "muse/cli/" |
| 38 | |
| 39 | [[rules]] |
| 40 | type = "layer_boundary" |
| 41 | name = "plugins must not import from cli" |
| 42 | lower = "muse/plugins/" |
| 43 | upper = "muse/cli/" |
| 44 | |
| 45 | [[rules]] |
| 46 | type = "required_test" |
| 47 | name = "all billing functions must have tests" |
| 48 | source_pattern = "src/billing.py" |
| 49 | test_pattern = "tests/test_billing.py" |
| 50 | |
| 51 | Usage:: |
| 52 | |
| 53 | muse invariants |
| 54 | muse invariants --commit HEAD~5 |
| 55 | muse invariants --json |
| 56 | |
| 57 | Output:: |
| 58 | |
| 59 | Invariant check — commit a1b2c3d4 |
| 60 | ────────────────────────────────────────────────────────────── |
| 61 | |
| 62 | ✅ no import cycles passed |
| 63 | 🔴 core must not import cli VIOLATED |
| 64 | muse/core/snapshot.py imports muse/cli/app (1 violation) |
| 65 | ✅ plugins must not import from cli passed |
| 66 | |
    2 rules passed · 1 rule violated
| 68 | |
| 69 | Flags: |
| 70 | |
| 71 | ``--commit, -c REF`` |
| 72 | Check a historical snapshot instead of HEAD. |
| 73 | |
| 74 | ``--json`` |
| 75 | Emit results as JSON. |
| 76 | """ |
| 77 | |
from __future__ import annotations

import dataclasses
import json
import logging
import pathlib
import re

import typer

from muse.core.errors import ExitCode
from muse.core.object_store import read_object
from muse.core.repo import require_repo
from muse.core.store import get_commit_snapshot_manifest, resolve_commit_ref
from muse.plugins.code._query import is_semantic, symbols_for_snapshot
from muse.plugins.code.ast_parser import parse_symbols
| 93 | |
| 94 | logger = logging.getLogger(__name__) |
| 95 | |
| 96 | app = typer.Typer() |
| 97 | |
| 98 | _INVARIANTS_FILE = pathlib.PurePosixPath(".muse") / "invariants.toml" |
| 99 | |
| 100 | |
| 101 | def _read_repo_id(root: pathlib.Path) -> str: |
| 102 | return str(json.loads((root / ".muse" / "repo.json").read_text())["repo_id"]) |
| 103 | |
| 104 | |
| 105 | def _read_branch(root: pathlib.Path) -> str: |
| 106 | head_ref = (root / ".muse" / "HEAD").read_text().strip() |
| 107 | return head_ref.removeprefix("refs/heads/").strip() |
| 108 | |
| 109 | |
| 110 | class _RuleResult: |
| 111 | def __init__(self, name: str, rule_type: str, passed: bool, violations: list[str]) -> None: |
| 112 | self.name = name |
| 113 | self.rule_type = rule_type |
| 114 | self.passed = passed |
| 115 | self.violations = violations |
| 116 | |
| 117 | def to_dict(self) -> dict[str, str | bool | list[str]]: |
| 118 | return { |
| 119 | "name": self.name, |
| 120 | "rule_type": self.rule_type, |
| 121 | "passed": self.passed, |
| 122 | "violations": self.violations, |
| 123 | } |
| 124 | |
| 125 | |
| 126 | def _parse_toml_rules(text: str) -> list[dict[str, str]]: |
| 127 | """Minimal TOML parser for [[rules]] sections (no external dependencies). |
| 128 | |
| 129 | Parses key = "value" lines within [[rules]] blocks. Does not support |
| 130 | nested tables, arrays, or multi-line strings — the invariants format is |
| 131 | intentionally simple. |
| 132 | """ |
| 133 | rules: list[dict[str, str]] = [] |
| 134 | current: dict[str, str] | None = None |
| 135 | for line in text.splitlines(): |
| 136 | line = line.strip() |
| 137 | if line == "[[rules]]": |
| 138 | if current is not None: |
| 139 | rules.append(current) |
| 140 | current = {} |
| 141 | continue |
| 142 | if current is not None and "=" in line and not line.startswith("#"): |
| 143 | key, _, val = line.partition("=") |
| 144 | key = key.strip() |
| 145 | val = val.strip().strip('"').strip("'") |
| 146 | current[key] = val |
| 147 | if current is not None: |
| 148 | rules.append(current) |
| 149 | return rules |
| 150 | |
| 151 | |
def _build_import_map(root: pathlib.Path, manifest: dict[str, str]) -> dict[str, list[str]]:
    """Return ``{file_path: [imported_file_paths]}`` from the snapshot.

    Import targets are matched by module *stem* only; if two files in the
    manifest share a stem, the later one wins — TODO confirm whether stem
    collisions matter for this repo layout.

    Fix over the previous version: a file importing the same module more
    than once (e.g. ``import x`` plus ``from x import y``) no longer
    produces duplicate edges, so rule checks report each dependency once.
    """
    stem_to_file: dict[str, str] = {
        pathlib.PurePosixPath(fp).stem: fp for fp in manifest
    }
    imports: dict[str, list[str]] = {fp: [] for fp in manifest}
    for file_path, obj_id in sorted(manifest.items()):
        raw = read_object(root, obj_id)
        if raw is None:
            # Object missing from the store: skip rather than fail the scan.
            continue
        seen: set[str] = set()
        tree = parse_symbols(raw, file_path)
        for rec in tree.values():
            if rec["kind"] != "import":
                continue
            # Last dotted component, minus the parser's "import::" marker,
            # is the imported module's stem.
            imported = rec["qualified_name"].split(".")[-1].replace("import::", "")
            target = stem_to_file.get(imported)
            # Ignore self-imports and targets outside the snapshot;
            # dedupe while preserving first-seen order.
            if target and target != file_path and target not in seen:
                seen.add(target)
                imports[file_path].append(target)
    return imports
| 171 | |
| 172 | |
| 173 | def _find_cycles(imports: dict[str, list[str]]) -> list[list[str]]: |
| 174 | """DFS cycle detection.""" |
| 175 | cycles: list[list[str]] = [] |
| 176 | visited: set[str] = set() |
| 177 | in_stack: set[str] = set() |
| 178 | |
| 179 | def dfs(node: str, path: list[str]) -> None: |
| 180 | if node in in_stack: |
| 181 | start = path.index(node) |
| 182 | cycles.append(path[start:] + [node]) |
| 183 | return |
| 184 | if node in visited: |
| 185 | return |
| 186 | visited.add(node) |
| 187 | in_stack.add(node) |
| 188 | for n in imports.get(node, []): |
| 189 | dfs(n, path + [node]) |
| 190 | in_stack.discard(node) |
| 191 | |
| 192 | for node in imports: |
| 193 | if node not in visited: |
| 194 | dfs(node, []) |
| 195 | return cycles |
| 196 | |
| 197 | |
def _match_import_violations(
    import_map: dict[str, list[str]],
    source_substr: str,
    target_substr: str,
    template: str,
) -> list[str]:
    """Return formatted violations where a matching file imports a matching target.

    Substring semantics (unchanged from the inline loops this replaces):
    an empty ``source_substr`` selects every file, while an empty
    ``target_substr`` matches no dependency at all.
    """
    violations: list[str] = []
    for fp, deps in sorted(import_map.items()):
        if source_substr and source_substr not in fp:
            continue
        for dep in deps:
            if target_substr and target_substr in dep:
                violations.append(template.format(fp=fp, dep=dep))
    return violations


def _collect_function_names(
    manifest: dict[str, str],
    pattern: str,
    root: pathlib.Path,
    *,
    public_only: bool,
) -> set[str]:
    """Return function names parsed from manifest files whose path contains ``pattern``.

    An empty ``pattern`` matches every file. With ``public_only`` set,
    names starting with ``_`` are skipped.
    """
    names: set[str] = set()
    for fp in manifest:
        if pattern and pattern not in fp:
            continue
        raw = read_object(root, manifest[fp])
        if raw is None:
            continue
        for rec in parse_symbols(raw, fp).values():
            if rec["kind"] not in ("function", "async_function"):
                continue
            if public_only and rec["name"].startswith("_"):
                continue
            names.add(rec["name"])
    return names


def _check_rule(
    rule: dict[str, str],
    manifest: dict[str, str],
    import_map: dict[str, list[str]],
    root: pathlib.Path,
) -> _RuleResult:
    """Evaluate one invariant rule against the snapshot.

    Dispatches on ``rule["type"]`` (no_cycles, forbidden_dependency,
    layer_boundary, required_test); an unrecognised type is reported as
    a failed result rather than raising.
    """
    name = rule.get("name", "unnamed")
    rule_type = rule.get("type", "")

    if rule_type == "no_cycles":
        violations = [" → ".join(cycle) for cycle in _find_cycles(import_map)]
    elif rule_type == "forbidden_dependency":
        violations = _match_import_violations(
            import_map,
            rule.get("source_pattern", ""),
            rule.get("forbidden_pattern", ""),
            "{fp} imports {dep}",
        )
    elif rule_type == "layer_boundary":
        # Same scan as forbidden_dependency, phrased in layer terms.
        violations = _match_import_violations(
            import_map,
            rule.get("lower", ""),
            rule.get("upper", ""),
            "{fp} (lower layer) imports {dep} (upper layer)",
        )
    elif rule_type == "required_test":
        src_funcs = _collect_function_names(
            manifest, rule.get("source_pattern", ""), root, public_only=True
        )
        test_funcs = _collect_function_names(
            manifest, rule.get("test_pattern", ""), root, public_only=False
        )
        # A function is covered by either test_<name> or an exact-name match.
        violations = [
            f"no test found for function '{func}'"
            for func in sorted(src_funcs)
            if f"test_{func}" not in test_funcs and func not in test_funcs
        ]
    else:
        # Unknown rule type: surface it as a failure instead of raising.
        return _RuleResult(name, rule_type, False, [f"unknown rule type: {rule_type!r}"])

    return _RuleResult(name, rule_type, not violations, violations)
| 272 | |
| 273 | |
@app.callback(invoke_without_command=True)
def invariants(
    ctx: typer.Context,
    ref: str | None = typer.Option(
        None, "--commit", "-c", metavar="REF",
        help="Check a historical snapshot instead of HEAD.",
    ),
    as_json: bool = typer.Option(False, "--json", help="Emit results as JSON."),
) -> None:
    """Check architectural invariants from .muse/invariants.toml.

    Loads declarative rules and verifies them against the committed snapshot:

    * **no_cycles** — the import graph must be acyclic
    * **forbidden_dependency** — enforces layer boundaries
    * **layer_boundary** — lower layers must not import from upper layers
    * **required_test** — public functions must have corresponding tests

    Create ``.muse/invariants.toml`` with ``[[rules]]`` blocks to define your
    architectural constraints. All rules run against the committed snapshot;
    no working-tree parsing or code execution required.
    """
    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = _read_branch(root)

    # A missing rule file is guidance, not an error: return without raising
    # so the command is safe to run in repos that have not adopted it yet.
    invariants_path = root / ".muse" / "invariants.toml"
    if not invariants_path.exists():
        typer.echo(
            "⚠️ .muse/invariants.toml not found.\n"
            "Create it with [[rules]] blocks to define architectural constraints.\n"
            "See: muse invariants --help for the rule format."
        )
        return

    rules = _parse_toml_rules(invariants_path.read_text())
    if not rules:
        # An empty rule file is likewise treated as "nothing to check".
        typer.echo(" (no rules defined in .muse/invariants.toml)")
        return

    # Resolve the target commit: --commit REF when given, else branch HEAD.
    commit = resolve_commit_ref(root, repo_id, branch, ref)
    if commit is None:
        typer.echo(f"❌ Commit '{ref or 'HEAD'}' not found.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Build the snapshot's import graph once; every rule shares it.
    manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
    import_map = _build_import_map(root, manifest)

    results: list[_RuleResult] = []
    for rule in rules:
        result = _check_rule(rule, manifest, import_map, root)
        results.append(result)

    # Machine-readable path: emit the full result set and return early.
    if as_json:
        typer.echo(json.dumps(
            {
                "schema_version": 1,
                "commit": commit.commit_id[:8],
                "rules_checked": len(results),
                "passed": sum(1 for r in results if r.passed),
                "violated": sum(1 for r in results if not r.passed),
                "results": [r.to_dict() for r in results],
            },
            indent=2,
        ))
        return

    # Human-readable report: header, one line per rule, then a summary.
    typer.echo(f"\nInvariant check — commit {commit.commit_id[:8]}")
    typer.echo("─" * 62)

    for result in results:
        icon = "✅" if result.passed else "🔴"
        status = "passed" if result.passed else "VIOLATED"
        typer.echo(f"\n{icon} {result.name:<40} {status}")
        if not result.passed:
            # Cap per-rule detail at 5 violations to keep the report scannable.
            for v in result.violations[:5]:
                typer.echo(f"   {v}")
            if len(result.violations) > 5:
                typer.echo(f"   … and {len(result.violations) - 5} more")

    passed = sum(1 for r in results if r.passed)
    violated = sum(1 for r in results if not r.passed)
    typer.echo(f"\n {passed} rule(s) passed · {violated} rule(s) violated")