diff options
author | David Lord <davidism@gmail.com> | 2021-04-23 08:38:47 -0700 |
---|---|---|
committer | David Lord <davidism@gmail.com> | 2021-04-23 08:38:47 -0700 |
commit | 0103c9570650daa59560baf42ad0a27e57b3157f (patch) | |
tree | 20c3c4ea3d69c24b91401881a903a4d822448dce /src/click/testing.py | |
parent | 77ce48f8d7d3b64a09741cf53dd2995d668317cf (diff) | |
download | click-0103c9570650daa59560baf42ad0a27e57b3157f.tar.gz |
add typing annotations
Diffstat (limited to 'src/click/testing.py')
-rw-r--r-- | src/click/testing.py | 171 |
1 files changed, 102 insertions, 69 deletions
diff --git a/src/click/testing.py b/src/click/testing.py index 637c46c..d19b850 100644 --- a/src/click/testing.py +++ b/src/click/testing.py @@ -5,49 +5,54 @@ import shlex import shutil import sys import tempfile +import typing as t +from types import TracebackType from . import formatting from . import termui from . import utils from ._compat import _find_binary_reader +if t.TYPE_CHECKING: + from .core import BaseCommand + class EchoingStdin: - def __init__(self, input, output): + def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None: self._input = input self._output = output self._paused = False - def __getattr__(self, x): + def __getattr__(self, x: str) -> t.Any: return getattr(self._input, x) - def _echo(self, rv): + def _echo(self, rv: bytes) -> bytes: if not self._paused: self._output.write(rv) return rv - def read(self, n=-1): + def read(self, n: int = -1) -> bytes: return self._echo(self._input.read(n)) - def read1(self, n=-1): - return self._echo(self._input.read1(n)) + def read1(self, n: int = -1) -> bytes: + return self._echo(self._input.read1(n)) # type: ignore - def readline(self, n=-1): + def readline(self, n: int = -1) -> bytes: return self._echo(self._input.readline(n)) - def readlines(self): + def readlines(self) -> t.List[bytes]: return [self._echo(x) for x in self._input.readlines()] - def __iter__(self): + def __iter__(self) -> t.Iterator[bytes]: return iter(self._echo(x) for x in self._input) - def __repr__(self): + def __repr__(self) -> str: return repr(self._input) @contextlib.contextmanager -def _pause_echo(stream): +def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]: if stream is None: yield else: @@ -57,24 +62,28 @@ def _pause_echo(stream): class _NamedTextIOWrapper(io.TextIOWrapper): - def __init__(self, buffer, name=None, mode=None, **kwargs): + def __init__( + self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any + ) -> None: super().__init__(buffer, **kwargs) self._name = name self._mode = mode @property - def name(self): + def name(self) -> str: return self._name @property - def mode(self): + def mode(self) -> str: return self._mode -def make_input_stream(input, charset): +def make_input_stream( + input: t.Optional[t.Union[str, bytes, t.IO]], charset: str +) -> t.BinaryIO: # Is already an input stream. if hasattr(input, "read"): - rv = _find_binary_reader(input) + rv = _find_binary_reader(t.cast(t.IO, input)) if rv is not None: return rv @@ -83,10 +92,10 @@ def make_input_stream(input, charset): if input is None: input = b"" - elif not isinstance(input, bytes): + elif isinstance(input, str): input = input.encode(charset) - return io.BytesIO(input) + return io.BytesIO(t.cast(bytes, input)) class Result: @@ -94,13 +103,15 @@ class Result: def __init__( self, - runner, - stdout_bytes, - stderr_bytes, - return_value, - exit_code, - exception, - exc_info=None, + runner: "CliRunner", + stdout_bytes: bytes, + stderr_bytes: t.Optional[bytes], + return_value: t.Any, + exit_code: int, + exception: t.Optional[BaseException], + exc_info: t.Optional[ + t.Tuple[t.Type[BaseException], BaseException, TracebackType] + ] = None, ): #: The runner that created the result self.runner = runner @@ -120,19 +131,19 @@ class Result: self.exc_info = exc_info @property - def output(self): + def output(self) -> str: """The (standard) output as unicode string.""" return self.stdout @property - def stdout(self): + def stdout(self) -> str: """The standard output as unicode string.""" return self.stdout_bytes.decode(self.runner.charset, "replace").replace( "\r\n", "\n" ) @property - def stderr(self): + def stderr(self) -> str: """The standard error as unicode string.""" if self.stderr_bytes is None: raise ValueError("stderr not separately captured") @@ -140,7 +151,7 @@ class Result: "\r\n", "\n" ) - def __repr__(self): + def __repr__(self) -> str: exc_str = repr(self.exception) if self.exception else "okay" return f"<{type(self).__name__} {exc_str}>" @@ -164,20 +175,28 @@ class CliRunner: independently """ - def __init__(self, charset="utf-8", env=None, echo_stdin=False, mix_stderr=True): + def __init__( + self, + charset: str = "utf-8", + env: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + echo_stdin: bool = False, + mix_stderr: bool = True, + ) -> None: self.charset = charset self.env = env or {} self.echo_stdin = echo_stdin self.mix_stderr = mix_stderr - def get_default_prog_name(self, cli): + def get_default_prog_name(self, cli: "BaseCommand") -> str: """Given a command object it will return the default program name for it. The default is the `name` attribute or ``"root"`` if not set. """ return cli.name or "root" - def make_env(self, overrides=None): + def make_env( + self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None + ) -> t.Mapping[str, t.Optional[str]]: """Returns the environment overrides for invoking a script.""" rv = dict(self.env) if overrides: @@ -185,7 +204,12 @@ class CliRunner: return rv @contextlib.contextmanager - def isolation(self, input=None, env=None, color=False): + def isolation( + self, + input: t.Optional[t.Union[str, bytes, t.IO]] = None, + env: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + color: bool = False, + ) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]: """A context manager that sets up the isolation for invoking of a command line tool. This sets up stdin with the given input data and `os.environ` with the overrides from the given dictionary. @@ -206,7 +230,7 @@ class CliRunner: .. versionchanged:: 4.0 Added the ``color`` parameter. """ - input = make_input_stream(input, self.charset) + bytes_input = make_input_stream(input, self.charset) echo_input = None old_stdin = sys.stdin @@ -220,16 +244,18 @@ class CliRunner: bytes_output = io.BytesIO() if self.echo_stdin: - input = echo_input = EchoingStdin(input, bytes_output) + bytes_input = echo_input = t.cast( + t.BinaryIO, EchoingStdin(bytes_input, bytes_output) + ) - sys.stdin = input = _NamedTextIOWrapper( - input, encoding=self.charset, name="<stdin>", mode="r" + sys.stdin = text_input = _NamedTextIOWrapper( + bytes_input, encoding=self.charset, name="<stdin>", mode="r" ) if self.echo_stdin: # Force unbuffered reads, otherwise TextIOWrapper reads a # large chunk which is echoed early. - input._CHUNK_SIZE = 1 + text_input._CHUNK_SIZE = 1 # type: ignore sys.stdout = _NamedTextIOWrapper( bytes_output, encoding=self.charset, name="<stdout>", mode="w" @@ -248,22 +274,22 @@ class CliRunner: errors="backslashreplace", ) - @_pause_echo(echo_input) - def visible_input(prompt=None): + @_pause_echo(echo_input) # type: ignore + def visible_input(prompt: t.Optional[str] = None) -> str: sys.stdout.write(prompt or "") - val = input.readline().rstrip("\r\n") + val = text_input.readline().rstrip("\r\n") sys.stdout.write(f"{val}\n") sys.stdout.flush() return val - @_pause_echo(echo_input) - def hidden_input(prompt=None): + @_pause_echo(echo_input) # type: ignore + def hidden_input(prompt: t.Optional[str] = None) -> str: sys.stdout.write(f"{prompt or ''}\n") sys.stdout.flush() - return input.readline().rstrip("\r\n") + return text_input.readline().rstrip("\r\n") - @_pause_echo(echo_input) - def _getchar(echo): + @_pause_echo(echo_input) # type: ignore + def _getchar(echo: bool) -> str: char = sys.stdin.read(1) if echo: @@ -274,7 +300,9 @@ class CliRunner: default_color = color - def should_strip_ansi(stream=None, color=None): + def should_strip_ansi( + stream: t.Optional[t.IO] = None, color: t.Optional[bool] = None + ) -> bool: if color is None: return not default_color return not color @@ -282,11 +310,11 @@ class CliRunner: old_visible_prompt_func = termui.visible_prompt_func old_hidden_prompt_func = termui.hidden_prompt_func old__getchar_func = termui._getchar - old_should_strip_ansi = utils.should_strip_ansi + old_should_strip_ansi = utils.should_strip_ansi # type: ignore termui.visible_prompt_func = visible_input termui.hidden_prompt_func = hidden_input termui._getchar = _getchar - utils.should_strip_ansi = should_strip_ansi + utils.should_strip_ansi = should_strip_ansi # type: ignore old_env = {} try: @@ -315,19 +343,19 @@ class CliRunner: termui.visible_prompt_func = old_visible_prompt_func termui.hidden_prompt_func = old_hidden_prompt_func termui._getchar = old__getchar_func - utils.should_strip_ansi = old_should_strip_ansi + utils.should_strip_ansi = old_should_strip_ansi # type: ignore formatting.FORCED_WIDTH = old_forced_width def invoke( self, - cli, - args=None, - input=None, - env=None, - catch_exceptions=True, - color=False, - **extra, - ): + cli: "BaseCommand", + args: t.Optional[t.Union[str, t.Sequence[str]]] = None, + input: t.Optional[t.Union[str, bytes, t.IO]] = None, + env: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + catch_exceptions: bool = True, + color: bool = False, + **extra: t.Any, + ) -> Result: """Invokes a command in an isolated environment. The arguments are forwarded directly to the command line script, the `extra` keyword arguments are passed to the :meth:`~clickpkg.Command.main` function of @@ -365,7 +393,7 @@ class CliRunner: exc_info = None with self.isolation(input=input, env=env, color=color) as outstreams: return_value = None - exception = None + exception: t.Optional[BaseException] = None exit_code = 0 if isinstance(args, str): @@ -380,17 +408,20 @@ class CliRunner: return_value = cli.main(args=args or (), prog_name=prog_name, **extra) except SystemExit as e: exc_info = sys.exc_info() - exit_code = e.code - if exit_code is None: - exit_code = 0 + e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code) - if exit_code != 0: + if e_code is None: + e_code = 0 + + if e_code != 0: exception = e - if not isinstance(exit_code, int): - sys.stdout.write(str(exit_code)) + if not isinstance(e_code, int): + sys.stdout.write(str(e_code)) sys.stdout.write("\n") - exit_code = 1 + e_code = 1 + + exit_code = e_code except Exception as e: if not catch_exceptions: @@ -404,7 +435,7 @@ class CliRunner: if self.mix_stderr: stderr = None else: - stderr = outstreams[1].getvalue() + stderr = outstreams[1].getvalue() # type: ignore return Result( runner=self, @@ -413,11 +444,13 @@ class CliRunner: return_value=return_value, exit_code=exit_code, exception=exception, - exc_info=exc_info, + exc_info=exc_info, # type: ignore ) @contextlib.contextmanager - def isolated_filesystem(self, temp_dir=None): + def isolated_filesystem( + self, temp_dir: t.Optional[t.Union[str, os.PathLike]] = None + ) -> t.Iterator[str]: """A context manager that creates a temporary directory and changes the current working directory to it. This isolates tests that affect the contents of the CWD to prevent them from |