diff options
-rw-r--r-- | pylint/checkers/classes/__init__.py | 370 | ||||
-rw-r--r-- | pylint/checkers/classes/special_methods_checker.py | 386 |
2 files changed, 387 insertions, 369 deletions
diff --git a/pylint/checkers/classes/__init__.py b/pylint/checkers/classes/__init__.py index 52601bc42..15c9cdb6c 100644 --- a/pylint/checkers/classes/__init__.py +++ b/pylint/checkers/classes/__init__.py @@ -58,9 +58,9 @@ import astroid from astroid import nodes from pylint.checkers import BaseChecker, utils +from pylint.checkers.classes.special_methods_checker import SpecialMethodsChecker from pylint.checkers.utils import ( PYMETHODS, - SPECIAL_METHODS_PARAMS, check_messages, class_is_abstract, decorated_with, @@ -71,7 +71,6 @@ from pylint.checkers.utils import ( is_attr_protected, is_builtin_object, is_comprehension, - is_function_body_ellipsis, is_iterable, is_overload_stub, is_property_setter, @@ -86,7 +85,6 @@ from pylint.checkers.utils import ( from pylint.interfaces import IAstroidChecker from pylint.utils import get_global_option -NEXT_METHOD = "__next__" INVALID_BASE_CLASSES = {"bool", "range", "slice", "memoryview"} BUILTIN_DECORATORS = {"builtins.property", "builtins.classmethod"} ASTROID_TYPE_COMPARATORS = { @@ -428,29 +426,6 @@ def _has_bare_super_call(fundef_node): return False -def _safe_infer_call_result(node, caller, context=None): - """ - Safely infer the return value of a function. - - Returns None if inference failed or if there is some ambiguity (more than - one node has been inferred). Otherwise returns inferred value. - """ - try: - inferit = node.infer_call_result(caller, context=context) - value = next(inferit) - except astroid.InferenceError: - return None # inference failed - except StopIteration: - return None # no values inferred - try: - next(inferit) - return None # there is ambiguity on the inferred node - except astroid.InferenceError: - return None # there is some kind of ambiguity - except StopIteration: - return value - - def _has_same_layout_slots(slots, assigned_value): inferred = next(assigned_value.infer()) if isinstance(inferred, nodes.ClassDef): @@ -2065,349 +2040,6 @@ a metaclass class method.", ) -class SpecialMethodsChecker(BaseChecker): - """Checker which verifies that special methods - are implemented correctly. - """ - - __implements__ = (IAstroidChecker,) - name = "classes" - msgs = { - "E0301": ( - "__iter__ returns non-iterator", - "non-iterator-returned", - "Used when an __iter__ method returns something which is not an " - f"iterable (i.e. has no `{NEXT_METHOD}` method)", - { - "old_names": [ - ("W0234", "old-non-iterator-returned-1"), - ("E0234", "old-non-iterator-returned-2"), - ] - }, - ), - "E0302": ( - "The special method %r expects %s param(s), %d %s given", - "unexpected-special-method-signature", - "Emitted when a special method was defined with an " - "invalid number of parameters. If it has too few or " - "too many, it might not work at all.", - {"old_names": [("E0235", "bad-context-manager")]}, - ), - "E0303": ( - "__len__ does not return non-negative integer", - "invalid-length-returned", - "Used when a __len__ method returns something which is not a " - "non-negative integer", - ), - "E0304": ( - "__bool__ does not return bool", - "invalid-bool-returned", - "Used when a __bool__ method returns something which is not a bool", - ), - "E0305": ( - "__index__ does not return int", - "invalid-index-returned", - "Used when an __index__ method returns something which is not " - "an integer", - ), - "E0306": ( - "__repr__ does not return str", - "invalid-repr-returned", - "Used when a __repr__ method returns something which is not a string", - ), - "E0307": ( - "__str__ does not return str", - "invalid-str-returned", - "Used when a __str__ method returns something which is not a string", - ), - "E0308": ( - "__bytes__ does not return bytes", - "invalid-bytes-returned", - "Used when a __bytes__ method returns something which is not bytes", - ), - "E0309": ( - "__hash__ does not return int", - "invalid-hash-returned", - "Used when a __hash__ method returns something which is not an integer", - ), - "E0310": ( - "__length_hint__ does not return non-negative integer", - "invalid-length-hint-returned", - "Used when a __length_hint__ method returns something which is not a " - "non-negative integer", - ), - "E0311": ( - "__format__ does not return str", - "invalid-format-returned", - "Used when a __format__ method returns something which is not a string", - ), - "E0312": ( - "__getnewargs__ does not return a tuple", - "invalid-getnewargs-returned", - "Used when a __getnewargs__ method returns something which is not " - "a tuple", - ), - "E0313": ( - "__getnewargs_ex__ does not return a tuple containing (tuple, dict)", - "invalid-getnewargs-ex-returned", - "Used when a __getnewargs_ex__ method returns something which is not " - "of the form tuple(tuple, dict)", - ), - } - priority = -2 - - def __init__(self, linter=None): - super().__init__(linter) - self._protocol_map = { - "__iter__": self._check_iter, - "__len__": self._check_len, - "__bool__": self._check_bool, - "__index__": self._check_index, - "__repr__": self._check_repr, - "__str__": self._check_str, - "__bytes__": self._check_bytes, - "__hash__": self._check_hash, - "__length_hint__": self._check_length_hint, - "__format__": self._check_format, - "__getnewargs__": self._check_getnewargs, - "__getnewargs_ex__": self._check_getnewargs_ex, - } - - @check_messages( - "unexpected-special-method-signature", - "non-iterator-returned", - "invalid-length-returned", - "invalid-bool-returned", - "invalid-index-returned", - "invalid-repr-returned", - "invalid-str-returned", - "invalid-bytes-returned", - "invalid-hash-returned", - "invalid-length-hint-returned", - "invalid-format-returned", - "invalid-getnewargs-returned", - "invalid-getnewargs-ex-returned", - ) - def visit_functiondef(self, node: nodes.FunctionDef) -> None: - if not node.is_method(): - return - - inferred = _safe_infer_call_result(node, node) - # Only want to check types that we are able to infer - if ( - inferred - and node.name in self._protocol_map - and not is_function_body_ellipsis(node) - ): - self._protocol_map[node.name](node, inferred) - - if node.name in PYMETHODS: - self._check_unexpected_method_signature(node) - - visit_asyncfunctiondef = visit_functiondef - - def _check_unexpected_method_signature(self, node): - expected_params = SPECIAL_METHODS_PARAMS[node.name] - - if expected_params is None: - # This can support a variable number of parameters. - return - if not node.args.args and not node.args.vararg: - # Method has no parameter, will be caught - # by no-method-argument. - return - - if decorated_with(node, ["builtins.staticmethod"]): - # We expect to not take in consideration self. - all_args = node.args.args - else: - all_args = node.args.args[1:] - mandatory = len(all_args) - len(node.args.defaults) - optional = len(node.args.defaults) - current_params = mandatory + optional - - if isinstance(expected_params, tuple): - # The expected number of parameters can be any value from this - # tuple, although the user should implement the method - # to take all of them in consideration. - emit = mandatory not in expected_params - # pylint: disable-next=consider-using-f-string - expected_params = "between %d or %d" % expected_params - else: - # If the number of mandatory parameters doesn't - # suffice, the expected parameters for this - # function will be deduced from the optional - # parameters. - rest = expected_params - mandatory - if rest == 0: - emit = False - elif rest < 0: - emit = True - elif rest > 0: - emit = not ((optional - rest) >= 0 or node.args.vararg) - - if emit: - verb = "was" if current_params <= 1 else "were" - self.add_message( - "unexpected-special-method-signature", - args=(node.name, expected_params, current_params, verb), - node=node, - ) - - @staticmethod - def _is_wrapped_type(node, type_): - return ( - isinstance(node, astroid.Instance) - and node.name == type_ - and not isinstance(node, nodes.Const) - ) - - @staticmethod - def _is_int(node): - if SpecialMethodsChecker._is_wrapped_type(node, "int"): - return True - - return isinstance(node, nodes.Const) and isinstance(node.value, int) - - @staticmethod - def _is_str(node): - if SpecialMethodsChecker._is_wrapped_type(node, "str"): - return True - - return isinstance(node, nodes.Const) and isinstance(node.value, str) - - @staticmethod - def _is_bool(node): - if SpecialMethodsChecker._is_wrapped_type(node, "bool"): - return True - - return isinstance(node, nodes.Const) and isinstance(node.value, bool) - - @staticmethod - def _is_bytes(node): - if SpecialMethodsChecker._is_wrapped_type(node, "bytes"): - return True - - return isinstance(node, nodes.Const) and isinstance(node.value, bytes) - - @staticmethod - def _is_tuple(node): - if SpecialMethodsChecker._is_wrapped_type(node, "tuple"): - return True - - return isinstance(node, nodes.Const) and isinstance(node.value, tuple) - - @staticmethod - def _is_dict(node): - if SpecialMethodsChecker._is_wrapped_type(node, "dict"): - return True - - return isinstance(node, nodes.Const) and isinstance(node.value, dict) - - @staticmethod - def _is_iterator(node): - if node is astroid.Uninferable: - # Just ignore Uninferable objects. - return True - if isinstance(node, astroid.bases.Generator): - # Generators can be iterated. - return True - - if isinstance(node, astroid.Instance): - try: - node.local_attr(NEXT_METHOD) - return True - except astroid.NotFoundError: - pass - elif isinstance(node, nodes.ClassDef): - metaclass = node.metaclass() - if metaclass and isinstance(metaclass, nodes.ClassDef): - try: - metaclass.local_attr(NEXT_METHOD) - return True - except astroid.NotFoundError: - pass - return False - - def _check_iter(self, node, inferred): - if not self._is_iterator(inferred): - self.add_message("non-iterator-returned", node=node) - - def _check_len(self, node, inferred): - if not self._is_int(inferred): - self.add_message("invalid-length-returned", node=node) - elif isinstance(inferred, nodes.Const) and inferred.value < 0: - self.add_message("invalid-length-returned", node=node) - - def _check_bool(self, node, inferred): - if not self._is_bool(inferred): - self.add_message("invalid-bool-returned", node=node) - - def _check_index(self, node, inferred): - if not self._is_int(inferred): - self.add_message("invalid-index-returned", node=node) - - def _check_repr(self, node, inferred): - if not self._is_str(inferred): - self.add_message("invalid-repr-returned", node=node) - - def _check_str(self, node, inferred): - if not self._is_str(inferred): - self.add_message("invalid-str-returned", node=node) - - def _check_bytes(self, node, inferred): - if not self._is_bytes(inferred): - self.add_message("invalid-bytes-returned", node=node) - - def _check_hash(self, node, inferred): - if not self._is_int(inferred): - self.add_message("invalid-hash-returned", node=node) - - def _check_length_hint(self, node, inferred): - if not self._is_int(inferred): - self.add_message("invalid-length-hint-returned", node=node) - elif isinstance(inferred, nodes.Const) and inferred.value < 0: - self.add_message("invalid-length-hint-returned", node=node) - - def _check_format(self, node, inferred): - if not self._is_str(inferred): - self.add_message("invalid-format-returned", node=node) - - def _check_getnewargs(self, node, inferred): - if not self._is_tuple(inferred): - self.add_message("invalid-getnewargs-returned", node=node) - - def _check_getnewargs_ex(self, node, inferred): - if not self._is_tuple(inferred): - self.add_message("invalid-getnewargs-ex-returned", node=node) - return - - if not isinstance(inferred, nodes.Tuple): - # If it's not an astroid.Tuple we can't analyze it further - return - - found_error = False - - if len(inferred.elts) != 2: - found_error = True - else: - for arg, check in ( - (inferred.elts[0], self._is_tuple), - (inferred.elts[1], self._is_dict), - ): - - if isinstance(arg, nodes.Call): - arg = safe_infer(arg) - - if arg and arg is not astroid.Uninferable: - if not check(arg): - found_error = True - break - - if found_error: - self.add_message("invalid-getnewargs-ex-returned", node=node) - - def _ancestors_to_call(klass_node, method="__init__"): """return a dictionary where keys are the list of base classes providing the queried method, and so that should/may be called from the method node diff --git a/pylint/checkers/classes/special_methods_checker.py b/pylint/checkers/classes/special_methods_checker.py new file mode 100644 index 000000000..adfb7b3f1 --- /dev/null +++ b/pylint/checkers/classes/special_methods_checker.py @@ -0,0 +1,386 @@ +# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html +# For details: https://github.com/PyCQA/pylint/blob/main/LICENSE + +"""Special methods checker and helper function's module""" + +import astroid +from astroid import nodes + +from pylint.checkers import BaseChecker +from pylint.checkers.utils import ( + PYMETHODS, + SPECIAL_METHODS_PARAMS, + check_messages, + decorated_with, + is_function_body_ellipsis, + safe_infer, +) +from pylint.interfaces import IAstroidChecker + +NEXT_METHOD = "__next__" + + +def _safe_infer_call_result(node, caller, context=None): + """ + Safely infer the return value of a function. + + Returns None if inference failed or if there is some ambiguity (more than + one node has been inferred). Otherwise returns inferred value. + """ + try: + inferit = node.infer_call_result(caller, context=context) + value = next(inferit) + except astroid.InferenceError: + return None # inference failed + except StopIteration: + return None # no values inferred + try: + next(inferit) + return None # there is ambiguity on the inferred node + except astroid.InferenceError: + return None # there is some kind of ambiguity + except StopIteration: + return value + + +class SpecialMethodsChecker(BaseChecker): + """Checker which verifies that special methods + are implemented correctly. + """ + + __implements__ = (IAstroidChecker,) + name = "classes" + msgs = { + "E0301": ( + "__iter__ returns non-iterator", + "non-iterator-returned", + "Used when an __iter__ method returns something which is not an " + f"iterable (i.e. has no `{NEXT_METHOD}` method)", + { + "old_names": [ + ("W0234", "old-non-iterator-returned-1"), + ("E0234", "old-non-iterator-returned-2"), + ] + }, + ), + "E0302": ( + "The special method %r expects %s param(s), %d %s given", + "unexpected-special-method-signature", + "Emitted when a special method was defined with an " + "invalid number of parameters. If it has too few or " + "too many, it might not work at all.", + {"old_names": [("E0235", "bad-context-manager")]}, + ), + "E0303": ( + "__len__ does not return non-negative integer", + "invalid-length-returned", + "Used when a __len__ method returns something which is not a " + "non-negative integer", + ), + "E0304": ( + "__bool__ does not return bool", + "invalid-bool-returned", + "Used when a __bool__ method returns something which is not a bool", + ), + "E0305": ( + "__index__ does not return int", + "invalid-index-returned", + "Used when an __index__ method returns something which is not " + "an integer", + ), + "E0306": ( + "__repr__ does not return str", + "invalid-repr-returned", + "Used when a __repr__ method returns something which is not a string", + ), + "E0307": ( + "__str__ does not return str", + "invalid-str-returned", + "Used when a __str__ method returns something which is not a string", + ), + "E0308": ( + "__bytes__ does not return bytes", + "invalid-bytes-returned", + "Used when a __bytes__ method returns something which is not bytes", + ), + "E0309": ( + "__hash__ does not return int", + "invalid-hash-returned", + "Used when a __hash__ method returns something which is not an integer", + ), + "E0310": ( + "__length_hint__ does not return non-negative integer", + "invalid-length-hint-returned", + "Used when a __length_hint__ method returns something which is not a " + "non-negative integer", + ), + "E0311": ( + "__format__ does not return str", + "invalid-format-returned", + "Used when a __format__ method returns something which is not a string", + ), + "E0312": ( + "__getnewargs__ does not return a tuple", + "invalid-getnewargs-returned", + "Used when a __getnewargs__ method returns something which is not " + "a tuple", + ), + "E0313": ( + "__getnewargs_ex__ does not return a tuple containing (tuple, dict)", + "invalid-getnewargs-ex-returned", + "Used when a __getnewargs_ex__ method returns something which is not " + "of the form tuple(tuple, dict)", + ), + } + priority = -2 + + def __init__(self, linter=None): + super().__init__(linter) + self._protocol_map = { + "__iter__": self._check_iter, + "__len__": self._check_len, + "__bool__": self._check_bool, + "__index__": self._check_index, + "__repr__": self._check_repr, + "__str__": self._check_str, + "__bytes__": self._check_bytes, + "__hash__": self._check_hash, + "__length_hint__": self._check_length_hint, + "__format__": self._check_format, + "__getnewargs__": self._check_getnewargs, + "__getnewargs_ex__": self._check_getnewargs_ex, + } + + @check_messages( + "unexpected-special-method-signature", + "non-iterator-returned", + "invalid-length-returned", + "invalid-bool-returned", + "invalid-index-returned", + "invalid-repr-returned", + "invalid-str-returned", + "invalid-bytes-returned", + "invalid-hash-returned", + "invalid-length-hint-returned", + "invalid-format-returned", + "invalid-getnewargs-returned", + "invalid-getnewargs-ex-returned", + ) + def visit_functiondef(self, node: nodes.FunctionDef) -> None: + if not node.is_method(): + return + + inferred = _safe_infer_call_result(node, node) + # Only want to check types that we are able to infer + if ( + inferred + and node.name in self._protocol_map + and not is_function_body_ellipsis(node) + ): + self._protocol_map[node.name](node, inferred) + + if node.name in PYMETHODS: + self._check_unexpected_method_signature(node) + + visit_asyncfunctiondef = visit_functiondef + + def _check_unexpected_method_signature(self, node): + expected_params = SPECIAL_METHODS_PARAMS[node.name] + + if expected_params is None: + # This can support a variable number of parameters. + return + if not node.args.args and not node.args.vararg: + # Method has no parameter, will be caught + # by no-method-argument. + return + + if decorated_with(node, ["builtins.staticmethod"]): + # We expect to not take in consideration self. + all_args = node.args.args + else: + all_args = node.args.args[1:] + mandatory = len(all_args) - len(node.args.defaults) + optional = len(node.args.defaults) + current_params = mandatory + optional + + if isinstance(expected_params, tuple): + # The expected number of parameters can be any value from this + # tuple, although the user should implement the method + # to take all of them in consideration. + emit = mandatory not in expected_params + # pylint: disable-next=consider-using-f-string + expected_params = "between %d or %d" % expected_params + else: + # If the number of mandatory parameters doesn't + # suffice, the expected parameters for this + # function will be deduced from the optional + # parameters. + rest = expected_params - mandatory + if rest == 0: + emit = False + elif rest < 0: + emit = True + elif rest > 0: + emit = not ((optional - rest) >= 0 or node.args.vararg) + + if emit: + verb = "was" if current_params <= 1 else "were" + self.add_message( + "unexpected-special-method-signature", + args=(node.name, expected_params, current_params, verb), + node=node, + ) + + @staticmethod + def _is_wrapped_type(node, type_): + return ( + isinstance(node, astroid.Instance) + and node.name == type_ + and not isinstance(node, nodes.Const) + ) + + @staticmethod + def _is_int(node): + if SpecialMethodsChecker._is_wrapped_type(node, "int"): + return True + + return isinstance(node, nodes.Const) and isinstance(node.value, int) + + @staticmethod + def _is_str(node): + if SpecialMethodsChecker._is_wrapped_type(node, "str"): + return True + + return isinstance(node, nodes.Const) and isinstance(node.value, str) + + @staticmethod + def _is_bool(node): + if SpecialMethodsChecker._is_wrapped_type(node, "bool"): + return True + + return isinstance(node, nodes.Const) and isinstance(node.value, bool) + + @staticmethod + def _is_bytes(node): + if SpecialMethodsChecker._is_wrapped_type(node, "bytes"): + return True + + return isinstance(node, nodes.Const) and isinstance(node.value, bytes) + + @staticmethod + def _is_tuple(node): + if SpecialMethodsChecker._is_wrapped_type(node, "tuple"): + return True + + return isinstance(node, nodes.Const) and isinstance(node.value, tuple) + + @staticmethod + def _is_dict(node): + if SpecialMethodsChecker._is_wrapped_type(node, "dict"): + return True + + return isinstance(node, nodes.Const) and isinstance(node.value, dict) + + @staticmethod + def _is_iterator(node): + if node is astroid.Uninferable: + # Just ignore Uninferable objects. + return True + if isinstance(node, astroid.bases.Generator): + # Generators can be iterated. + return True + + if isinstance(node, astroid.Instance): + try: + node.local_attr(NEXT_METHOD) + return True + except astroid.NotFoundError: + pass + elif isinstance(node, nodes.ClassDef): + metaclass = node.metaclass() + if metaclass and isinstance(metaclass, nodes.ClassDef): + try: + metaclass.local_attr(NEXT_METHOD) + return True + except astroid.NotFoundError: + pass + return False + + def _check_iter(self, node, inferred): + if not self._is_iterator(inferred): + self.add_message("non-iterator-returned", node=node) + + def _check_len(self, node, inferred): + if not self._is_int(inferred): + self.add_message("invalid-length-returned", node=node) + elif isinstance(inferred, nodes.Const) and inferred.value < 0: + self.add_message("invalid-length-returned", node=node) + + def _check_bool(self, node, inferred): + if not self._is_bool(inferred): + self.add_message("invalid-bool-returned", node=node) + + def _check_index(self, node, inferred): + if not self._is_int(inferred): + self.add_message("invalid-index-returned", node=node) + + def _check_repr(self, node, inferred): + if not self._is_str(inferred): + self.add_message("invalid-repr-returned", node=node) + + def _check_str(self, node, inferred): + if not self._is_str(inferred): + self.add_message("invalid-str-returned", node=node) + + def _check_bytes(self, node, inferred): + if not self._is_bytes(inferred): + self.add_message("invalid-bytes-returned", node=node) + + def _check_hash(self, node, inferred): + if not self._is_int(inferred): + self.add_message("invalid-hash-returned", node=node) + + def _check_length_hint(self, node, inferred): + if not self._is_int(inferred): + self.add_message("invalid-length-hint-returned", node=node) + elif isinstance(inferred, nodes.Const) and inferred.value < 0: + self.add_message("invalid-length-hint-returned", node=node) + + def _check_format(self, node, inferred): + if not self._is_str(inferred): + self.add_message("invalid-format-returned", node=node) + + def _check_getnewargs(self, node, inferred): + if not self._is_tuple(inferred): + self.add_message("invalid-getnewargs-returned", node=node) + + def _check_getnewargs_ex(self, node, inferred): + if not self._is_tuple(inferred): + self.add_message("invalid-getnewargs-ex-returned", node=node) + return + + if not isinstance(inferred, nodes.Tuple): + # If it's not an astroid.Tuple we can't analyze it further + return + + found_error = False + + if len(inferred.elts) != 2: + found_error = True + else: + for arg, check in ( + (inferred.elts[0], self._is_tuple), + (inferred.elts[1], self._is_dict), + ): + + if isinstance(arg, nodes.Call): + arg = safe_infer(arg) + + if arg and arg is not astroid.Uninferable: + if not check(arg): + found_error = True + break + + if found_error: + self.add_message("invalid-getnewargs-ex-returned", node=node) |