gabriel / muse public
_invariants.py python
658 lines 22.4 KB
766ee24d feat: code domain leverages core invariants, query engine, manifests, p… Gabriel Cardona <gabriel@tellurstori.com> 5d ago
1 """Code-domain invariants engine for Muse.
2
3 Evaluates semantic rules against code snapshots. Rules are declared in
4 ``.muse/code_invariants.toml`` and evaluated at commit time, merge time, or
5 on-demand via ``muse code-check``.
6
7 Rule types
8 ----------
9
``max_complexity``
    Detects functions / methods whose estimated cyclomatic complexity exceeds
    *threshold*. Complexity is approximated by counting control-flow branch
    points (``if``, ``elif``, ``for``, ``while``, ``except``, ``with``,
    ``and``, ``or``, ``assert``, comprehensions) inside each symbol's body.
    The estimate is computed with Python's ``ast`` module, so only ``.py`` /
    ``.pyi`` files are scored; files in other languages are skipped.
17
18 ``no_circular_imports``
19 Detects import cycles among Python files in the snapshot. Builds a
20 directed graph (file → files it imports) and runs DFS cycle detection.
21 Reports each cycle as one violation at the root file of the cycle.
22
23 ``no_dead_exports``
24 Detects top-level functions and classes that are never imported by any
25 other file in the snapshot (dead exports / unreachable public API).
26 Only applies to semantic files; test files and ``__init__.py`` are exempt.
27
28 ``test_coverage_floor``
29 Requires that at least *min_ratio* of non-test functions have a
30 corresponding test function (detected by ``test_`` prefix convention).
31 Reports the actual vs required coverage ratio when the floor is not met.
32
33 TOML example
34 ------------
35 ::
36
37 [[rule]]
38 name = "complexity_gate"
39 severity = "error"
40 scope = "function"
41 rule_type = "max_complexity"
42 [rule.params]
43 threshold = 15
44
45 [[rule]]
46 name = "no_cycles"
47 severity = "error"
48 scope = "file"
49 rule_type = "no_circular_imports"
50
51 [[rule]]
52 name = "dead_exports"
53 severity = "warning"
54 scope = "file"
55 rule_type = "no_dead_exports"
56
57 [[rule]]
58 name = "test_coverage"
59 severity = "warning"
60 scope = "repo"
61 rule_type = "test_coverage_floor"
62 [rule.params]
63 min_ratio = 0.30
64
65 Public API
66 ----------
67 - :class:`CodeInvariantRule` — code-specific rule declaration.
68 - :class:`CodeViolation` — violation with file + symbol address.
69 - :class:`CodeInvariantReport` — full report for one commit.
70 - :class:`CodeChecker` — satisfies :class:`~muse.core.invariants.InvariantChecker`.
71 - :func:`load_invariant_rules` — load from TOML with built-in defaults.
72 - :func:`run_invariants` — top-level runner.
73 """
74 from __future__ import annotations
75
76 import ast
77 import logging
78 import pathlib
79 from typing import Literal, TypedDict
80
81 from muse.core.invariants import (
82 BaseReport,
83 BaseViolation,
84 InvariantSeverity,
85 format_report,
86 load_rules_toml,
87 make_report,
88 )
89 from muse.core.object_store import read_object
90 from muse.core.store import get_commit_snapshot_manifest
91 from muse.plugins.code.ast_parser import (
92 SEMANTIC_EXTENSIONS,
93 SymbolTree,
94 adapter_for_path,
95 parse_symbols,
96 )
97
98 logger = logging.getLogger(__name__)
99
100 _DEFAULT_RULES_FILE = ".muse/code_invariants.toml"
101
102 # ---------------------------------------------------------------------------
103 # Types
104 # ---------------------------------------------------------------------------
105
106
class _RuleRequired(TypedDict):
    """Keys that every code invariant rule mapping must carry."""

    # Identifier of the rule.
    name: str
    # Severity attached to the violations this rule produces.
    severity: InvariantSeverity
    # Granularity the rule is declared for.
    scope: Literal["function", "file", "repo", "global"]
    # Name of the built-in check that evaluates the rule.
    rule_type: str
