summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r--lib/sqlalchemy/sql/base.py80
1 files changed, 50 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 04cc34480..bb606a4d6 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -439,46 +439,53 @@ class CompileState(object):
plugins = {}
@classmethod
- def _create(cls, statement, compiler, **kw):
+ def create_for_statement(cls, statement, compiler, **kw):
# factory construction.
- if statement._compile_state_plugin is not None:
- constructor = cls.plugins.get(
- (
- statement._compile_state_plugin,
- statement.__visit_name__,
- None,
- ),
- cls,
+ if statement._propagate_attrs:
+ plugin_name = statement._propagate_attrs.get(
+ "compile_state_plugin", "default"
)
else:
- constructor = cls
+ plugin_name = "default"
+
+ klass = cls.plugins[(plugin_name, statement.__visit_name__)]
- return constructor(statement, compiler, **kw)
+ if klass is cls:
+ return cls(statement, compiler, **kw)
+ else:
+ return klass.create_for_statement(statement, compiler, **kw)
def __init__(self, statement, compiler, **kw):
self.statement = statement
@classmethod
- def get_plugin_classmethod(cls, statement, name):
- if statement._compile_state_plugin is not None:
- fn = cls.plugins.get(
- (
- statement._compile_state_plugin,
- statement.__visit_name__,
- name,
- ),
- None,
- )
- if fn is not None:
- return fn
- return getattr(cls, name)
+ def get_plugin_class(cls, statement):
+ plugin_name = statement._propagate_attrs.get(
+ "compile_state_plugin", "default"
+ )
+ try:
+ return cls.plugins[(plugin_name, statement.__visit_name__)]
+ except KeyError:
+ return None
@classmethod
- def plugin_for(cls, plugin_name, visit_name, method_name=None):
- def decorate(fn):
- cls.plugins[(plugin_name, visit_name, method_name)] = fn
- return fn
+ def _get_plugin_compile_state_cls(cls, statement, plugin_name):
+ statement_plugin_name = statement._propagate_attrs.get(
+ "compile_state_plugin", "default"
+ )
+ if statement_plugin_name != plugin_name:
+ return None
+ try:
+ return cls.plugins[(plugin_name, statement.__visit_name__)]
+ except KeyError:
+ return None
+
+ @classmethod
+ def plugin_for(cls, plugin_name, visit_name):
+ def decorate(cls_to_decorate):
+ cls.plugins[(plugin_name, visit_name)] = cls_to_decorate
+ return cls_to_decorate
return decorate
@@ -508,12 +515,12 @@ class InPlaceGenerative(HasMemoized):
class HasCompileState(Generative):
"""A class that has a :class:`.CompileState` associated with it."""
- _compile_state_factory = CompileState._create
-
_compile_state_plugin = None
_attributes = util.immutabledict()
+ _compile_state_factory = CompileState.create_for_statement
+
class _MetaOptions(type):
"""metaclass for the Options class."""
@@ -549,6 +556,16 @@ class Options(util.with_metaclass(_MetaOptions)):
def add_to_element(self, name, value):
return self + {name: getattr(self, name) + value}
+ @hybridmethod
+ def _state_dict(self):
+ return self.__dict__
+
+ _state_dict_const = util.immutabledict()
+
+ @_state_dict.classlevel
+ def _state_dict(cls):
+ return cls._state_dict_const
+
class CacheableOptions(Options, HasCacheKey):
@hybridmethod
@@ -590,6 +607,9 @@ class Executable(Generative):
def _disable_caching(self):
self._cache_enable = HasCacheKey()
+ def _get_plugin_compile_state_cls(self, plugin_name):
+ return CompileState._get_plugin_compile_state_cls(self, plugin_name)
+
@_generative
def options(self, *options):
"""Apply options to this statement.