gabriel / muse public
ast_parser.py python
1005 lines 34.9 KB
44b98511 fix(code): close 4 architectural gaps — validation, deps, find-symbol, … Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """AST parsing and symbol extraction for the code domain plugin.
2
3 This module provides the :class:`LanguageAdapter` protocol and concrete
4 adapters for parsing source files into :type:`SymbolTree` structures.
5
6 Language support matrix
7 -----------------------
8 - **Python** (``*.py``, ``*.pyi``): Full AST-based extraction using the
9 stdlib :mod:`ast` module. Content IDs are hashes of normalized (unparsed)
10 AST text — insensitive to whitespace, comments, and formatting.
11 - **JavaScript / TypeScript** (``*.js``, ``*.jsx``, ``*.mjs``, ``*.cjs``,
12 ``*.ts``, ``*.tsx``): tree-sitter based.
13 - **Go** (``*.go``): tree-sitter based. Method qualified names carry the
14 receiver type (e.g. ``Dog.Bark``).
15 - **Rust** (``*.rs``): tree-sitter based. Functions inside ``impl`` blocks
16 are qualified with the implementing type (e.g. ``Dog.bark``).
17 - **Java** (``*.java``), **C#** (``*.cs``): tree-sitter based.
18 - **C** (``*.c``, ``*.h``), **C++** (``*.cpp``, ``*.cc``, ``*.cxx``,
19 ``*.hpp``, ``*.hxx``): tree-sitter based.
20 - **Ruby** (``*.rb``), **Kotlin** (``*.kt``, ``*.kts``): tree-sitter based.
21
22 Symbol addresses
23 ----------------
24 Every extracted symbol is stored in the :type:`SymbolTree` dict under a
25 stable *address* key of the form::
26
27 "<workspace-relative-posix-path>::<qualified-symbol-name>"
28
29 Nested symbols (class methods) use dotted qualified names::
30
31 "src/models.py::User.save"
32 "src/models.py::User.__init__"
33
34 Top-level symbols::
35
36 "src/utils.py::calculate_total"
37 "src/utils.py::import::pathlib"
38
39 Content IDs and rename / move detection
40 ----------------------------------------
41 Each :class:`SymbolRecord` carries three hashes:
42
43 ``content_id``
44 SHA-256 of the full normalized AST of the symbol (includes name,
45 signature, and body). Two symbols are "the same thing" when their
46 ``content_id`` matches — regardless of where in the repo they live.
47
48 ``body_hash``
49 SHA-256 of the normalized body statements only (excludes the ``def``
50 line). Used to detect *renames*: same body, different name.
51
52 ``signature_id``
53 SHA-256 of ``"name(args) -> return"``. Used to detect *implementation-
54 only changes*: signature unchanged, body changed.
55
56 Extending
57 ---------
58 Implement :class:`LanguageAdapter` and append an instance to
59 :data:`ADAPTERS`. The adapter is selected by the file's suffix, with the
60 first matching adapter taking priority.
61 """
62 from __future__ import annotations
63
64 import ast
65 import hashlib
66 import importlib
67 import logging
68 import pathlib
69 import re
70 from typing import Literal, Protocol, TypedDict, runtime_checkable
71
72 from tree_sitter import Language, Node, Parser, Query, QueryCursor
73
74 logger = logging.getLogger(__name__)
75
76 # ---------------------------------------------------------------------------
77 # Symbol record types
78 # ---------------------------------------------------------------------------
79
#: Closed set of symbol categories the adapters can emit. ``method`` /
#: ``async_method`` are functions nested inside a class; ``variable`` covers
#: top-level assignments; ``import`` covers import statements.
SymbolKind = Literal[
    "function",
    "async_function",
    "class",
    "method",
    "async_method",
    "variable",
    "import",
]
89
90
class SymbolRecord(TypedDict):
    """Content-addressed record for a single named symbol in source code."""

    kind: SymbolKind
    name: str
    qualified_name: str  # "ClassName.method" for nested; flat name for top-level
    content_id: str  # SHA-256 of full normalized AST (name + signature + body)
    body_hash: str  # SHA-256 of body stmts only — for rename detection
    signature_id: str  # SHA-256 of "name(args)->return" — for impl-only changes
    lineno: int  # 1-based first source line of the symbol
    end_lineno: int  # 1-based last source line (inclusive)
