diff options
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/plugin.py')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/plugin.py | 215 |
1 files changed, 215 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py new file mode 100644 index 000000000..9fcd09b1e --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -0,0 +1,215 @@ +# ext/mypy/plugin.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Mypy plugin for SQLAlchemy ORM. + +""" +from typing import List +from typing import Tuple +from typing import Type + +from mypy import nodes +from mypy.mro import calculate_mro +from mypy.mro import MroError +from mypy.nodes import Block +from mypy.nodes import ClassDef +from mypy.nodes import GDEF +from mypy.nodes import MypyFile +from mypy.nodes import NameExpr +from mypy.nodes import SymbolTable +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeInfo +from mypy.plugin import AttributeContext +from mypy.plugin import Callable +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import Optional +from mypy.plugin import Plugin +from mypy.types import Instance + +from . import decl_class +from . import names +from . import util + + +class CustomPlugin(Plugin): + def get_dynamic_class_hook( + self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + if names._type_id_for_fullname(fullname) is names.DECLARATIVE_BASE: + return _dynamic_class_hook + return None + + def get_base_class_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + + # kind of a strange relationship between get_metaclass_hook() + # and get_base_class_hook(). the former doesn't fire off for + # subclasses. but then you can just check it here from the "base" + # and get the same effect. + sym = self.lookup_fully_qualified(fullname) + if ( + sym + and isinstance(sym.node, TypeInfo) + and sym.node.metaclass_type + and names._type_id_for_named_node(sym.node.metaclass_type.type) + is names.DECLARATIVE_META + ): + return _base_cls_hook + return None + + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + + sym = self.lookup_fully_qualified(fullname) + + if ( + sym is not None + and names._type_id_for_named_node(sym.node) + is names.MAPPED_DECORATOR + ): + return _cls_decorator_hook + return None + + def get_customize_class_mro_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return _fill_in_decorators + + def get_attribute_hook( + self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + if fullname.startswith( + "sqlalchemy.orm.attributes.QueryableAttribute." + ): + return _queryable_getattr_hook + return None + + def get_additional_deps( + self, file: MypyFile + ) -> List[Tuple[int, str, int]]: + return [ + (10, "sqlalchemy.orm.attributes", -1), + (10, "sqlalchemy.orm.decl_api", -1), + ] + + +def plugin(version: str): + return CustomPlugin + + +def _queryable_getattr_hook(ctx: AttributeContext) -> Type: + # how do I....tell it it has no attribute of a certain name? + # can't find any Type that seems to match that + return ctx.default_attr_type + + +def _fill_in_decorators(ctx: ClassDefContext) -> None: + for decorator in ctx.cls.decorators: + # set the ".fullname" attribute of a class decorator + # that is a MemberExpr. This causes the logic in + # semanal.py->apply_class_plugin_hooks to invoke the + # get_class_decorator_hook for our "registry.map_class()" method. + # this seems like a bug in mypy that these decorators are otherwise + # skipped. + if ( + isinstance(decorator, nodes.MemberExpr) + and decorator.name == "mapped" + ): + + sym = ctx.api.lookup( + decorator.expr.name, decorator, suppress_errors=True + ) + if sym: + if sym.node.type and hasattr(sym.node.type, "type"): + decorator.fullname = ( + f"{sym.node.type.type.fullname}.{decorator.name}" + ) + else: + # if the registry is in the same file as where the + # decorator is used, it might not have semantic + # symbols applied and we can't get a fully qualified + # name or an inferred type, so we are actually going to + # flag an error in this case that they need to annotate + # it. The "registry" is declared just + # once (or few times), so they have to just not use + # type inference for its assignment in this one case. + util.fail( + ctx.api, + "Class decorator called mapped(), but we can't " + "tell if it's from an ORM registry. Please " + "annotate the registry assignment, e.g. " + "my_registry: registry = registry()", + sym.node, + ) + + +def _cls_metadata_hook(ctx: ClassDefContext) -> None: + decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _base_cls_hook(ctx: ClassDefContext) -> None: + decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _cls_decorator_hook(ctx: ClassDefContext) -> None: + assert isinstance(ctx.reason, nodes.MemberExpr) + expr = ctx.reason.expr + assert names._type_id_for_named_node(expr.node.type.type) is names.REGISTRY + + decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: + """Generate a declarative Base class when the declarative_base() function + is encountered.""" + + cls = ClassDef(ctx.name, Block([])) + cls.fullname = ctx.api.qualified_name(ctx.name) + + declarative_meta_sym: SymbolTableNode = ctx.api.modules[ + "sqlalchemy.orm.decl_api" + ].names["DeclarativeMeta"] + declarative_meta_typeinfo: TypeInfo = declarative_meta_sym.node + + declarative_meta_name: NameExpr = NameExpr("DeclarativeMeta") + declarative_meta_name.kind = GDEF + declarative_meta_name.fullname = "sqlalchemy.orm.decl_api.DeclarativeMeta" + declarative_meta_name.node = declarative_meta_typeinfo + + cls.metaclass = declarative_meta_name + + declarative_meta_instance = Instance(declarative_meta_typeinfo, []) + + info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id) + info.declared_metaclass = info.metaclass_type = declarative_meta_instance + cls.info = info + + cls_arg = util._get_callexpr_kwarg(ctx.call, "cls") + if cls_arg is not None: + decl_class._scan_declarative_assignments_and_apply_types( + cls_arg.node.defn, ctx.api, is_mixin_scan=True + ) + info.bases = [Instance(cls_arg.node, [])] + else: + obj = ctx.api.builtin_type("builtins.object") + + info.bases = [obj] + + try: + calculate_mro(info) + except MroError: + util.fail( + ctx.api, "Not able to calculate MRO for declarative base", ctx.call + ) + info.bases = [obj] + info.fallback_to_any = True + + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) |