summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Iwan Aucamp <aucampia@gmail.com> 2023-03-12 10:08:50 +0100
committer GitHub <noreply@github.com> 2023-03-12 10:08:50 +0100
commit 314c7701f569d5cf49265f1ad5706b68d701a950 (patch)
tree fb3989cb05c945181bb9136b863415cdd1cb962b
parent e9a81ceb510ff5d16fd7e7e5e3eb0f52182d1f98 (diff)
download rdflib-314c7701f569d5cf49265f1ad5706b68d701a950.tar.gz
feat: add typing to `rdflib.util` (#2262)
Mainly so that users can use RDFLib in a safer way, and so that we can make safer changes to RDFLib in the future. There are also some accommodating type-hint related changes outside of `rdflib.util`. This change does not have a runtime impact.
-rw-r--r-- rdflib/plugins/serializers/rdfxml.py 3
-rw-r--r-- rdflib/util.py 49
-rw-r--r-- test/test_parsers/test_parser_turtlelike.py 3
-rw-r--r-- test/utils/sparql_checker.py 12
4 files changed, 46 insertions, 21 deletions
diff --git a/rdflib/plugins/serializers/rdfxml.py b/rdflib/plugins/serializers/rdfxml.py
index 2bba8b40..e3d9ec77 100644
--- a/rdflib/plugins/serializers/rdfxml.py
+++ b/rdflib/plugins/serializers/rdfxml.py
@@ -251,7 +251,8 @@ class PrettyXMLSerializer(Serializer):
type = first(store.objects(subject, RDF.type))
try:
- self.nm.qname(type)
+ # type error: Argument 1 to "qname" of "NamespaceManager" has incompatible type "Optional[Node]"; expected "str"
+ self.nm.qname(type) # type: ignore[arg-type]
except:
type = None
diff --git a/rdflib/util.py b/rdflib/util.py
index 487d7bd1..4485de2e 100644
--- a/rdflib/util.py
+++ b/rdflib/util.py
@@ -31,12 +31,16 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
+ Dict,
+ Hashable,
+ Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
+ Union,
overload,
)
from urllib.parse import quote, urlsplit, urlunsplit
@@ -65,17 +69,21 @@ __all__ = [
"_iri2uri",
]
+_HashableT = TypeVar("_HashableT", bound=Hashable)
+_AnyT = TypeVar("_AnyT")
+
-def list2set(seq):
+def list2set(seq: Iterable[_HashableT]) -> List[_HashableT]:
"""
Return a new list without duplicates.
Preserves the order, unlike set(seq)
"""
seen = set()
- return [x for x in seq if x not in seen and not seen.add(x)]
+ # type error: "add" of "set" does not return a value
+ return [x for x in seq if x not in seen and not seen.add(x)] # type: ignore[func-returns-value]
-def first(seq):
+def first(seq: Iterable[_AnyT]) -> Optional[_AnyT]:
"""
return the first element in a python sequence
for graphs, use graph.value instead
@@ -85,7 +93,7 @@ def first(seq):
return None
-def uniq(sequence, strip=0):
+def uniq(sequence: Iterable[str], strip: int = 0) -> Set[str]:
"""removes duplicate strings from the sequence."""
if strip:
return set(s.strip() for s in sequence)
@@ -93,7 +101,7 @@ def uniq(sequence, strip=0):
return set(sequence)
-def more_than(sequence, number):
+def more_than(sequence: Iterable[Any], number: int) -> int:
"Returns 1 if sequence has more items than number and 0 if not."
i = 0
for item in sequence:
@@ -103,7 +111,9 @@ def more_than(sequence, number):
return 0
-def to_term(s, default=None):
+def to_term(
+ s: Optional[str], default: Optional[rdflib.term.Identifier] = None
+) -> Optional[rdflib.term.Identifier]:
"""
Creates and returns an Identifier of type corresponding
to the pattern of the given positional argument string ``s``:
@@ -130,7 +140,12 @@ def to_term(s, default=None):
raise Exception(msg)
-def from_n3(s: str, default=None, backend=None, nsm=None):
+def from_n3(
+ s: str,
+ default: Optional[str] = None,
+ backend: Optional[str] = None,
+ nsm: Optional[rdflib.namespace.NamespaceManager] = None,
+) -> Optional[Union[rdflib.term.Node, str]]:
r'''
Creates the Identifier corresponding to the given n3 string.
@@ -196,7 +211,8 @@ def from_n3(s: str, default=None, backend=None, nsm=None):
# Hack: this should correctly handle strings with either native unicode
# characters, or \u1234 unicode escapes.
value = value.encode("raw-unicode-escape").decode("unicode-escape")
- return rdflib.term.Literal(value, language, datatype)
+ # type error: Argument 3 to "Literal" has incompatible type "Union[Node, str, None]"; expected "Optional[str]"
+ return rdflib.term.Literal(value, language, datatype) # type: ignore[arg-type]
elif s == "true" or s == "false":
return rdflib.term.Literal(s == "true")
elif (
@@ -214,10 +230,14 @@ def from_n3(s: str, default=None, backend=None, nsm=None):
elif s.startswith("{"):
identifier = from_n3(s[1:-1])
- return rdflib.graph.QuotedGraph(backend, identifier)
+ # type error: Argument 1 to "QuotedGraph" has incompatible type "Optional[str]"; expected "Union[Store, str]"
+ # type error: Argument 2 to "QuotedGraph" has incompatible type "Union[Node, str, None]"; expected "Union[IdentifiedNode, str, None]"
+ return rdflib.graph.QuotedGraph(backend, identifier) # type: ignore[arg-type]
elif s.startswith("["):
identifier = from_n3(s[1:-1])
- return rdflib.graph.Graph(backend, identifier)
+ # type error: Argument 1 to "Graph" has incompatible type "Optional[str]"; expected "Union[Store, str]"
+ # type error: Argument 2 to "Graph" has incompatible type "Union[Node, str, None]"; expected "Union[IdentifiedNode, str, None]"
+ return rdflib.graph.Graph(backend, identifier) # type: ignore[arg-type]
elif s.startswith("_:"):
return rdflib.term.BNode(s[2:])
elif ":" in s:
@@ -266,7 +286,7 @@ def date_time(t=None, local_time_zone=False):
return s
-def parse_date_time(val):
+def parse_date_time(val: str) -> int:
"""always returns seconds in UTC
# tests are written like this to make any errors easier to understand
@@ -330,7 +350,7 @@ SUFFIX_FORMAT_MAP = {
}
-def guess_format(fpath, fmap=None) -> Optional[str]:
+def guess_format(fpath: str, fmap: Optional[Dict[str, str]] = None) -> Optional[str]:
"""
Guess RDF serialization based on file suffix. Uses
``SUFFIX_FORMAT_MAP`` unless ``fmap`` is provided. Examples:
@@ -364,7 +384,7 @@ def guess_format(fpath, fmap=None) -> Optional[str]:
return fmap.get(_get_ext(fpath)) or fmap.get(fpath.lower())
-def _get_ext(fpath, lower=True):
+def _get_ext(fpath: str, lower: bool = True) -> str:
"""
Gets the file extension from a file(path); stripped of leading '.' and in
lower case. Examples:
@@ -465,9 +485,6 @@ def get_tree(
return (mapper(root), sorted(tree, key=sortkey))
-_AnyT = TypeVar("_AnyT")
-
-
@overload
def _coalesce(*args: Optional[_AnyT], default: _AnyT) -> _AnyT:
...
diff --git a/test/test_parsers/test_parser_turtlelike.py b/test/test_parsers/test_parser_turtlelike.py
index 1eac25e1..e74a55e7 100644
--- a/test/test_parsers/test_parser_turtlelike.py
+++ b/test/test_parsers/test_parser_turtlelike.py
@@ -78,7 +78,8 @@ def parse_identifier(identifier_string: str, format: str) -> Identifier:
def parse_n3_identifier(identifier_string: str, format: str) -> Identifier:
- return from_n3(identifier_string)
+ # type error: Incompatible return value type (got "Union[Node, str, None]", expected "Identifier")
+ return from_n3(identifier_string) # type: ignore[return-value]
ParseFunction = Callable[[str, str], Identifier]
diff --git a/test/utils/sparql_checker.py b/test/utils/sparql_checker.py
index b076a73a..836c040f 100644
--- a/test/utils/sparql_checker.py
+++ b/test/utils/sparql_checker.py
@@ -132,7 +132,10 @@ class GraphData:
"public_id = %s - graph = %s\n%s", public_id, graph_path, graph_text
)
dataset.parse(
- data=graph_text, publicID=public_id, format=guess_format(graph_path)
+ # type error: Argument 1 to "guess_format" has incompatible type "Path"; expected "str"
+ data=graph_text,
+ publicID=public_id,
+ format=guess_format(graph_path), # type: ignore[arg-type]
)
@@ -211,7 +214,9 @@ class SPARQLEntry(ManifestEntry):
data_text,
)
dataset.default_context.parse(
- data=data_text, format=guess_format(data_path)
+ # type error: Argument 1 to "guess_format" has incompatible type "Path"; expected "str"
+ data=data_text,
+ format=guess_format(data_path), # type: ignore[arg-type]
)
if graph_data_set is not None:
for graph_data in graph_data_set:
@@ -352,7 +357,8 @@ def patched_query_context_load(uri_mapper: URIMapper) -> Callable[..., Any]:
) -> None:
public_id = None
use_source: Union[URIRef, Path] = source
- format = guess_format(use_source)
+ # type error: Argument 1 to "guess_format" has incompatible type "Union[URIRef, Path]"; expected "str"
+ format = guess_format(use_source) # type: ignore[arg-type]
if f"{source}".startswith(("https://", "http://")):
use_source = uri_mapper.to_local_path(source)
public_id = source