102
103
#: Flat map from symbol address (``"<path>::<qualified-name>"``) to
#: :class:`SymbolRecord`.
#: Nested symbols (methods) appear at their qualified address alongside the
#: parent class.
SymbolTree = dict[str, SymbolRecord]
108
109
110 # ---------------------------------------------------------------------------
111 # Language adapter protocol
112 # ---------------------------------------------------------------------------
113
114
# runtime_checkable so registry code can use isinstance() against the protocol.
@runtime_checkable
class LanguageAdapter(Protocol):
    """Protocol every language adapter must implement.

    Adapters are stateless. The same instance may be called concurrently
    for different files without synchronization.
    """

    def supported_extensions(self) -> frozenset[str]:
        """Return the set of lowercase file suffixes this adapter handles."""
        ...

    def parse_symbols(self, source: bytes, file_path: str) -> SymbolTree:
        """Extract the symbol tree from raw source bytes.

        Args:
            source: Raw bytes of the source file.
            file_path: Workspace-relative POSIX path — used to build the
                symbol address prefix.

        Returns:
            A :type:`SymbolTree` mapping symbol addresses to
            :class:`SymbolRecord` dicts. Returns an empty dict on parse
            errors so that the caller can fall through to file-level ops.
        """
        ...

    def file_content_id(self, source: bytes) -> str:
        """Return a stable content identifier for the whole file.

        For AST-capable adapters: hash of the normalized (unparsed) module
        AST — insensitive to formatting and comments.
        For non-AST adapters: SHA-256 of raw bytes.

        Args:
            source: Raw bytes of the file.

        Returns:
            Hex-encoded SHA-256 digest.
        """
        ...
156
157
158 # ---------------------------------------------------------------------------
159 # Helpers
160 # ---------------------------------------------------------------------------
161
162
163 def _sha256(text: str) -> str:
164 return hashlib.sha256(text.encode("utf-8", errors="replace")).hexdigest()
165
166
167 def _sha256_bytes(data: bytes) -> str:
168 return hashlib.sha256(data).hexdigest()
169
170
171 # ---------------------------------------------------------------------------
172 # Python adapter
173 # ---------------------------------------------------------------------------
174
175
class PythonAdapter:
    """Python language adapter — AST-based, zero external dependencies.

    Parsing uses :func:`ast.parse`; normalization uses :func:`ast.unparse`,
    yielding a deterministic representation that is insensitive to
    whitespace, comments, and indentation.

    ``ast.unparse`` is available since Python 3.9; Muse requires 3.12.
    """

    def supported_extensions(self) -> frozenset[str]:
        return frozenset({".py", ".pyi"})

    def parse_symbols(self, source: bytes, file_path: str) -> SymbolTree:
        try:
            module = ast.parse(source, filename=file_path)
        except SyntaxError:
            # Unparseable file → empty tree, so callers fall back to
            # file-level operations instead of crashing.
            return {}
        collected: SymbolTree = {}
        _extract_stmts(module.body, file_path, "", collected)
        return collected

    def file_content_id(self, source: bytes) -> str:
        try:
            return _sha256(ast.unparse(ast.parse(source)))
        except SyntaxError:
            # Broken syntax → raw-bytes identity as a last resort.
            return _sha256_bytes(source)
204
205
206 # ---------------------------------------------------------------------------
207 # AST extraction helpers (module-level so they can be tested independently)
208 # ---------------------------------------------------------------------------
209
210
211 def _extract_stmts(
212 stmts: list[ast.stmt],
213 file_path: str,
214 class_prefix: str,
215 out: SymbolTree,
216 ) -> None:
217 """Recursively walk *stmts* and populate *out* with symbol records.
218
219 Args:
220 stmts: Statement list from an :class:`ast.Module` or
221 :class:`ast.ClassDef` body.
222 file_path: Workspace-relative POSIX path — used as address prefix.
223 class_prefix: Dotted class path for methods (e.g. ``"MyClass."``).
224 Empty string at top-level.
225 out: Accumulator — modified in place.
226 """
227 for node in stmts:
228 if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
229 is_async = isinstance(node, ast.AsyncFunctionDef)
230 if class_prefix:
231 kind: SymbolKind = "async_method" if is_async else "method"
232 else:
233 kind = "async_function" if is_async else "function"
234 qualified = f"{class_prefix}{node.name}"
235 addr = f"{file_path}::{qualified}"
236 out[addr] = _make_func_record(node, node.name, qualified, kind)
237
238 elif isinstance(node, ast.ClassDef):
239 qualified = f"{class_prefix}{node.name}"
240 addr = f"{file_path}::{qualified}"
241 out[addr] = _make_class_record(node, qualified)
242 _extract_stmts(node.body, file_path, f"{qualified}.", out)
243
244 elif isinstance(node, (ast.Assign, ast.AnnAssign)) and not class_prefix:
245 # Only top-level assignments — class-level attributes are captured
246 # as part of the parent class's content_id.
247 for name in _assignment_names(node):
248 addr = f"{file_path}::{name}"
249 out[addr] = _make_var_record(node, name)
250
251 elif isinstance(node, (ast.Import, ast.ImportFrom)) and not class_prefix:
252 for name in _import_names(node):
253 addr = f"{file_path}::import::{name}"
254 out[addr] = _make_import_record(node, name)
255
256
def _make_func_record(
    node: ast.FunctionDef | ast.AsyncFunctionDef,
    name: str,
    qualified_name: str,
    kind: SymbolKind,
) -> SymbolRecord:
    """Build a :class:`SymbolRecord` for a function or method definition.

    All three hashes are computed from :func:`ast.unparse` output so they
    ignore whitespace, comments, and formatting.
    """
    returns = ast.unparse(node.returns) if node.returns else ""
    signature = f"{name}({ast.unparse(node.args)})->{returns}"
    # Body-only hash (excludes the def line) powers rename detection.
    body_text = "\n".join(ast.unparse(stmt) for stmt in node.body)
    return SymbolRecord(
        kind=kind,
        name=name,
        qualified_name=qualified_name,
        content_id=_sha256(ast.unparse(node)),
        body_hash=_sha256(body_text),
        signature_id=_sha256(signature),
        lineno=node.lineno,
        end_lineno=node.end_lineno or node.lineno,
    )