112
113
class CodeInvariantRule(_RuleRequired, total=False):
    """Declaration of one code invariant rule.

    Required keys (inherited from the base mapping):

    ``name``       Human-readable identifier, intended to be unique.
    ``severity``   ``"info"``, ``"warning"``, or ``"error"``.
    ``scope``      Granularity: ``"function"``, ``"file"``, ``"repo"``,
                   or ``"global"``.
    ``rule_type``  Built-in type: ``"max_complexity"``,
                   ``"no_circular_imports"``, ``"no_dead_exports"``,
                   ``"test_coverage_floor"``.

    Optional key:

    ``params``     Rule-specific numeric / string parameters.
    """

    # Optional because the class is declared with ``total=False``.
    params: dict[str, str | int | float]
127
128
class CodeViolation(TypedDict):
    """One code invariant violation together with its source location.

    ``rule_name``    Name of the rule that fired.
    ``severity``     Severity copied from that rule.
    ``address``      ``"file.py::symbol_name"``, or ``"file.py"`` when the
                     violation concerns a whole file.
    ``description``  Human-readable explanation of the problem.
    ``file``         Workspace-relative path of the offending file.
    ``symbol``       Symbol name; empty string for file-level violations.
    ``detail``       Extra context such as a complexity score or cycle path.
    """

    rule_name: str
    severity: InvariantSeverity
    address: str
    description: str
    file: str
    symbol: str
    detail: str
148
149
150 # ---------------------------------------------------------------------------
151 # Built-in default rules
152 # ---------------------------------------------------------------------------
153
# Rules used when the rules file yields no entries: a complexity warning
# above 10, an error on import cycles, and a warning on dead exports.
_BUILTIN_DEFAULTS: list[CodeInvariantRule] = [
    # Warn when a function's estimated complexity is greater than 10.
    CodeInvariantRule(
        name="complexity_gate",
        severity="warning",
        scope="function",
        rule_type="max_complexity",
        params={"threshold": 10},
    ),
    # Import cycles are reported as errors.
    CodeInvariantRule(
        name="no_cycles",
        severity="error",
        scope="file",
        rule_type="no_circular_imports",
        params={},
    ),
    # Public symbols that nothing imports are reported as warnings.
    CodeInvariantRule(
        name="dead_exports",
        severity="warning",
        scope="file",
        rule_type="no_dead_exports",
        params={},
    ),
]
177
178
179 # ---------------------------------------------------------------------------
180 # Rule implementations
181 # ---------------------------------------------------------------------------
182
183
def _estimate_complexity(source: bytes, file_path: str) -> dict[str, int]:
    """Return ``{symbol_address: complexity_score}`` for a Python source file.

    Uses a simple branch-count heuristic: each ``if`` / ``elif``, ``for``,
    ``while``, ``except``, ``with``, boolean operation (``and`` / ``or``),
    ``assert`` and comprehension clause adds 1 to the enclosing function's
    score.  The starting score is 1 (the function itself).  Branches inside a
    nested function also count toward every function that encloses it.

    Each function is scored exactly once, under its dotted qualified name:
    ``file.py::func``, ``file.py::Class.method``, ``file.py::outer.inner``.

    Args:
        source: Raw file contents.
        file_path: Path used both as the address prefix and to decide whether
            the file is Python (``.py`` / ``.pyi``).

    Returns:
        Mapping of symbol address to score.  Empty for non-Python files and
        for sources that cannot be parsed.
    """
    if not file_path.endswith((".py", ".pyi")):
        return {}
    try:
        tree = ast.parse(source, filename=file_path)
    except (SyntaxError, ValueError):
        # ValueError: interpreters before 3.12 reject NUL bytes this way.
        return {}

    branch_nodes = (
        ast.If, ast.For, ast.While, ast.ExceptHandler,
        ast.With, ast.AsyncWith, ast.AsyncFor,
        ast.BoolOp, ast.Assert, ast.comprehension,
    )
    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)

    scores: dict[str, int] = {}

    def _collect(parent: ast.AST, prefix: str) -> None:
        """Score every function found below *parent*, qualifying by *prefix*."""
        for child in ast.iter_child_nodes(parent):
            if isinstance(child, function_nodes):
                qualified = f"{prefix}{child.name}"
                branches = sum(
                    isinstance(inner, branch_nodes) for inner in ast.walk(child)
                )
                scores[f"{file_path}::{qualified}"] = 1 + branches
                _collect(child, f"{qualified}.")
            elif isinstance(child, ast.ClassDef):
                _collect(child, f"{prefix}{child.name}.")
            elif not isinstance(child, ast.expr):
                # Definitions may sit inside compound statements (if/try/with);
                # expressions cannot contain ``def``, so they are not descended.
                _collect(child, prefix)

    _collect(tree, "")
    return scores
