diff options
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 80 |
1 files changed, 50 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 04cc34480..bb606a4d6 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -439,46 +439,53 @@ class CompileState(object): plugins = {} @classmethod - def _create(cls, statement, compiler, **kw): + def create_for_statement(cls, statement, compiler, **kw): # factory construction. - if statement._compile_state_plugin is not None: - constructor = cls.plugins.get( - ( - statement._compile_state_plugin, - statement.__visit_name__, - None, - ), - cls, + if statement._propagate_attrs: + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" ) else: - constructor = cls + plugin_name = "default" + + klass = cls.plugins[(plugin_name, statement.__visit_name__)] - return constructor(statement, compiler, **kw) + if klass is cls: + return cls(statement, compiler, **kw) + else: + return klass.create_for_statement(statement, compiler, **kw) def __init__(self, statement, compiler, **kw): self.statement = statement @classmethod - def get_plugin_classmethod(cls, statement, name): - if statement._compile_state_plugin is not None: - fn = cls.plugins.get( - ( - statement._compile_state_plugin, - statement.__visit_name__, - name, - ), - None, - ) - if fn is not None: - return fn - return getattr(cls, name) + def get_plugin_class(cls, statement): + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + try: + return cls.plugins[(plugin_name, statement.__visit_name__)] + except KeyError: + return None @classmethod - def plugin_for(cls, plugin_name, visit_name, method_name=None): - def decorate(fn): - cls.plugins[(plugin_name, visit_name, method_name)] = fn - return fn + def _get_plugin_compile_state_cls(cls, statement, plugin_name): + statement_plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + if statement_plugin_name != plugin_name: + return None + try: + return cls.plugins[(plugin_name, statement.__visit_name__)] + except KeyError: + return None + + @classmethod + def plugin_for(cls, plugin_name, visit_name): + def decorate(cls_to_decorate): + cls.plugins[(plugin_name, visit_name)] = cls_to_decorate + return cls_to_decorate return decorate @@ -508,12 +515,12 @@ class InPlaceGenerative(HasMemoized): class HasCompileState(Generative): """A class that has a :class:`.CompileState` associated with it.""" - _compile_state_factory = CompileState._create - _compile_state_plugin = None _attributes = util.immutabledict() + _compile_state_factory = CompileState.create_for_statement + class _MetaOptions(type): """metaclass for the Options class.""" @@ -549,6 +556,16 @@ class Options(util.with_metaclass(_MetaOptions)): def add_to_element(self, name, value): return self + {name: getattr(self, name) + value} + @hybridmethod + def _state_dict(self): + return self.__dict__ + + _state_dict_const = util.immutabledict() + + @_state_dict.classlevel + def _state_dict(cls): + return cls._state_dict_const + class CacheableOptions(Options, HasCacheKey): @hybridmethod @@ -590,6 +607,9 @@ class Executable(Generative): def _disable_caching(self): self._cache_enable = HasCacheKey() + def _get_plugin_compile_state_cls(self, plugin_name): + return CompileState._get_plugin_compile_state_cls(self, plugin_name) + @_generative def options(self, *options): """Apply options to this statement. |