277
278
def _make_class_record(node: ast.ClassDef, qualified_name: str) -> SymbolRecord:
    """Build a :class:`SymbolRecord` for a class definition.

    ``body_hash`` captures class *structure* (bases + sorted method names)
    but NOT method bodies — those change independently and have their own
    records.
    """
    bases = ", ".join(ast.unparse(b) for b in node.bases) if node.bases else ""
    method_names = sorted(
        stmt.name
        for stmt in node.body
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
    )
    structure = f"class {node.name}({bases}):{method_names}"
    header = f"class {node.name}({bases})" if node.bases else f"class {node.name}"
    return SymbolRecord(
        kind="class",
        name=node.name,
        qualified_name=qualified_name,
        content_id=_sha256(ast.unparse(node)),
        body_hash=_sha256(structure),
        signature_id=_sha256(header),
        lineno=node.lineno,
        end_lineno=node.end_lineno or node.lineno,
    )
301
302
def _make_var_record(node: ast.Assign | ast.AnnAssign, name: str) -> SymbolRecord:
    """Build a :class:`SymbolRecord` for a top-level assignment target.

    Variables have no body/signature split: the whole normalized statement
    serves as both content_id and body_hash; the bare name is the signature.
    """
    statement_text = ast.unparse(node)
    return SymbolRecord(
        kind="variable",
        name=name,
        qualified_name=name,
        content_id=_sha256(statement_text),
        body_hash=_sha256(statement_text),
        signature_id=_sha256(name),
        lineno=node.lineno,
        end_lineno=node.end_lineno or node.lineno,
    )
315
316
def _make_import_record(
    node: ast.Import | ast.ImportFrom, name: str
) -> SymbolRecord:
    """Build a :class:`SymbolRecord` for one name bound by an import statement."""
    statement_text = ast.unparse(node)
    return SymbolRecord(
        kind="import",
        name=name,
        qualified_name=f"import::{name}",
        content_id=_sha256(statement_text),
        body_hash=_sha256(statement_text),
        signature_id=_sha256(name),
        # Imports are treated as single-line: start line only.
        lineno=node.lineno,
        end_lineno=node.lineno,
    )
331
332
333 def _assignment_names(node: ast.Assign | ast.AnnAssign) -> list[str]:
334 if isinstance(node, ast.Assign):
335 return [t.id for t in node.targets if isinstance(t, ast.Name)]
336 if isinstance(node.target, ast.Name):
337 return [node.target.id]
338 return []
339
340
341 def _import_names(node: ast.Import | ast.ImportFrom) -> list[str]:
342 if isinstance(node, ast.Import):
343 return [a.asname or a.name for a in node.names]
344 # ImportFrom
345 if node.names and node.names[0].name == "*":
346 return [f"*:{node.module or '?'}"]
347 return [a.asname or a.name for a in node.names]
348
349
350 # ---------------------------------------------------------------------------
351 # Fallback adapter — file-level identity only, no symbol extraction
352 # ---------------------------------------------------------------------------
353
354
class FallbackAdapter:
    """Fallback adapter for languages without a dedicated AST parser.

    Provides file-level tracking only: symbol extraction always yields an
    empty :type:`SymbolTree`, and the content identity is the raw-bytes
    SHA-256 (formatting-sensitive, but the only safe option with no parser).
    """

    def __init__(self, extensions: frozenset[str]) -> None:
        # The claimed suffix set is injected so one class serves any language.
        self._extensions = extensions

    def supported_extensions(self) -> frozenset[str]:
        return self._extensions

    def parse_symbols(self, source: bytes, file_path: str) -> SymbolTree:  # noqa: ARG002
        # No parser available — report no symbols rather than guessing.
        return {}

    def file_content_id(self, source: bytes) -> str:
        return _sha256_bytes(source)
