cgcardona / muse public
test_code_invariants.py python
290 lines 11.1 KB
766ee24d feat: code domain leverages core invariants, query engine, manifests, p… Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """Tests for the code-domain invariants engine."""
2 from __future__ import annotations
3
4 import pathlib
5 import tempfile
6
7 import pytest
8
9 from muse.core.invariants import InvariantChecker
10 from muse.plugins.code._invariants import (
11 CodeChecker,
12 CodeInvariantRule,
13 check_max_complexity,
14 check_no_circular_imports,
15 check_no_dead_exports,
16 check_test_coverage_floor,
17 load_invariant_rules,
18 run_invariants,
19 )
20
21
22 # ---------------------------------------------------------------------------
23 # Helpers
24 # ---------------------------------------------------------------------------
25
26
def _make_repo(tmp_path: pathlib.Path) -> pathlib.Path:
    """Lay out a bare-bones ``.muse/`` repository skeleton under *tmp_path*.

    Writes ``repo.json`` (repo id ``"test"``) and a ``HEAD`` file pointing at
    ``refs/heads/main``, then creates empty ``commits/``, ``snapshots/``,
    ``refs/heads/`` and ``objects/`` directories.  ``Path.mkdir`` is called
    without ``exist_ok``, so an existing ``.muse`` raises ``FileExistsError``.

    Returns *tmp_path* itself so callers can use it directly as the repo root.
    """
    dot_muse = tmp_path / ".muse"
    dot_muse.mkdir()

    seed_files = (
        ("repo.json", '{"repo_id":"test"}'),
        ("HEAD", "refs/heads/main"),
    )
    for filename, text in seed_files:
        (dot_muse / filename).write_text(text)

    for subdir in ("commits", "snapshots"):
        (dot_muse / subdir).mkdir()
    # ``refs`` does not exist yet, so creating the nested ``heads`` needs parents=True.
    dot_muse.joinpath("refs", "heads").mkdir(parents=True)
    (dot_muse / "objects").mkdir()

    return tmp_path
38
39
def _write_object(root: pathlib.Path, content: bytes) -> str:
    """Store *content* in the repo's content-addressed object store.

    The blob is keyed by its SHA-256 hex digest and written to
    ``.muse/objects/<first 2 hex chars>/<remaining hex chars>`` under *root*;
    missing shard directories are created and an existing object is
    overwritten.  Returns the digest so tests can reference it in a manifest.
    """
    from hashlib import sha256

    digest = sha256(content).hexdigest()
    shard, rest = digest[:2], digest[2:]
    target = root.joinpath(".muse", "objects", shard, rest)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(content)
    return digest
