gabriel / muse public
cli_test_helper.py python
201 lines 6.2 KB
86000da9 fix: replace typer CliRunner with argparse-compatible test helper Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """Argparse-compatible CliRunner replacement for Muse test suite.
2
3 Replaces ``typer.testing.CliRunner`` so tests can call ``runner.invoke(cli,
4 args)`` without modification after the typer → argparse migration. The first
5 argument (``cli``) is ignored; ``muse.cli.app.main`` is always the target.
6 """
7
8 from __future__ import annotations
9
10 import contextlib
11 import io
12 import os
13 import re
14 import sys
15 import traceback
16 from typing import Any
17
18 from muse.cli.app import main
19
# Matches CSI-style escape sequences: ESC '[' then digits, ';' or '?', then a
# final letter.  This covers colours/styles (``\x1b[31m``) as well as erase and
# cursor controls (``\x1b[2K``, ``\x1b[?25l``).  Same pattern Click uses when
# it strips ANSI codes, which is what typer's CliRunner relied on.
_ANSI_ESCAPE = re.compile(r"\x1b\[[;?0-9]*[a-zA-Z]")


def _strip_ansi(text: str) -> str:
    """Return *text* with ANSI escape sequences removed.

    typer's CliRunner stripped these automatically, so tests written against
    it compare plain text; this keeps that behaviour.
    """
    return _ANSI_ESCAPE.sub("", text)
26
27
class _StdinWithBuffer:
    """Text-mode stdin stand-in with a ``.buffer`` attribute for binary reads.

    Some plumbing commands (e.g. ``unpack-objects``) read raw bytes from
    ``sys.stdin.buffer``.  ``io.StringIO`` has no ``.buffer``, so this class
    pairs a ``StringIO`` (text reads) with a ``BytesIO`` holding the same
    content encoded as UTF-8.  The two streams keep independent positions:
    consuming one does not advance the other.

    Line iteration (``for line in sys.stdin``) and ``readlines()`` are also
    forwarded to the text stream, since they are common ways of reading stdin.
    """

    # Encoding used to build the binary view in ``buffer``.
    encoding = "utf-8"

    def __init__(self, text: str) -> None:
        self._text = io.StringIO(text)
        self.buffer = io.BytesIO(text.encode(self.encoding))

    def read(self, n: int = -1) -> str:
        """Read up to *n* characters (everything remaining when *n* < 0)."""
        return self._text.read(n)

    def readline(self, size: int = -1) -> str:
        """Read one line, or at most *size* characters of it when *size* >= 0."""
        return self._text.readline(size)

    def readlines(self, hint: int = -1) -> list[str]:
        """Read the remaining lines from the text stream."""
        return self._text.readlines(hint)

    def __iter__(self) -> "_StdinWithBuffer":
        return self

    def __next__(self) -> str:
        # StringIO raises StopIteration itself once the text is exhausted.
        return next(self._text)

    def isatty(self) -> bool:
        return False
47
48
class _StdoutCapture:
    """Text-mode stdout stand-in with a ``.buffer`` attribute for binary writes.

    Some plumbing commands (e.g. ``cat-object``) write raw bytes to
    ``sys.stdout.buffer``.  ``io.StringIO`` has no ``.buffer``, so this class
    pairs a ``StringIO`` (text writes) with a companion ``BytesIO`` (binary
    writes) and merges both in :meth:`getvalue`.
    """

    # Advertised encoding; matches how getvalue() decodes the binary channel.
    encoding = "utf-8"

    def __init__(self) -> None:
        self._text = io.StringIO()
        self.buffer = io.BytesIO()

    # --- text-mode interface ------------------------------------------------
    def write(self, s: str) -> int:
        """Append *s* to the text channel; return the number of characters."""
        return self._text.write(s)

    def writelines(self, lines: list[str]) -> None:
        self._text.writelines(lines)

    def flush(self) -> None:
        self._text.flush()
        self.buffer.flush()

    def isatty(self) -> bool:
        return False

    # --- value retrieval ----------------------------------------------------
    def getvalue(self) -> str:
        """Return all text output followed by the binary output as text.

        The two channels are not interleaved: every text write comes first,
        then every byte written to ``buffer``, decoded as UTF-8.
        """
        # errors="replace" cannot raise, so no try/except is needed:
        # undecodable bytes appear as U+FFFD in the result.
        decoded = self.buffer.getvalue().decode("utf-8", errors="replace")
        return self._text.getvalue() + decoded