373
374
375 # ---------------------------------------------------------------------------
376 # tree-sitter adapter — shared infrastructure for all non-Python languages
377 # ---------------------------------------------------------------------------
378
379 _WS_RE: re.Pattern[bytes] = re.compile(rb"\s+")
380
381
382 def _norm_ws(src: bytes) -> bytes:
383 """Collapse all whitespace runs to a single space and strip the result."""
384 return _WS_RE.sub(b" ", src).strip()
385
386
387 def _node_text(src: bytes, node: Node) -> bytes:
388 """Extract the raw source bytes covered by a tree-sitter node."""
389 return src[node.start_byte : node.end_byte]
390
391
def _class_name_from(src: bytes, node: Node, field: str) -> str | None:
    """Extract a class/struct name from a parent CST node.

    Prefers ``child_by_field_name(field)`` (covers Java, C#, C++, Rust).
    When the grammar exposes no such field (e.g. Kotlin), falls back to the
    first named child whose type is ``identifier``.
    """
    name_node = node.child_by_field_name(field)
    if name_node is None:
        name_node = next(
            (c for c in node.named_children if c.type == "identifier"),
            None,
        )
    if name_node is None:
        return None
    return _node_text(src, name_node).decode("utf-8", errors="replace")
408
409
def _qualified_name_ts(
    src: bytes,
    sym_node: Node,
    name: str,
    class_node_types: frozenset[str],
    class_name_field: str,
) -> str:
    """Build a dotted qualified name by walking *sym_node*'s ancestor chain.

    ``bark`` inside ``class Dog`` becomes ``"Dog.bark"``; a top-level
    function stays as-is. Each enclosing class-like ancestor contributes one
    leading component, outermost first.
    """
    prefix: list[str] = []
    ancestor = sym_node.parent
    while ancestor is not None:
        if ancestor.type in class_node_types:
            owner = _class_name_from(src, ancestor, class_name_field)
            if owner:
                # Walking bottom-up, so each new owner goes in front.
                prefix.insert(0, owner)
        ancestor = ancestor.parent
    return ".".join([*prefix, name])
431
432
class LangSpec(TypedDict):
    """Per-language tree-sitter configuration consumed by :class:`TreeSitterAdapter`."""

    extensions: frozenset[str]  # Lowercase file suffixes this language claims
    module_name: str  # Python import name, e.g. ``"tree_sitter_javascript"``
    lang_func: str  # Attribute on the module returning the raw capsule
    query_str: str  # tree-sitter S-expr query — must capture ``@sym`` and ``@name``
    kind_map: dict[str, SymbolKind]  # CST node type → SymbolKind
    class_node_types: frozenset[str]  # Ancestor types that scope methods
    class_name_field: str  # Field name for the class name (e.g. ``"name"`` or ``"type"``)
    receiver_capture: str  # Capture name for Go-style method receivers; ``""`` to skip
