From 9ce7724713eab3fdec44d782d4d7607316c14449 Mon Sep 17 00:00:00 2001 From: Iwan Aucamp Date: Thu, 26 May 2022 14:00:01 +0200 Subject: Add more typing for SPARQL (#1965) This adds typing to `rdflib/plugins/sparql/algebra.py` and `rdflib/plugins/sparql/parserutils.py`. This is mainly being done to help detect issues in the new PRs that were recently opened that relate to SPARQL. This patch contain no runtime changes. --- rdflib/plugins/sparql/algebra.py | 230 ++++++++++++++++++++++++----------- rdflib/plugins/sparql/parserutils.py | 20 ++- 2 files changed, 179 insertions(+), 71 deletions(-) diff --git a/rdflib/plugins/sparql/algebra.py b/rdflib/plugins/sparql/algebra.py index 15e88ee3..eed49db5 100644 --- a/rdflib/plugins/sparql/algebra.py +++ b/rdflib/plugins/sparql/algebra.py @@ -8,12 +8,32 @@ http://www.w3.org/TR/sparql11-query/#sparqlQuery import collections import functools import operator +import typing from functools import reduce +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + overload, +) from pyparsing import ParseResults -from rdflib import BNode, Literal, URIRef, Variable -from rdflib.paths import AlternativePath, InvPath, MulPath, NegatedPath, SequencePath +from rdflib.paths import ( + AlternativePath, + InvPath, + MulPath, + NegatedPath, + Path, + SequencePath, +) from rdflib.plugins.sparql.operators import TrueFilter, and_ from rdflib.plugins.sparql.operators import simplify as simplifyFilters from rdflib.plugins.sparql.parserutils import CompValue, Expr @@ -21,62 +41,68 @@ from rdflib.plugins.sparql.sparql import Prologue, Query, Update # --------------------------- # Some convenience methods -from rdflib.term import Identifier +from rdflib.term import BNode, Identifier, Literal, URIRef, Variable -def OrderBy(p, expr): +def OrderBy(p: CompValue, expr: List[CompValue]) -> CompValue: return CompValue("OrderBy", p=p, expr=expr) -def ToMultiSet(p): +def ToMultiSet( + p: typing.Union[List[Dict[Variable, Identifier]], CompValue] +) -> CompValue: return CompValue("ToMultiSet", p=p) -def Union(p1, p2): +def Union(p1: CompValue, p2: CompValue) -> CompValue: return CompValue("Union", p1=p1, p2=p2) -def Join(p1, p2): +def Join(p1: CompValue, p2: Optional[CompValue]) -> CompValue: return CompValue("Join", p1=p1, p2=p2) -def Minus(p1, p2): +def Minus(p1: CompValue, p2: CompValue) -> CompValue: return CompValue("Minus", p1=p1, p2=p2) -def Graph(term, graph): +def Graph(term, graph) -> CompValue: return CompValue("Graph", term=term, p=graph) -def BGP(triples=None): +def BGP(triples=None) -> CompValue: return CompValue("BGP", triples=triples or []) -def LeftJoin(p1, p2, expr): +def LeftJoin(p1: CompValue, p2: CompValue, expr) -> CompValue: return CompValue("LeftJoin", p1=p1, p2=p2, expr=expr) -def Filter(expr, p): +def Filter(expr, p: CompValue) -> CompValue: return CompValue("Filter", expr=expr, p=p) -def Extend(p, expr, var): +def Extend(p: CompValue, expr, var) -> CompValue: return CompValue("Extend", p=p, expr=expr, var=var) -def Values(res): +def Values(res) -> CompValue: return CompValue("values", res=res) -def Project(p, PV): +def Project(p: CompValue, PV) -> CompValue: return CompValue("Project", p=p, PV=PV) -def Group(p, expr=None): +def Group(p: CompValue, expr: Optional[List[Variable]] = None) -> CompValue: return CompValue("Group", p=p, expr=expr) -def _knownTerms(triple, varsknown, varscount): +def _knownTerms( + triple: Tuple[Identifier, Identifier, Identifier], + varsknown: Set[typing.Union[BNode, Variable]], + varscount: Dict[Identifier, int], +): return ( len( [ @@ -90,7 +116,9 @@ def _knownTerms(triple, varsknown, varscount): ) -def reorderTriples(l_): +def reorderTriples( + l_: Iterable[Tuple[Identifier, Identifier, Identifier]] +) -> List[Tuple[Identifier, Identifier, Identifier]]: """ Reorder triple patterns so that we execute the ones with most bindings first @@ -100,9 +128,13 @@ def reorderTriples(l_): if isinstance(term, (Variable, BNode)): varsknown.add(term) - l_ = [(None, x) for x in l_] - varsknown = set() - varscount = collections.defaultdict(int) + # NOTE on type errors: most of these are because the same variable is used + # for different types. + + # type error: List comprehension has incompatible type List[Tuple[None, Tuple[Identifier, Identifier, Identifier]]]; expected List[Tuple[Identifier, Identifier, Identifier]] + l_ = [(None, x) for x in l_] # type: ignore[misc] + varsknown: Set[typing.Union[BNode, Variable]] = set() + varscount: Dict[Identifier, int] = collections.defaultdict(int) for t in l_: for c in t[1]: if isinstance(c, (Variable, BNode)): @@ -117,8 +149,11 @@ def reorderTriples(l_): # we sort by decorate/undecorate, since we need the value of the sort keys while i < len(l_): - l_[i:] = sorted((_knownTerms(x[1], varsknown, varscount), x[1]) for x in l_[i:]) - t = l_[i][0][0] # top block has this many terms bound + # type error: Generator has incompatible item type "Tuple[Any, Identifier]"; expected "Tuple[Identifier, Identifier, Identifier]" + # type error: Argument 1 to "_knownTerms" has incompatible type "Identifier"; expected "Tuple[Identifier, Identifier, Identifier]" + l_[i:] = sorted((_knownTerms(x[1], varsknown, varscount), x[1]) for x in l_[i:]) # type: ignore[misc,arg-type] + # type error: Incompatible types in assignment (expression has type "str", variable has type "Tuple[Identifier, Identifier, Identifier]") + t = l_[i][0][0] # type: ignore[assignment] # top block has this many terms bound j = 0 while i + j < len(l_) and l_[i + j][0][0] == t: for c in l_[i + j][1]: @@ -126,18 +161,26 @@ def reorderTriples(l_): j += 1 i += 1 - return [x[1] for x in l_] - + # type error: List comprehension has incompatible type List[Identifier]; expected List[Tuple[Identifier, Identifier, Identifier]] + return [x[1] for x in l_] # type: ignore[misc] -def triples(l): # noqa: E741 - l = reduce(lambda x, y: x + y, l) # noqa: E741 +def triples( + l: typing.Union[ # noqa: E741 + List[List[Identifier]], List[Tuple[Identifier, Identifier, Identifier]] + ] +) -> List[Tuple[Identifier, Identifier, Identifier]]: + # NOTE on type errors: errors are a result of the variable being reused for + # a different type. + # type error: Incompatible types in assignment (expression has type "Sequence[Identifier]", variable has type "Union[List[List[Identifier]], List[Tuple[Identifier, Identifier, Identifier]]]") + l = reduce(lambda x, y: x + y, l) # type: ignore[assignment] # noqa: E741 if (len(l) % 3) != 0: raise Exception("these aint triples") - return reorderTriples((l[x], l[x + 1], l[x + 2]) for x in range(0, len(l), 3)) + # type error: Generator has incompatible item type "Tuple[Union[List[Identifier], Tuple[Identifier, Identifier, Identifier]], Union[List[Identifier], Tuple[Identifier, Identifier, Identifier]], Union[List[Identifier], Tuple[Identifier, Identifier, Identifier]]]"; expected "Tuple[Identifier, Identifier, Identifier]" + return reorderTriples((l[x], l[x + 1], l[x + 2]) for x in range(0, len(l), 3)) # type: ignore[misc] -def translatePName(p, prologue): +def translatePName(p: typing.Union[CompValue, str], prologue: Prologue): """ Expand prefixed/relative URIs """ @@ -145,14 +188,26 @@ def translatePName(p, prologue): if p.name == "pname": return prologue.absolutize(p) if p.name == "literal": + # type error: Argument "datatype" to "Literal" has incompatible type "Union[CompValue, str, None]"; expected "Optional[str]" return Literal( - p.string, lang=p.lang, datatype=prologue.absolutize(p.datatype) + p.string, lang=p.lang, datatype=prologue.absolutize(p.datatype) # type: ignore[arg-type] ) elif isinstance(p, URIRef): return prologue.absolutize(p) -def translatePath(p): +@overload +def translatePath(p: URIRef) -> None: + ... + + +@overload +def translatePath(p: CompValue) -> "Path": + ... + + +# type error: Missing return statement +def translatePath(p: typing.Union[CompValue, URIRef]) -> Optional["Path"]: # type: ignore[return] """ Translate PropertyPath expressions """ @@ -197,7 +252,9 @@ def translatePath(p): return NegatedPath(p.part) -def translateExists(e): +def translateExists( + e: typing.Union[Expr, Literal, Variable] +) -> typing.Union[Expr, Literal, Variable]: """ Translate the graph pattern used by EXISTS and NOT EXISTS http://www.w3.org/TR/sparql11-query/#sparqlCollectFilters @@ -242,8 +299,8 @@ def collectAndRemoveFilters(parts): return None -def translateGroupOrUnionGraphPattern(graphPattern): - A = None +def translateGroupOrUnionGraphPattern(graphPattern: CompValue) -> Optional[CompValue]: + A: Optional[CompValue] = None for g in graphPattern.graph: g = translateGroupGraphPattern(g) @@ -254,15 +311,15 @@ def translateGroupOrUnionGraphPattern(graphPattern): return A -def translateGraphGraphPattern(graphPattern): +def translateGraphGraphPattern(graphPattern: CompValue) -> CompValue: return Graph(graphPattern.term, translateGroupGraphPattern(graphPattern.graph)) -def translateInlineData(graphPattern): +def translateInlineData(graphPattern: CompValue) -> CompValue: return ToMultiSet(translateValues(graphPattern)) -def translateGroupGraphPattern(graphPattern): +def translateGroupGraphPattern(graphPattern: CompValue) -> CompValue: """ http://www.w3.org/TR/sparql11-query/#convertGraphPattern """ @@ -275,7 +332,7 @@ def translateGroupGraphPattern(graphPattern): filters = collectAndRemoveFilters(graphPattern.part) - g = [] + g: List[CompValue] = [] for p in graphPattern.part: if p.name == "TriplesBlock": # merge adjacent TripleBlocks @@ -327,7 +384,11 @@ class StopTraversal(Exception): # noqa: N818 self.rv = rv -def _traverse(e, visitPre=lambda n: None, visitPost=lambda n: None): +def _traverse( + e: Any, + visitPre: Callable[[Any], Any] = lambda n: None, + visitPost: Callable[[Any], Any] = lambda n: None, +): """ Traverse a parse-tree, visit each node @@ -342,21 +403,23 @@ def _traverse(e, visitPre=lambda n: None, visitPost=lambda n: None): if isinstance(e, (list, ParseResults)): return [_traverse(x, visitPre, visitPost) for x in e] - elif isinstance(e, tuple): + # type error: Statement is unreachable + elif isinstance(e, tuple): # type: ignore[unreachable] return tuple([_traverse(x, visitPre, visitPost) for x in e]) elif isinstance(e, CompValue): for k, val in e.items(): e[k] = _traverse(val, visitPre, visitPost) - _e = visitPost(e) + # type error: Statement is unreachable + _e = visitPost(e) # type: ignore[unreachable] if _e is not None: return _e return e -def _traverseAgg(e, visitor=lambda n, v: None): +def _traverseAgg(e, visitor: Callable[[Any, Any], Any] = lambda n, v: None): """ Traverse a parse-tree, visit each node @@ -367,8 +430,8 @@ def _traverseAgg(e, visitor=lambda n, v: None): if isinstance(e, (list, ParseResults, tuple)): res = [_traverseAgg(x, visitor) for x in e] - - elif isinstance(e, CompValue): + # type error: Statement is unreachable + elif isinstance(e, CompValue): # type: ignore[unreachable] for k, val in e.items(): if val is not None: res.append(_traverseAgg(val, visitor)) @@ -376,7 +439,12 @@ def _traverseAgg(e, visitor=lambda n, v: None): return visitor(e, res) -def traverse(tree, visitPre=lambda n: None, visitPost=lambda n: None, complete=None): +def traverse( + tree, + visitPre: Callable[[Any], Any] = lambda n: None, + visitPost: Callable[[Any], Any] = lambda n: None, + complete: Optional[bool] = None, +): """ Traverse tree, visit each node with visit function visit function may raise StopTraversal to stop traversal @@ -392,7 +460,7 @@ def traverse(tree, visitPre=lambda n: None, visitPost=lambda n: None, complete=N return st.rv -def _hasAggregate(x): +def _hasAggregate(x) -> None: """ Traverse parse(sub)Tree return true if any aggregates are used @@ -403,7 +471,8 @@ def _hasAggregate(x): raise StopTraversal(True) -def _aggs(e, A): +# type error: Missing return statement +def _aggs(e, A) -> Optional[Variable]: # type: ignore[return] """ Collect Aggregates in A replaces aggregates with variable references @@ -418,7 +487,8 @@ def _aggs(e, A): return aggvar -def _findVars(x, res): +# type error: Missing return statement +def _findVars(x, res: Set[Variable]) -> Optional[CompValue]: # type: ignore[return] """ Find all variables in a tree """ @@ -434,7 +504,7 @@ def _findVars(x, res): return x -def _addVars(x, children): +def _addVars(x, children) -> Set[Variable]: """ find which variables may be bound by this part of the query """ @@ -467,7 +537,8 @@ def _addVars(x, children): return reduce(operator.or_, children, set()) -def _sample(e, v=None): +# type error: Missing return statement +def _sample(e: typing.Union[CompValue, List[Expr], Expr, List[str], Variable], v: Optional[Variable] = None) -> Optional[CompValue]: # type: ignore[return] """ For each unaggregated variable V in expr Replace V with Sample(V) @@ -483,9 +554,11 @@ def _simplifyFilters(e): return simplifyFilters(e) -def translateAggregates(q, M): - E = [] - A = [] +def translateAggregates( + q: CompValue, M: CompValue +) -> Tuple[CompValue, List[Tuple[Variable, Variable]]]: + E: List[Tuple[Variable, Variable]] = [] + A: List[CompValue] = [] # collect/replace aggs in : # select expr as ?var @@ -517,11 +590,13 @@ def translateAggregates(q, M): return CompValue("AggregateJoin", A=A, p=M), E -def translateValues(v): +def translateValues( + v: CompValue, +) -> typing.Union[List[Dict[Variable, Identifier]], CompValue]: # if len(v.var)!=len(v.value): # raise Exception("Unmatched vars and values in ValueClause: "+str(v)) - res = [] + res: List[Dict[Variable, Identifier]] = [] if not v.var: return res if not v.value: @@ -537,7 +612,7 @@ def translateValues(v): return Values(res) -def translate(q): +def translate(q: CompValue) -> Tuple[CompValue, List[Variable]]: """ http://www.w3.org/TR/sparql11-query/#convertSolMod @@ -548,7 +623,7 @@ def translate(q): q.where = traverse(q.where, visitPost=translatePath) # TODO: Var scope test - VS = set() + VS: Set[Variable] = set() traverse(q.where, functools.partial(_findVars, res=VS)) # all query types have a where part @@ -646,7 +721,8 @@ def translate(q): return M, PV -def simplify(n): +# type error: Missing return statement +def simplify(n) -> Optional[CompValue]: # type: ignore[return] """Remove joins to empty BGPs""" if isinstance(n, CompValue): if n.name == "Join": @@ -678,7 +754,12 @@ def analyse(n, children): return True -def translatePrologue(p, base, initNs=None, prologue=None): +def translatePrologue( + p: ParseResults, + base: Optional[str], + initNs: Optional[Mapping[str, str]] = None, + prologue: Optional[Prologue] = None, +) -> Prologue: if prologue is None: prologue = Prologue() @@ -689,6 +770,7 @@ def translatePrologue(p, base, initNs=None, prologue=None): for k, v in initNs.items(): prologue.bind(k, v) + x: CompValue for x in p: if x.name == "Base": prologue.base = x.iri @@ -698,13 +780,15 @@ def translatePrologue(p, base, initNs=None, prologue=None): return prologue -def translateQuads(quads): +def translateQuads(quads: CompValue): if quads.triples: alltriples = triples(quads.triples) else: alltriples = [] - allquads = collections.defaultdict(list) + allquads: DefaultDict[ + str, List[Tuple[Identifier, Identifier, Identifier]] + ] = collections.defaultdict(list) if quads.quadsNotTriples: for q in quads.quadsNotTriples: @@ -714,7 +798,7 @@ def translateQuads(quads): return alltriples, allquads -def translateUpdate1(u, prologue): +def translateUpdate1(u: CompValue, prologue: Prologue) -> CompValue: if u.name in ("Load", "Clear", "Drop", "Create"): pass # no translation needed elif u.name in ("Add", "Move", "Copy"): @@ -738,15 +822,20 @@ def translateUpdate1(u, prologue): return u -def translateUpdate(q, base=None, initNs=None): +def translateUpdate( + q: CompValue, + base: Optional[str] = None, + initNs: Optional[Mapping[str, str]] = None, +) -> Update: """ Returns a list of SPARQL Update Algebra expressions """ - res = [] + res: List[CompValue] = [] prologue = None if not q.request: - return res + # type error: Incompatible return value type (got "List[CompValue]", expected "Update") + return res # type: ignore[return-value] for p, u in zip(q.prologue, q.request): prologue = translatePrologue(p, base, initNs, prologue) @@ -758,10 +847,15 @@ def translateUpdate(q, base=None, initNs=None): res.append(translateUpdate1(u, prologue)) - return Update(prologue, res) + # type error: Argument 1 to "Update" has incompatible type "Optional[Any]"; expected "Prologue" + return Update(prologue, res) # type: ignore[arg-type] -def translateQuery(q, base=None, initNs=None): +def translateQuery( + q: ParseResults, + base: Optional[str] = None, + initNs: Optional[Mapping[str, str]] = None, +) -> Query: """ Translate a query-parsetree to a SPARQL Algebra Expression @@ -798,7 +892,7 @@ class ExpressionNotCoveredException(Exception): # noqa: N818 pass -def translateAlgebra(query_algebra: Query): +def translateAlgebra(query_algebra: Query) -> str: """ :param query_algebra: An algebra returned by the function call algebra.translateQuery(parse_tree). diff --git a/rdflib/plugins/sparql/parserutils.py b/rdflib/plugins/sparql/parserutils.py index 6748a7f0..a936b046 100644 --- a/rdflib/plugins/sparql/parserutils.py +++ b/rdflib/plugins/sparql/parserutils.py @@ -1,10 +1,14 @@ from collections import OrderedDict from types import MethodType +from typing import TYPE_CHECKING, Any from pyparsing import ParseResults, TokenConverter, originalTextFor from rdflib import BNode, Variable +if TYPE_CHECKING: + from rdflib.plugins.sparql.sparql import FrozenBindings + """ NOTE: PyParsing setResultName/__call__ provides a very similar solution to this @@ -38,7 +42,12 @@ the resulting CompValue # Comp('Sum')( Param('x')(Number) + '+' + Param('y')(Number) ) -def value(ctx, val, variables=False, errors=False): +def value( + ctx: "FrozenBindings", + val: Any, + variables: bool = False, + errors: bool = False, +): """ utility function for evaluating something... @@ -138,7 +147,7 @@ class CompValue(OrderedDict): """ - def __init__(self, name, **values): + def __init__(self, name: str, **values): OrderedDict.__init__(self) self.name = name self.update(values) @@ -164,7 +173,7 @@ class CompValue(OrderedDict): def get(self, a, variables=False, errors=False): return self._value(OrderedDict.get(self, a, a), variables, errors) - def __getattr__(self, a): + def __getattr__(self, a: str) -> Any: # Hack hack: OrderedDict relies on this if a in ("_OrderedDict__root", "_OrderedDict__end"): raise AttributeError() @@ -174,6 +183,11 @@ class CompValue(OrderedDict): # raise AttributeError('no such attribute '+a) return None + if TYPE_CHECKING: + # this is here because properties are dynamically set on CompValue + def __setattr__(self, __name: str, __value: Any) -> None: + ... + class Expr(CompValue): """ -- cgit v1.2.1