diff options
Diffstat (limited to 'Cython/TestUtils.py')
-rw-r--r-- | Cython/TestUtils.py | 111 |
1 files changed, 109 insertions, 2 deletions
diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index bb2070d39..45a8e6f59 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -1,18 +1,21 @@ from __future__ import absolute_import import os +import re import unittest import shlex import sys import tempfile import textwrap from io import open +from functools import partial from .Compiler import Errors from .CodeWriter import CodeWriter -from .Compiler.TreeFragment import TreeFragment, strip_common_indent +from .Compiler.TreeFragment import TreeFragment, strip_common_indent, StringParseContext from .Compiler.Visitor import TreeVisitor, VisitorTransform from .Compiler import TreePath +from .Compiler.ParseTreeTransforms import PostParse class NodeTypeWriter(TreeVisitor): @@ -161,11 +164,81 @@ class TransformTest(CythonTest): return tree +# For the test C code validation, we have to take care that the test directives (and thus +# the match strings) do not just appear in (multiline) C code comments containing the original +# Cython source code. Thus, we discard the comments before matching. +# This seems a prime case for re.VERBOSE, but it seems to match some of the whitespace. +_strip_c_comments = partial(re.compile( + re.sub(r'\s+', '', r''' + /[*] ( + (?: [^*\n] | [*][^/] )* + [\n] + (?: [^*] | [*][^/] )* + ) [*]/ + ''') +).sub, '') + +_strip_cython_code_from_html = partial(re.compile( + re.sub(r'\s\s+', '', r''' + <pre class=["'][^"']*cython\s+line[^"']*["']\s*> + (?:[^<]|<(?!/pre))+ + </pre> + ''') +).sub, '') + + class TreeAssertVisitor(VisitorTransform): # actually, a TreeVisitor would be enough, but this needs to run # as part of the compiler pipeline - def visit_CompilerDirectivesNode(self, node): + def __init__(self): + super(TreeAssertVisitor, self).__init__() + self._module_pos = None + self._c_patterns = [] + self._c_antipatterns = [] + + def create_c_file_validator(self): + patterns, antipatterns = self._c_patterns, self._c_antipatterns + + def fail(pos, pattern, found, file_path): + Errors.error(pos, "Pattern '%s' %s found in %s" %( + pattern, + 'was' if found else 'was not', + file_path, + )) + + def validate_file_content(file_path, content): + for pattern in patterns: + #print("Searching pattern '%s'" % pattern) + if not re.search(pattern, content): + fail(self._module_pos, pattern, found=False, file_path=file_path) + + for antipattern in antipatterns: + #print("Searching antipattern '%s'" % antipattern) + if re.search(antipattern, content): + fail(self._module_pos, antipattern, found=True, file_path=file_path) + + def validate_c_file(result): + c_file = result.c_file + if not (patterns or antipatterns): + #print("No patterns defined for %s" % c_file) + return result + + with open(c_file, encoding='utf8') as f: + content = f.read() + content = _strip_c_comments(content) + validate_file_content(c_file, content) + + html_file = os.path.splitext(c_file)[0] + ".html" + if os.path.exists(html_file) and os.path.getmtime(c_file) <= os.path.getmtime(html_file): + with open(html_file, encoding='utf8') as f: + content = f.read() + content = _strip_cython_code_from_html(content) + validate_file_content(html_file, content) + + return validate_c_file + + def _check_directives(self, node): directives = node.directives if 'test_assert_path_exists' in directives: for path in directives['test_assert_path_exists']: @@ -179,6 +252,19 @@ class TreeAssertVisitor(VisitorTransform): Errors.error( node.pos, "Unexpected path '%s' found in result tree" % path) + if 'test_assert_c_code_has' in directives: + self._c_patterns.extend(directives['test_assert_c_code_has']) + if 'test_fail_if_c_code_has' in directives: + self._c_antipatterns.extend(directives['test_fail_if_c_code_has']) + + def visit_ModuleNode(self, node): + self._module_pos = node.pos + self._check_directives(node) + self.visitchildren(node) + return node + + def visit_CompilerDirectivesNode(self, node): + self._check_directives(node) self.visitchildren(node) return node @@ -272,3 +358,24 @@ def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None while other_time is None or other_time >= os.path.getmtime(file_path): write_file(file_path, content, dedent=dedent, encoding=encoding) + + +def py_parse_code(code): + """ + Compiles code far enough to get errors from the parser and post-parse stage. + + Is useful for checking for syntax errors, however it doesn't generate runable + code. + """ + context = StringParseContext("test") + # all the errors we care about are in the parsing or postparse stage + try: + with Errors.local_errors() as errors: + result = TreeFragment(code, pipeline=[PostParse(context)]) + result = result.substitute() + if errors: + raise errors[0] # compile error, which should get caught + else: + return result + except Errors.CompileError as e: + raise SyntaxError(e.message_only) |