444
445
class TreeSitterAdapter:
    """Implements :class:`LanguageAdapter` using tree-sitter for real CST parsing.

    tree-sitter is the same parsing technology used by GitHub Copilot, VS Code,
    Neovim, and Zed. It produces a concrete syntax tree from every source file,
    even if the file has syntax errors — making it suitable for real-world repos
    that may contain partially-written code.

    Parsing is error-tolerant: individual file failures are logged at DEBUG
    level and return an empty :type:`SymbolTree` so the caller falls back to
    file-level diffing rather than crashing.
    """

    def __init__(
        self,
        spec: LangSpec,
        parser: Parser,
        language: Language,
    ) -> None:
        self._spec = spec
        self._parser = parser
        self._language = language
        # Compile the S-expression query once per adapter, not once per file.
        self._query = Query(language, spec["query_str"])

    def supported_extensions(self) -> frozenset[str]:
        return self._spec["extensions"]

    def parse_symbols(self, source: bytes, file_path: str) -> SymbolTree:
        try:
            tree = self._parser.parse(source)
            cursor = QueryCursor(self._query)
            symbols: SymbolTree = {}
            recv_cap = self._spec["receiver_capture"]

            for _pat, caps in cursor.matches(tree.root_node):
                # A usable match must bind both @sym (the definition node)
                # and @name (its identifier); skip partial matches.
                sym_list = caps.get("sym", [])
                name_list = caps.get("name", [])
                if not sym_list or not name_list:
                    continue
                sym_node = sym_list[0]
                name_node = name_list[0]

                name_txt = _node_text(source, name_node).decode(
                    "utf-8", errors="replace"
                )
                # Unknown node types default to "function".
                kind = self._spec["kind_map"].get(sym_node.type, "function")

                # Build qualified name — walking ancestor chain for methods.
                qualified = _qualified_name_ts(
                    source,
                    sym_node,
                    name_txt,
                    self._spec["class_node_types"],
                    self._spec["class_name_field"],
                )

                # Go-style receiver prefix: (d *Dog) → "Dog.Bark"
                if recv_cap:
                    recv_list = caps.get(recv_cap, [])
                    if recv_list:
                        recv_txt = (
                            _node_text(source, recv_list[0])
                            .decode("utf-8", errors="replace")
                            .lstrip("*")
                            .strip()
                        )
                        if recv_txt:
                            qualified = f"{recv_txt}.{qualified}"

                addr = f"{file_path}::{qualified}"
                node_bytes = _node_text(source, sym_node)
                name_bytes = _node_text(source, name_node)
                # Substitute the name with a placeholder to isolate the body
                # from the identifier — two symbols with the same body but
                # different names share the same body_hash, signalling a rename.
                # NOTE(review): replaces the *first* occurrence of the name
                # bytes anywhere in the node — assumes that is the definition
                # site; confirm for grammars with leading modifiers.
                body_bytes = node_bytes.replace(name_bytes, b"\xfe", 1)

                # The parameter-list field name varies by grammar; try each
                # known spelling so signature_id works across languages.
                params_node = (
                    sym_node.child_by_field_name("parameters")
                    or sym_node.child_by_field_name("formal_parameters")
                    or sym_node.child_by_field_name("function_value_parameters")
                )
                params_bytes = (
                    _node_text(source, params_node)
                    if params_node is not None
                    else b""
                )

                # All hashes are whitespace-normalised so reformatting does
                # not change symbol identity.
                symbols[addr] = SymbolRecord(
                    kind=kind,
                    name=name_txt,
                    qualified_name=qualified,
                    content_id=_sha256_bytes(_norm_ws(node_bytes)),
                    body_hash=_sha256_bytes(_norm_ws(body_bytes)),
                    signature_id=_sha256_bytes(_norm_ws(name_bytes + params_bytes)),
                    lineno=sym_node.start_point[0] + 1,
                    end_lineno=sym_node.end_point[0] + 1,
                )
            return symbols
        except Exception as exc:  # noqa: BLE001
            logger.debug("tree-sitter parse error in %s: %s", file_path, exc)
            return {}

    def file_content_id(self, source: bytes) -> str:
        """Whitespace-normalised SHA-256 of the source — insensitive to reformatting."""
        return _sha256_bytes(_norm_ws(source))

    def validate_source(self, source: bytes) -> str | None:
        """Return an error description if *source* has syntax errors, else None.

        tree-sitter always produces a parse tree even for broken code.
        Errors appear as nodes with ``type == "ERROR"`` or ``is_missing == True``.
        ``root_node.has_error`` is the fast top-level check.
        """
        try:
            tree = self._parser.parse(source)
        except Exception as exc:  # noqa: BLE001
            return f"parser error: {exc}"

        if not tree.root_node.has_error:
            return None

        # Walk the tree to find the first concrete error site.
        error_node = _first_error_node(tree.root_node)
        if error_node is not None:
            line = error_node.start_point[0] + 1
            # Quote at most 60 bytes of the offending source in the message.
            fragment = source[
                error_node.start_byte : min(error_node.end_byte, error_node.start_byte + 60)
            ].decode("utf-8", errors="replace").strip()
            msg = f"syntax error on line {line}"
            if fragment:
                msg += f": {fragment!r}"
            return msg
        return "syntax error (unknown location)"
580
581
def _make_ts_adapter(spec: LangSpec) -> LanguageAdapter:
    """Build a :class:`TreeSitterAdapter`; fall back to :class:`FallbackAdapter` on error.

    Importing the grammar capsule is deferred to this factory so that a
    missing or incompatible grammar package degrades gracefully (file-level
    tracking only) rather than preventing the entire plugin from loading.
    """
    try:
        module_name = spec["module_name"]
        func_name = spec["lang_func"]
        grammar_module = importlib.import_module(module_name)
        capsule = getattr(grammar_module, func_name)()
        language = Language(capsule)
        return TreeSitterAdapter(spec, Parser(language), language)
    except Exception as exc:  # noqa: BLE001
        logger.debug(
            "tree-sitter grammar %s.%s unavailable — using file-level fallback: %s",
            spec["module_name"],
            spec["lang_func"],
            exc,
        )
        return FallbackAdapter(spec["extensions"])
603
604
605 # ---------------------------------------------------------------------------
606 # Per-language tree-sitter specs
607 # ---------------------------------------------------------------------------
608
_JS_SPEC: LangSpec = {
    "extensions": frozenset({".js", ".jsx", ".mjs", ".cjs"}),
    "module_name": "tree_sitter_javascript",
    "lang_func": "language",
    # Note: tree-sitter-javascript uses "class" for both class declarations and
    # named class expressions. "class_expression" is not a valid node type.
    "query_str": (
        "(function_declaration name: (identifier) @name) @sym\n"
        "(function_expression name: (identifier) @name) @sym\n"
        "(generator_function_declaration name: (identifier) @name) @sym\n"
        "(class_declaration name: (identifier) @name) @sym\n"
        "(class name: (identifier) @name) @sym\n"
        "(method_definition name: (property_identifier) @name) @sym"
    ),
    "kind_map": {
        "function_declaration": "function",
        "function_expression": "function",
        "generator_function_declaration": "function",
        "class_declaration": "class",
        "class": "class",
        "method_definition": "method",
    },
    # Methods nested under these nodes get "Class.method" qualified names.
    "class_node_types": frozenset({"class_declaration", "class"}),
    "class_name_field": "name",
    # JavaScript has no Go-style receivers.
    "receiver_capture": "",
}
635
_TS_QUERY = (
    # TypeScript uses type_identifier (not identifier) for class names.
    "(function_declaration name: (identifier) @name) @sym\n"
    "(function_expression name: (identifier) @name) @sym\n"
    "(generator_function_declaration name: (identifier) @name) @sym\n"
    "(class_declaration name: (type_identifier) @name) @sym\n"
    "(class name: (type_identifier) @name) @sym\n"
    "(abstract_class_declaration name: (type_identifier) @name) @sym\n"
    "(method_definition name: (property_identifier) @name) @sym\n"
    "(interface_declaration name: (type_identifier) @name) @sym\n"
    "(type_alias_declaration name: (type_identifier) @name) @sym\n"
    "(enum_declaration name: (identifier) @name) @sym"
)

