diff options
author | Peter Law <PeterJCLaw@gmail.com> | 2020-03-17 20:53:37 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-17 13:53:37 -0700 |
commit | 76416437ef22077b5e2949e78fa3000b3580e319 (patch) | |
tree | 166bb5a59dd96ae521f5b47dfa6625af1e91bef6 | |
parent | ef3c5cb31e06e9f094294353c3c83a9decc0fe57 (diff) | |
download | pyflakes-76416437ef22077b5e2949e78fa3000b3580e319.tar.gz |
Fix quoted type annotations in unusual contexts (#516)
* Extract a context manager for when we're in a annotation context
* Detect quoted type annotations within typing.cast calls
* Refactor typing name check
This will make it easier to detect an unspecified typing module
member as well as opening up other checks (such as testing for
one of a collection many members all in one go).
* Detect quoted annotations within subscripts of typing classes
* Add nested quoted test case (Callable)
* Use a lambda here for clarity
This is slightly more usual than accessing a .__eq__ method and is
more obviously similar to the other usage of this helper.
-rw-r--r-- | pyflakes/checker.py | 75 | ||||
-rw-r--r-- | pyflakes/test/test_type_annotations.py | 40 |
2 files changed, 102 insertions, 13 deletions
diff --git a/pyflakes/checker.py b/pyflakes/checker.py index f0b7c37..c239950 100644 --- a/pyflakes/checker.py +++ b/pyflakes/checker.py @@ -8,6 +8,7 @@ import __future__ import ast import bisect import collections +import contextlib import doctest import functools import os @@ -663,17 +664,26 @@ def getNodeName(node): return node.name -def _is_typing(node, typing_attr, scope_stack): +TYPING_MODULES = frozenset(('typing', 'typing_extensions')) + + +def _is_typing_helper(node, is_name_match_fn, scope_stack): + """ + Internal helper to determine whether or not something is a member of a + typing module. This is used as part of working out whether we are within a + type annotation context. + + Note: you probably don't want to use this function directly. Instead see the + utils below which wrap it (`_is_typing` and `_is_any_typing_member`). + """ + def _bare_name_is_attr(name): - expected_typing_names = { - 'typing.{}'.format(typing_attr), - 'typing_extensions.{}'.format(typing_attr), - } for scope in reversed(scope_stack): if name in scope: return ( isinstance(scope[name], ImportationFrom) and - scope[name].fullName in expected_typing_names + scope[name].module in TYPING_MODULES and + is_name_match_fn(scope[name].real_name) ) return False @@ -685,12 +695,33 @@ def _is_typing(node, typing_attr, scope_stack): ) or ( isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and - node.value.id in {'typing', 'typing_extensions'} and - node.attr == typing_attr + node.value.id in TYPING_MODULES and + is_name_match_fn(node.attr) ) ) +def _is_typing(node, typing_attr, scope_stack): + """ + Determine whether `node` represents the member of a typing module specified + by `typing_attr`. + + This is used as part of working out whether we are within a type annotation + context. + """ + return _is_typing_helper(node, lambda x: x == typing_attr, scope_stack) + + +def _is_any_typing_member(node, scope_stack): + """ + Determine whether `node` represents any member of a typing module. + + This is used as part of working out whether we are within a type annotation + context. + """ + return _is_typing_helper(node, lambda x: True, scope_stack) + + def is_typing_overload(value, scope_stack): return ( isinstance(value.source, FUNCTION_TYPES) and @@ -704,11 +735,8 @@ def is_typing_overload(value, scope_stack): def in_annotation(func): @functools.wraps(func) def in_annotation_func(self, *args, **kwargs): - orig, self._in_annotation = self._in_annotation, True - try: + with self._enter_annotation(): return func(self, *args, **kwargs) - finally: - self._in_annotation = orig return in_annotation_func @@ -1236,6 +1264,14 @@ class Checker(object): except KeyError: self.report(messages.UndefinedName, node, name) + @contextlib.contextmanager + def _enter_annotation(self): + orig, self._in_annotation = self._in_annotation, True + try: + yield + finally: + self._in_annotation = orig + def _handle_type_comments(self, node): for (lineno, col_offset), comment in self._type_comments.get(node, ()): comment = comment.split(':', 1)[1].strip() @@ -1428,7 +1464,11 @@ class Checker(object): finally: self._in_typing_literal = orig else: - self.handleChildren(node) + if _is_any_typing_member(node.value, self.scopeStack): + with self._enter_annotation(): + self.handleChildren(node) + else: + self.handleChildren(node) def _handle_string_dot_format(self, node): try: @@ -1557,6 +1597,15 @@ class Checker(object): node.func.attr == 'format' ): self._handle_string_dot_format(node) + + if ( + _is_typing(node.func, 'cast', self.scopeStack) and + len(node.args) >= 1 and + isinstance(node.args[0], ast.Str) + ): + with self._enter_annotation(): + self.handleNode(node.args[0], node) + self.handleChildren(node) def _handle_percent_format(self, node): diff --git a/pyflakes/test/test_type_annotations.py b/pyflakes/test/test_type_annotations.py index 15c658b..4804dda 100644 --- a/pyflakes/test/test_type_annotations.py +++ b/pyflakes/test/test_type_annotations.py @@ -449,6 +449,46 @@ class TestTypeAnnotations(TestCase): return None """) + def test_partially_quoted_type_assignment(self): + self.flakes(""" + from queue import Queue + from typing import Optional + + MaybeQueue = Optional['Queue[str]'] + """) + + def test_nested_partially_quoted_type_assignment(self): + self.flakes(""" + from queue import Queue + from typing import Callable + + Func = Callable[['Queue[str]'], None] + """) + + def test_quoted_type_cast(self): + self.flakes(""" + from typing import cast, Optional + + maybe_int = cast('Optional[int]', 42) + """) + + def test_type_cast_literal_str_to_str(self): + # Checks that our handling of quoted type annotations in the first + # argument to `cast` doesn't cause issues when (only) the _second_ + # argument is a literal str which looks a bit like a type annoation. + self.flakes(""" + from typing import cast + + a_string = cast(str, 'Optional[int]') + """) + + def test_quoted_type_cast_renamed_import(self): + self.flakes(""" + from typing import cast as tsac, Optional as Maybe + + maybe_int = tsac('Maybe[int]', 42) + """) + @skipIf(version_info < (3,), 'new in Python 3') def test_literal_type_typing(self): self.flakes(""" |