228
229
def check_max_complexity(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
    rule_name: str,
    severity: InvariantSeverity,
    *,
    threshold: int = 10,
) -> list[CodeViolation]:
    """Report every Python symbol whose estimated complexity is above *threshold*.

    Args:
        manifest: Mapping of file path to content hash for the snapshot.
        repo_root: Repository root handed to the object reader.
        rule_name: Name recorded on each violation.
        severity: Severity recorded on each violation.
        threshold: Highest score that is still accepted.

    Returns:
        One violation per offending symbol; within a file they are ordered
        by symbol address.
    """
    python_suffixes = (".py", ".pyi")
    found: list[CodeViolation] = []

    for path, digest in manifest.items():
        if not path.endswith(python_suffixes):
            continue
        blob = read_object(repo_root, digest)
        if blob is None:
            # Object missing from the store: nothing to score.
            continue

        per_symbol = _estimate_complexity(blob, path)
        for address in sorted(per_symbol):
            value = per_symbol[address]
            if value <= threshold:
                continue
            # Everything after the first "::" is the symbol part.
            _, separator, tail = address.partition("::")
            found.append(
                CodeViolation(
                    rule_name=rule_name,
                    severity=severity,
                    address=address,
                    description=(
                        f"Complexity {value} exceeds threshold {threshold}. "
                        "Consider extracting helper functions."
                    ),
                    file=path,
                    symbol=tail if separator else "",
                    detail=f"score={value} threshold={threshold}",
                )
            )

    return found
263
264
def _build_import_graph(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
) -> dict[str, set[str]]:
    """Build a directed import graph: ``{file → set of imported files}``.

    Only intra-repo imports are tracked, i.e. imports that resolve to a file
    present in *manifest*.  Resolution is name-based: a dotted module name is
    matched against the path-derived module name of each ``.py`` / ``.pyi``
    file, and, failing that, against the last path segment as a best guess.

    The index is built in sorted path order so that results are identical
    between runs: when several files share a basename the lexicographically
    first one wins the guess, and a ``.py`` file wins over a ``.pyi`` stub of
    the same module.

    Args:
        manifest: Mapping of file path to content hash for the snapshot.
        repo_root: Repository root handed to the object reader.

    Returns:
        A graph keyed by every ``.py`` file in the manifest.  Files that are
        unreadable or unparsable keep an empty edge set.
    """
    python_files = sorted(fp for fp in manifest if fp.endswith((".py", ".pyi")))

    # Module-name → file-path index for intra-repo resolution.
    module_to_file: dict[str, str] = {}
    last_segments: list[tuple[str, str]] = []
    for fp in python_files:
        # Convert path to module name: strip suffix, replace separators with dots.
        mod = fp.removesuffix(".pyi").removesuffix(".py").replace("/", ".").replace("\\", ".")
        module_to_file.setdefault(mod, fp)
        last_segments.append((mod.rsplit(".", 1)[-1], fp))
    # Basename guesses go in afterwards so they never shadow a full module name.
    for last, fp in last_segments:
        module_to_file.setdefault(last, fp)

    graph: dict[str, set[str]] = {fp: set() for fp in python_files if fp.endswith(".py")}

    for file_path in graph:
        source = read_object(repo_root, manifest[file_path])
        if source is None:
            continue
        try:
            tree = ast.parse(source, filename=file_path)
        except (SyntaxError, ValueError):
            # ValueError: interpreters before 3.12 reject NUL bytes this way.
            continue

        edges = graph[file_path]
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imported = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.module:
                imported = [node.module]
            else:
                continue
            for module_name in imported:
                target = module_to_file.get(module_name)
                # Self-imports are not edges.
                if target and target != file_path:
                    edges.add(target)

    return graph
310
311
def _find_cycles(graph: dict[str, set[str]]) -> list[list[str]]:
    """Find cycles in *graph* with a depth-first search.

    Nodes and neighbours are visited in sorted order.  Each cycle is returned
    as the slice of the DFS path from the revisited node to the current one.
    Two cycles over the same set of nodes are reported only once.  Neighbours
    that are not keys of *graph* are treated as already finished.
    """
    UNSEEN, ACTIVE, DONE = range(3)
    state = dict.fromkeys(graph, UNSEEN)
    path: list[str] = []
    found: list[list[str]] = []
    reported: set[frozenset[str]] = set()

    def visit(current: str) -> None:
        state[current] = ACTIVE
        path.append(current)
        for nxt in sorted(graph.get(current, set())):
            nxt_state = state.get(nxt, DONE)
            if nxt_state == UNSEEN:
                visit(nxt)
            elif nxt_state == ACTIVE:
                # Back edge: the cycle is the path from `nxt` to `current`.
                loop = path[path.index(nxt):]
                members = frozenset(loop)
                if members not in reported:
                    reported.add(members)
                    found.append(list(loop))
        path.pop()
        state[current] = DONE

    for start in sorted(graph):
        if state[start] == UNSEEN:
            visit(start)

    return found
