summaryrefslogtreecommitdiff
path: root/src/click/testing.py
diff options
context:
space:
mode:
authorDavid Lord <davidism@gmail.com>2021-04-23 08:38:47 -0700
committerDavid Lord <davidism@gmail.com>2021-04-23 08:38:47 -0700
commit0103c9570650daa59560baf42ad0a27e57b3157f (patch)
tree20c3c4ea3d69c24b91401881a903a4d822448dce /src/click/testing.py
parent77ce48f8d7d3b64a09741cf53dd2995d668317cf (diff)
downloadclick-0103c9570650daa59560baf42ad0a27e57b3157f.tar.gz
add typing annotations
Diffstat (limited to 'src/click/testing.py')
-rw-r--r--src/click/testing.py171
1 files changed, 102 insertions, 69 deletions
diff --git a/src/click/testing.py b/src/click/testing.py
index 637c46c..d19b850 100644
--- a/src/click/testing.py
+++ b/src/click/testing.py
@@ -5,49 +5,54 @@ import shlex
import shutil
import sys
import tempfile
+import typing as t
+from types import TracebackType
from . import formatting
from . import termui
from . import utils
from ._compat import _find_binary_reader
+if t.TYPE_CHECKING:
+ from .core import BaseCommand
+
class EchoingStdin:
- def __init__(self, input, output):
+ def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
self._input = input
self._output = output
self._paused = False
- def __getattr__(self, x):
+ def __getattr__(self, x: str) -> t.Any:
return getattr(self._input, x)
- def _echo(self, rv):
+ def _echo(self, rv: bytes) -> bytes:
if not self._paused:
self._output.write(rv)
return rv
- def read(self, n=-1):
+ def read(self, n: int = -1) -> bytes:
return self._echo(self._input.read(n))
- def read1(self, n=-1):
- return self._echo(self._input.read1(n))
+ def read1(self, n: int = -1) -> bytes:
+ return self._echo(self._input.read1(n)) # type: ignore
- def readline(self, n=-1):
+ def readline(self, n: int = -1) -> bytes:
return self._echo(self._input.readline(n))
- def readlines(self):
+ def readlines(self) -> t.List[bytes]:
return [self._echo(x) for x in self._input.readlines()]
- def __iter__(self):
+ def __iter__(self) -> t.Iterator[bytes]:
return iter(self._echo(x) for x in self._input)
- def __repr__(self):
+ def __repr__(self) -> str:
return repr(self._input)
@contextlib.contextmanager
-def _pause_echo(stream):
+def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]:
if stream is None:
yield
else:
@@ -57,24 +62,28 @@ def _pause_echo(stream):
class _NamedTextIOWrapper(io.TextIOWrapper):
- def __init__(self, buffer, name=None, mode=None, **kwargs):
+ def __init__(
+ self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any
+ ) -> None:
super().__init__(buffer, **kwargs)
self._name = name
self._mode = mode
@property
- def name(self):
+ def name(self) -> str:
return self._name
@property
- def mode(self):
+ def mode(self) -> str:
return self._mode
-def make_input_stream(input, charset):
+def make_input_stream(
+ input: t.Optional[t.Union[str, bytes, t.IO]], charset: str
+) -> t.BinaryIO:
# Is already an input stream.
if hasattr(input, "read"):
- rv = _find_binary_reader(input)
+ rv = _find_binary_reader(t.cast(t.IO, input))
if rv is not None:
return rv
@@ -83,10 +92,10 @@ def make_input_stream(input, charset):
if input is None:
input = b""
- elif not isinstance(input, bytes):
+ elif isinstance(input, str):
input = input.encode(charset)
- return io.BytesIO(input)
+ return io.BytesIO(t.cast(bytes, input))
class Result:
@@ -94,13 +103,15 @@ class Result:
def __init__(
self,
- runner,
- stdout_bytes,
- stderr_bytes,
- return_value,
- exit_code,
- exception,
- exc_info=None,
+ runner: "CliRunner",
+ stdout_bytes: bytes,
+ stderr_bytes: t.Optional[bytes],
+ return_value: t.Any,
+ exit_code: int,
+ exception: t.Optional[BaseException],
+ exc_info: t.Optional[
+ t.Tuple[t.Type[BaseException], BaseException, TracebackType]
+ ] = None,
):
#: The runner that created the result
self.runner = runner
@@ -120,19 +131,19 @@ class Result:
self.exc_info = exc_info
@property
- def output(self):
+ def output(self) -> str:
"""The (standard) output as unicode string."""
return self.stdout
@property
- def stdout(self):
+ def stdout(self) -> str:
"""The standard output as unicode string."""
return self.stdout_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)
@property
- def stderr(self):
+ def stderr(self) -> str:
"""The standard error as unicode string."""
if self.stderr_bytes is None:
raise ValueError("stderr not separately captured")
@@ -140,7 +151,7 @@ class Result:
"\r\n", "\n"
)
- def __repr__(self):
+ def __repr__(self) -> str:
exc_str = repr(self.exception) if self.exception else "okay"
return f"<{type(self).__name__} {exc_str}>"
@@ -164,20 +175,28 @@ class CliRunner:
independently
"""
- def __init__(self, charset="utf-8", env=None, echo_stdin=False, mix_stderr=True):
+ def __init__(
+ self,
+ charset: str = "utf-8",
+ env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
+ echo_stdin: bool = False,
+ mix_stderr: bool = True,
+ ) -> None:
self.charset = charset
self.env = env or {}
self.echo_stdin = echo_stdin
self.mix_stderr = mix_stderr
- def get_default_prog_name(self, cli):
+ def get_default_prog_name(self, cli: "BaseCommand") -> str:
"""Given a command object it will return the default program name
for it. The default is the `name` attribute or ``"root"`` if not
set.
"""
return cli.name or "root"
- def make_env(self, overrides=None):
+ def make_env(
+ self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None
+ ) -> t.Mapping[str, t.Optional[str]]:
"""Returns the environment overrides for invoking a script."""
rv = dict(self.env)
if overrides:
@@ -185,7 +204,12 @@ class CliRunner:
return rv
@contextlib.contextmanager
- def isolation(self, input=None, env=None, color=False):
+ def isolation(
+ self,
+ input: t.Optional[t.Union[str, bytes, t.IO]] = None,
+ env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
+ color: bool = False,
+ ) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]:
"""A context manager that sets up the isolation for invoking of a
command line tool. This sets up stdin with the given input data
and `os.environ` with the overrides from the given dictionary.
@@ -206,7 +230,7 @@ class CliRunner:
.. versionchanged:: 4.0
Added the ``color`` parameter.
"""
- input = make_input_stream(input, self.charset)
+ bytes_input = make_input_stream(input, self.charset)
echo_input = None
old_stdin = sys.stdin
@@ -220,16 +244,18 @@ class CliRunner:
bytes_output = io.BytesIO()
if self.echo_stdin:
- input = echo_input = EchoingStdin(input, bytes_output)
+ bytes_input = echo_input = t.cast(
+ t.BinaryIO, EchoingStdin(bytes_input, bytes_output)
+ )
- sys.stdin = input = _NamedTextIOWrapper(
- input, encoding=self.charset, name="<stdin>", mode="r"
+ sys.stdin = text_input = _NamedTextIOWrapper(
+ bytes_input, encoding=self.charset, name="<stdin>", mode="r"
)
if self.echo_stdin:
# Force unbuffered reads, otherwise TextIOWrapper reads a
# large chunk which is echoed early.
- input._CHUNK_SIZE = 1
+ text_input._CHUNK_SIZE = 1 # type: ignore
sys.stdout = _NamedTextIOWrapper(
bytes_output, encoding=self.charset, name="<stdout>", mode="w"
@@ -248,22 +274,22 @@ class CliRunner:
errors="backslashreplace",
)
- @_pause_echo(echo_input)
- def visible_input(prompt=None):
+ @_pause_echo(echo_input) # type: ignore
+ def visible_input(prompt: t.Optional[str] = None) -> str:
sys.stdout.write(prompt or "")
- val = input.readline().rstrip("\r\n")
+ val = text_input.readline().rstrip("\r\n")
sys.stdout.write(f"{val}\n")
sys.stdout.flush()
return val
- @_pause_echo(echo_input)
- def hidden_input(prompt=None):
+ @_pause_echo(echo_input) # type: ignore
+ def hidden_input(prompt: t.Optional[str] = None) -> str:
sys.stdout.write(f"{prompt or ''}\n")
sys.stdout.flush()
- return input.readline().rstrip("\r\n")
+ return text_input.readline().rstrip("\r\n")
- @_pause_echo(echo_input)
- def _getchar(echo):
+ @_pause_echo(echo_input) # type: ignore
+ def _getchar(echo: bool) -> str:
char = sys.stdin.read(1)
if echo:
@@ -274,7 +300,9 @@ class CliRunner:
default_color = color
- def should_strip_ansi(stream=None, color=None):
+ def should_strip_ansi(
+ stream: t.Optional[t.IO] = None, color: t.Optional[bool] = None
+ ) -> bool:
if color is None:
return not default_color
return not color
@@ -282,11 +310,11 @@ class CliRunner:
old_visible_prompt_func = termui.visible_prompt_func
old_hidden_prompt_func = termui.hidden_prompt_func
old__getchar_func = termui._getchar
- old_should_strip_ansi = utils.should_strip_ansi
+ old_should_strip_ansi = utils.should_strip_ansi # type: ignore
termui.visible_prompt_func = visible_input
termui.hidden_prompt_func = hidden_input
termui._getchar = _getchar
- utils.should_strip_ansi = should_strip_ansi
+ utils.should_strip_ansi = should_strip_ansi # type: ignore
old_env = {}
try:
@@ -315,19 +343,19 @@ class CliRunner:
termui.visible_prompt_func = old_visible_prompt_func
termui.hidden_prompt_func = old_hidden_prompt_func
termui._getchar = old__getchar_func
- utils.should_strip_ansi = old_should_strip_ansi
+ utils.should_strip_ansi = old_should_strip_ansi # type: ignore
formatting.FORCED_WIDTH = old_forced_width
def invoke(
self,
- cli,
- args=None,
- input=None,
- env=None,
- catch_exceptions=True,
- color=False,
- **extra,
- ):
+ cli: "BaseCommand",
+ args: t.Optional[t.Union[str, t.Sequence[str]]] = None,
+ input: t.Optional[t.Union[str, bytes, t.IO]] = None,
+ env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
+ catch_exceptions: bool = True,
+ color: bool = False,
+ **extra: t.Any,
+ ) -> Result:
"""Invokes a command in an isolated environment. The arguments are
forwarded directly to the command line script, the `extra` keyword
arguments are passed to the :meth:`~clickpkg.Command.main` function of
@@ -365,7 +393,7 @@ class CliRunner:
exc_info = None
with self.isolation(input=input, env=env, color=color) as outstreams:
return_value = None
- exception = None
+ exception: t.Optional[BaseException] = None
exit_code = 0
if isinstance(args, str):
@@ -380,17 +408,20 @@ class CliRunner:
return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
except SystemExit as e:
exc_info = sys.exc_info()
- exit_code = e.code
- if exit_code is None:
- exit_code = 0
+ e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code)
- if exit_code != 0:
+ if e_code is None:
+ e_code = 0
+
+ if e_code != 0:
exception = e
- if not isinstance(exit_code, int):
- sys.stdout.write(str(exit_code))
+ if not isinstance(e_code, int):
+ sys.stdout.write(str(e_code))
sys.stdout.write("\n")
- exit_code = 1
+ e_code = 1
+
+ exit_code = e_code
except Exception as e:
if not catch_exceptions:
@@ -404,7 +435,7 @@ class CliRunner:
if self.mix_stderr:
stderr = None
else:
- stderr = outstreams[1].getvalue()
+ stderr = outstreams[1].getvalue() # type: ignore
return Result(
runner=self,
@@ -413,11 +444,13 @@ class CliRunner:
return_value=return_value,
exit_code=exit_code,
exception=exception,
- exc_info=exc_info,
+ exc_info=exc_info, # type: ignore
)
@contextlib.contextmanager
- def isolated_filesystem(self, temp_dir=None):
+ def isolated_filesystem(
+ self, temp_dir: t.Optional[t.Union[str, os.PathLike]] = None
+ ) -> t.Iterator[str]:
"""A context manager that creates a temporary directory and
changes the current working directory to it. This isolates tests
that affect the contents of the CWD to prevent them from