diff options
-rw-r--r-- | rdflib/plugins/sparql/parser.py | 25 | ||||
-rw-r--r-- | rdflib/plugins/sparql/parserutils.py | 35 | ||||
-rw-r--r-- | test/test_sparql/test_sparql.py | 24 |
3 files changed, 60 insertions, 24 deletions
diff --git a/rdflib/plugins/sparql/parser.py b/rdflib/plugins/sparql/parser.py index 2a897f82..302271f7 100644 --- a/rdflib/plugins/sparql/parser.py +++ b/rdflib/plugins/sparql/parser.py @@ -6,6 +6,9 @@ based on pyparsing import re import sys +from typing import Any, BinaryIO +from typing import Optional as OptionalType +from typing import TextIO, Tuple, Union from pyparsing import CaselessKeyword as Keyword # watch out :) from pyparsing import ( @@ -37,15 +40,15 @@ DEBUG = False # ---------------- ACTIONS -def neg(literal): +def neg(literal) -> rdflib.Literal: return rdflib.Literal(-literal, datatype=literal.datatype) -def setLanguage(terms): +def setLanguage(terms: Tuple[Any, OptionalType[str]]) -> rdflib.Literal: return rdflib.Literal(terms[0], lang=terms[1]) -def setDataType(terms): +def setDataType(terms: Tuple[Any, OptionalType[str]]) -> rdflib.Literal: return rdflib.Literal(terms[0], datatype=terms[1]) @@ -1508,25 +1511,27 @@ QueryUnit.ignore("#" + restOfLine) UpdateUnit.ignore("#" + restOfLine) -expandUnicodeEscapes_re = re.compile(r"\\u([0-9a-f]{4}(?:[0-9a-f]{4})?)", flags=re.I) +expandUnicodeEscapes_re: re.Pattern = re.compile( + r"\\u([0-9a-f]{4}(?:[0-9a-f]{4})?)", flags=re.I +) -def expandUnicodeEscapes(q): +def expandUnicodeEscapes(q: str) -> str: r""" The syntax of the SPARQL Query Language is expressed over code points in Unicode [UNICODE]. The encoding is always UTF-8 [RFC3629]. Unicode code points may also be expressed using an \ uXXXX (U+0 to U+FFFF) or \ UXXXXXXXX syntax (for U+10000 onwards) where X is a hexadecimal digit [0-9A-F] """ - def expand(m): + def expand(m: re.Match) -> str: try: return chr(int(m.group(1), 16)) - except: # noqa: E722 - raise Exception("Invalid unicode code point: " + m) + except (ValueError, OverflowError) as e: + raise ValueError("Invalid unicode code point: " + m.group(1)) from e return expandUnicodeEscapes_re.sub(expand, q) -def parseQuery(q): +def parseQuery(q: Union[str, bytes, TextIO, BinaryIO]) -> ParseResults: if hasattr(q, "read"): q = q.read() if isinstance(q, bytes): @@ -1536,7 +1541,7 @@ def parseQuery(q): return Query.parseString(q, parseAll=True) -def parseUpdate(q): +def parseUpdate(q: Union[str, bytes, TextIO, BinaryIO]): if hasattr(q, "read"): q = q.read() diff --git a/rdflib/plugins/sparql/parserutils.py b/rdflib/plugins/sparql/parserutils.py index 1f2e88ea..09f19ff8 100644 --- a/rdflib/plugins/sparql/parserutils.py +++ b/rdflib/plugins/sparql/parserutils.py @@ -1,10 +1,11 @@ from collections import OrderedDict from types import MethodType -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List, Tuple, Union from pyparsing import ParseResults, TokenConverter, originalTextFor from rdflib import BNode, Variable +from rdflib.term import Identifier if TYPE_CHECKING: from rdflib.plugins.sparql.sparql import FrozenBindings @@ -252,26 +253,34 @@ class Comp(TokenConverter): return self -def prettify_parsetree(t, indent="", depth=0): - out = [] - if isinstance(t, ParseResults): - for e in t.asList(): - out.append(prettify_parsetree(e, indent, depth + 1)) - for k, v in sorted(t.items()): - out.append("%s%s- %s:\n" % (indent, " " * depth, k)) - out.append(prettify_parsetree(v, indent, depth + 1)) - elif isinstance(t, CompValue): +def prettify_parsetree(t: ParseResults, indent: str = "", depth: int = 0) -> str: + out: List[str] = [] + for e in t.asList(): + out.append(_prettify_sub_parsetree(e, indent, depth + 1)) + for k, v in sorted(t.items()): + out.append("%s%s- %s:\n" % (indent, " " * depth, k)) + out.append(_prettify_sub_parsetree(v, indent, depth + 1)) + return "".join(out) + + +def _prettify_sub_parsetree( + t: Union[Identifier, CompValue, set, list, dict, Tuple, bool, None], + indent: str = "", + depth: int = 0, +) -> str: + out: List[str] = [] + if isinstance(t, CompValue): out.append("%s%s> %s:\n" % (indent, " " * depth, t.name)) for k, v in t.items(): out.append("%s%s- %s:\n" % (indent, " " * (depth + 1), k)) - out.append(prettify_parsetree(v, indent, depth + 2)) + out.append(_prettify_sub_parsetree(v, indent, depth + 2)) elif isinstance(t, dict): for k, v in t.items(): out.append("%s%s- %s:\n" % (indent, " " * (depth + 1), k)) - out.append(prettify_parsetree(v, indent, depth + 2)) + out.append(_prettify_sub_parsetree(v, indent, depth + 2)) elif isinstance(t, list): for e in t: - out.append(prettify_parsetree(e, indent, depth + 1)) + out.append(_prettify_sub_parsetree(e, indent, depth + 1)) else: out.append("%s%s- %r\n" % (indent, " " * depth, t)) return "".join(out) diff --git a/test/test_sparql/test_sparql.py b/test/test_sparql/test_sparql.py index 02406bdf..80768604 100644 --- a/test/test_sparql/test_sparql.py +++ b/test/test_sparql/test_sparql.py @@ -16,7 +16,7 @@ from rdflib.plugins.sparql import prepareQuery, sparql from rdflib.plugins.sparql.algebra import translateQuery from rdflib.plugins.sparql.evaluate import evalPart from rdflib.plugins.sparql.evalutils import _eval -from rdflib.plugins.sparql.parser import parseQuery +from rdflib.plugins.sparql.parser import expandUnicodeEscapes, parseQuery from rdflib.plugins.sparql.parserutils import prettify_parsetree from rdflib.plugins.sparql.sparql import SPARQLError from rdflib.query import Result, ResultRow @@ -957,3 +957,25 @@ def test_sparql_describe( subjects = {s for s in r.graph.subjects() if not isinstance(s, BNode)} assert subjects == expected_subjects assert len(r.graph) == expected_size + + +@pytest.mark.parametrize( + "arg, expected_result, expected_valid", + [ + ("abc", "abc", True), + ("1234", "1234", True), + (r"1234\u0050", "1234P", True), + (r"1234\u00e3", "1234\u00e3", True), + (r"1234\u00e3\u00e5", "1234ãå", True), + (r"1234\u900000e5", "", False), + (r"1234\u010000e5", "", False), + (r"1234\u001000e5", "1234\U001000e5", True), + ], +) +def test_expand_unicode_escapes(arg: str, expected_result: str, expected_valid: bool): + if expected_valid: + actual_result = expandUnicodeEscapes(arg) + assert actual_result == expected_result + else: + with pytest.raises(ValueError, match="Invalid unicode code point"): + _ = expandUnicodeEscapes(arg) |