gabriel / muse public
_invariants.py python
659 lines 22.4 KB
bda49bdb feat: redesign .museignore as TOML with domain-scoped sections (#100) Gabriel Cardona <cgcardona@gmail.com> 5d ago
"""Code-domain invariants engine for Muse.

Evaluates semantic rules against code snapshots. Rules are declared in
``.muse/code_invariants.toml`` and evaluated at commit time, merge time, or
on-demand via ``muse code-check``.

Rule types
----------

``max_complexity``
    Detects functions / methods whose estimated cyclomatic complexity exceeds
    *threshold*. Complexity is approximated by counting control-flow branch
    points (``if``/``elif``, ``for``, ``while``, ``except``, ``with``, each
    boolean ``and``/``or`` expression, ``assert``, comprehension clauses)
    inside each symbol's body. Only Python sources (``.py`` / ``.pyi``) are
    scored; files in other languages produce no complexity violations.

``no_circular_imports``
    Detects import cycles among Python (``.py``) files in the snapshot.
    Builds a directed graph (file → intra-repo files it imports) and runs DFS
    cycle detection. Reports each cycle as one violation at the root file of
    the cycle.

``no_dead_exports``
    Detects top-level functions and classes that are never imported by any
    other file in the snapshot (dead exports / unreachable public API).
    Only ``.py`` files are checked. Test files, ``__init__.py``, files that
    declare ``__all__``, ``_``-prefixed names and ``main`` functions are
    exempt.

``test_coverage_floor``
    Requires that at least *min_ratio* of public non-test functions have a
    corresponding test function (``foo`` counts as covered when a test file
    defines ``test_foo``). Reports the actual vs required coverage ratio when
    the floor is not met.

When the rules file yields no rules, built-in defaults apply:
``complexity_gate`` (``max_complexity``, threshold 10, warning),
``no_cycles`` (``no_circular_imports``, error) and ``dead_exports``
(``no_dead_exports``, warning). ``test_coverage_floor`` is opt-in.

TOML example
------------
::

    [[rule]]
    name = "complexity_gate"
    severity = "error"
    scope = "function"
    rule_type = "max_complexity"
    [rule.params]
    threshold = 15

    [[rule]]
    name = "no_cycles"
    severity = "error"
    scope = "file"
    rule_type = "no_circular_imports"

    [[rule]]
    name = "dead_exports"
    severity = "warning"
    scope = "file"
    rule_type = "no_dead_exports"

    [[rule]]
    name = "test_coverage"
    severity = "warning"
    scope = "repo"
    rule_type = "test_coverage_floor"
    [rule.params]
    min_ratio = 0.30

Public API
----------
- :class:`CodeInvariantRule` — code-specific rule declaration.
- :class:`CodeViolation` — violation with file + symbol address.
- :class:`CodeChecker` — satisfies :class:`~muse.core.invariants.InvariantChecker`.
- :func:`load_invariant_rules` — load from TOML with built-in defaults.
- :func:`run_invariants` — top-level runner; returns a
  :class:`~muse.core.invariants.BaseReport` for one commit.
- :func:`check_max_complexity`, :func:`check_no_circular_imports`,
  :func:`check_no_dead_exports`, :func:`check_test_coverage_floor` —
  the individual rule implementations.
"""
74
75 from __future__ import annotations
76
77 import ast
78 import logging
79 import pathlib
80 from typing import Literal, TypedDict
81
82 from muse.core.invariants import (
83 BaseReport,
84 BaseViolation,
85 InvariantSeverity,
86 format_report,
87 load_rules_toml,
88 make_report,
89 )
90 from muse.core.object_store import read_object
91 from muse.core.store import get_commit_snapshot_manifest
92 from muse.plugins.code.ast_parser import (
93 SEMANTIC_EXTENSIONS,
94 SymbolTree,
95 adapter_for_path,
96 parse_symbols,
97 )
98
logger = logging.getLogger(__name__)

# Location of the code-invariant rules file inside the repository's ``.muse/``
# directory. NOTE: this is a *relative* path — ``pathlib.Path(_DEFAULT_RULES_FILE)``
# resolves against the current working directory unless a caller anchors it.
_DEFAULT_RULES_FILE = ".muse/code_invariants.toml"
102
103 # ---------------------------------------------------------------------------
104 # Types
105 # ---------------------------------------------------------------------------
106
107
class _RuleRequired(TypedDict):
    """Keys every :class:`CodeInvariantRule` must carry."""

    name: str
    severity: InvariantSeverity
    scope: Literal["function", "file", "repo", "global"]
    rule_type: str


class CodeInvariantRule(_RuleRequired, total=False):
    """A single code invariant rule declaration.

    ``name``       Unique human-readable identifier.
    ``severity``   ``"info"``, ``"warning"``, or ``"error"``.
    ``scope``      Granularity: ``"function"``, ``"file"``, ``"repo"``, or
                   ``"global"``. Not consulted by this module's dispatcher,
                   which selects the implementation by ``rule_type`` alone.
    ``rule_type``  Built-in type: ``"max_complexity"``,
                   ``"no_circular_imports"``, ``"no_dead_exports"``,
                   ``"test_coverage_floor"``.
    ``params``     Optional rule-specific numeric / string parameters, e.g.
                   ``threshold`` for ``max_complexity`` or ``min_ratio`` for
                   ``test_coverage_floor``. May be omitted entirely.
    """

    params: dict[str, str | int | float]
128
129
class CodeViolation(TypedDict):
    """A code invariant violation with precise source location.

    ``rule_name``    Rule that fired.
    ``severity``     Inherited from the rule.
    ``address``      ``"file.py::symbol_name"`` for symbol-level violations,
                     ``"file.py"`` for file-level ones, or ``"repo"`` for
                     repository-wide ones (e.g. ``test_coverage_floor``).
    ``description``  Human-readable explanation.
    ``file``         Workspace-relative file path (empty string for
                     repository-wide violations).
    ``symbol``       Symbol name (empty string for file- and repo-level
                     violations).
    ``detail``       Additional context (e.g. complexity score, cycle path).
    """

    rule_name: str
    severity: InvariantSeverity
    address: str
    description: str
    file: str
    symbol: str
    detail: str
149
150
151 # ---------------------------------------------------------------------------
152 # Built-in default rules
153 # ---------------------------------------------------------------------------
154
# Rules used when the TOML rules file yields no rules (see
# ``load_invariant_rules``). ``test_coverage_floor`` is opt-in and therefore
# deliberately absent here.
_BUILTIN_DEFAULTS: list[CodeInvariantRule] = [
    {
        "name": "complexity_gate",
        "severity": "warning",
        "scope": "function",
        "rule_type": "max_complexity",
        "params": {"threshold": 10},
    },
    {
        "name": "no_cycles",
        "severity": "error",
        "scope": "file",
        "rule_type": "no_circular_imports",
        "params": {},
    },
    {
        "name": "dead_exports",
        "severity": "warning",
        "scope": "file",
        "rule_type": "no_dead_exports",
        "params": {},
    },
]
178
179
180 # ---------------------------------------------------------------------------
181 # Rule implementations
182 # ---------------------------------------------------------------------------
183
184
def _estimate_complexity(source: bytes, file_path: str) -> dict[str, int]:
    """Return ``{symbol_address: complexity_score}`` for a Python source file.

    Uses a simple branch-count heuristic: every ``if``/``elif``, ``for``,
    ``async for``, ``while``, ``except`` handler, ``with``, ``async with``,
    boolean ``and``/``or`` expression, ``assert`` and comprehension clause
    adds 1 to the score of the function that *directly* contains it. The
    starting score is 1 (the function itself).

    Each function is reported exactly once, under its qualified name:
    ``file.py::func``, ``file.py::Class.method``,
    ``file.py::Outer.Inner.method``, or ``file.py::outer.inner`` for a
    function nested in another. Branches inside a nested function or class
    count only toward that nested symbol, never toward its enclosing one.

    Returns an empty dict for non-Python files or sources that fail to parse.
    """
    if not file_path.endswith((".py", ".pyi")):
        return {}
    try:
        tree = ast.parse(source, filename=file_path)
    except (SyntaxError, ValueError):  # ValueError: NUL bytes before Python 3.12
        return {}

    branch_nodes = (
        ast.If, ast.For, ast.While, ast.ExceptHandler,
        ast.With, ast.AsyncWith, ast.AsyncFor,
        ast.BoolOp, ast.Assert, ast.comprehension,
    )
    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
    scope_nodes = (*function_nodes, ast.ClassDef)

    def _branch_score(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> int:
        # Count branch points in fn, stopping at nested defs/classes: those
        # are scored on their own so no branch is counted twice.
        score = 1  # base complexity
        pending = list(ast.iter_child_nodes(fn))
        while pending:
            node = pending.pop()
            if isinstance(node, scope_nodes):
                continue
            if isinstance(node, branch_nodes):
                score += 1
            pending.extend(ast.iter_child_nodes(node))
        return score

    scores: dict[str, int] = {}
    # Structural walk carrying the qualified-name prefix of enclosing scopes.
    # An explicit stack (instead of recursion) keeps very deep expression
    # trees from hitting the recursion limit.
    stack: list[tuple[ast.AST, str]] = [(tree, "")]
    while stack:
        parent, prefix = stack.pop()
        for child in ast.iter_child_nodes(parent):
            if isinstance(child, function_nodes):
                qualified = f"{prefix}{child.name}"
                scores[f"{file_path}::{qualified}"] = _branch_score(child)
                stack.append((child, f"{qualified}."))
            elif isinstance(child, ast.ClassDef):
                stack.append((child, f"{prefix}{child.name}."))
            else:
                # Defs under if/try/with blocks keep the surrounding prefix.
                stack.append((child, prefix))
    return scores
229
230
def check_max_complexity(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
    rule_name: str,
    severity: InvariantSeverity,
    *,
    threshold: int = 10,
) -> list[CodeViolation]:
    """Flag Python functions whose estimated complexity is above *threshold*.

    Scores come from :func:`_estimate_complexity`; a score equal to the
    threshold passes. Only ``.py`` / ``.pyi`` entries of *manifest* are
    examined, objects that cannot be read are skipped, and within each file
    violations are ordered by symbol address.
    """
    python_entries = [
        (path, digest)
        for path, digest in manifest.items()
        if path.endswith((".py", ".pyi"))
    ]

    found: list[CodeViolation] = []
    for path, digest in python_entries:
        blob = read_object(repo_root, digest)
        if blob is None:
            continue
        offenders = [
            (address, complexity)
            for address, complexity in sorted(_estimate_complexity(blob, path).items())
            if complexity > threshold
        ]
        for address, complexity in offenders:
            # "file.py::Qualified.name" -> "Qualified.name"; "" if no separator.
            _, _, symbol = address.partition("::")
            found.append(CodeViolation(
                rule_name=rule_name,
                severity=severity,
                address=address,
                description=(
                    f"Complexity {complexity} exceeds threshold {threshold}. "
                    "Consider extracting helper functions."
                ),
                file=path,
                symbol=symbol,
                detail=f"score={complexity} threshold={threshold}",
            ))
    return found
264
265
def _index_python_modules(
    py_files: list[str],
) -> tuple[dict[str, str], dict[str, str]]:
    """Map dotted module names to ``.py`` paths as ``(exact, by_suffix)``.

    ``exact`` maps each file's full dotted name (``pkg/__init__.py`` answers
    to ``pkg``). ``by_suffix`` maps every proper dotted suffix of those names
    (``core.x`` and ``x`` for ``muse.core.x``). Files are processed in sorted
    path order, so the first one wins any collision deterministically.
    """
    exact: dict[str, str] = {}
    by_suffix: dict[str, str] = {}
    for fp in sorted(py_files):
        parts = [p for p in fp.removesuffix(".py").replace("\\", "/").split("/") if p]
        if parts and parts[-1] == "__init__":
            parts.pop()
        if not parts:
            continue
        exact.setdefault(".".join(parts), fp)
        for i in range(1, len(parts)):
            by_suffix.setdefault(".".join(parts[i:]), fp)
    return exact, by_suffix


def _imported_repo_files(
    tree: ast.AST,
    file_path: str,
    exact: dict[str, str],
    by_suffix: dict[str, str],
) -> set[str]:
    """Return the repo ``.py`` files that *file_path*'s imports refer to (self excluded)."""
    # The importing file's package: its directory, as dotted-name segments.
    package = [p for p in file_path.replace("\\", "/").split("/")[:-1] if p]
    targets: set[str] = set()

    def link(name: str, *, relative: bool) -> bool:
        # Relative imports are precise, so only exact matches count for them;
        # absolute ones may also match a dotted suffix (src/-style layouts).
        target = exact.get(name)
        if target is None and not relative:
            target = by_suffix.get(name)
        if target is None:
            return False
        if target != file_path:
            targets.add(target)
        return True

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                link(alias.name, relative=False)
        elif isinstance(node, ast.ImportFrom):
            relative = bool(node.level)
            if relative:
                climb = node.level - 1
                if climb > len(package):
                    continue  # climbs above the repository root
                base = package[: len(package) - climb]
            else:
                base = []
            if node.module:
                base = [*base, *node.module.split(".")]
            prefix = ".".join(base)
            for alias in node.names:
                if alias.name != "*":
                    submodule = f"{prefix}.{alias.name}" if prefix else alias.name
                    if link(submodule, relative=relative):
                        continue
                # Not a submodule (or a star import): the dependency is the
                # module the names are taken from.
                if prefix:
                    link(prefix, relative=relative)
    return targets


def _build_import_graph(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
) -> dict[str, set[str]]:
    """Build a directed import graph: {file → set of imported files}.

    Only intra-repo imports are tracked, and only ``.py`` files take part:
    they are the graph's nodes and the only possible edge targets. Files that
    cannot be read or parsed become nodes without outgoing edges.

    Resolution (deterministic — modules are indexed in sorted path order):

    * A file's module name is its path with ``.py`` stripped and separators
      turned into dots; ``pkg/__init__.py`` answers to ``pkg``.
    * Absolute imports match a module name exactly, falling back to the first
      file whose dotted name ends with the imported name.
    * Relative imports (``from . import x``, ``from ..a import b``) are
      resolved against the importing file's directory and must match exactly.
    * ``from pkg import name`` links to the submodule ``pkg.name`` when such a
      file exists, otherwise to ``pkg`` itself.
    """
    py_files = sorted(fp for fp in manifest if fp.endswith(".py"))
    exact, by_suffix = _index_python_modules(py_files)
    graph: dict[str, set[str]] = {fp: set() for fp in py_files}

    for file_path in py_files:
        source = read_object(repo_root, manifest[file_path])
        if source is None:
            continue
        try:
            tree = ast.parse(source, filename=file_path)
        except (SyntaxError, ValueError):  # ValueError: NUL bytes before Python 3.12
            continue
        graph[file_path] = _imported_repo_files(tree, file_path, exact, by_suffix)

    return graph
311
312
def _find_cycles(graph: dict[str, set[str]]) -> list[list[str]]:
    """Find import cycles with a depth-first search over *graph*.

    Roots and each node's neighbours are visited in sorted order, so the
    result is deterministic. Every back edge to a node on the current DFS
    path yields the cycle from that node to the current one (e.g.
    ``["a.py", "b.py"]`` for ``a.py → b.py → a.py``). Cycles over the same set
    of files are reported once, and neighbours that are not keys of *graph*
    are ignored. This finds at least one cycle in every cyclic strongly
    connected component, not every elementary cycle.

    The walk keeps an explicit stack, so arbitrarily long import chains
    cannot exceed Python's recursion limit.
    """
    UNVISITED, ON_PATH, DONE = 0, 1, 2
    exhausted = object()  # sentinel returned by next() once a node is fully explored
    state = dict.fromkeys(graph, UNVISITED)
    path: list[str] = []
    cycles: list[list[str]] = []
    seen: set[frozenset[str]] = set()

    for root in sorted(graph):
        if state[root] != UNVISITED:
            continue
        state[root] = ON_PATH
        path.append(root)
        # One iterator over sorted neighbours per node on the current path.
        frontier = [iter(sorted(graph[root]))]
        while frontier:
            neighbor = next(frontier[-1], exhausted)
            if neighbor is exhausted:
                frontier.pop()
                state[path.pop()] = DONE
                continue
            mark = state.get(neighbor, DONE)  # nodes outside the graph are skipped
            if mark == UNVISITED:
                state[neighbor] = ON_PATH
                path.append(neighbor)
                frontier.append(iter(sorted(graph.get(neighbor, set()))))
            elif mark == ON_PATH:
                cycle = path[path.index(neighbor):]
                key = frozenset(cycle)
                if key not in seen:
                    seen.add(key)
                    cycles.append(cycle)
    return cycles
342
343
def check_no_circular_imports(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
    rule_name: str,
    severity: InvariantSeverity,
) -> list[CodeViolation]:
    """Report every import cycle found among the snapshot's Python files.

    One violation is emitted per cycle, anchored at the cycle's first file.
    Both the description and ``detail`` spell out the closed loop, e.g.
    ``a.py → b.py → a.py``.
    """
    import_graph = _build_import_graph(manifest, repo_root)

    result: list[CodeViolation] = []
    for loop in _find_cycles(import_graph):
        head = loop[0]
        path_text = " → ".join([*loop, head])
        result.append(CodeViolation(
            rule_name=rule_name,
            severity=severity,
            address=head,
            description=f"Circular import cycle detected: {path_text}",
            file=head,
            symbol="",
            detail=path_text,
        ))
    return result
367
368
def check_no_dead_exports(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
    rule_name: str,
    severity: InvariantSeverity,
) -> list[CodeViolation]:
    """Detect top-level functions/classes never imported by any other file.

    A name counts as imported when any ``.py`` file in the snapshot contains
    ``from m import name`` (under any alias) or ``import pkg.name``. Use via
    module attribute access (``import m; m.name()``) is *not* detected, so
    such APIs can be reported as false positives.

    Exempt: test files (``test_*.py`` / ``*_test.py``), ``__init__.py``,
    files that assign ``__all__`` by plain, annotated or augmented assignment
    (which signals deliberate public API), ``_``-prefixed names, and ``main``
    functions. Files that cannot be read or parsed are skipped.
    """

    def _declares_all(tree: ast.AST) -> bool:
        # True when the module binds ``__all__`` anywhere, in any assignment form.
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                targets = node.targets
            elif isinstance(node, (ast.AnnAssign, ast.AugAssign)):
                targets = [node.target]
            else:
                continue
            if any(isinstance(t, ast.Name) and t.id == "__all__" for t in targets):
                return True
        return False

    # Parse every Python file once; both passes below reuse the trees.
    trees: dict[str, ast.Module] = {}
    for file_path, content_hash in manifest.items():
        if not file_path.endswith(".py"):
            continue
        source = read_object(repo_root, content_hash)
        if source is None:
            continue
        try:
            trees[file_path] = ast.parse(source, filename=file_path)
        except (SyntaxError, ValueError):  # ValueError: NUL bytes before Python 3.12
            continue

    # Record the *original* imported names: for ``from m import foo as bar``
    # the export being consumed is ``foo``, not the local alias ``bar``.
    imported_names: set[str] = set()
    for tree in trees.values():
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imported_names.update(alias.name.rsplit(".", 1)[-1] for alias in node.names)
            elif isinstance(node, ast.ImportFrom):
                imported_names.update(alias.name for alias in node.names)

    violations: list[CodeViolation] = []
    for file_path, tree in trees.items():
        base = pathlib.PurePosixPath(file_path).name
        if base == "__init__.py" or base.startswith("test_") or base.endswith("_test.py"):
            continue
        if _declares_all(tree):
            continue  # the file manages its own exports

        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if node.name == "main":
                    continue
                description = (
                    f"'{node.name}' is never imported by any other file. "
                    "Consider removing, making private (prefix _), or adding to __all__."
                )
            elif isinstance(node, ast.ClassDef):
                description = (
                    f"Class '{node.name}' is never imported by any other file. "
                    "Consider removing or making private."
                )
            else:
                continue
            if node.name.startswith("_") or node.name in imported_names:
                continue
            violations.append(CodeViolation(
                rule_name=rule_name,
                severity=severity,
                address=f"{file_path}::{node.name}",
                description=description,
                file=file_path,
                symbol=node.name,
                detail="no importers found",
            ))

    return violations
463
464
def check_test_coverage_floor(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
    rule_name: str,
    severity: InvariantSeverity,
    *,
    min_ratio: float = 0.30,
) -> list[CodeViolation]:
    """Require that at least *min_ratio* of public functions have a test counterpart.

    Test files are ``.py`` files named ``test_*.py`` or ``*_test.py``. A public
    (non-``_``-prefixed) function or method ``foo`` defined in a non-test file
    counts as "tested" when some test file defines a function or method named
    exactly ``test_foo``. Names are compared as sets, so same-named functions
    in different files count once. This is a naming-convention heuristic, not
    true coverage.

    Returns one repo-level violation (address ``"repo"``) listing up to ten
    untested names when the ratio is below *min_ratio*; otherwise ``[]``
    (also when no public non-test functions exist). Unreadable or unparsable
    files are skipped.
    """
    test_fn_names: set[str] = set()
    all_fn_names: set[str] = set()

    for file_path, content_hash in manifest.items():
        if not file_path.endswith(".py"):
            continue
        source = read_object(repo_root, content_hash)
        if source is None:
            continue
        try:
            tree = ast.parse(source, filename=file_path)
        except (SyntaxError, ValueError):  # ValueError: NUL bytes before Python 3.12
            continue

        base = pathlib.PurePosixPath(file_path).name
        is_test = base.startswith("test_") or base.endswith("_test.py")

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if is_test and node.name.startswith("test_"):
                    test_fn_names.add(node.name.removeprefix("test_"))
                elif not is_test and not node.name.startswith("_"):
                    all_fn_names.add(node.name)

    if not all_fn_names:
        return []

    covered = all_fn_names & test_fn_names
    ratio = len(covered) / len(all_fn_names)

    if ratio < min_ratio:
        pct_actual = round(ratio * 100, 1)
        pct_required = round(min_ratio * 100, 1)
        untested = sorted(all_fn_names - covered)
        preview = ", ".join(untested[:10]) + ("…" if len(untested) > 10 else "")
        return [CodeViolation(
            rule_name=rule_name,
            severity=severity,
            address="repo",
            description=(
                f"Test coverage floor not met: {pct_actual}% of functions have test counterparts "
                f"(required {pct_required}%). Untested: {preview}"
            ),
            file="",
            symbol="",
            detail=f"ratio={ratio:.3f} required={min_ratio:.3f}",
        )]
    return []
527
528
529 # ---------------------------------------------------------------------------
530 # Rule dispatch
531 # ---------------------------------------------------------------------------
532
533
def _dispatch_rule(
    rule: CodeInvariantRule,
    manifest: dict[str, str],
    repo_root: pathlib.Path,
) -> list[CodeViolation]:
    """Route *rule* to the checker implementing its ``rule_type``.

    Numeric parameters are coerced (``threshold`` → ``int``, ``min_ratio`` →
    ``float``) using the same defaults the checkers declare. An unrecognised
    ``rule_type`` is logged and yields no violations.
    """
    params = rule.get("params", {})
    name, sev, kind = rule["name"], rule["severity"], rule["rule_type"]

    # Each handler is a thunk so parameter coercion happens only for the
    # rule type actually selected.
    handlers = {
        "max_complexity": lambda: check_max_complexity(
            manifest, repo_root, name, sev,
            threshold=int(params.get("threshold", 10)),
        ),
        "no_circular_imports": lambda: check_no_circular_imports(
            manifest, repo_root, name, sev,
        ),
        "no_dead_exports": lambda: check_no_dead_exports(
            manifest, repo_root, name, sev,
        ),
        "test_coverage_floor": lambda: check_test_coverage_floor(
            manifest, repo_root, name, sev,
            min_ratio=float(params.get("min_ratio", 0.30)),
        ),
    }

    run = handlers.get(kind)
    if run is None:
        logger.warning("Unknown code invariant rule_type: %r — skipping", kind)
        return []
    return run()
561
562
563 # ---------------------------------------------------------------------------
564 # Public entry points
565 # ---------------------------------------------------------------------------
566
567
def load_invariant_rules(
    rules_file: pathlib.Path | None = None,
) -> list[CodeInvariantRule]:
    """Load code invariant rules from TOML, falling back to built-in defaults.

    Each entry returned by ``load_rules_toml`` is normalised into a
    :class:`CodeInvariantRule`. A missing ``name`` becomes ``"unnamed"``, a
    missing ``severity`` becomes ``"warning"``, a missing ``scope`` becomes
    ``"function"``, and a missing ``rule_type`` becomes ``""``. An
    unrecognised severity or scope, or a ``params`` value that is not a
    table, falls back to the same defaults *with a logged warning*. Silently
    downgrading e.g. ``severity = "eror"`` would quietly change how the rule
    is reported.

    Args:
        rules_file: Path to the TOML file. ``None`` → use
            ``.muse/code_invariants.toml`` relative to the current working
            directory.

    Returns:
        List of :class:`CodeInvariantRule` dicts; a copy of the built-in
        defaults when the file yields no rules.
    """
    path = rules_file or pathlib.Path(_DEFAULT_RULES_FILE)
    raw = load_rules_toml(path)
    if not raw:
        return list(_BUILTIN_DEFAULTS)

    # Identity maps double as validators and give type checkers the Literal types.
    severities: dict[str, InvariantSeverity] = {
        "info": "info", "warning": "warning", "error": "error",
    }
    scopes: dict[str, Literal["function", "file", "repo", "global"]] = {
        "function": "function", "file": "file", "repo": "repo", "global": "global",
    }

    rules: list[CodeInvariantRule] = []
    for r in raw:
        name = str(r.get("name", "unnamed"))

        raw_sev = str(r.get("severity", "warning"))
        severity = severities.get(raw_sev)
        if severity is None:
            logger.warning(
                "Code invariant rule %r: unknown severity %r — using 'warning'",
                name, raw_sev,
            )
            severity = "warning"

        scope_raw = str(r.get("scope", "function"))
        scope = scopes.get(scope_raw)
        if scope is None:
            logger.warning(
                "Code invariant rule %r: unknown scope %r — using 'function'",
                name, scope_raw,
            )
            scope = "function"

        rule_type = str(r.get("rule_type", ""))

        raw_params = r.get("params", {})
        if isinstance(raw_params, dict):
            params: dict[str, str | int | float] = dict(raw_params)
        else:
            logger.warning(
                "Code invariant rule %r: 'params' must be a table, got %s — ignoring",
                name, type(raw_params).__name__,
            )
            params = {}

        rules.append(CodeInvariantRule(
            name=name, severity=severity, scope=scope, rule_type=rule_type, params=params
        ))
    return rules
607
608
def run_invariants(
    repo_root: pathlib.Path,
    commit_id: str,
    rules: list[CodeInvariantRule],
) -> BaseReport:
    """Evaluate every rule in *rules* against the snapshot of *commit_id*.

    Each rule is handed its own copy of the snapshot manifest. A rule that
    raises is logged with its traceback and contributes nothing, while the
    remaining rules still run. Violations are narrowed to the
    :class:`~muse.core.invariants.BaseViolation` fields. ``file``, ``symbol``
    and ``detail`` are not carried over.

    Args:
        repo_root: Repository root.
        commit_id: Commit to check.
        rules:     Rules to evaluate (from :func:`load_invariant_rules`).

    Returns:
        A :class:`~muse.core.invariants.BaseReport` built by ``make_report``
        from all violations and ``len(rules)``. If the snapshot cannot be
        loaded, an empty report is built with ``0`` in that position.
    """
    manifest = get_commit_snapshot_manifest(repo_root, commit_id)
    if manifest is None:
        logger.warning("Could not load snapshot for commit %s", commit_id)
        return make_report(commit_id, "code", [], 0)

    def _narrow(found: CodeViolation) -> BaseViolation:
        # CodeViolation carries a superset of BaseViolation's fields.
        return BaseViolation(
            rule_name=found["rule_name"],
            severity=found["severity"],
            address=found["address"],
            description=found["description"],
        )

    collected: list[BaseViolation] = []
    for rule in rules:
        try:
            for found in _dispatch_rule(rule, dict(manifest), repo_root):
                collected.append(_narrow(found))
        except Exception:
            logger.exception("Error evaluating rule %r on commit %s", rule["name"], commit_id)

    return make_report(commit_id, "code", collected, len(rules))
645
646
class CodeChecker:
    """Satisfies :class:`~muse.core.invariants.InvariantChecker` for the code domain."""

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Run code invariant checks against *commit_id*.

        Args:
            repo_root:  Repository root. Also the anchor for the default
                        rules file.
            commit_id:  Commit whose snapshot is checked.
            rules_file: Explicit TOML rules file. ``None`` → the
                        repository's ``.muse/code_invariants.toml``, resolved
                        against *repo_root* rather than the current working
                        directory. Running from a subdirectory therefore still
                        picks up the repository's own rules instead of
                        silently falling back to the built-in defaults.

        Returns:
            The report produced by :func:`run_invariants`.
        """
        if rules_file is None:
            rules_file = repo_root / _DEFAULT_RULES_FILE
        rules = load_invariant_rules(rules_file)
        return run_invariants(repo_root, commit_id, rules)