gabriel / muse public
cli_test_helper.py python
191 lines 6.1 KB
dec4604a feat(mwp): replace JSON+base64 wire protocol with MWP binary msgpack Gabriel Cardona <gabriel@tellurstori.com> 13h ago
1 """Argparse-compatible CliRunner replacement for Muse test suite.
2
3 Replaces ``typer.testing.CliRunner`` so tests can call ``runner.invoke(cli,
4 args)`` without modification after the typer → argparse migration. The first
5 argument (``cli``) is always ``None`` (a stub) after migration; it is accepted
6 but ignored, and ``muse.cli.app.main`` is always the target.
7 """
8
9 from __future__ import annotations
10
11 import contextlib
12 import io
13 import os
14 import re
15 import sys
16 import traceback
17
18 from muse.cli.app import main
19
20 _ANSI_ESCAPE = re.compile(r"\x1b\[[0-9;]*m")
21
22
23 def _strip_ansi(text: str) -> str:
24 """Remove ANSI escape sequences — typer's CliRunner did this automatically."""
25 return _ANSI_ESCAPE.sub("", text)
26
27
28 class _StdinWithBuffer(io.StringIO):
29 """Text-mode stdin backed by StringIO, with a ``.buffer`` BytesIO sibling.
30
31 Accepts either a ``str`` (text) or ``bytes`` (binary). When binary is
32 provided the text surface is left empty — commands that read from
33 ``sys.stdin.buffer`` (e.g. ``unpack-objects``) get the raw bytes while
34 text-mode reads get an empty string.
35 """
36
37 def __init__(self, text: str | bytes) -> None:
38 if isinstance(text, bytes):
39 super().__init__("")
40 self.buffer = io.BytesIO(text)
41 else:
42 super().__init__(text)
43 self.buffer = io.BytesIO(text.encode())
44
45 def isatty(self) -> bool:
46 return False
47
48
49 class _StdoutCapture(io.StringIO):
50 """Text-mode stdout backed by StringIO, with a ``.buffer`` BytesIO sibling.
51
52 Some plumbing commands (e.g. ``cat-object``) write raw bytes to
53 ``sys.stdout.buffer``. Subclassing ``StringIO`` makes this assignable to
54 ``sys.stdout`` (and passable to ``contextlib.redirect_stdout``) without
55 any type annotation workaround. Binary output is decoded and appended to
56 the text output in ``getvalue()``.
57 """
58
59 def __init__(self) -> None:
60 super().__init__()
61 self.buffer = io.BytesIO()
62
63 def isatty(self) -> bool:
64 return False
65
66 def getvalue(self) -> str:
67 text_out = super().getvalue()
68 bytes_out = self.buffer.getvalue()
69 if bytes_out:
70 try:
71 text_out += bytes_out.decode("utf-8", errors="replace")
72 except Exception:
73 pass
74 return text_out
75
76
77 def _restore_env(saved: dict[str, str | None]) -> None:
78 """Restore environment variables to their pre-invoke state."""
79 for k, orig in saved.items():
80 if orig is None:
81 os.environ.pop(k, None)
82 else:
83 os.environ[k] = orig
84
85
class InvokeResult:
    """Outcome of a single CLI invocation, shaped like ``typer.testing.Result``.

    Attributes: ``exit_code``; ``output`` and ``stdout`` (the same text);
    ``stderr``; ``stdout_bytes`` (raw bytes captured separately); and
    ``exception``, which starts as ``None`` and is filled in by the caller
    when an exception was caught.
    """

    def __init__(
        self,
        exit_code: int,
        output: str,
        stderr_output: str = "",
        stdout_bytes: bytes = b"",
    ) -> None:
        self.exit_code = exit_code
        # typer exposed the same combined text under both names.
        self.output = self.stdout = output
        self.stderr = stderr_output
        self.stdout_bytes = stdout_bytes
        self.exception: BaseException | None = None

    def __repr__(self) -> str:
        return f"InvokeResult(exit_code={self.exit_code}, output={self.output!r})"
