cgcardona / muse public
_code_query.py python
377 lines 11.9 KB
766ee24d feat: code domain leverages core invariants, query engine, manifests, p… Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """Code-domain query evaluator for the Muse generic query engine.
2
3 Implements :data:`~muse.core.query_engine.CommitEvaluator` for the code domain.
4 Allows agents and humans to search the commit history for code changes::
5
6 muse code-query "symbol == 'my_function' and change == 'added'"
7 muse code-query "language == 'Python' and author == 'agent-x'"
8 muse code-query "agent_id == 'claude' and sem_ver_bump == 'major'"
9 muse code-query "file == 'src/core.py'"
10 muse code-query "change == 'added' and kind == 'class'"
11
12 Query language
13 --------------
14
15 query = and_expr ( 'or' and_expr )*
16 and_expr = atom ( 'and' atom )*
17 atom = FIELD OP VALUE
18 FIELD = 'symbol' | 'file' | 'language' | 'kind' | 'change'
19 | 'author' | 'agent_id' | 'model_id' | 'toolchain_id'
20 | 'sem_ver_bump' | 'branch'
21 OP = '==' | '!=' | 'contains' | 'startswith'
22 VALUE = QUOTED_STRING | UNQUOTED_WORD
23
24 Supported fields
25 ----------------
26
27 ``symbol`` Qualified symbol name (e.g. ``"MyClass.method"``).
28 ``file`` Workspace-relative file path.
29 ``language`` Language name (``"Python"``, ``"TypeScript"``…).
30 ``kind`` Symbol kind (``"function"``, ``"class"``, ``"method"``…).
31 ``change`` ``"added"``, ``"removed"``, or ``"modified"``.
32 ``author`` Commit author string.
33 ``agent_id`` Agent identity from commit provenance.
34 ``model_id`` Model ID from commit provenance.
35 ``toolchain_id`` Toolchain string from commit provenance.
36 ``sem_ver_bump`` Semantic version bump: ``"none"``, ``"patch"``,
37 ``"minor"``, ``"major"``.
38 ``branch`` Branch name.
39 """
from __future__ import annotations

import logging
import pathlib
import re
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, TypeGuard, get_args

from muse.core.query_engine import CommitEvaluator, QueryMatch
from muse.core.store import CommitRecord
from muse.domain import DomainOp
from muse.plugins.code._query import language_of, symbols_for_snapshot
52
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Query AST types
# ---------------------------------------------------------------------------

# Field names a comparison may test. author/agent_id/model_id/toolchain_id/
# sem_ver_bump/branch are read from the commit record; the rest are resolved
# from the ops of the commit's structured delta.
CodeField = Literal[
    "symbol", "file", "language", "kind", "change",
    "author", "agent_id", "model_id", "toolchain_id",
    "sem_ver_bump", "branch",
]

# Comparison operators supported by the query DSL.
CodeOp = Literal["==", "!=", "contains", "startswith"]


@dataclass(frozen=True)
class Comparison:
    """A single ``field OP value`` predicate, e.g. ``kind == 'class'``."""

    field: CodeField  # which attribute of the commit / delta record to test
    op: CodeOp  # how to compare it
    value: str  # literal right-hand side (surrounding quotes already removed by the parser)


@dataclass(frozen=True)
class AndExpr:
    """Conjunction of predicates (all must match).

    ``clauses`` accepts any sequence (the parser builds a list) but is stored
    as a tuple, so the frozen node is genuinely immutable and hashable. With a
    list field the dataclass-generated ``__hash__`` would raise ``TypeError``.
    """

    clauses: Sequence[Comparison]

    def __post_init__(self) -> None:
        # frozen=True forbids normal assignment; object.__setattr__ is the
        # standard way to normalise a field during initialisation.
        object.__setattr__(self, "clauses", tuple(self.clauses))


@dataclass(frozen=True)
class OrExpr:
    """Disjunction of AND-expressions (any must match).

    ``clauses`` is normalised to a tuple for the same reason as in
    :class:`AndExpr`.
    """

    clauses: Sequence[AndExpr]

    def __post_init__(self) -> None:
        object.__setattr__(self, "clauses", tuple(self.clauses))