47
48
49 # ---------------------------------------------------------------------------
50 # _estimate_complexity (via check_max_complexity)
51 # ---------------------------------------------------------------------------
52
53
class TestMaxComplexity:
    """Tests for ``check_max_complexity`` (exercises ``_estimate_complexity`` indirectly)."""

    def test_simple_function_no_violation(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        src = b"def simple():\n return 1\n"
        h = _write_object(root, src)
        manifest = {"mod.py": h}
        violations = check_max_complexity(manifest, root, "test", "error", threshold=10)
        assert violations == []

    def test_complex_function_triggers_violation(self, tmp_path: pathlib.Path) -> None:
        # Seven sequential ``if`` branches, so the estimate should exceed threshold 5.
        src = b"""
def complex():
    if True:
        pass
    if True:
        pass
    if True:
        pass
    if True:
        pass
    if True:
        pass
    if True:
        pass
    if True:
        pass
    return 1
"""
        root = _make_repo(tmp_path)
        h = _write_object(root, src)
        manifest = {"mod.py": h}
        violations = check_max_complexity(manifest, root, "gate", "error", threshold=5)
        assert len(violations) >= 1
        assert violations[0]["rule_name"] == "gate"
        assert "complexity" in violations[0]["description"].lower()

    def test_non_python_file_skipped(self, tmp_path: pathlib.Path) -> None:
        # threshold=1 would flag almost anything, so an empty result shows the
        # ``.js`` entry was not analysed at all.
        root = _make_repo(tmp_path)
        src = b"def hello() { return 1; }"
        h = _write_object(root, src)
        manifest = {"mod.js": h}
        violations = check_max_complexity(manifest, root, "c", "error", threshold=1)
        assert violations == []
101
102
103 # ---------------------------------------------------------------------------
104 # check_no_circular_imports
105 # ---------------------------------------------------------------------------
106
107
class TestNoCircularImports:
    """Tests for ``check_no_circular_imports`` over flat, top-level modules."""

    def test_no_cycle_returns_empty(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        # a -> b is a dependency edge, but nothing points back to a.
        ha = _write_object(root, b"import b\n")
        hb = _write_object(root, b"x = 1\n")
        manifest = {"a.py": ha, "b.py": hb}
        violations = check_no_circular_imports(manifest, root, "no_cycles", "error")
        assert violations == []

    def test_cycle_detected(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        # a imports b, b imports a → cycle
        ha = _write_object(root, b"import b\n")
        hb = _write_object(root, b"import a\n")
        manifest = {"a.py": ha, "b.py": hb}
        violations = check_no_circular_imports(manifest, root, "no_cycles", "error")
        assert len(violations) >= 1
        assert "cycle" in violations[0]["description"].lower()

    def test_three_file_cycle_detected(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        # a -> b -> c -> a: the cycle only closes transitively.
        ha = _write_object(root, b"import b\n")
        hb = _write_object(root, b"import c\n")
        hc = _write_object(root, b"import a\n")
        manifest = {"a.py": ha, "b.py": hb, "c.py": hc}
        violations = check_no_circular_imports(manifest, root, "cycles", "error")
        assert len(violations) >= 1
        # Pin the kind of violation, not merely that something was reported.
        assert any("cycle" in v["description"].lower() for v in violations)
145
146
147 # ---------------------------------------------------------------------------
148 # check_no_dead_exports
149 # ---------------------------------------------------------------------------
150
151
class TestNoDeadExports:
    """Tests for ``check_no_dead_exports``.

    Unimported public functions are reported; functions imported elsewhere,
    private (``_``-prefixed) functions, and functions in test files are not.
    """

    def test_used_function_not_reported(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        hl = _write_object(root, b"def my_func():\n return 1\n")
        hm = _write_object(root, b"from lib import my_func\n")
        manifest = {"lib.py": hl, "main.py": hm}
        violations = check_no_dead_exports(manifest, root, "dead", "warning")
        # lib.my_func is imported by main.py → should not be reported.
        addresses = [v["address"] for v in violations]
        assert "lib.py::my_func" not in addresses

    def test_unused_function_reported(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        hl = _write_object(root, b"def orphan_fn():\n return 1\n")
        ho = _write_object(root, b"x = 1\n")
        manifest = {"lib.py": hl, "other.py": ho}
        violations = check_no_dead_exports(manifest, root, "dead", "warning")
        addresses = [v["address"] for v in violations]
        assert "lib.py::orphan_fn" in addresses

    def test_private_function_exempt(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        h = _write_object(root, b"def _private():\n return 1\n")
        manifest = {"lib.py": h}
        violations = check_no_dead_exports(manifest, root, "dead", "warning")
        # Private functions are exempt.
        assert all("_private" not in v["address"] for v in violations)

    def test_test_file_exempt(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        test_src = b"def test_something():\n assert True\n"
        h = _write_object(root, test_src)
        manifest = {"test_stuff.py": h}
        violations = check_no_dead_exports(manifest, root, "dead", "warning")
        # Test functions are collected by the runner, never imported, so a
        # test file must not produce dead-export reports.
        assert violations == []
196
197
198 # ---------------------------------------------------------------------------
199 # check_test_coverage_floor
200 # ---------------------------------------------------------------------------
201
202
class TestTestCoverageFloor:
    """Tests for ``check_test_coverage_floor`` with ``min_ratio=0.5``."""

    def test_well_covered_code_no_violation(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        hs = _write_object(root, b"def foo():\n return 1\n")
        ht = _write_object(root, b"def test_foo():\n assert True\n")
        manifest = {"src.py": hs, "test_src.py": ht}
        violations = check_test_coverage_floor(manifest, root, "coverage", "warning", min_ratio=0.5)
        assert violations == []

    def test_uncovered_code_violates(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        # Three functions and no test file at all.
        src = b"def foo():\n pass\ndef bar():\n pass\ndef baz():\n pass\n"
        h = _write_object(root, src)
        manifest = {"src.py": h}
        violations = check_test_coverage_floor(manifest, root, "coverage", "warning", min_ratio=0.5)
        assert len(violations) == 1
        assert "coverage floor" in violations[0]["description"].lower()

    def test_no_functions_no_violation(self, tmp_path: pathlib.Path) -> None:
        root = _make_repo(tmp_path)
        # No functions means nothing to cover, so the floor cannot be breached.
        h = _write_object(root, b"X = 1\n")
        manifest = {"config.py": h}
        violations = check_test_coverage_floor(manifest, root, "coverage", "warning", min_ratio=0.5)
        assert violations == []
233
234
235 # ---------------------------------------------------------------------------
236 # load_invariant_rules
237 # ---------------------------------------------------------------------------
238
239
class TestLoadInvariantRules:
    """Tests for ``load_invariant_rules`` (defaults vs. a TOML rules file)."""

    def test_no_file_returns_defaults(self, tmp_path: pathlib.Path) -> None:
        # A name inside a fresh tmp dir is guaranteed not to exist, unlike a
        # hard-coded absolute path on the host filesystem.
        rules = load_invariant_rules(tmp_path / "missing.toml")
        assert len(rules) >= 1
        rule_types = {r["rule_type"] for r in rules}
        assert "max_complexity" in rule_types

    def test_toml_file_loaded(self, tmp_path: pathlib.Path) -> None:
        toml = (
            "[[rule]]\n"
            "name='r1'\n"
            "severity='error'\n"
            "scope='function'\n"
            "rule_type='max_complexity'\n"
        )
        path = tmp_path / "invariants.toml"
        path.write_text(toml, encoding="utf-8")
        rules = load_invariant_rules(path)
        assert any(r["rule_type"] == "max_complexity" for r in rules)
        # The defaults also contain a max_complexity rule, so check for the
        # custom name to prove the file was actually read.
        assert any(r.get("name") == "r1" for r in rules)
258
259
260 # ---------------------------------------------------------------------------
261 # CodeChecker (protocol)
262 # ---------------------------------------------------------------------------
263
264
class TestCodeChecker:
    """Tests for ``CodeChecker`` as an implementation of the core ``InvariantChecker`` protocol."""

    def test_satisfies_invariant_checker_protocol(self) -> None:
        checker = CodeChecker()
        assert isinstance(checker, InvariantChecker)

    def test_check_returns_base_report(self, tmp_path: pathlib.Path) -> None:
        import datetime

        from muse.core.store import CommitRecord, SnapshotRecord, write_commit, write_snapshot

        root = _make_repo(tmp_path)
        # One commit whose snapshot has an empty manifest: check() must resolve
        # the commit and return a well-formed report for it.
        snap = SnapshotRecord(snapshot_id="snap1", manifest={})
        write_snapshot(root, snap)
        commit = CommitRecord(
            commit_id="abc123",
            repo_id="test",
            branch="main",
            snapshot_id="snap1",
            message="init",
            committed_at=datetime.datetime.now(datetime.timezone.utc),
        )
        write_commit(root, commit)
        report = CodeChecker().check(root, "abc123")
        assert report["commit_id"] == "abc123"
        assert report["domain"] == "code"
        assert isinstance(report["violations"], list)