84
85
86 def _restore_env(saved: dict[str, str | None]) -> None:
87 """Restore environment variables to their pre-invoke state."""
88 for k, orig in saved.items():
89 if orig is None:
90 os.environ.pop(k, None)
91 else:
92 os.environ[k] = orig
93
94
class InvokeResult:
    """Outcome of ``CliRunner.invoke``, shaped like ``typer.testing.Result``.

    Attributes:
        exit_code: Integer status the command finished with.
        output: Captured text as passed in by the caller.
        stdout: Alias holding the same value as ``output``.
        stderr: Captured stderr text as passed in by the caller.
        stdout_bytes: Raw bytes written to ``sys.stdout.buffer``.
        exception: Starts as ``None``; callers may attach the caught exception.
    """

    def __init__(
        self,
        exit_code: int,
        output: str,
        stderr_output: str = "",
        stdout_bytes: bytes = b"",
    ) -> None:
        self.exit_code = exit_code
        # ``stdout`` mirrors ``output`` for typer-style call sites.
        self.output = self.stdout = output
        self.stderr = stderr_output
        self.stdout_bytes = stdout_bytes
        self.exception: BaseException | None = None

    def __repr__(self) -> str:
        code, text = self.exit_code, self.output
        return f"InvokeResult(exit_code={code}, output={text!r})"
114
115
class CliRunner:
    """Drop-in replacement for ``typer.testing.CliRunner``.

    Captures stdout and stderr, calls ``main(args)``, and returns an
    ``InvokeResult`` whose interface matches the typer equivalent closely
    enough for the existing test suite to run without changes.

    Honoured parameters:
    - ``env``: key/value pairs set in ``os.environ`` for the duration of the
      call and restored afterward.  A value of ``None`` unsets the variable
      for the duration of the call.
    - ``input``: string fed to ``sys.stdin`` (needed by ``unpack-objects``).
    - ``catch_exceptions``: when False, exceptions other than ``SystemExit``
      propagate to the caller.

    Any other keyword argument is accepted and ignored.
    """

    def invoke(
        self,
        _cli: Any,
        args: list[str],
        catch_exceptions: bool = True,
        input: str | None = None,
        env: dict[str, str | None] | None = None,
        **_kwargs: Any,
    ) -> InvokeResult:
        """Invoke ``main(args)`` and return captured output + exit code.

        Args:
            _cli: Ignored; kept so ``runner.invoke(cli, args)`` call sites work.
            args: Argument vector; a fresh list copy is passed to ``main``.
            catch_exceptions: If True, a non-``SystemExit`` exception is stored
                on ``result.exception``, its traceback is appended to stderr,
                and the exit code is 1.  If False it is re-raised once stdin
                and the environment have been restored.
            input: Text exposed as ``sys.stdin`` (and, UTF-8 encoded, as
                ``sys.stdin.buffer``).  When None, ``sys.stdin`` is untouched.
            env: Environment overrides; ``None`` values unset the variable.

        Returns:
            An ``InvokeResult`` whose ``output`` is captured stdout followed by
            captured stderr, with ANSI escape sequences stripped.
        """
        saved: dict[str, str | None] = {}
        stdout_cap = _StdoutCapture()
        stderr_buf = io.StringIO()
        exit_code = 0
        caught: BaseException | None = None
        orig_stdin = sys.stdin

        # All process-state mutation happens inside this try so the finally
        # clause is the single place that undoes it, on every exit path.
        try:
            for key, value in (env or {}).items():
                # Record the original before changing it, so a failure part-way
                # through the loop is still rolled back by the finally clause.
                saved[key] = os.environ.get(key)
                if value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = value

            if input is not None:
                # _StdinWithBuffer also provides sys.stdin.buffer.
                sys.stdin = _StdinWithBuffer(input)  # type: ignore[assignment]

            try:
                with (
                    contextlib.redirect_stdout(stdout_cap),  # type: ignore[arg-type]
                    contextlib.redirect_stderr(stderr_buf),
                ):
                    main(list(args))
            except SystemExit as exc:
                exit_code = self._exit_code_from(exc.code, stderr_buf)
            except Exception as exc:
                if not catch_exceptions:
                    raise
                stderr_buf.write(traceback.format_exc())
                exit_code = 1
                caught = exc
        finally:
            sys.stdin = orig_stdin
            _restore_env(saved)

        result = InvokeResult(
            exit_code,
            _strip_ansi(stdout_cap.getvalue() + stderr_buf.getvalue()),
            _strip_ansi(stderr_buf.getvalue()),
            stdout_bytes=stdout_cap.buffer.getvalue(),
        )
        result.exception = caught
        return result

    @staticmethod
    def _exit_code_from(code: object, stderr: io.StringIO) -> int:
        """Translate a ``SystemExit.code`` value into an integer exit status.

        ``None`` maps to 0 and an ``int`` is used as-is.  An object with a
        ``.value`` attribute (e.g. an ``Enum`` member) is unwrapped with
        ``int(code.value)``.  Anything else is handled the way the interpreter
        handles ``sys.exit("message")``: the object is written to *stderr*
        followed by a newline, and the status is 1.
        """
        if code is None:
            return 0
        if isinstance(code, int):
            return code
        if hasattr(code, "value"):
            return int(code.value)
        # Previously int(code) here raised ValueError for message strings.
        stderr.write(f"{code}\n")
        return 1