90
91
92 # ---------------------------------------------------------------------------
93 # Tokeniser & parser
94 # ---------------------------------------------------------------------------
95
# Lexer for the query DSL. At each position the alternatives are tried in order:
#   keyword - or/and/contains/startswith, only when NOT followed by an
#             identifier character, so e.g. ``order`` or ``and_expr`` fall
#             through and lex as a single word;
#   op      - ``==`` or ``!=``;
#   quoted  - a single- or double-quoted string (no escape sequences);
#   word    - an unquoted identifier-like value; ``.`` is allowed so dotted
#             names such as ``MyClass.method`` stay one token.
# Characters such as ``-`` or ``/`` belong to no alternative, so values that
# contain them must be written quoted.
_TOKEN_RE = re.compile(
    r"""
    (?P<keyword>(?:or|and|contains|startswith)(?![A-Za-z0-9_.]))
    |(?P<op>==|!=)
    |(?P<quoted>"[^"]*"|'[^']*')
    |(?P<word>[A-Za-z_][A-Za-z0-9_.]*)
    """,
    re.VERBOSE,
)

# Runtime sets of the names permitted by the CodeField / CodeOp literals.
# They are derived with get_args so they cannot drift from the type aliases.
_VALID_FIELDS: frozenset[str] = frozenset(get_args(CodeField))
_VALID_OPS: frozenset[str] = frozenset(get_args(CodeOp))
108
109
def _is_code_field(tok: str) -> TypeGuard[CodeField]:
    """Type guard: report whether *tok* names one of the :data:`CodeField` fields."""
    membership = tok in _VALID_FIELDS
    return membership
112
113
def _is_code_op(tok: str) -> TypeGuard[CodeOp]:
    """Type guard: report whether *tok* is one of the :data:`CodeOp` operators."""
    membership = tok in _VALID_OPS
    return membership
116
117
def _as_code_field(tok: str) -> CodeField:
    """Return *tok* narrowed to :data:`CodeField`.

    Raises:
        ValueError: If *tok* is not a supported field; the message lists the
            valid field names in sorted order.
    """
    if _is_code_field(tok):
        return tok
    allowed = sorted(_VALID_FIELDS)
    raise ValueError(f"Unknown field: {tok!r}. Valid: {allowed}")
123
124
def _as_code_op(tok: str) -> CodeOp:
    """Return *tok* narrowed to :data:`CodeOp`.

    Raises:
        ValueError: If *tok* is not a supported operator; the message lists
            the valid operators in sorted order.
    """
    if _is_code_op(tok):
        return tok
    allowed = sorted(_VALID_OPS)
    raise ValueError(f"Unknown operator: {tok!r}. Valid: {allowed}")
130
131
def _tokenize(query: str) -> list[str]:
    """Split *query* into the tokens matched by :data:`_TOKEN_RE`.

    Only whitespace may appear between tokens. Any other text that no token
    pattern matches raises instead of being silently skipped. Examples are a
    stray ``-`` or ``/`` in an unquoted value, or an unterminated quote.
    Skipping such text would turn ``file == src/core.py`` into
    ``file == src``.

    Raises:
        ValueError: If *query* contains text that is not part of any token.
    """
    tokens: list[str] = []

    def reject_stray_text(start: int, end: int) -> None:
        # Text between two consecutive matches (or after the last one) must be
        # pure whitespace; anything else is a lexing error.
        gap = query[start:end]
        stray = gap.strip()
        if stray:
            offset = start + len(gap) - len(gap.lstrip())
            raise ValueError(
                f"Unexpected text {stray!r} at position {offset} in query; "
                "quote values that are not simple identifiers"
            )

    cursor = 0
    for match in _TOKEN_RE.finditer(query):
        reject_stray_text(cursor, match.start())
        tokens.append(match.group())
        cursor = match.end()
    reject_stray_text(cursor, len(query))
    return tokens
