cgcardona / muse public
deps.py python
370 lines 13.2 KB
44b98511 fix(code): close 4 architectural gaps — validation, deps, find-symbol, … Gabriel Cardona <cgcardona@gmail.com> 1d ago
1 """muse deps — import graph and call-graph analysis.
2
3 Answers two questions that Git cannot:
4
5 **File mode** (``muse deps src/billing.py``):
6 What does this file import, and what files in the repo import it?
7
8 **Symbol mode** (``muse deps "src/billing.py::compute_invoice_total"``):
9 What does this function call? (Python only; uses stdlib ``ast``.)
10 With ``--reverse``: what symbols in the repo call this function?
11
These relationships are invisible to Git: Git stores files as blobs of text
with no concept of imports or call sites. Muse parses the symbols of the
committed snapshot and the AST of the working tree to answer these questions
directly.
16
17 Usage::
18
19 muse deps src/billing.py # import graph (file)
20 muse deps src/billing.py --reverse # who imports this file?
21 muse deps "src/billing.py::compute_invoice_total" # call graph (Python)
22 muse deps "src/billing.py::compute_invoice_total" --reverse # callers
23
24 Flags:
25
26 ``--commit, -c REF``
27 Inspect a historical snapshot instead of HEAD (import graph mode only).
28
``--reverse, -r``
    Invert the query: show importers instead of imports (file mode), or
    callers instead of callees (symbol mode). Callers are matched by bare
    function name across the HEAD snapshot.
32
33 ``--json``
34 Emit results as JSON.
35 """
36 from __future__ import annotations
37
38 import ast
39 import json
40 import logging
41 import pathlib
42
43 import typer
44
45 from muse.core.errors import ExitCode
46 from muse.core.object_store import read_object
47 from muse.core.repo import require_repo
48 from muse.core.store import get_commit_snapshot_manifest, resolve_commit_ref
49 from muse.plugins.code._query import language_of, symbols_for_snapshot
50 from muse.plugins.code.ast_parser import SEMANTIC_EXTENSIONS, SymbolTree, parse_symbols
51
# Logger for this module, named with the standard ``__name__`` convention.
logger: logging.Logger = logging.getLogger(__name__)

# Typer application object; the ``deps`` command below is registered on it
# via ``@app.callback(invoke_without_command=True)``.
app: typer.Typer = typer.Typer()
55
56
def _read_repo_id(root: pathlib.Path) -> str:
    """Return the ``repo_id`` recorded in ``<root>/.muse/repo.json``.

    The file is decoded explicitly as UTF-8 (the JSON interchange encoding)
    instead of the platform's locale encoding, so the result does not depend
    on the machine the command runs on.

    Raises:
        OSError: if ``repo.json`` cannot be read.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if the JSON object has no ``repo_id`` key.
    """
    metadata_path = root / ".muse" / "repo.json"
    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    return str(metadata["repo_id"])
59
60
def _read_branch(root: pathlib.Path) -> str:
    """Return the current branch name stored in ``<root>/.muse/HEAD``.

    HEAD holds a ref such as ``refs/heads/main``; the ``refs/heads/`` prefix
    is removed when present and surrounding whitespace is stripped. The file
    is decoded as UTF-8 so non-ASCII branch names do not depend on the
    platform's locale encoding.

    Raises:
        OSError: if ``HEAD`` cannot be read.
    """
    head_ref = (root / ".muse" / "HEAD").read_text(encoding="utf-8").strip()
    return head_ref.removeprefix("refs/heads/").strip()
64
65
66 # ---------------------------------------------------------------------------
67 # Import graph helpers
68 # ---------------------------------------------------------------------------
69
70
def _imports_in_tree(tree: SymbolTree) -> list[str]:
    """Collect the qualified names of all ``import`` records in *tree*, sorted.

    Only records whose ``kind`` is ``"import"`` are included; duplicate names
    are kept.
    """
    imported = [
        record["qualified_name"]
        for record in tree.values()
        if record["kind"] == "import"
    ]
    imported.sort()
    return imported
