cgcardona / muse public
_query.py python
206 lines 6.8 KB
675d76ba refactor: full Python 3.13 idiom pass Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """Shared query helpers for the code-domain CLI commands.
2
3 This module provides the low-level primitives that multiple code-domain
4 commands need — symbol extraction from snapshots, commit-graph walking,
5 and language classification — so each command can stay thin.
6
7 None of these functions are part of the public ``CodePlugin`` API. They
8 are internal helpers for the CLI layer and must not be imported by any
9 core module.
10 """
11 from __future__ import annotations
12
13 import itertools
14 import logging
15 import pathlib
16 from collections.abc import Iterator
17
18 from muse.core.object_store import read_object
19 from muse.core.store import CommitRecord, read_commit
20 from muse.domain import DomainOp
21 from muse.plugins.code.ast_parser import (
22 SEMANTIC_EXTENSIONS,
23 SymbolRecord,
24 SymbolTree,
25 parse_symbols,
26 )
27
28 logger = logging.getLogger(__name__)
29
30 # ---------------------------------------------------------------------------
31 # Language classification
32 # ---------------------------------------------------------------------------
33
# Suffixes grouped by the display language they denote. The grouping is
# flattened below into the suffix → language lookup used by ``language_of``.
_LANGUAGE_SUFFIXES: dict[str, tuple[str, ...]] = {
    "Python": (".py", ".pyi"),
    "TypeScript": (".ts", ".tsx"),
    "JavaScript": (".js", ".jsx", ".mjs", ".cjs"),
    "Go": (".go",),
    "Rust": (".rs",),
    "Java": (".java",),
    "C#": (".cs",),
    "C": (".c", ".h"),
    "C++": (".cpp", ".cc", ".cxx", ".hpp", ".hxx"),
    "Ruby": (".rb",),
    "Kotlin": (".kt", ".kts"),
}

_SUFFIX_LANG: dict[str, str] = {
    suffix: language
    for language, suffixes in _LANGUAGE_SUFFIXES.items()
    for suffix in suffixes
}


def language_of(file_path: str) -> str:
    """Map *file_path* to a human-readable language name.

    The lookup is keyed on the lower-cased final suffix of the path.

    Args:
        file_path: POSIX-style path of the file to classify.

    Returns:
        The language name from ``_SUFFIX_LANG`` when the suffix is known,
        the lower-cased suffix itself when it is not, or ``"(no ext)"``
        when the path has no suffix at all.
    """
    extension = pathlib.PurePosixPath(file_path).suffix.lower()
    if not extension:
        return "(no ext)"
    return _SUFFIX_LANG.get(extension, extension)
54
55
def is_semantic(file_path: str) -> bool:
    """Tell whether *file_path* can be analysed at the AST level.

    Args:
        file_path: POSIX-style path of the file to test.

    Returns:
        ``True`` when the lower-cased final suffix of the path is a member
        of ``SEMANTIC_EXTENSIONS``, ``False`` otherwise.
    """
    extension = pathlib.PurePosixPath(file_path).suffix
    return extension.lower() in SEMANTIC_EXTENSIONS
60
61
62 # ---------------------------------------------------------------------------
63 # Symbol extraction from a snapshot manifest
64 # ---------------------------------------------------------------------------
65
66
def symbols_for_snapshot(
    root: pathlib.Path,
    manifest: dict[str, str],
    *,
    kind_filter: str | None = None,
    file_filter: str | None = None,
    language_filter: str | None = None,
) -> dict[str, SymbolTree]:
    """Build the per-file symbol trees for a snapshot manifest.

    Files are visited in sorted path order. A file is parsed only when it
    has a semantic suffix and passes the optional file and language
    filters; its content is then read from the object store and handed to
    ``parse_symbols``.

    Args:
        root: Repository root, passed through to ``read_object``.
        manifest: Mapping of file path → object ID for the snapshot.
        kind_filter: When truthy, keep only symbols whose ``kind`` equals it.
        file_filter: When truthy, keep only the file whose path equals it.
        language_filter: When truthy, keep only files whose
            ``language_of`` name equals it.

    Returns:
        Mapping of file path → symbol tree. Files whose object is missing
        from the store, or whose tree is empty after filtering, are left out.
    """

    def _selected(path: str) -> bool:
        # Cheap path-only checks, applied before touching the object store.
        if not is_semantic(path):
            return False
        if file_filter and path != file_filter:
            return False
        return not (language_filter and language_of(path) != language_filter)

    trees: dict[str, SymbolTree] = {}
    for path in sorted(manifest):
        if not _selected(path):
            continue
        blob_id = manifest[path]
        blob = read_object(root, blob_id)
        if blob is None:
            logger.debug("Object %s missing for %s — skipping", blob_id[:8], path)
            continue
        symbols = parse_symbols(blob, path)
        if kind_filter:
            symbols = {
                address: record
                for address, record in symbols.items()
                if record["kind"] == kind_filter
            }
        if symbols:
            trees[path] = symbols
    return trees