134
135
def _parse_query(query: str) -> OrExpr:
    """Parse a query string into an :class:`OrExpr` AST.

    Recursive-descent parser for::

        query    = and_expr ( 'or' and_expr )*
        and_expr = atom ( 'and' atom )*
        atom     = FIELD OP VALUE

    ``and`` binds tighter than ``or``. Quoted values have their surrounding
    quotes removed; unquoted words are used verbatim.

    Raises:
        ValueError: If the query is empty, ends in the middle of a comparison,
            names an unknown field or operator, has an operator where a value
            is expected, or leaves tokens unconsumed (e.g. a missing ``and``).
    """
    tokens = _tokenize(query.strip())
    if not tokens:
        raise ValueError("Empty query")
    pos = 0

    def peek() -> str | None:
        return tokens[pos] if pos < len(tokens) else None

    def consume(expected: str) -> str:
        # Running out of tokens is a parse error, not an IndexError.
        nonlocal pos
        if pos >= len(tokens):
            raise ValueError(f"Unexpected end of query: expected {expected}")
        tok = tokens[pos]
        pos += 1
        return tok

    def parse_atom() -> Comparison:
        validated_field = _as_code_field(consume("a field name"))
        validated_op = _as_code_op(consume("an operator"))
        val_tok = consume("a value")
        if val_tok in ("==", "!="):
            raise ValueError(
                f"Expected a value after {validated_op!r}, got operator {val_tok!r}"
            )
        if val_tok.startswith(("'", '"')):
            val_tok = val_tok[1:-1]
        return Comparison(
            field=validated_field,
            op=validated_op,
            value=val_tok,
        )

    def parse_and() -> AndExpr:
        clauses: list[Comparison] = [parse_atom()]
        while peek() == "and":
            consume("'and'")
            clauses.append(parse_atom())
        return AndExpr(clauses=clauses)

    def parse_or() -> OrExpr:
        clauses: list[AndExpr] = [parse_and()]
        while peek() == "or":
            consume("'or'")
            clauses.append(parse_and())
        return OrExpr(clauses=clauses)

    ast = parse_or()
    # Anything left over means the comparisons were not joined by and/or.
    if pos < len(tokens):
        raise ValueError(
            f"Unexpected token {tokens[pos]!r}; expected 'and' or 'or' between comparisons"
        )
    return ast
179
180
181 # ---------------------------------------------------------------------------
182 # Evaluator
183 # ---------------------------------------------------------------------------
184
185
186 def _match_op(actual: str, op: CodeOp, expected: str) -> bool:
187 """Apply *op* to *actual* and *expected* strings."""
188 if op == "==":
189 return actual == expected
190 if op == "!=":
191 return actual != expected
192 if op == "contains":
193 return expected.lower() in actual.lower()
194 # op == "startswith"
195 return actual.lower().startswith(expected.lower())
196
197
# Fields resolved from the CommitRecord itself; every other field is resolved
# per record of the commit's structured delta.
_COMMIT_FIELDS: frozenset[str] = frozenset(
    {"author", "agent_id", "model_id", "toolchain_id", "sem_ver_bump", "branch"}
)

# Cap on the number of symbol-level matches reported for a single commit.
_MAX_MATCHES_PER_COMMIT = 20


def _change_of(op_type: str) -> str:
    """Translate a delta op type into the ``change`` field vocabulary.

    ``insert`` -> ``added``, ``delete`` -> ``removed``, anything else ->
    ``modified``.
    """
    if op_type == "insert":
        return "added"
    if op_type == "delete":
        return "removed"
    return "modified"


def _split_address(address: str) -> tuple[str, str]:
    """Split ``"path::symbol"`` into ``(path, symbol)`` at the first ``::``.

    An address without ``::`` is a bare file path with an empty symbol.
    """
    file_path, _sep, symbol_name = address.partition("::")
    return file_path, symbol_name


def _commit_field_value(commit: CommitRecord, field: str) -> str:
    """Return the commit-level attribute selected by *field*.

    Falsy values become ``""``, so the string operators never receive
    ``None``. ``agent_id``/``model_id`` may evidently be unset, since the
    match builder checks them before copying.
    """
    values = {
        "author": commit.author,
        "agent_id": commit.agent_id,
        "model_id": commit.model_id,
        "toolchain_id": commit.toolchain_id,
        "sem_ver_bump": commit.sem_ver_bump,
        "branch": commit.branch,
    }
    return values.get(field) or ""


def _symbol_records(commit: CommitRecord) -> list[dict[str, str]]:
    """Flatten *commit*'s structured delta into one record per op.

    Each record maps ``file``, ``symbol``, ``kind``, ``change`` and
    ``language`` to strings. A ``patch`` op contributes a record for itself,
    followed by one per entry of its ``child_ops``. A child op falls back to
    the parent's file/symbol when its own address lacks them. ``language`` is
    derived from the top-level op's file. A commit without a structured delta
    yields no records.
    """
    delta = commit.structured_delta
    if delta is None:
        return []

    records: list[dict[str, str]] = []
    for op_rec in delta.get("ops", []):
        address: str = op_rec.get("address", "")
        file_path, symbol_name = _split_address(address)
        lang = language_of(file_path)

        all_ops: list[DomainOp] = [op_rec]
        # ``.get`` avoids a KeyError when "op" is missing. The subscript
        # comparison presumably narrows the DomainOp union so that
        # ``child_ops`` type-checks.
        if op_rec.get("op") == "patch" and op_rec["op"] == "patch":
            all_ops = [op_rec] + op_rec["child_ops"]

        for rec in all_ops:
            rec_file, rec_symbol = _split_address(str(rec.get("address", address)))
            records.append({
                "file": rec_file or file_path,
                "symbol": rec_symbol or symbol_name,
                "kind": str(rec.get("kind", "")),
                "change": _change_of(str(rec.get("op", ""))),
                "language": lang,
            })
    return records