341
342
def check_no_circular_imports(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
    rule_name: str,
    severity: InvariantSeverity,
) -> list[CodeViolation]:
    """Report one violation per import cycle found in the snapshot.

    Each violation is attached to the first file of its cycle, and the cycle
    is rendered as ``a → b → a`` in both the description and the detail.
    """
    results: list[CodeViolation] = []
    for members in _find_cycles(_build_import_graph(manifest, repo_root)):
        head = members[0]
        # Repeat the first file at the end so the loop reads as closed.
        rendered = " → ".join(members + [head])
        results.append(
            CodeViolation(
                rule_name=rule_name,
                severity=severity,
                address=head,
                description=f"Circular import cycle detected: {rendered}",
                file=head,
                symbol="",
                detail=rendered,
            )
        )
    return results
366
367
def check_no_dead_exports(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
    rule_name: str,
    severity: InvariantSeverity,
) -> list[CodeViolation]:
    """Detect top-level functions/classes whose name is never imported.

    A symbol counts as used when its name appears as an imported name (or as
    an import alias) in any ``.py`` file of the snapshot.  This is a name-based
    heuristic: access through a module attribute (``mod.func``) is not seen.

    Exempt: test files (``test_*.py`` and ``*_test.py``), ``__init__.py``,
    files that assign ``__all__`` (which signals a deliberate public API),
    names starting with ``_``, and functions named ``main``.

    Args:
        manifest: Mapping of file path to content hash for the snapshot.
        repo_root: Repository root handed to the object reader.
        rule_name: Name recorded on each violation.
        severity: Severity recorded on each violation.

    Returns:
        One violation per unused symbol, in manifest order and then in
        definition order within a file.
    """
    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)

    def _assigns_dunder_all(node: ast.AST) -> bool:
        """Return True when *node* assigns to the name ``__all__``."""
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, (ast.AnnAssign, ast.AugAssign)):
            targets = [node.target]
        else:
            return False
        return any(isinstance(t, ast.Name) and t.id == "__all__" for t in targets)

    imported_names: set[str] = set()
    # (file path, symbol name, is_class) for every exported top-level symbol.
    candidates: list[tuple[str, str, bool]] = []

    # Single pass: each file is read and parsed once, yielding both the names
    # it imports and the symbols it exports.
    for file_path, content_hash in manifest.items():
        if not file_path.endswith(".py"):
            continue
        source = read_object(repo_root, content_hash)
        if source is None:
            continue
        try:
            tree = ast.parse(source, filename=file_path)
        except (SyntaxError, ValueError):
            # ValueError: interpreters before 3.12 reject NUL bytes this way.
            continue

        declares_all = False
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    imported_names.add(alias.asname or alias.name.split(".")[-1])
            elif isinstance(node, ast.ImportFrom):
                for alias in node.names:
                    imported_names.add(alias.asname or alias.name)
            elif not declares_all and _assigns_dunder_all(node):
                declares_all = True

        base = pathlib.PurePosixPath(file_path).name
        is_test = base.startswith("test_") or base.endswith("_test.py")
        if is_test or base == "__init__.py" or declares_all:
            continue

        for node in tree.body:
            if isinstance(node, function_nodes):
                if node.name.startswith("_") or node.name == "main":
                    continue
                candidates.append((file_path, node.name, False))
            elif isinstance(node, ast.ClassDef):
                if node.name.startswith("_"):
                    continue
                candidates.append((file_path, node.name, True))

    violations: list[CodeViolation] = []
    for file_path, symbol, is_class in candidates:
        if symbol in imported_names:
            continue
        if is_class:
            description = (
                f"Class '{symbol}' is never imported by any other file. "
                "Consider removing or making private."
            )
        else:
            description = (
                f"'{symbol}' is never imported by any other file. "
                "Consider removing, making private (prefix _), or adding to __all__."
            )
        violations.append(CodeViolation(
            rule_name=rule_name,
            severity=severity,
            address=f"{file_path}::{symbol}",
            description=description,
            file=file_path,
            symbol=symbol,
            detail="no importers found",
        ))

    return violations