# Interfaces and enums map to "class", type aliases to "variable" — the
# SymbolKind set has no dedicated kinds for those constructs.
_TS_KIND_MAP: dict[str, SymbolKind] = {
    "function_declaration": "function",
    "function_expression": "function",
    "generator_function_declaration": "function",
    "class_declaration": "class",
    "class": "class",
    "abstract_class_declaration": "class",
    "method_definition": "method",
    "interface_declaration": "class",
    "type_alias_declaration": "variable",
    "enum_declaration": "class",
}

_TS_CLASS_NODES: frozenset[str] = frozenset(
    {"class_declaration", "class", "abstract_class_declaration"}
)

# .ts and .tsx share the query and kind tables but load different grammar
# entry points from the same package (language_typescript vs language_tsx).
_TS_SPEC: LangSpec = {
    "extensions": frozenset({".ts"}),
    "module_name": "tree_sitter_typescript",
    "lang_func": "language_typescript",
    "query_str": _TS_QUERY,
    "kind_map": _TS_KIND_MAP,
    "class_node_types": _TS_CLASS_NODES,
    "class_name_field": "name",
    "receiver_capture": "",
}

_TSX_SPEC: LangSpec = {
    "extensions": frozenset({".tsx"}),
    "module_name": "tree_sitter_typescript",
    "lang_func": "language_tsx",
    "query_str": _TS_QUERY,
    "kind_map": _TS_KIND_MAP,
    "class_node_types": _TS_CLASS_NODES,
    "class_name_field": "name",
    "receiver_capture": "",
}
688
_GO_SPEC: LangSpec = {
    "extensions": frozenset({".go"}),
    "module_name": "tree_sitter_go",
    "lang_func": "language",
    "query_str": (
        "(function_declaration name: (identifier) @name) @sym\n"
        "(method_declaration\n"
        " receiver: (parameter_list\n"
        " (parameter_declaration type: _ @recv))\n"
        " name: (field_identifier) @name) @sym\n"
        "(type_spec name: (type_identifier) @name) @sym"
    ),
    "kind_map": {
        "function_declaration": "function",
        "method_declaration": "method",
        "type_spec": "class",
    },
    # Go methods are qualified via the @recv capture (receiver type) in the
    # adapter, not by ancestor walking — hence the empty class_node_types.
    "class_node_types": frozenset(),
    "class_name_field": "name",
    "receiver_capture": "recv",
}
710
_RUST_SPEC: LangSpec = {
    "extensions": frozenset({".rs"}),
    "module_name": "tree_sitter_rust",
    "lang_func": "language",
    "query_str": (
        "(function_item name: (identifier) @name) @sym\n"
        "(struct_item name: (type_identifier) @name) @sym\n"
        "(enum_item name: (type_identifier) @name) @sym\n"
        "(trait_item name: (type_identifier) @name) @sym"
    ),
    # Structs, enums, and traits all map to "class" — the closest SymbolKind.
    "kind_map": {
        "function_item": "function",
        "struct_item": "class",
        "enum_item": "class",
        "trait_item": "class",
    },
    # impl_item scopes methods; its implementing type is in the "type" field.
    "class_node_types": frozenset({"impl_item"}),
    "class_name_field": "type",
    "receiver_capture": "",
}
732
_JAVA_SPEC: LangSpec = {
    "extensions": frozenset({".java"}),
    "module_name": "tree_sitter_java",
    "lang_func": "language",
    "query_str": (
        "(method_declaration name: (identifier) @name) @sym\n"
        "(constructor_declaration name: (identifier) @name) @sym\n"
        "(class_declaration name: (identifier) @name) @sym\n"
        "(interface_declaration name: (identifier) @name) @sym\n"
        "(enum_declaration name: (identifier) @name) @sym"
    ),
    # Constructors are kinded "function" (not "method") by this table.
    "kind_map": {
        "method_declaration": "method",
        "constructor_declaration": "function",
        "class_declaration": "class",
        "interface_declaration": "class",
        "enum_declaration": "class",
    },
    "class_node_types": frozenset({"class_declaration", "interface_declaration"}),
    "class_name_field": "name",
    "receiver_capture": "",
}
755
_C_SPEC: LangSpec = {
    "extensions": frozenset({".c", ".h"}),
    "module_name": "tree_sitter_c",
    "lang_func": "language",
    # C: only function definitions are extracted (no classes).
    "query_str": (
        "(function_definition\n"
        " declarator: (function_declarator\n"
        " declarator: (identifier) @name)) @sym"
    ),
    "kind_map": {"function_definition": "function"},
    "class_node_types": frozenset(),
    "class_name_field": "name",
    "receiver_capture": "",
}

