diff options
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 128 |
1 files changed, 120 insertions, 8 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 2d023c6a6..04cc34480 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -15,11 +15,14 @@ import operator import re from .traversals import HasCacheKey # noqa +from .traversals import MemoizedHasCacheKey # noqa from .visitors import ClauseVisitor +from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import exc from .. import util from ..util import HasMemoized +from ..util import hybridmethod if util.TYPE_CHECKING: from types import ModuleType @@ -433,22 +436,52 @@ class CompileState(object): __slots__ = ("statement",) + plugins = {} + @classmethod def _create(cls, statement, compiler, **kw): # factory construction. - # specific CompileState classes here will look for - # "plugins" in the given statement. From there they will invoke - # the appropriate plugin constructor if one is found and return - # the alternate CompileState object. + if statement._compile_state_plugin is not None: + constructor = cls.plugins.get( + ( + statement._compile_state_plugin, + statement.__visit_name__, + None, + ), + cls, + ) + else: + constructor = cls - c = cls.__new__(cls) - c.__init__(statement, compiler, **kw) - return c + return constructor(statement, compiler, **kw) def __init__(self, statement, compiler, **kw): self.statement = statement + @classmethod + def get_plugin_classmethod(cls, statement, name): + if statement._compile_state_plugin is not None: + fn = cls.plugins.get( + ( + statement._compile_state_plugin, + statement.__visit_name__, + name, + ), + None, + ) + if fn is not None: + return fn + return getattr(cls, name) + + @classmethod + def plugin_for(cls, plugin_name, visit_name, method_name=None): + def decorate(fn): + cls.plugins[(plugin_name, visit_name, method_name)] = fn + return fn + + return decorate + class Generative(HasMemoized): """Provide a method-chaining pattern in conjunction with the @@ -479,6 +512,57 @@ class HasCompileState(Generative): _compile_state_plugin = None + _attributes = util.immutabledict() + + +class _MetaOptions(type): + """metaclass for the Options class.""" + + def __init__(cls, classname, bases, dict_): + cls._cache_attrs = tuple( + sorted(d for d in dict_ if not d.startswith("__")) + ) + type.__init__(cls, classname, bases, dict_) + + def __add__(self, other): + o1 = self() + o1.__dict__.update(other) + return o1 + + +class Options(util.with_metaclass(_MetaOptions)): + """A cacheable option dictionary with defaults. + + + """ + + def __init__(self, **kw): + self.__dict__.update(kw) + + def __add__(self, other): + o1 = self.__class__.__new__(self.__class__) + o1.__dict__.update(self.__dict__) + o1.__dict__.update(other) + return o1 + + @hybridmethod + def add_to_element(self, name, value): + return self + {name: getattr(self, name) + value} + + +class CacheableOptions(Options, HasCacheKey): + @hybridmethod + def _gen_cache_key(self, anon_map, bindparams): + return HasCacheKey._gen_cache_key(self, anon_map, bindparams) + + @_gen_cache_key.classlevel + def _gen_cache_key(cls, anon_map, bindparams): + return (cls, ()) + + @hybridmethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key_for_object(self) + class Executable(Generative): """Mark a ClauseElement as supporting execution. @@ -492,7 +576,21 @@ class Executable(Generative): supports_execution = True _execution_options = util.immutabledict() _bind = None + _with_options = () + _with_context_options = () + _cache_enable = True + + _executable_traverse_internals = [ + ("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list), + ("_with_context_options", ExtendedInternalTraversal.dp_plain_obj), + ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj), + ] + + @_generative + def _disable_caching(self): + self._cache_enable = HasCacheKey() + @_generative def options(self, *options): """Apply options to this statement. @@ -522,7 +620,21 @@ class Executable(Generative): to the usage of ORM queries """ - self._options += options + self._with_options += options + + @_generative + def _add_context_option(self, callable_, cache_args): + """Add a context option to this statement. + + These are callable functions that will + be given the CompileState object upon compilation. + + A second argument cache_args is required, which will be combined + with the identity of the function itself in order to produce a + cache key. + + """ + self._with_context_options += ((callable_, cache_args),) @_generative def execution_options(self, **kw): |