462
463
def check_test_coverage_floor(
    manifest: dict[str, str],
    repo_root: pathlib.Path,
    rule_name: str,
    severity: InvariantSeverity,
    *,
    min_ratio: float = 0.30,
) -> list[CodeViolation]:
    """Require that at least *min_ratio* of functions have a test counterpart.

    A function ``foo`` is considered "tested" only when some test file
    (``test_*.py`` or ``*_test.py``) defines a function or method named
    exactly ``test_foo``.  Functions are matched by bare name across the
    whole snapshot, and names starting with ``_`` are ignored.  This is a
    naming-convention heuristic, not true coverage.

    Args:
        manifest: Mapping of file path to content hash for the snapshot.
        repo_root: Repository root handed to the object reader.
        rule_name: Name recorded on the violation.
        severity: Severity recorded on the violation.
        min_ratio: Required fraction of tested function names.

    Returns:
        An empty list when the floor is met or no functions were found,
        otherwise a single repo-level violation naming up to ten untested
        functions.
    """
    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
    test_fn_names: set[str] = set()
    all_fn_names: set[str] = set()

    for file_path, content_hash in manifest.items():
        if not file_path.endswith(".py"):
            continue
        source = read_object(repo_root, content_hash)
        if source is None:
            continue
        try:
            tree = ast.parse(source, filename=file_path)
        except (SyntaxError, ValueError):
            # ValueError: interpreters before 3.12 reject NUL bytes this way.
            continue

        base = pathlib.PurePosixPath(file_path).name
        is_test = base.startswith("test_") or base.endswith("_test.py")

        for node in ast.walk(tree):
            if not isinstance(node, function_nodes):
                continue
            if is_test:
                # Helpers inside test files count neither as tests nor as code.
                if node.name.startswith("test_"):
                    test_fn_names.add(node.name.removeprefix("test_"))
            elif not node.name.startswith("_"):
                all_fn_names.add(node.name)

    if not all_fn_names:
        return []

    covered = all_fn_names & test_fn_names
    ratio = len(covered) / len(all_fn_names)

    if ratio < min_ratio:
        untested = sorted(all_fn_names - covered)
        pct_actual = round(ratio * 100, 1)
        pct_required = round(min_ratio * 100, 1)
        return [CodeViolation(
            rule_name=rule_name,
            severity=severity,
            address="repo",
            description=(
                f"Test coverage floor not met: {pct_actual}% of functions have test counterparts "
                f"(required {pct_required}%). Untested: "
                + ", ".join(untested[:10])
                + ("…" if len(untested) > 10 else "")
            ),
            file="",
            symbol="",
            detail=f"ratio={ratio:.3f} required={min_ratio:.3f}",
        )]
    return []
526
527
528 # ---------------------------------------------------------------------------
529 # Rule dispatch
530 # ---------------------------------------------------------------------------
531
532
def _dispatch_rule(
    rule: CodeInvariantRule,
    manifest: dict[str, str],
    repo_root: pathlib.Path,
) -> list[CodeViolation]:
    """Run the check selected by ``rule["rule_type"]`` and return its violations.

    Rule parameters are coerced to the type the check expects; missing ones
    fall back to ``threshold=10`` and ``min_ratio=0.30``.  An unrecognised
    rule type is logged and yields no violations.
    """
    options = rule.get("params", {})
    label = rule["name"]
    level = rule["severity"]
    kind = rule["rule_type"]

    # Checks that take no parameters.
    if kind == "no_circular_imports":
        return check_no_circular_imports(manifest, repo_root, label, level)
    if kind == "no_dead_exports":
        return check_no_dead_exports(manifest, repo_root, label, level)

    # Checks configured through ``params``.
    if kind == "max_complexity":
        return check_max_complexity(
            manifest,
            repo_root,
            label,
            level,
            threshold=int(options.get("threshold", 10)),
        )
    if kind == "test_coverage_floor":
        return check_test_coverage_floor(
            manifest,
            repo_root,
            label,
            level,
            min_ratio=float(options.get("min_ratio", 0.30)),
        )

    logger.warning("Unknown code invariant rule_type: %r — skipping", kind)
    return []
