summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r--lib/sqlalchemy/sql/base.py128
1 files changed, 120 insertions, 8 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 2d023c6a6..04cc34480 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -15,11 +15,14 @@ import operator
import re
from .traversals import HasCacheKey # noqa
+from .traversals import MemoizedHasCacheKey # noqa
from .visitors import ClauseVisitor
+from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import exc
from .. import util
from ..util import HasMemoized
+from ..util import hybridmethod
if util.TYPE_CHECKING:
from types import ModuleType
@@ -433,22 +436,52 @@ class CompileState(object):
__slots__ = ("statement",)
+ plugins = {}
+
@classmethod
def _create(cls, statement, compiler, **kw):
# factory construction.
- # specific CompileState classes here will look for
- # "plugins" in the given statement. From there they will invoke
- # the appropriate plugin constructor if one is found and return
- # the alternate CompileState object.
+ if statement._compile_state_plugin is not None:
+ constructor = cls.plugins.get(
+ (
+ statement._compile_state_plugin,
+ statement.__visit_name__,
+ None,
+ ),
+ cls,
+ )
+ else:
+ constructor = cls
- c = cls.__new__(cls)
- c.__init__(statement, compiler, **kw)
- return c
+ return constructor(statement, compiler, **kw)
def __init__(self, statement, compiler, **kw):
self.statement = statement
+ @classmethod
+ def get_plugin_classmethod(cls, statement, name):
+ if statement._compile_state_plugin is not None:
+ fn = cls.plugins.get(
+ (
+ statement._compile_state_plugin,
+ statement.__visit_name__,
+ name,
+ ),
+ None,
+ )
+ if fn is not None:
+ return fn
+ return getattr(cls, name)
+
+ @classmethod
+ def plugin_for(cls, plugin_name, visit_name, method_name=None):
+ def decorate(fn):
+ cls.plugins[(plugin_name, visit_name, method_name)] = fn
+ return fn
+
+ return decorate
+
class Generative(HasMemoized):
"""Provide a method-chaining pattern in conjunction with the
@@ -479,6 +512,57 @@ class HasCompileState(Generative):
_compile_state_plugin = None
+ _attributes = util.immutabledict()
+
+
+class _MetaOptions(type):
+ """metaclass for the Options class."""
+
+ def __init__(cls, classname, bases, dict_):
+ cls._cache_attrs = tuple(
+ sorted(d for d in dict_ if not d.startswith("__"))
+ )
+ type.__init__(cls, classname, bases, dict_)
+
+ def __add__(self, other):
+ o1 = self()
+ o1.__dict__.update(other)
+ return o1
+
+
+class Options(util.with_metaclass(_MetaOptions)):
+ """A cacheable option dictionary with defaults.
+
+
+ """
+
+ def __init__(self, **kw):
+ self.__dict__.update(kw)
+
+ def __add__(self, other):
+ o1 = self.__class__.__new__(self.__class__)
+ o1.__dict__.update(self.__dict__)
+ o1.__dict__.update(other)
+ return o1
+
+ @hybridmethod
+ def add_to_element(self, name, value):
+ return self + {name: getattr(self, name) + value}
+
+
+class CacheableOptions(Options, HasCacheKey):
+ @hybridmethod
+ def _gen_cache_key(self, anon_map, bindparams):
+ return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
+
+ @_gen_cache_key.classlevel
+ def _gen_cache_key(cls, anon_map, bindparams):
+ return (cls, ())
+
+ @hybridmethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key_for_object(self)
+
class Executable(Generative):
"""Mark a ClauseElement as supporting execution.
@@ -492,7 +576,21 @@ class Executable(Generative):
supports_execution = True
_execution_options = util.immutabledict()
_bind = None
+ _with_options = ()
+ _with_context_options = ()
+ _cache_enable = True
+
+ _executable_traverse_internals = [
+ ("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list),
+ ("_with_context_options", ExtendedInternalTraversal.dp_plain_obj),
+ ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj),
+ ]
+
+ @_generative
+ def _disable_caching(self):
+ self._cache_enable = HasCacheKey()
+ @_generative
def options(self, *options):
"""Apply options to this statement.
@@ -522,7 +620,21 @@ class Executable(Generative):
to the usage of ORM queries
"""
- self._options += options
+ self._with_options += options
+
+ @_generative
+ def _add_context_option(self, callable_, cache_args):
+ """Add a context option to this statement.
+
+ These are callable functions that will
+ be given the CompileState object upon compilation.
+
+ A second argument cache_args is required, which will be combined
+ with the identity of the function itself in order to produce a
+ cache key.
+
+ """
+ self._with_context_options += ((callable_, cache_args),)
@_generative
def execution_options(self, **kw):