diff options
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index b235f5132..aba80222a 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -515,12 +515,20 @@ class CompileState(object): @classmethod def get_plugin_class(cls, statement): plugin_name = statement._propagate_attrs.get( - "compile_state_plugin", "default" + "compile_state_plugin", None ) + + if plugin_name: + key = (plugin_name, statement._effective_plugin_target) + if key in cls.plugins: + return cls.plugins[key] + + # there's no case where we call upon get_plugin_class() and want + # to get None back, there should always be a default. return that + # if there was no plugin-specific class (e.g. Insert with "orm" + # plugin) try: - return cls.plugins[ - (plugin_name, statement._effective_plugin_target) - ] + return cls.plugins[("default", statement._effective_plugin_target)] except KeyError: return None @@ -1665,7 +1673,7 @@ def _entity_namespace(entity): raise -def _entity_namespace_key(entity, key): +def _entity_namespace_key(entity, key, default=NO_ARG): """Return an entry from an entity_namespace. @@ -1676,7 +1684,10 @@ def _entity_namespace_key(entity, key): try: ns = _entity_namespace(entity) - return getattr(ns, key) + if default is not NO_ARG: + return getattr(ns, key, default) + else: + return getattr(ns, key) except AttributeError as err: util.raise_( exc.InvalidRequestError( |