summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDavid Lord <davidism@gmail.com>2021-01-29 18:27:37 -0800
committerDavid Lord <davidism@gmail.com>2021-01-29 18:40:05 -0800
commit6a64222bfb18bd49e3a12f509f38ee7f2585799f (patch)
tree8c36679e1883ee23bcbecd6ccd0ae5d92e8f995e /src
parent9724cdedc887632d64d8fc7ed40056d0a8431f06 (diff)
downloadmarkupsafe-6a64222bfb18bd49e3a12f509f38ee7f2585799f.tar.gz
add type annotations
Diffstat (limited to 'src')
-rw-r--r--src/markupsafe/__init__.py128
-rw-r--r--src/markupsafe/_native.py13
-rw-r--r--src/markupsafe/_speedups.pyi20
-rw-r--r--src/markupsafe/py.typed0
4 files changed, 110 insertions, 51 deletions
diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py
index 789979f..1d786fd 100644
--- a/src/markupsafe/__init__.py
+++ b/src/markupsafe/__init__.py
@@ -1,11 +1,32 @@
+import functools
import re
import string
+import typing as t
+
+if t.TYPE_CHECKING:
+
+ class HasHTML(t.Protocol):
+ def __html__(self) -> str:
+ pass
+
__version__ = "2.0.0a1"
_striptags_re = re.compile(r"(<!--.*?-->|<[^>]*>)")
+def _simple_escaping_wrapper(name: str) -> t.Callable[..., "Markup"]:
+ orig = getattr(str, name)
+
+ @functools.wraps(orig)
+ def wrapped(self: "Markup", *args: t.Any, **kwargs: t.Any) -> "Markup":
+ args = _escape_argspec(list(args), enumerate(args), self.escape) # type: ignore
+ _escape_argspec(kwargs, kwargs.items(), self.escape)
+ return self.__class__(orig(self, *args, **kwargs))
+
+ return wrapped
+
+
class Markup(str):
"""A string that is ready to be safely inserted into an HTML or XML
document, either because it was escaped or because it was marked
@@ -44,64 +65,76 @@ class Markup(str):
__slots__ = ()
- def __new__(cls, base="", encoding=None, errors="strict"):
+ def __new__(
+ cls, base: t.Any = "", encoding: t.Optional[str] = None, errors: str = "strict"
+ ) -> "Markup":
if hasattr(base, "__html__"):
base = base.__html__()
+
if encoding is None:
return super().__new__(cls, base)
+
return super().__new__(cls, base, encoding, errors)
- def __html__(self):
+ def __html__(self) -> "Markup":
return self
- def __add__(self, other):
+ def __add__(self, other: t.Union[str, "HasHTML"]) -> "Markup":
if isinstance(other, str) or hasattr(other, "__html__"):
return self.__class__(super().__add__(self.escape(other)))
+
return NotImplemented
- def __radd__(self, other):
+ def __radd__(self, other: t.Union[str, "HasHTML"]) -> "Markup":
if isinstance(other, str) or hasattr(other, "__html__"):
return self.escape(other).__add__(self)
+
return NotImplemented
- def __mul__(self, num):
+ def __mul__(self, num: int) -> "Markup":
if isinstance(num, int):
return self.__class__(super().__mul__(num))
- return NotImplemented
+
+ return NotImplemented # type: ignore
__rmul__ = __mul__
- def __mod__(self, arg):
+ def __mod__(self, arg: t.Any) -> "Markup":
if isinstance(arg, tuple):
arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg)
else:
arg = _MarkupEscapeHelper(arg, self.escape)
+
return self.__class__(super().__mod__(arg))
- def __repr__(self):
+ def __repr__(self) -> str:
return f"{self.__class__.__name__}({super().__repr__()})"
- def join(self, seq):
+ def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "Markup":
return self.__class__(super().join(map(self.escape, seq)))
join.__doc__ = str.join.__doc__
- def split(self, *args, **kwargs):
- return list(map(self.__class__, super().split(*args, **kwargs)))
+ def split( # type: ignore
+ self, sep: t.Optional[str] = None, maxsplit: int = -1
+ ) -> t.List["Markup"]:
+ return [self.__class__(v) for v in super().split(sep, maxsplit)]
split.__doc__ = str.split.__doc__
- def rsplit(self, *args, **kwargs):
- return list(map(self.__class__, super().rsplit(*args, **kwargs)))
+ def rsplit( # type: ignore
+ self, sep: t.Optional[str] = None, maxsplit: int = -1
+ ) -> t.List["Markup"]:
+ return [self.__class__(v) for v in super().rsplit(sep, maxsplit)]
rsplit.__doc__ = str.rsplit.__doc__
- def splitlines(self, *args, **kwargs):
- return list(map(self.__class__, super().splitlines(*args, **kwargs)))
+ def splitlines(self, keepends: bool = False) -> t.List["Markup"]: # type: ignore
+ return [self.__class__(v) for v in super().splitlines(keepends)]
splitlines.__doc__ = str.splitlines.__doc__
- def unescape(self):
+ def unescape(self) -> str:
"""Convert escaped markup back into a text string. This replaces
HTML entities with the characters they represent.
@@ -112,7 +145,7 @@ class Markup(str):
return unescape(str(self))
- def striptags(self):
+ def striptags(self) -> str:
""":meth:`unescape` the markup, remove tags, and normalize
whitespace to single spaces.
@@ -123,26 +156,16 @@ class Markup(str):
return Markup(stripped).unescape()
@classmethod
- def escape(cls, s):
+ def escape(cls, s: t.Any) -> "Markup":
"""Escape a string. Calls :func:`escape` and ensures that for
subclasses the correct type is returned.
"""
rv = escape(s)
+
if rv.__class__ is not cls:
return cls(rv)
- return rv
-
- def make_simple_escaping_wrapper(name): # noqa: B902
- orig = getattr(str, name)
-
- def func(self, *args, **kwargs):
- args = _escape_argspec(list(args), enumerate(args), self.escape)
- _escape_argspec(kwargs, kwargs.items(), self.escape)
- return self.__class__(orig(self, *args, **kwargs))
- func.__name__ = orig.__name__
- func.__doc__ = orig.__doc__
- return func
+ return rv
for method in (
"__getitem__",
@@ -162,31 +185,36 @@ class Markup(str):
"swapcase",
"zfill",
):
- locals()[method] = make_simple_escaping_wrapper(method)
+ locals()[method] = _simple_escaping_wrapper(method)
- del method, make_simple_escaping_wrapper
+ del method
- def partition(self, sep):
- return tuple(map(self.__class__, super().partition(self.escape(sep))))
+ def partition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
+ l, s, r = super().partition(self.escape(sep))
+ cls = self.__class__
+ return cls(l), cls(s), cls(r)
- def rpartition(self, sep):
- return tuple(map(self.__class__, super().rpartition(self.escape(sep))))
+ def rpartition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
+ l, s, r = super().rpartition(self.escape(sep))
+ cls = self.__class__
+ return cls(l), cls(s), cls(r)
- def format(self, *args, **kwargs):
+ def format(self, *args: t.Any, **kwargs: t.Any) -> "Markup":
formatter = EscapeFormatter(self.escape)
return self.__class__(formatter.vformat(self, args, kwargs))
- def __html_format__(self, format_spec):
+ def __html_format__(self, format_spec: str) -> "Markup":
if format_spec:
raise ValueError("Unsupported format specification for Markup.")
+
return self
class EscapeFormatter(string.Formatter):
- def __init__(self, escape):
+ def __init__(self, escape: t.Callable[[t.Any], Markup]) -> None:
self.escape = escape
- def format_field(self, value, format_spec):
+ def format_field(self, value: t.Any, format_spec: str) -> str:
if hasattr(value, "__html_format__"):
rv = value.__html_format__(format_spec)
elif hasattr(value, "__html__"):
@@ -204,34 +232,40 @@ class EscapeFormatter(string.Formatter):
return str(self.escape(rv))
-def _escape_argspec(obj, iterable, escape):
+_ListOrDict = t.TypeVar("_ListOrDict", list, dict)
+
+
+def _escape_argspec(
+ obj: _ListOrDict, iterable: t.Iterable[t.Any], escape: t.Callable[[t.Any], Markup]
+) -> _ListOrDict:
"""Helper for various string-wrapped functions."""
for key, value in iterable:
if isinstance(value, str) or hasattr(value, "__html__"):
obj[key] = escape(value)
+
return obj
class _MarkupEscapeHelper:
"""Helper for :meth:`Markup.__mod__`."""
- def __init__(self, obj, escape):
+ def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None:
self.obj = obj
self.escape = escape
- def __getitem__(self, item):
+ def __getitem__(self, item: t.Any) -> "_MarkupEscapeHelper":
return _MarkupEscapeHelper(self.obj[item], self.escape)
- def __str__(self):
+ def __str__(self) -> str:
return str(self.escape(self.obj))
- def __repr__(self):
+ def __repr__(self) -> str:
return str(self.escape(repr(self.obj)))
- def __int__(self):
+ def __int__(self) -> int:
return int(self.obj)
- def __float__(self):
+ def __float__(self) -> float:
return float(self.obj)
diff --git a/src/markupsafe/_native.py b/src/markupsafe/_native.py
index 7722296..1a3f214 100644
--- a/src/markupsafe/_native.py
+++ b/src/markupsafe/_native.py
@@ -1,7 +1,9 @@
+import typing as t
+
from . import Markup
-def escape(s):
+def escape(s: t.Any) -> Markup:
"""Replace the characters ``&``, ``<``, ``>``, ``'``, and ``"`` in
the string with HTML-safe sequences. Use this if you need to display
text that might contain such characters in HTML.
@@ -14,6 +16,7 @@ def escape(s):
"""
if hasattr(s, "__html__"):
return Markup(s.__html__())
+
return Markup(
str(s)
.replace("&", "&amp;")
@@ -24,7 +27,7 @@ def escape(s):
)
-def escape_silent(s):
+def escape_silent(s: t.Optional[t.Any]) -> Markup:
"""Like :func:`escape` but treats ``None`` as the empty string.
Useful with optional values, as otherwise you get the string
``'None'`` when the value is ``None``.
@@ -36,10 +39,11 @@ def escape_silent(s):
"""
if s is None:
return Markup()
+
return escape(s)
-def soft_str(s):
+def soft_str(s: t.Any) -> str:
"""Convert an object to a string if it isn't already. This preserves
a :class:`Markup` string rather than converting it back to a basic
string, so it will still be marked as safe and won't be escaped
@@ -55,10 +59,11 @@ def soft_str(s):
"""
if not isinstance(s, str):
return str(s)
+
return s
-def soft_unicode(s):
+def soft_unicode(s: t.Any) -> str:
import warnings
warnings.warn(
diff --git a/src/markupsafe/_speedups.pyi b/src/markupsafe/_speedups.pyi
new file mode 100644
index 0000000..a3cad64
--- /dev/null
+++ b/src/markupsafe/_speedups.pyi
@@ -0,0 +1,20 @@
+from typing import Any
+from typing import Optional
+
+from . import Markup
+
+
+def escape(s: Any) -> Markup:
+ ...
+
+
+def escape_silent(s: Optional[Any]) -> Markup:
+ ...
+
+
+def soft_str(s: Any) -> str:
+ ...
+
+
+def soft_unicode(s: Any) -> str:
+ ...
diff --git a/src/markupsafe/py.typed b/src/markupsafe/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/markupsafe/py.typed