gabriel / muse public
cli_test_helper.py python
187 lines 5.9 KB
5175b6a3 fix: resolve typing_audit violations in cli_test_helper Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """Argparse-compatible CliRunner replacement for Muse test suite.
2
3 Replaces ``typer.testing.CliRunner`` so tests can call ``runner.invoke(cli,
4 args)`` without modification after the typer → argparse migration. The first
5 argument (``cli``) is always ``None`` (a stub) after migration; it is accepted
6 but ignored, and ``muse.cli.app.main`` is always the target.
7 """
8
9 from __future__ import annotations
10
11 import contextlib
12 import io
13 import os
14 import re
15 import sys
16 import traceback
17
18 from muse.cli.app import main
19
# Matches CSI escape sequences: ESC '[' followed by optional parameter bytes
# (digits, ';' or '?') and a final letter. This covers colour/style codes
# such as ``\x1b[31m`` as well as cursor-movement, line-erase and
# private-mode codes such as ``\x1b[2K`` and ``\x1b[?25l``. It is the same
# pattern Click's ``strip_ansi`` helper uses.
_ANSI_ESCAPE = re.compile(r"\x1b\[[;?0-9]*[a-zA-Z]")


def _strip_ansi(text: str) -> str:
    """Remove ANSI escape sequences — typer's CliRunner did this automatically.

    Any CSI sequence matched by ``_ANSI_ESCAPE`` is deleted; all other text,
    including literal ``[`` characters, is returned unchanged.
    """
    return _ANSI_ESCAPE.sub("", text)
26
27
class _StdinWithBuffer(io.StringIO):
    """In-memory text stdin that additionally exposes a binary ``.buffer``.

    Plumbing commands such as ``unpack-objects`` consume raw bytes through
    ``sys.stdin.buffer``, an attribute a bare ``StringIO`` does not provide.
    Being a ``StringIO`` subclass, an instance can be assigned straight to
    ``sys.stdin``. Meanwhile ``.buffer`` serves the same payload, UTF-8
    encoded, for binary readers.
    """

    def __init__(self, text: str) -> None:
        super().__init__(text)
        # Separate binary copy of the input: reading the text side does
        # not advance this stream, and vice versa.
        payload: bytes = bytes(text, "utf-8")
        self.buffer = io.BytesIO(payload)

    def isatty(self) -> bool:
        """Always report a non-interactive stream."""
        interactive = False
        return interactive
43
44
class _StdoutCapture(io.StringIO):
    """Text-mode stdout backed by StringIO, with a ``.buffer`` BytesIO sibling.

    Some plumbing commands (e.g. ``cat-object``) write raw bytes to
    ``sys.stdout.buffer``. Subclassing ``StringIO`` makes this assignable to
    ``sys.stdout`` (and passable to ``contextlib.redirect_stdout``) without
    any type annotation workaround. Binary output is decoded and appended
    after the text output in ``getvalue()``. The relative order of
    interleaved text and binary writes is therefore not preserved.
    """

    def __init__(self) -> None:
        super().__init__()
        # Receives anything written through ``sys.stdout.buffer``.
        self.buffer = io.BytesIO()

    def isatty(self) -> bool:
        """Always report a non-interactive stream."""
        return False

    def getvalue(self) -> str:
        """Return the captured text followed by the binary output as UTF-8.

        ``errors="replace"`` turns undecodable bytes into U+FFFD, so the
        decode cannot raise ``UnicodeDecodeError`` and needs no exception
        handling.
        """
        text_out = super().getvalue()
        bytes_out = self.buffer.getvalue()
        if not bytes_out:
            return text_out
        return text_out + bytes_out.decode("utf-8", errors="replace")
71
72
def _restore_env(saved: dict[str, str | None]) -> None:
    """Put ``os.environ`` back to the state recorded in ``saved``.

    Each key maps to the variable's value before it was overridden, or to
    ``None`` when the variable did not exist. Those variables are removed
    again, tolerating their absence.
    """
    for name in saved:
        previous = saved[name]
        if previous is not None:
            os.environ[name] = previous
            continue
        os.environ.pop(name, None)
80
81
class InvokeResult:
    """Outcome of a single ``CliRunner.invoke`` call.

    Mirrors the fields that ``typer.testing.Result`` exposed, so existing
    assertions keep working. ``stdout`` is an alias of ``output``. As
    populated by ``CliRunner``, ``stdout_bytes`` holds only what was written
    to ``sys.stdout.buffer``.
    """

    def __init__(
        self,
        exit_code: int,
        output: str,
        stderr_output: str = "",
        stdout_bytes: bytes = b"",
    ) -> None:
        self.exit_code = exit_code
        # ``output`` and ``stdout`` refer to the very same string.
        self.output = self.stdout = output
        self.stderr = stderr_output
        self.stdout_bytes = stdout_bytes
        # Filled in by the runner when it caught an exception during the call.
        self.exception: BaseException | None = None

    def __repr__(self) -> str:
        code, text = self.exit_code, self.output
        return f"InvokeResult(exit_code={code}, output={text!r})"
