diff options
Diffstat (limited to 'lib/sqlalchemy/sql/functions.py')
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 63 |
1 files changed, 35 insertions, 28 deletions
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index b7b9257b4..3b6da7175 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -37,8 +37,8 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias +from .type_api import TypeEngine from .visitors import InternalTraversal -from .visitors import TraversibleType from .. import util @@ -48,7 +48,7 @@ _registry = util.defaultdict(dict) def register_function(identifier, fn, package="_default"): """Associate a callable with a particular func. name. - This is normally called by _GenericMeta, but is also + This is normally called by GenericFunction, but is also available by itself so that a non-Function construct can be associated with the :data:`.func` accessor (i.e. CAST, EXTRACT). @@ -828,7 +828,11 @@ class Function(FunctionElement): ("type", InternalTraversal.dp_type), ] - type = sqltypes.NULLTYPE + name: str + + identifier: str + + type: TypeEngine = sqltypes.NULLTYPE """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. @@ -871,30 +875,7 @@ class Function(FunctionElement): ) -class _GenericMeta(TraversibleType): - def __init__(cls, clsname, bases, clsdict): - if annotation.Annotated not in cls.__mro__: - cls.name = name = clsdict.get("name", clsname) - cls.identifier = identifier = clsdict.get("identifier", name) - package = clsdict.pop("package", "_default") - # legacy - if "__return_type__" in clsdict: - cls.type = clsdict["__return_type__"] - - # Check _register attribute status - cls._register = getattr(cls, "_register", True) - - # Register the function if required - if cls._register: - register_function(identifier, cls, package) - else: - # Set _register to True to register child classes by default - cls._register = True - - super(_GenericMeta, cls).__init__(clsname, bases, clsdict) - - -class GenericFunction(Function, metaclass=_GenericMeta): +class GenericFunction(Function): """Define a 'generic' function. A generic function is a pre-established :class:`.Function` @@ -986,9 +967,34 @@ class GenericFunction(Function, metaclass=_GenericMeta): """ coerce_arguments = True - _register = False inherit_cache = True + name = "GenericFunction" + + def __init_subclass__(cls) -> None: + if annotation.Annotated not in cls.__mro__: + cls._register_generic_function(cls.__name__, cls.__dict__) + super().__init_subclass__() + + @classmethod + def _register_generic_function(cls, clsname, clsdict): + cls.name = name = clsdict.get("name", clsname) + cls.identifier = identifier = clsdict.get("identifier", name) + package = clsdict.get("package", "_default") + # legacy + if "__return_type__" in clsdict: + cls.type = clsdict["__return_type__"] + + # Check _register attribute status + cls._register = getattr(cls, "_register", True) + + # Register the function if required + if cls._register: + register_function(identifier, cls, package) + else: + # Set _register to True to register child classes by default + cls._register = True + def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: @@ -1006,6 +1012,7 @@ class GenericFunction(Function, metaclass=_GenericMeta): self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *parsed_args ).self_group() + self.type = sqltypes.to_instance( kwargs.pop("type_", None) or getattr(self, "type", None) ) |