78
79
def _file_imports(
    root: pathlib.Path,
    manifest: dict[str, str],
    target_file: str,
) -> list[str]:
    """Parse *target_file* as stored in *manifest* and return its sorted import names.

    Yields an empty list when the file is not listed in the manifest or when
    its object cannot be read from the object store.
    """
    object_id = manifest.get(target_file)
    blob = None if object_id is None else read_object(root, object_id)
    if blob is None:
        return []
    return _imports_in_tree(parse_symbols(blob, target_file))
94
95
def _reverse_imports(
    root: pathlib.Path,
    manifest: dict[str, str],
    target_file: str,
) -> list[str]:
    """Return the sorted paths in *manifest* whose imports appear to name *target_file*.

    Matching is a name heuristic, not real import resolution. The target's
    stem (``billing`` for ``src/billing.py``) and dotted module path
    (``src.billing``) are compared with every import name declared by each
    other file in the snapshot. A file counts as an importer when one of its
    import names equals the stem or module path, ends with ``.<stem>`` or
    ``.<module>``, or contains the stem as one of its dotted components.

    For a package initializer (``src/pkg/__init__.py``) it is the package
    that gets imported, so the package directory (``pkg`` / ``src.pkg``) is
    matched instead of ``__init__``.

    Files whose extension is not in ``SEMANTIC_EXTENSIONS`` and files whose
    object cannot be read are skipped.
    """
    module_path = pathlib.PurePosixPath(target_file).with_suffix("")
    if module_path.name == "__init__" and module_path.parent.name:
        # ``import pkg`` loads ``pkg/__init__.py``; no importer ever names
        # "__init__", so match against the package directory instead.
        module_path = module_path.parent
    target_stem = module_path.name
    target_module = module_path.as_posix().replace("/", ".")
    suffix_patterns = (f".{target_stem}", f".{target_module}")

    def names_target(imp_name: str) -> bool:
        # Exact stem/module, dotted suffix of either, or the stem appearing
        # as any component (catches e.g. ``src.billing.X``).
        return (
            imp_name in (target_stem, target_module)
            or imp_name.endswith(suffix_patterns)
            or target_stem in imp_name.split(".")
        )

    importers: list[str] = []
    for file_path, obj_id in manifest.items():
        if file_path == target_file:
            continue
        if pathlib.PurePosixPath(file_path).suffix.lower() not in SEMANTIC_EXTENSIONS:
            continue
        raw = read_object(root, obj_id)
        if raw is None:
            continue
        tree = parse_symbols(raw, file_path)
        if any(names_target(imp_name) for imp_name in _imports_in_tree(tree)):
            importers.append(file_path)
    return sorted(importers)
133
134
135 # ---------------------------------------------------------------------------
136 # Call-graph helpers (Python only)
137 # ---------------------------------------------------------------------------
138
139
def _call_name(func_node: ast.expr) -> str | None:
    """Return the bare callee name for the ``func`` part of an ``ast.Call``.

    ``foo()`` yields ``"foo"``. For attribute calls such as ``obj.method()``
    or ``module.func()`` only the final attribute is kept (``"method"`` /
    ``"func"``). Any other callee expression (subscript, call, lambda, ...)
    yields ``None``.
    """
    if isinstance(func_node, ast.Attribute):
        return func_node.attr
    return func_node.id if isinstance(func_node, ast.Name) else None