def _commit_matches_comparison(
    comparison: Comparison,
    commit: CommitRecord,
    manifest: dict[str, str],
    root: pathlib.Path,
    symbol_matches: list[dict[str, str]],
) -> bool:
    """Return True if *commit* satisfies the single predicate *comparison*.

    Commit-level fields are tested against the commit record. For the
    symbol/file/language/kind/change fields, the predicate holds if any delta
    record satisfies it, and every satisfying record is appended to
    *symbol_matches*. *manifest* and *root* are accepted for signature
    compatibility and are not used.

    :func:`build_evaluator` does not use this helper, because it needs all
    symbol-level predicates of an ``and`` clause to hold on the same record.
    """
    f, op, v = comparison.field, comparison.op, comparison.value
    if f in _COMMIT_FIELDS:
        return _match_op(_commit_field_value(commit, f), op, v)

    hit = False
    for record in _symbol_records(commit):
        if _match_op(record.get(f, ""), op, v):
            hit = True
            symbol_matches.append(record)
    return hit


def _make_match(commit: CommitRecord, detail: str, extra: dict[str, str]) -> QueryMatch:
    """Build a :class:`QueryMatch` for *commit*, adding agent/model provenance when set."""
    result = QueryMatch(
        commit_id=commit.commit_id,
        author=commit.author,
        committed_at=commit.committed_at.isoformat(),
        branch=commit.branch,
        detail=detail,
        extra=extra,
    )
    if commit.agent_id:
        result["agent_id"] = commit.agent_id
    if commit.model_id:
        result["model_id"] = commit.model_id
    return result


def _symbol_detail(record: dict[str, str]) -> str:
    """Format the detail line for a symbol-level match, e.g. ``"foo (added)"``."""
    detail = record.get("symbol") or record.get("file", "?")
    change = record.get("change", "")
    return f"{detail} ({change})" if change else detail


def build_evaluator(query: str) -> CommitEvaluator:
    """Parse *query* and return a :data:`CommitEvaluator` for :func:`~muse.core.query_engine.walk_history`.

    Matching semantics:

    * ``or`` clauses are tried in order, and the first clause that matches
      decides the result for the commit.
    * Within an ``and`` clause, every commit-level predicate (author,
      agent_id, model_id, toolchain_id, sem_ver_bump, branch) must hold.
      Every symbol-level predicate must hold on the *same* delta record.
      For example, ``change == 'added' and kind == 'class'`` finds added
      classes, not commits that add something and separately touch a class.
    * A clause with only commit-level predicates yields one match per commit,
      whose detail is the first 80 characters of the commit message.
      Otherwise there is one match per satisfying record, capped at 20 per
      commit.

    Args:
        query: A query string in the code query DSL.

    Returns:
        A callable that can be passed to :func:`~muse.core.query_engine.walk_history`.

    Raises:
        ValueError: If the query cannot be parsed.
    """
    ast = _parse_query(query)

    # Split each ``and`` clause once, up front, into commit-level and
    # symbol-level predicates.
    plans: list[tuple[list[Comparison], list[Comparison]]] = [
        (
            [cmp for cmp in and_expr.clauses if cmp.field in _COMMIT_FIELDS],
            [cmp for cmp in and_expr.clauses if cmp.field not in _COMMIT_FIELDS],
        )
        for and_expr in ast.clauses
    ]

    def evaluator(
        commit: CommitRecord,
        manifest: dict[str, str],
        root: pathlib.Path,
    ) -> list[QueryMatch]:
        # Delta records are built lazily, at most once per commit.
        records: list[dict[str, str]] | None = None
        for commit_cmps, symbol_cmps in plans:
            if not all(
                _match_op(_commit_field_value(commit, cmp.field), cmp.op, cmp.value)
                for cmp in commit_cmps
            ):
                continue
            if not symbol_cmps:
                return [_make_match(commit, commit.message[:80], {})]
            if records is None:
                records = _symbol_records(commit)
            hits = [
                record
                for record in records
                if all(
                    _match_op(record.get(cmp.field, ""), cmp.op, cmp.value)
                    for cmp in symbol_cmps
                )
            ]
            if hits:
                return [
                    _make_match(commit, _symbol_detail(record), dict(record))
                    for record in hits[:_MAX_MATCHES_PER_COMMIT]
                ]
        return []

    return evaluator