_CPP_SPEC: LangSpec = {
    "extensions": frozenset({".cpp", ".cc", ".cxx", ".hpp", ".hxx"}),
    "module_name": "tree_sitter_cpp",
    "lang_func": "language",
    # NOTE(review): the function pattern matches a plain (identifier)
    # declarator — out-of-class definitions like ``Foo::bar`` presumably use
    # a qualified declarator and would not be captured; confirm against the
    # tree-sitter-cpp grammar.
    "query_str": (
        "(function_definition\n"
        " declarator: (function_declarator\n"
        " declarator: (identifier) @name)) @sym\n"
        "(class_specifier name: (type_identifier) @name) @sym\n"
        "(struct_specifier name: (type_identifier) @name) @sym"
    ),
    "kind_map": {
        "function_definition": "function",
        "class_specifier": "class",
        "struct_specifier": "class",
    },
    "class_node_types": frozenset({"class_specifier", "struct_specifier"}),
    "class_name_field": "name",
    "receiver_capture": "",
}
791
_CS_SPEC: LangSpec = {
    "extensions": frozenset({".cs"}),
    "module_name": "tree_sitter_c_sharp",
    "lang_func": "language",
    "query_str": (
        "(method_declaration name: (identifier) @name) @sym\n"
        "(constructor_declaration name: (identifier) @name) @sym\n"
        "(class_declaration name: (identifier) @name) @sym\n"
        "(interface_declaration name: (identifier) @name) @sym\n"
        "(enum_declaration name: (identifier) @name) @sym\n"
        "(struct_declaration name: (identifier) @name) @sym"
    ),
    # Constructors are kinded "function", mirroring the Java spec.
    "kind_map": {
        "method_declaration": "method",
        "constructor_declaration": "function",
        "class_declaration": "class",
        "interface_declaration": "class",
        "enum_declaration": "class",
        "struct_declaration": "class",
    },
    "class_node_types": frozenset(
        {"class_declaration", "interface_declaration", "struct_declaration"}
    ),
    "class_name_field": "name",
    "receiver_capture": "",
}
818
_RUBY_SPEC: LangSpec = {
    "extensions": frozenset({".rb"}),
    "module_name": "tree_sitter_ruby",
    "lang_func": "language",
    # Ruby class/module names are (constant) nodes, not identifiers.
    "query_str": (
        "(method name: (identifier) @name) @sym\n"
        "(singleton_method name: (identifier) @name) @sym\n"
        "(class name: (constant) @name) @sym\n"
        "(module name: (constant) @name) @sym"
    ),
    "kind_map": {
        "method": "method",
        "singleton_method": "method",
        "class": "class",
        "module": "class",
    },
    "class_node_types": frozenset({"class", "module"}),
    "class_name_field": "name",
    "receiver_capture": "",
}