105
106
class CliRunner:
    """Drop-in replacement for ``typer.testing.CliRunner``.

    Captures stdout and stderr, calls ``main(args)``, and returns an
    ``InvokeResult`` whose interface matches the typer equivalent closely
    enough for the existing test suite to run without changes.

    Honoured parameters:
    - ``args``: a list of arguments, or a single string that is split
      shell-style (typer/click accept both); ``None`` means no arguments.
    - ``env``: key/value pairs set in ``os.environ`` for the duration of the
      call and restored afterward.
    - ``input``: ``str`` or ``bytes`` fed to ``sys.stdin`` (needed by
      ``unpack-objects``); bytes are exposed only via ``sys.stdin.buffer``.
    - ``catch_exceptions``: when False, exceptions propagate to the caller.
    """

    @staticmethod
    def _exit_code_from(raw: object, stderr_buf: io.StringIO) -> int:
        """Translate a ``SystemExit.code`` into a numeric exit status.

        ``None`` maps to 0 and integers pass through unchanged. Objects with a
        ``.value`` attribute (e.g. enum members) use ``int(raw.value)``;
        anything else uses ``int(raw)``. If that conversion fails (for example
        ``sys.exit("fatal: ...")``), the code is written to *stderr_buf* and
        1 is returned, matching how the interpreter treats a non-integer
        exit code, instead of letting ``ValueError`` escape ``invoke``.
        """
        if raw is None:
            return 0
        if isinstance(raw, int):
            return raw
        candidate = raw.value if hasattr(raw, "value") else raw
        try:
            return int(candidate)
        except (TypeError, ValueError):
            stderr_buf.write(f"{raw}\n")
            return 1

    def invoke(
        self,
        _cli: None,
        args: list[str] | str | None,
        catch_exceptions: bool = True,
        input: str | bytes | None = None,
        env: dict[str, str] | None = None,
    ) -> InvokeResult:
        """Invoke ``main(args)`` and return captured output + exit code.

        ``result.output`` is the captured stdout (text, then any decoded
        binary output) followed by stderr, with ANSI escapes stripped.
        ``sys.stdin`` and every overridden environment variable are restored
        even if setup or the command itself raises.
        """
        import shlex  # only needed for string-form ``args``

        if args is None:
            argv: list[str] = []
        elif isinstance(args, str):
            # list() on a bare string would explode it into single characters.
            argv = shlex.split(args)
        else:
            argv = list(args)

        stdout_cap = _StdoutCapture()
        stderr_buf = io.StringIO()
        exit_code = 0
        caught: BaseException | None = None
        saved: dict[str, str | None] = {}
        orig_stdin = sys.stdin

        try:
            # Record each original value before overwriting it, so a failure
            # part-way through is still fully undone by the finally clause.
            for key, value in (env or {}).items():
                saved[key] = os.environ.get(key)
                os.environ[key] = value
            # Patch sys.stdin for commands that read from it (e.g. unpack-objects).
            if input is not None:
                sys.stdin = _StdinWithBuffer(input)

            try:
                with (
                    contextlib.redirect_stdout(stdout_cap),
                    contextlib.redirect_stderr(stderr_buf),
                ):
                    main(argv)
            except SystemExit as exc:
                exit_code = self._exit_code_from(exc.code, stderr_buf)
            except Exception as exc:
                if not catch_exceptions:
                    raise
                # Report the crash like a CLI would: traceback on stderr, status 1.
                stderr_buf.write(traceback.format_exc())
                exit_code = 1
                caught = exc
        finally:
            sys.stdin = orig_stdin
            _restore_env(saved)

        stderr_text = stderr_buf.getvalue()
        result = InvokeResult(
            exit_code,
            _strip_ansi(stdout_cap.getvalue() + stderr_text),
            _strip_ansi(stderr_text),
            stdout_bytes=stdout_cap.buffer.getvalue(),
        )
        result.exception = caught
        return result