diff options
author | David Lord <davidism@gmail.com> | 2021-01-29 18:27:37 -0800 |
---|---|---|
committer | David Lord <davidism@gmail.com> | 2021-01-29 18:40:05 -0800 |
commit | 6a64222bfb18bd49e3a12f509f38ee7f2585799f (patch) | |
tree | 8c36679e1883ee23bcbecd6ccd0ae5d92e8f995e /src | |
parent | 9724cdedc887632d64d8fc7ed40056d0a8431f06 (diff) | |
download | markupsafe-6a64222bfb18bd49e3a12f509f38ee7f2585799f.tar.gz |
add type annotations
Diffstat (limited to 'src')
-rw-r--r-- | src/markupsafe/__init__.py | 128 | ||||
-rw-r--r-- | src/markupsafe/_native.py | 13 | ||||
-rw-r--r-- | src/markupsafe/_speedups.pyi | 20 | ||||
-rw-r--r-- | src/markupsafe/py.typed | 0 |
4 files changed, 110 insertions, 51 deletions
diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index 789979f..1d786fd 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -1,11 +1,32 @@ +import functools import re import string +import typing as t + +if t.TYPE_CHECKING: + + class HasHTML(t.Protocol): + def __html__(self) -> str: + pass + __version__ = "2.0.0a1" _striptags_re = re.compile(r"(<!--.*?-->|<[^>]*>)") +def _simple_escaping_wrapper(name: str) -> t.Callable[..., "Markup"]: + orig = getattr(str, name) + + @functools.wraps(orig) + def wrapped(self: "Markup", *args: t.Any, **kwargs: t.Any) -> "Markup": + args = _escape_argspec(list(args), enumerate(args), self.escape) # type: ignore + _escape_argspec(kwargs, kwargs.items(), self.escape) + return self.__class__(orig(self, *args, **kwargs)) + + return wrapped + + class Markup(str): """A string that is ready to be safely inserted into an HTML or XML document, either because it was escaped or because it was marked @@ -44,64 +65,76 @@ class Markup(str): __slots__ = () - def __new__(cls, base="", encoding=None, errors="strict"): + def __new__( + cls, base: t.Any = "", encoding: t.Optional[str] = None, errors: str = "strict" + ) -> "Markup": if hasattr(base, "__html__"): base = base.__html__() + if encoding is None: return super().__new__(cls, base) + return super().__new__(cls, base, encoding, errors) - def __html__(self): + def __html__(self) -> "Markup": return self - def __add__(self, other): + def __add__(self, other: t.Union[str, "HasHTML"]) -> "Markup": if isinstance(other, str) or hasattr(other, "__html__"): return self.__class__(super().__add__(self.escape(other))) + return NotImplemented - def __radd__(self, other): + def __radd__(self, other: t.Union[str, "HasHTML"]) -> "Markup": if isinstance(other, str) or hasattr(other, "__html__"): return self.escape(other).__add__(self) + return NotImplemented - def __mul__(self, num): + def __mul__(self, num: int) -> "Markup": if isinstance(num, int): return self.__class__(super().__mul__(num)) - return NotImplemented + + return NotImplemented # type: ignore __rmul__ = __mul__ - def __mod__(self, arg): + def __mod__(self, arg: t.Any) -> "Markup": if isinstance(arg, tuple): arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg) else: arg = _MarkupEscapeHelper(arg, self.escape) + return self.__class__(super().__mod__(arg)) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({super().__repr__()})" - def join(self, seq): + def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "Markup": return self.__class__(super().join(map(self.escape, seq))) join.__doc__ = str.join.__doc__ - def split(self, *args, **kwargs): - return list(map(self.__class__, super().split(*args, **kwargs))) + def split( # type: ignore + self, sep: t.Optional[str] = None, maxsplit: int = -1 + ) -> t.List["Markup"]: + return [self.__class__(v) for v in super().split(sep, maxsplit)] split.__doc__ = str.split.__doc__ - def rsplit(self, *args, **kwargs): - return list(map(self.__class__, super().rsplit(*args, **kwargs))) + def rsplit( # type: ignore + self, sep: t.Optional[str] = None, maxsplit: int = -1 + ) -> t.List["Markup"]: + return [self.__class__(v) for v in super().rsplit(sep, maxsplit)] rsplit.__doc__ = str.rsplit.__doc__ - def splitlines(self, *args, **kwargs): - return list(map(self.__class__, super().splitlines(*args, **kwargs))) + def splitlines(self, keepends: bool = False) -> t.List["Markup"]: # type: ignore + return [self.__class__(v) for v in super().splitlines(keepends)] splitlines.__doc__ = str.splitlines.__doc__ - def unescape(self): + def unescape(self) -> str: """Convert escaped markup back into a text string. This replaces HTML entities with the characters they represent. @@ -112,7 +145,7 @@ class Markup(str): return unescape(str(self)) - def striptags(self): + def striptags(self) -> str: """:meth:`unescape` the markup, remove tags, and normalize whitespace to single spaces. @@ -123,26 +156,16 @@ class Markup(str): return Markup(stripped).unescape() @classmethod - def escape(cls, s): + def escape(cls, s: t.Any) -> "Markup": """Escape a string. Calls :func:`escape` and ensures that for subclasses the correct type is returned. """ rv = escape(s) + if rv.__class__ is not cls: return cls(rv) - return rv - - def make_simple_escaping_wrapper(name): # noqa: B902 - orig = getattr(str, name) - - def func(self, *args, **kwargs): - args = _escape_argspec(list(args), enumerate(args), self.escape) - _escape_argspec(kwargs, kwargs.items(), self.escape) - return self.__class__(orig(self, *args, **kwargs)) - func.__name__ = orig.__name__ - func.__doc__ = orig.__doc__ - return func + return rv for method in ( "__getitem__", @@ -162,31 +185,36 @@ class Markup(str): "swapcase", "zfill", ): - locals()[method] = make_simple_escaping_wrapper(method) + locals()[method] = _simple_escaping_wrapper(method) - del method, make_simple_escaping_wrapper + del method - def partition(self, sep): - return tuple(map(self.__class__, super().partition(self.escape(sep)))) + def partition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]: + l, s, r = super().partition(self.escape(sep)) + cls = self.__class__ + return cls(l), cls(s), cls(r) - def rpartition(self, sep): - return tuple(map(self.__class__, super().rpartition(self.escape(sep)))) + def rpartition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]: + l, s, r = super().rpartition(self.escape(sep)) + cls = self.__class__ + return cls(l), cls(s), cls(r) - def format(self, *args, **kwargs): + def format(self, *args: t.Any, **kwargs: t.Any) -> "Markup": formatter = EscapeFormatter(self.escape) return self.__class__(formatter.vformat(self, args, kwargs)) - def __html_format__(self, format_spec): + def __html_format__(self, format_spec: str) -> "Markup": if format_spec: raise ValueError("Unsupported format specification for Markup.") + return self class EscapeFormatter(string.Formatter): - def __init__(self, escape): + def __init__(self, escape: t.Callable[[t.Any], Markup]) -> None: self.escape = escape - def format_field(self, value, format_spec): + def format_field(self, value: t.Any, format_spec: str) -> str: if hasattr(value, "__html_format__"): rv = value.__html_format__(format_spec) elif hasattr(value, "__html__"): @@ -204,34 +232,40 @@ class EscapeFormatter(string.Formatter): return str(self.escape(rv)) -def _escape_argspec(obj, iterable, escape): +_ListOrDict = t.TypeVar("_ListOrDict", list, dict) + + +def _escape_argspec( + obj: _ListOrDict, iterable: t.Iterable[t.Any], escape: t.Callable[[t.Any], Markup] +) -> _ListOrDict: """Helper for various string-wrapped functions.""" for key, value in iterable: if isinstance(value, str) or hasattr(value, "__html__"): obj[key] = escape(value) + return obj class _MarkupEscapeHelper: """Helper for :meth:`Markup.__mod__`.""" - def __init__(self, obj, escape): + def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None: self.obj = obj self.escape = escape - def __getitem__(self, item): + def __getitem__(self, item: t.Any) -> "_MarkupEscapeHelper": return _MarkupEscapeHelper(self.obj[item], self.escape) - def __str__(self): + def __str__(self) -> str: return str(self.escape(self.obj)) - def __repr__(self): + def __repr__(self) -> str: return str(self.escape(repr(self.obj))) - def __int__(self): + def __int__(self) -> int: return int(self.obj) - def __float__(self): + def __float__(self) -> float: return float(self.obj) diff --git a/src/markupsafe/_native.py b/src/markupsafe/_native.py index 7722296..1a3f214 100644 --- a/src/markupsafe/_native.py +++ b/src/markupsafe/_native.py @@ -1,7 +1,9 @@ +import typing as t + from . import Markup -def escape(s): +def escape(s: t.Any) -> Markup: """Replace the characters ``&``, ``<``, ``>``, ``'``, and ``"`` in the string with HTML-safe sequences. Use this if you need to display text that might contain such characters in HTML. @@ -14,6 +16,7 @@ def escape(s): """ if hasattr(s, "__html__"): return Markup(s.__html__()) + return Markup( str(s) .replace("&", "&") @@ -24,7 +27,7 @@ def escape(s): ) -def escape_silent(s): +def escape_silent(s: t.Optional[t.Any]) -> Markup: """Like :func:`escape` but treats ``None`` as the empty string. Useful with optional values, as otherwise you get the string ``'None'`` when the value is ``None``. @@ -36,10 +39,11 @@ def escape_silent(s): """ if s is None: return Markup() + return escape(s) -def soft_str(s): +def soft_str(s: t.Any) -> str: """Convert an object to a string if it isn't already. This preserves a :class:`Markup` string rather than converting it back to a basic string, so it will still be marked as safe and won't be escaped @@ -55,10 +59,11 @@ def soft_str(s): """ if not isinstance(s, str): return str(s) + return s -def soft_unicode(s): +def soft_unicode(s: t.Any) -> str: import warnings warnings.warn( diff --git a/src/markupsafe/_speedups.pyi b/src/markupsafe/_speedups.pyi new file mode 100644 index 0000000..a3cad64 --- /dev/null +++ b/src/markupsafe/_speedups.pyi @@ -0,0 +1,20 @@ +from typing import Any +from typing import Optional + +from . import Markup + + +def escape(s: Any) -> Markup: + ... + + +def escape_silent(s: Optional[Any]) -> Markup: + ... + + +def soft_str(s: Any) -> str: + ... + + +def soft_unicode(s: Any) -> str: + ... diff --git a/src/markupsafe/py.typed b/src/markupsafe/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/markupsafe/py.typed |