101
102
class CliRunner:
    """Drop-in replacement for ``typer.testing.CliRunner``.

    Captures stdout and stderr, calls ``main(args)``, and returns an
    ``InvokeResult`` whose interface matches the typer equivalent closely
    enough for the existing test suite to run without changes.

    Honoured parameters:
    - ``args``: list of argument strings passed to ``main``. A single string
      is split with ``shlex.split`` (as Click's runner does) and ``None``
      means no arguments.
    - ``env``: key/value pairs set in ``os.environ`` for the duration of the
      call and restored afterward, even if applying them fails part-way.
    - ``input``: string fed to ``sys.stdin`` (needed by ``unpack-objects``).
    - ``catch_exceptions``: when False, exceptions other than ``SystemExit``
      propagate to the caller (after stdin and env have been restored).
    """

    @staticmethod
    def _exit_code_from(code: object, stderr: io.StringIO) -> int:
        """Translate a ``SystemExit.code`` into an integer exit status.

        ``None`` maps to 0. Ints (including ``IntEnum`` members) are returned
        as-is. Objects exposing ``.value`` (plain ``Enum`` members) are
        converted through that value, and anything else is passed to
        ``int()``. A value ``int()`` rejects — e.g. ``sys.exit("message")`` —
        is written to *stderr* and reported as status 1. This matches how
        the interpreter treats a non-integer exit code.
        """
        if code is None:
            return 0
        if isinstance(code, int):
            return code
        candidate = code.value if hasattr(code, "value") else code
        try:
            return int(candidate)
        except (TypeError, ValueError):
            # Without this, the conversion error would escape ``invoke``.
            stderr.write(f"{code}\n")
            return 1

    def invoke(
        self,
        _cli: None,
        args: list[str] | str | None = None,
        catch_exceptions: bool = True,
        input: str | None = None,
        env: dict[str, str] | None = None,
    ) -> InvokeResult:
        """Invoke ``main(args)`` and return captured output + exit code.

        ``_cli`` is accepted for signature compatibility and ignored.

        The result's ``output`` is stdout (text, then any bytes written to
        ``sys.stdout.buffer``) followed by stderr, and ``stderr`` is stderr
        alone; ANSI escapes are stripped from both. When a non-``SystemExit``
        exception is caught, its traceback is appended to stderr, the exit
        code is 1, and ``result.exception`` is set. ``sys.stdin`` and
        ``os.environ`` are restored on every path.
        """
        import shlex  # only needed for the string form of ``args``

        if args is None:
            argv: list[str] = []
        elif isinstance(args, str):
            argv = shlex.split(args)
        else:
            argv = list(args)

        stdout_cap = _StdoutCapture()
        stderr_buf = io.StringIO()
        exit_code = 0
        caught: Exception | None = None
        saved: dict[str, str | None] = {}
        orig_stdin = sys.stdin

        try:
            # Record each original value before overwriting it, so the
            # ``finally`` below can roll back even a partially applied env.
            for key, value in (env or {}).items():
                saved[key] = os.environ.get(key)
                os.environ[key] = value
            # Patch sys.stdin for commands that read from it (e.g.
            # unpack-objects). _StdinWithBuffer subclasses StringIO so the
            # assignment is well-typed.
            if input is not None:
                sys.stdin = _StdinWithBuffer(input)

            try:
                with (
                    contextlib.redirect_stdout(stdout_cap),
                    contextlib.redirect_stderr(stderr_buf),
                ):
                    main(argv)
            except SystemExit as exc:
                exit_code = self._exit_code_from(exc.code, stderr_buf)
            except Exception as exc:
                if not catch_exceptions:
                    raise
                stderr_buf.write(traceback.format_exc())
                exit_code = 1
                caught = exc
        finally:
            sys.stdin = orig_stdin
            _restore_env(saved)

        stderr_text = stderr_buf.getvalue()
        result = InvokeResult(
            exit_code,
            _strip_ansi(stdout_cap.getvalue() + stderr_text),
            _strip_ansi(stderr_text),
            stdout_bytes=stdout_cap.buffer.getvalue(),
        )
        result.exception = caught
        return result