cgcardona / muse public
test_core_query_engine.py python
193 lines 6.8 KB
8d5137ed fix(security): full surface hardening — validation, path containment, p… Gabriel Cardona <cgcardona@gmail.com> 10h ago
1 """Tests for the generic query engine in muse/core/query_engine.py."""
2
3 import datetime
4 import pathlib
5 import tempfile
6
7 import pytest
8
9 from muse.core.query_engine import QueryMatch, format_matches, walk_history
10 from muse.core.store import CommitRecord, write_commit
11
12
13 # ---------------------------------------------------------------------------
14 # Helpers
15 # ---------------------------------------------------------------------------
16
17
18 def _make_repo(tmp_path: pathlib.Path) -> pathlib.Path:
19 """Set up a minimal .muse/ structure for query_engine tests."""
20 muse = tmp_path / ".muse"
21 muse.mkdir()
22 (muse / "repo.json").write_text('{"repo_id":"test-repo"}')
23 (muse / "HEAD").write_text("refs/heads/main")
24 (muse / "commits").mkdir()
25 (muse / "snapshots").mkdir()
26 (muse / "refs" / "heads").mkdir(parents=True)
27 return tmp_path
28
29
def _write_commit(root: pathlib.Path, commit_id: str, parent_id: str | None = None) -> CommitRecord:
    """Store a synthetic commit on branch ``main`` under *root* and return it.

    Derived fields:

    * ``snapshot_id`` is ``"snap-" + commit_id``;
    * ``message`` is ``"commit <commit_id>"``;
    * ``committed_at`` is the current time in UTC;
    * ``repo_id`` / ``author`` are the fixed values ``test-repo`` / ``test-author``.

    *parent_id* is stored as ``parent_commit_id``; the default ``None``
    leaves the commit without a parent.
    """
    timestamp = datetime.datetime.now(datetime.timezone.utc)
    new_commit = CommitRecord(
        commit_id=commit_id,
        repo_id="test-repo",
        branch="main",
        snapshot_id="snap-" + commit_id,
        message=f"commit {commit_id}",
        committed_at=timestamp,
        parent_commit_id=parent_id,
        author="test-author",
    )
    write_commit(root, new_commit)
    return new_commit
43
44
45 # ---------------------------------------------------------------------------
46 # walk_history
47 # ---------------------------------------------------------------------------
48
49
class TestWalkHistory:
    """Tests for ``walk_history``: traversal order, match collection and limits."""

    @staticmethod
    def _point_branch(root: pathlib.Path, commit_id: str) -> None:
        """Make the ``refs/heads/main`` ref file under *root* contain *commit_id*."""
        (root / ".muse" / "refs" / "heads" / "main").write_text(commit_id)

    def test_empty_branch_returns_empty(self) -> None:
        # _make_repo writes no ref file for "main", so there is nothing to walk.
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            results = walk_history(root, "main", lambda c, m, r: [])
            assert results == []

    def test_single_commit_visited(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            _write_commit(root, "aaa111")
            self._point_branch(root, "aaa111")

            visited: list[str] = []

            def evaluator(
                commit: CommitRecord, manifest: dict[str, str], r: pathlib.Path
            ) -> list[QueryMatch]:
                visited.append(commit.commit_id)
                return []

            walk_history(root, "main", evaluator)
            assert visited == ["aaa111"]

    def test_chain_walked_newest_first(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            _write_commit(root, "aaa111")
            _write_commit(root, "bbb222", parent_id="aaa111")
            self._point_branch(root, "bbb222")

            visited: list[str] = []

            def evaluator(
                commit: CommitRecord, manifest: dict[str, str], r: pathlib.Path
            ) -> list[QueryMatch]:
                visited.append(commit.commit_id)
                return []

            walk_history(root, "main", evaluator)
            assert visited == ["bbb222", "aaa111"]

    def test_matches_collected(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            _write_commit(root, "ccc333")
            self._point_branch(root, "ccc333")

            def evaluator(
                commit: CommitRecord, manifest: dict[str, str], r: pathlib.Path
            ) -> list[QueryMatch]:
                return [QueryMatch(
                    commit_id=commit.commit_id,
                    author=commit.author,
                    committed_at=commit.committed_at.isoformat(),
                    branch=commit.branch,
                    detail="test match",
                    extra={},
                )]

            results = walk_history(root, "main", evaluator)
            assert len(results) == 1
            assert results[0]["detail"] == "test match"
            # The collected match must be the one built for the visited commit.
            assert results[0]["commit_id"] == "ccc333"

    def test_max_commits_limits_walk(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            ids = [f"commit{i:03d}" for i in range(10)]
            # Linear chain: each commit's parent is the previous id; commit000 has none.
            parent: str | None = None
            for cid in ids:
                _write_commit(root, cid, parent_id=parent)
                parent = cid
            self._point_branch(root, ids[-1])

            visited: list[str] = []

            def evaluator(
                commit: CommitRecord, manifest: dict[str, str], r: pathlib.Path
            ) -> list[QueryMatch]:
                visited.append(commit.commit_id)
                return []

            walk_history(root, "main", evaluator, max_commits=3)
            # Counting alone would accept any three commits; the walk is
            # newest-first, so exactly the three newest must be visited, in order.
            assert visited == ["commit009", "commit008", "commit007"]

    def test_head_commit_id_override(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            _write_commit(root, "aaa111")
            _write_commit(root, "bbb222", parent_id="aaa111")
            # The branch ref points at bbb222, but head_commit_id starts the walk at aaa111.
            self._point_branch(root, "bbb222")

            visited: list[str] = []

            def evaluator(
                commit: CommitRecord, manifest: dict[str, str], r: pathlib.Path
            ) -> list[QueryMatch]:
                visited.append(commit.commit_id)
                return []

            walk_history(root, "main", evaluator, head_commit_id="aaa111")
            assert visited == ["aaa111"]
142
143
144 # ---------------------------------------------------------------------------
145 # format_matches
146 # ---------------------------------------------------------------------------
147
148
class TestFormatMatches:
    """Tests for the text rendering produced by ``format_matches``."""

    def test_empty_returns_no_matches(self) -> None:
        rendered = format_matches([])
        assert "No matches" in rendered

    def test_single_match_formatted(self) -> None:
        full_id = "a" * 64
        match = QueryMatch(
            commit_id=full_id,
            author="gabriel",
            committed_at="2026-03-18T12:00:00+00:00",
            branch="main",
            detail="my_function (added)",
            extra={},
        )
        rendered = format_matches([match])
        # At least the 8-character prefix of the commit id must be shown,
        # along with the author and the detail text.
        for expected in (full_id[:8], "gabriel", "my_function (added)"):
            assert expected in rendered

    def test_agent_id_shown_when_present(self) -> None:
        agent_match = QueryMatch(
            commit_id="a" * 64,
            author="bot",
            committed_at="2026-03-18T12:00:00+00:00",
            branch="main",
            detail="something",
            extra={},
            agent_id="claude-v4",
        )
        rendered = format_matches([agent_match])
        assert "claude-v4" in rendered

    def test_max_results_capped(self) -> None:
        many = [
            QueryMatch(
                commit_id=f"commit{idx:04d}",
                author="x",
                committed_at="2026-01-01T00:00:00+00:00",
                branch="main",
                detail=f"match {idx}",
                extra={},
            )
            for idx in range(100)
        ]
        rendered = format_matches(many, max_results=5)
        # 100 matches capped at 5 leaves 95 that should be summarised as "more".
        assert "95 more" in rendered