149
150
def _find_func_node(
    stmts: list[ast.stmt],
    name_parts: list[str],
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    """Locate the function node named by the dotted path *name_parts* in *stmts*.

    ``["f"]`` finds a top-level (async) function ``f``; ``["Cls", "m"]``
    descends into class ``Cls`` and finds method ``m`` (nesting may continue
    through further classes). Returns ``None`` when nothing matches.

    When a name is defined more than once, the *last* definition is returned.
    This mirrors Python's rebinding semantics. It also skips
    ``@typing.overload`` stubs, which precede the real implementation and
    have only ``...`` as a body. If the last class with a matching name lacks
    the requested member, earlier definitions of that class are searched.
    """
    head, rest = name_parts[0], name_parts[1:]
    for stmt in reversed(stmts):
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if stmt.name == head and not rest:
                return stmt
        elif isinstance(stmt, ast.ClassDef) and stmt.name == head and rest:
            found = _find_func_node(stmt.body, rest)
            if found is not None:
                return found
    return None
164
165
def _python_callees(source: bytes, address: str) -> list[str]:
    """Return the sorted, de-duplicated callee names invoked by the symbol at *address*.

    *address* is ``"<file>::<qualified.name>"``; a bare qualified name is
    also accepted. Callee names are bare: ``obj.method()`` contributes
    ``"method"``.

    Only calls in the function's body are collected. Decorators, default
    argument values and annotations are evaluated when the function is
    *defined*, not when it runs, so they are excluded. Calls inside nested
    functions or lambdas within the body are included.

    Returns an empty list when *source* cannot be parsed or the symbol is not
    found.
    """
    sym_qualified = address.split("::", 1)[1] if "::" in address else address
    try:
        module = ast.parse(source)
    except (SyntaxError, ValueError):
        # Python versions before 3.12 raise ValueError (not SyntaxError) for
        # source containing null bytes.
        return []
    func_node = _find_func_node(module.body, sym_qualified.split("."))
    if func_node is None:
        return []
    names: set[str] = set()
    # Walk the body statements only: decorators, defaults and annotations hang
    # off the FunctionDef node itself.
    for stmt in func_node.body:
        for node in ast.walk(stmt):
            if isinstance(node, ast.Call):
                name = _call_name(node.func)
                if name:
                    names.add(name)
    return sorted(names)
183
184
def _python_callers(
    root: pathlib.Path,
    manifest: dict[str, str],
    target_name: str,
) -> list[str]:
    """Return addresses of Python functions/methods in *manifest* that call *target_name*.

    *target_name* is a bare name (the last component of the queried symbol),
    so matching is by name only: any ``target_name(...)`` or
    ``something.target_name(...)`` call counts.

    Only calls inside a function's body are considered. Decorators, default
    argument values and annotations run at definition time and are ignored.

    Only ``.py``/``.pyi`` files are scanned, in sorted path order. Files whose
    object cannot be read or that fail to parse are skipped.
    """
    function_kinds = {"function", "async_function", "method", "async_method"}
    callers: list[str] = []
    for file_path, obj_id in sorted(manifest.items()):
        if pathlib.PurePosixPath(file_path).suffix.lower() not in {".py", ".pyi"}:
            continue
        raw = read_object(root, obj_id)
        if raw is None:
            continue
        try:
            module = ast.parse(raw)
        except (SyntaxError, ValueError):
            # ValueError: null bytes on Python < 3.12. Skip the file rather
            # than aborting the whole scan.
            continue
        sym_tree = parse_symbols(raw, file_path)
        for addr, rec in sym_tree.items():
            if rec["kind"] not in function_kinds:
                continue
            func_node = _find_func_node(module.body, rec["qualified_name"].split("."))
            if func_node is None:
                continue
            if any(
                isinstance(node, ast.Call) and _call_name(node.func) == target_name
                for stmt in func_node.body
                for node in ast.walk(stmt)
            ):
                callers.append(addr)
    return callers
222
223
224 # ---------------------------------------------------------------------------
225 # Command
226 # ---------------------------------------------------------------------------
227
228
def _echo_listing(
    heading: str,
    items: list[str],
    empty_message: str,
    summary: str,
    *,
    subheading: str | None = None,
) -> None:
    """Print a titled, ruled listing of *items* followed by a *summary* line.

    Output order: a blank line and *heading*, the optional *subheading*, a
    62-character rule, one line per item (or *empty_message* when there are
    no items), then a blank line and *summary*.
    """
    typer.echo(f"\n{heading}")
    if subheading is not None:
        typer.echo(subheading)
    typer.echo("─" * 62)
    if not items:
        typer.echo(empty_message)
    for item in items:
        typer.echo(f" {item}")
    typer.echo(f"\n{summary}")


@app.callback(invoke_without_command=True)
def deps(
    ctx: typer.Context,
    target: str = typer.Argument(
        ..., metavar="TARGET",
        help=(
            'File path (e.g. "src/billing.py") for import graph, or '
            'symbol address (e.g. "src/billing.py::compute_invoice_total") for call graph.'
        ),
    ),
    reverse: bool = typer.Option(
        False, "--reverse", "-r",
        help="Show importers (file mode) or callers (symbol mode) instead.",
    ),
    ref: str | None = typer.Option(
        None, "--commit", "-c", metavar="REF",
        help="Inspect a historical commit instead of HEAD (import graph mode only).",
    ),
    as_json: bool = typer.Option(False, "--json", help="Emit results as JSON."),
) -> None:
    """Show the import graph or call graph for a file or symbol.

    **File mode** — pass a file path::

        muse deps src/billing.py            # what does billing.py import?
        muse deps src/billing.py --reverse  # what files import billing.py?

    **Symbol mode** — pass a symbol address (Python only for call graph)::

        muse deps "src/billing.py::compute_invoice_total"
        muse deps "src/billing.py::compute_invoice_total" --reverse

    Callees are read from the live working tree. Callers (``--reverse``) are
    matched by bare name across the HEAD snapshot. Import-graph analysis uses
    the committed snapshot (``--commit`` to pin).
    """
    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = _read_branch(root)

    # ----------------------------------------------------------------
    # Symbol mode: call-graph (Python only)
    # ----------------------------------------------------------------
    if "::" in target:
        file_rel, sym_qualified = target.split("::", 1)
        lang = language_of(file_rel)
        if lang != "Python":
            typer.echo(
                f"⚠️ Call-graph analysis is currently Python-only. "
                f"'{file_rel}' is {lang}.",
                err=True,
            )
            raise typer.Exit(code=ExitCode.USER_ERROR)

        # The file must exist in the working tree in both directions.
        # Candidates are tried in order: ``muse-work/<file>``, then
        # ``<root>/<file>``. ``is_file`` (not ``exists``) so a directory
        # path yields this error instead of crashing in ``read_bytes``.
        src_path = next(
            (
                candidate
                for candidate in (root / "muse-work" / file_rel, root / file_rel)
                if candidate.is_file()
            ),
            None,
        )
        if src_path is None:
            typer.echo(f"❌ File '{file_rel}' not found in working tree.", err=True)
            raise typer.Exit(code=ExitCode.USER_ERROR)

        if not reverse:
            # Only the forward query parses the working-tree source.
            callees = _python_callees(src_path.read_bytes(), target)
            if as_json:
                typer.echo(json.dumps({"address": target, "calls": callees}, indent=2))
                return
            _echo_listing(
                f"Callees of {target}",
                callees,
                " (no function calls detected)",
                f"{len(callees)} callee(s)",
            )
            return

        # Reverse: scan the HEAD snapshot for calls to the bare name.
        target_name = sym_qualified.split(".")[-1]
        commit = resolve_commit_ref(root, repo_id, branch, None)
        if commit is None:
            typer.echo("❌ No commits found.", err=True)
            raise typer.Exit(code=ExitCode.USER_ERROR)
        manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
        callers = _python_callers(root, manifest, target_name)
        if as_json:
            typer.echo(json.dumps(
                {"address": target, "target_name": target_name, "called_by": callers},
                indent=2,
            ))
            return
        _echo_listing(
            f"Callers of {target}",
            callers,
            " (no callers found in committed snapshot)",
            f"{len(callers)} caller(s) found",
            subheading=f" (matching bare name: {target_name!r})",
        )
        return

    # ----------------------------------------------------------------
    # File mode: import graph
    # ----------------------------------------------------------------
    commit = resolve_commit_ref(root, repo_id, branch, ref)
    if commit is None:
        typer.echo(f"❌ Commit '{ref or 'HEAD'}' not found.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}

    if not reverse:
        # A path missing from the snapshot (e.g. a typo) must not be reported
        # as a file that simply has no imports.
        if target not in manifest:
            typer.echo(f"❌ File '{target}' not found in commit '{ref or 'HEAD'}'.", err=True)
            raise typer.Exit(code=ExitCode.USER_ERROR)
        imports = _file_imports(root, manifest, target)
        if as_json:
            typer.echo(json.dumps({"file": target, "imports": imports}, indent=2))
            return
        _echo_listing(
            f"Imports declared in {target}",
            imports,
            " (no imports found)",
            f"{len(imports)} import(s)",
        )
        return

    # Reverse import lookup is heuristic and also works for paths no longer
    # present in the snapshot, so no membership check here.
    importers = _reverse_imports(root, manifest, target)
    if as_json:
        typer.echo(json.dumps({"file": target, "imported_by": importers}, indent=2))
        return
    _echo_listing(
        f"Files that import {target}",
        importers,
        " (no files import this module in the committed snapshot)",
        f"{len(importers)} importer(s) found",
    )