105
106
107 # ---------------------------------------------------------------------------
108 # Commit-graph walking
109 # ---------------------------------------------------------------------------
110
111
def walk_commits(
    root: pathlib.Path,
    start_commit_id: str,
    max_commits: int = 10_000,
) -> list[CommitRecord]:
    """Follow ``parent_commit_id`` links back from *start_commit_id*.

    The walk ends when a commit has no parent, when a commit cannot be
    read, when a commit ID repeats, or when *max_commits* records have
    been collected.

    Args:
        root: Repository root, passed through to ``read_commit``.
        start_commit_id: ID of the first commit to visit.
        max_commits: Upper bound on the number of records returned.

    Returns:
        The visited ``CommitRecord`` objects in visiting order, starting
        with the commit for *start_commit_id*.
    """
    history: list[CommitRecord] = []
    visited: set[str] = set()
    cursor: str | None = start_commit_id
    while len(history) < max_commits:
        # Stop on a missing parent link or on an ID already visited (cycle).
        if not cursor or cursor in visited:
            break
        visited.add(cursor)
        record = read_commit(root, cursor)
        if record is None:
            break
        history.append(record)
        cursor = record.parent_commit_id
    return history
138
139
def walk_commits_range(
    root: pathlib.Path,
    to_commit_id: str,
    from_commit_id: str | None,
    max_commits: int = 10_000,
) -> list[CommitRecord]:
    """Gather the commits reachable from *to_commit_id* up to a boundary.

    Parent links are followed from *to_commit_id* until *from_commit_id*
    is reached; the boundary commit itself is not read or returned. The
    walk also ends when a commit has no parent, when a commit cannot be
    read, when a commit ID repeats, or when *max_commits* records have
    been collected.

    Args:
        root: Repository root, passed through to ``read_commit``.
        to_commit_id: ID of the first commit to visit (included).
        from_commit_id: ID at which to stop (excluded); ``None`` disables
            the boundary so the walk continues to the end of the chain.
        max_commits: Upper bound on the number of records returned.

    Returns:
        The visited ``CommitRecord`` objects in visiting order, starting
        with the commit for *to_commit_id*.
    """
    collected: list[CommitRecord] = []
    visited: set[str] = set()
    cursor: str | None = to_commit_id
    while len(collected) < max_commits:
        # Stop on a missing parent link or on an ID already visited (cycle).
        if not cursor or cursor in visited:
            break
        visited.add(cursor)
        # The exclusive boundary is checked before the commit is read.
        if cursor == from_commit_id:
            break
        record = read_commit(root, cursor)
        if record is None:
            break
        collected.append(record)
        cursor = record.parent_commit_id
    return collected
170
171
172 # ---------------------------------------------------------------------------
173 # Op traversal helpers
174 # ---------------------------------------------------------------------------
175
176
def flat_symbol_ops(ops: list[DomainOp]) -> Iterator[DomainOp]:
    """Yield every symbol-level leaf op in *ops*, in document order.

    An op whose ``"op"`` field is ``"patch"`` is never yielded itself;
    its ``"child_ops"`` are searched instead, at any nesting depth. Every
    other op is a leaf and is yielded only when its ``"address"``
    contains ``::``.

    For input with a single level of patch nesting the output matches
    the earlier one-level implementation. The difference is that a patch
    op nested inside another patch op is now descended into rather than
    yielded as if it were a leaf.

    Args:
        ops: Ops to flatten; patch ops must carry a ``"child_ops"`` list.

    Yields:
        Each non-patch op whose address contains ``::``.
    """
    for op in ops:
        if op["op"] == "patch":
            # Recurse so that nested patches contribute their leaves.
            yield from flat_symbol_ops(op["child_ops"])
        elif "::" in op["address"]:
            yield op
189
190
def touched_files(ops: list[DomainOp]) -> frozenset[str]:
    """Collect the addresses of patch ops in *ops* that have child ops.

    Only top-level ops are inspected. An op contributes its ``"address"``
    when its ``"op"`` field is ``"patch"`` and its ``"child_ops"`` is
    non-empty; non-patch ops (such as coarse replace/insert/delete ops)
    never contribute.

    Args:
        ops: Ops to scan.

    Returns:
        The distinct contributing addresses as a ``frozenset``.
    """
    return frozenset(
        op["address"]
        for op in ops
        if op["op"] == "patch" and op["child_ops"]
    )
202
203
def file_pairs(files: frozenset[str]) -> Iterator[tuple[str, str]]:
    """Generate every unordered pair of distinct entries in *files*.

    The entries are sorted first, so each pair comes out as ``(a, b)``
    with ``a < b`` and the pairs themselves appear in sorted order.

    Args:
        files: Entries to pair up.

    Yields:
        One ``(a, b)`` tuple per pair; nothing when *files* has fewer
        than two entries.
    """
    ordered = sorted(files)
    for position, first in enumerate(ordered):
        for second in ordered[position + 1:]:
            yield first, second