560
561
562 # ---------------------------------------------------------------------------
563 # Public entry points
564 # ---------------------------------------------------------------------------
565
566
def load_invariant_rules(
    rules_file: pathlib.Path | None = None,
) -> list[CodeInvariantRule]:
    """Load code invariant rules from TOML, falling back to built-in defaults.

    Every field is normalised: ``name`` defaults to ``"unnamed"``, an unknown
    or missing ``severity`` becomes ``"warning"``, an unknown or missing
    ``scope`` becomes ``"function"``, ``rule_type`` defaults to ``""`` and a
    non-table ``params`` becomes an empty dict.

    Args:
        rules_file: Path to the TOML file. ``None`` → use default path.

    Returns:
        List of :class:`CodeInvariantRule` dicts.  When the file yields no
        rules, an independent copy of the built-in defaults is returned, so
        callers may modify the result without affecting later calls.
    """
    import copy  # local: only this function needs it

    path = rules_file if rules_file is not None else pathlib.Path(_DEFAULT_RULES_FILE)
    raw = load_rules_toml(path)
    if not raw:
        # Deep copy: the defaults hold nested ``params`` dicts that must not
        # be shared with callers.
        return copy.deepcopy(_BUILTIN_DEFAULTS)

    # Lookup tables that narrow free-form strings to the allowed literals;
    # built once rather than per rule.
    severities: dict[str, InvariantSeverity] = {
        "info": "info", "warning": "warning", "error": "error",
    }
    scopes: dict[str, Literal["function", "file", "repo", "global"]] = {
        "function": "function", "file": "file", "repo": "repo", "global": "global",
    }

    rules: list[CodeInvariantRule] = []
    for entry in raw:
        raw_params = entry.get("params", {})
        params: dict[str, str | int | float] = (
            dict(raw_params) if isinstance(raw_params, dict) else {}
        )
        rules.append(CodeInvariantRule(
            name=str(entry.get("name", "unnamed")),
            severity=severities.get(str(entry.get("severity", "warning")), "warning"),
            scope=scopes.get(str(entry.get("scope", "function")), "function"),
            rule_type=str(entry.get("rule_type", "")),
            params=params,
        ))
    return rules
606
607
def run_invariants(
    repo_root: pathlib.Path,
    commit_id: str,
    rules: list[CodeInvariantRule],
) -> BaseReport:
    """Evaluate every rule against the snapshot of *commit_id*.

    A rule that raises is logged and skipped, so one failing rule does not
    prevent the others from being reported.

    Args:
        repo_root: Repository root.
        commit_id: Commit to check.
        rules: Rules to evaluate (from :func:`load_invariant_rules`).

    Returns:
        A :class:`~muse.core.invariants.BaseReport` with all violations.
        When the snapshot cannot be loaded, the report is empty and records
        zero rules.
    """
    snapshot = get_commit_snapshot_manifest(repo_root, commit_id)
    if snapshot is None:
        logger.warning("Could not load snapshot for commit %s", commit_id)
        return make_report(commit_id, "code", [], 0)

    collected: list[BaseViolation] = []
    for current in rules:
        try:
            # Each rule receives its own copy of the manifest.
            hits = _dispatch_rule(current, dict(snapshot), repo_root)
            for hit in hits:
                # Keep only the fields the base violation type carries.
                collected.append(
                    BaseViolation(
                        rule_name=hit["rule_name"],
                        severity=hit["severity"],
                        address=hit["address"],
                        description=hit["description"],
                    )
                )
        except Exception:
            logger.exception("Error evaluating rule %r on commit %s", current["name"], commit_id)

    return make_report(commit_id, "code", collected, len(rules))
644
645
class CodeChecker:
    """Satisfies :class:`~muse.core.invariants.InvariantChecker` for the code domain."""

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Run code invariant checks against *commit_id*.

        Args:
            repo_root: Repository root.
            commit_id: Commit to check.
            rules_file: Explicit rules file.  When ``None``, the default
                rules file under *repo_root* is used if it exists; otherwise
                the loader's own default path (relative to the current
                working directory) applies.

        Returns:
            The report produced by :func:`run_invariants`.
        """
        if rules_file is None:
            # Resolve the default against the repository rather than the
            # process working directory, which may be elsewhere.
            candidate = pathlib.Path(repo_root) / _DEFAULT_RULES_FILE
            if candidate.is_file():
                rules_file = candidate
        rules = load_invariant_rules(rules_file)
        return run_invariants(repo_root, commit_id, rules)