_KT_SPEC: LangSpec = {
    "extensions": frozenset({".kt", ".kts"}),
    "module_name": "tree_sitter_kotlin",
    "lang_func": "language",
    "query_str": (
        "(function_declaration (identifier) @name) @sym\n"
        "(class_declaration (identifier) @name) @sym"
    ),
    "kind_map": {
        "function_declaration": "function",
        "class_declaration": "class",
    },
    # Kotlin methods are function_declaration nodes inside class_body.
    # child_by_field_name("name") is None for Kotlin classes; _class_name_from
    # falls back to the first identifier-typed named child automatically.
    "class_node_types": frozenset({"class_declaration"}),
    "class_name_field": "name",
    "receiver_capture": "",
}
859
#: All tree-sitter language specs, loaded in registration order. Order only
#: matters for overlapping suffix claims; the current extension sets are
#: pairwise disjoint.
_TS_LANG_SPECS: list[LangSpec] = [
    _JS_SPEC,
    _TS_SPEC,
    _TSX_SPEC,
    _GO_SPEC,
    _RUST_SPEC,
    _JAVA_SPEC,
    _C_SPEC,
    _CPP_SPEC,
    _CS_SPEC,
    _RUBY_SPEC,
    _KT_SPEC,
]
874
875
876 # ---------------------------------------------------------------------------
877 # Adapter registry and public helpers
878 # ---------------------------------------------------------------------------
879
880 _PYTHON = PythonAdapter()
881 _FALLBACK = FallbackAdapter(frozenset())
882
883 #: Adapters checked in order; first match wins.
884 ADAPTERS: list[LanguageAdapter] = [_PYTHON]
885
886 # Build and register tree-sitter adapters. _make_ts_adapter degrades to
887 # FallbackAdapter if a grammar package isn't installed.
888 for _spec in _TS_LANG_SPECS:
889 ADAPTERS.append(_make_ts_adapter(_spec))
890
#: File extensions that receive semantic (AST-based) symbol extraction.
#: Computed from the live registry, so it reflects only the grammars that
#: actually loaded (fallback adapters are excluded).
SEMANTIC_EXTENSIONS: frozenset[str] = frozenset(
    ext
    for adapter in ADAPTERS
    if not isinstance(adapter, FallbackAdapter)
    for ext in adapter.supported_extensions()
)
895
#: Source extensions tracked as first-class files (raw-bytes identity for
#: languages without an AST adapter, AST identity for Python).
#: Kept as a superset of every adapter-claimed suffix plus common config,
#: doc, and script formats.
SOURCE_EXTENSIONS: frozenset[str] = frozenset({
    ".py", ".pyi",
    ".ts", ".tsx", ".js", ".jsx", ".mjs", ".cjs",
    ".swift",
    ".go",
    ".rs",
    ".java",
    # .hxx added: the C++ adapter claims it alongside .hpp/.cxx.
    ".c", ".cpp", ".cc", ".cxx", ".h", ".hpp", ".hxx",
    ".rb",
    # .kts added: the Kotlin adapter claims both .kt and .kts.
    ".kt", ".kts",
    ".cs",
    ".sh", ".bash", ".zsh",
    ".toml", ".yaml", ".yml", ".json", ".jsonc",
    ".md", ".rst", ".txt",
    ".css", ".scss", ".html",
    ".sql",
    ".proto",
    ".tf",
})
917
918
def adapter_for_path(file_path: str) -> LanguageAdapter:
    """Return the best :class:`LanguageAdapter` for *file_path*.

    Scans the registry in registration order and returns the first adapter
    that claims the file's lowercase suffix; unclaimed suffixes get the
    shared no-op fallback.

    Args:
        file_path: Workspace-relative POSIX path (e.g. ``"src/utils.py"``).

    Returns:
        The first adapter whose :meth:`~LanguageAdapter.supported_extensions`
        set contains the file's lowercase suffix, else :data:`_FALLBACK`.
    """
    suffix = pathlib.PurePosixPath(file_path).suffix.lower()
    return next(
        (a for a in ADAPTERS if suffix in a.supported_extensions()),
        _FALLBACK,
    )
937
938
def parse_symbols(source: bytes, file_path: str) -> SymbolTree:
    """Parse *source* with the best available adapter for *file_path*.

    Args:
        source: Raw bytes of the source file.
        file_path: Workspace-relative POSIX path.

    Returns:
        A :type:`SymbolTree` (may be empty for unsupported file types).
    """
    adapter = adapter_for_path(file_path)
    return adapter.parse_symbols(source, file_path)
950
951
def file_content_id(source: bytes, file_path: str) -> str:
    """Return the semantic content ID for *file_path* given its raw *source*.

    Args:
        source: Raw bytes of the file.
        file_path: Workspace-relative POSIX path.

    Returns:
        Hex-encoded SHA-256 digest — AST-based for Python, raw-bytes for
        languages without an AST adapter.
    """
    adapter = adapter_for_path(file_path)
    return adapter.file_content_id(source)
963
964
965 def _first_error_node(node: Node) -> Node | None:
966 """Return the first ERROR or MISSING node in *node*'s subtree, depth-first."""
967 if node.type == "ERROR" or node.is_missing:
968 return node
969 for child in node.children:
970 found = _first_error_node(child)
971 if found is not None:
972 return found
973 return None
974
975
def validate_syntax(source: bytes, file_path: str) -> str | None:
    """Return a human-readable error description if *source* has syntax errors.

    Covers Python (via :mod:`ast`) and all tree-sitter languages. Returns
    ``None`` for valid files and for file types without a parser.

    This is used by ``muse patch`` to verify that a surgical replacement
    does not introduce a syntax error before writing the result to disk.

    Args:
        source: UTF-8 encoded source bytes to validate.
        file_path: Workspace-relative path — used to select the parser.

    Returns:
        A human-readable error string, or ``None`` if the file is valid.
    """
    suffix = pathlib.PurePosixPath(file_path).suffix.lower()

    if suffix in {".py", ".pyi"}:
        try:
            ast.parse(source)
            return None
        except SyntaxError as exc:
            # exc.lineno can be None for some compiler-raised errors.
            line = exc.lineno if exc.lineno is not None else "?"
            return f"syntax error on line {line}: {exc.msg}"
        except ValueError as exc:
            # ast.parse raises ValueError (not SyntaxError) for some invalid
            # inputs, e.g. null bytes on Python < 3.12. A validator whose
            # contract is to *report* bad input must not crash on it.
            return f"invalid source: {exc}"

    adapter = adapter_for_path(file_path)
    if isinstance(adapter, TreeSitterAdapter):
        return adapter.validate_source(source)

    # No parser for this file type — nothing to validate.
    return None
1005 return None