diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 91 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/roles.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 10 |
5 files changed, 65 insertions, 70 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index bb606a4d6..6415d4b37 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -470,12 +470,7 @@ class CompileState(object): return None @classmethod - def _get_plugin_compile_state_cls(cls, statement, plugin_name): - statement_plugin_name = statement._propagate_attrs.get( - "compile_state_plugin", "default" - ) - if statement_plugin_name != plugin_name: - return None + def _get_plugin_class_for_plugin(cls, statement, plugin_name): try: return cls.plugins[(plugin_name, statement.__visit_name__)] except KeyError: @@ -607,9 +602,6 @@ class Executable(Generative): def _disable_caching(self): self._cache_enable = HasCacheKey() - def _get_plugin_compile_state_cls(self, plugin_name): - return CompileState._get_plugin_compile_state_cls(self, plugin_name) - @_generative def options(self, *options): """Apply options to this statement. @@ -735,7 +727,9 @@ class Executable(Generative): "to execute this construct." % label ) raise exc.UnboundExecutionError(msg) - return e._execute_clauseelement(self, multiparams, params) + return e._execute_clauseelement( + self, multiparams, params, util.immutabledict() + ) @util.deprecated_20( ":meth:`.Executable.scalar`", diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index d8ef0222a..7503faf5b 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -50,16 +50,50 @@ def _document_text_coercion(paramname, meth_rst, param_rst): ) -def expect(role, element, apply_propagate_attrs=None, **kw): +def expect(role, element, apply_propagate_attrs=None, argname=None, **kw): # major case is that we are given a ClauseElement already, skip more # elaborate logic up front if possible impl = _impl_lookup[role] + original_element = element + if not isinstance( element, (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue), ): - resolved = impl._resolve_for_clause_element(element, **kw) + resolved = None + + if impl._resolve_string_only: + resolved = impl._literal_coercion(element, **kw) + else: + + original_element = element + + is_clause_element = False + + while hasattr(element, "__clause_element__"): + is_clause_element = True + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break + + if not is_clause_element: + if impl._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + insp._post_inspect + try: + resolved = insp.__clause_element__() + except AttributeError: + impl._raise_for_expected(original_element, argname) + + if resolved is None: + resolved = impl._literal_coercion( + element, argname=argname, **kw + ) + else: + resolved = element else: resolved = element @@ -72,10 +106,12 @@ def expect(role, element, apply_propagate_attrs=None, **kw): if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: - resolved = impl._post_coercion(resolved, **kw) + resolved = impl._post_coercion(resolved, argname=argname, **kw) return resolved else: - return impl._implicit_coercions(element, resolved, **kw) + return impl._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) def expect_as_key(role, element, **kw): @@ -107,51 +143,13 @@ class RoleImpl(object): raise NotImplementedError() _post_coercion = None + _resolve_string_only = False def __init__(self, role_class): self._role_class = role_class self.name = role_class._role_name self._use_inspection = issubclass(role_class, roles.UsesInspection) - def _resolve_for_clause_element(self, element, argname=None, **kw): - original_element = element - - is_clause_element = False - - while hasattr(element, "__clause_element__"): - is_clause_element = True - if not getattr(element, "is_clause_element", False): - element = element.__clause_element__() - else: - return element - - if not is_clause_element: - if self._use_inspection: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - insp._post_inspect - try: - element = insp.__clause_element__() - except AttributeError: - self._raise_for_expected(original_element, argname) - else: - return element - - return self._literal_coercion(element, argname=argname, **kw) - else: - return element - - if self._use_inspection: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - insp._post_inspect - try: - element = insp.__clause_element__() - except AttributeError: - self._raise_for_expected(original_element, argname) - - return self._literal_coercion(element, argname=argname, **kw) - def _implicit_coercions(self, element, resolved, argname=None, **kw): self._raise_for_expected(element, argname, resolved) @@ -191,8 +189,7 @@ class _Deannotate(object): class _StringOnly(object): __slots__ = () - def _resolve_for_clause_element(self, element, argname=None, **kw): - return self._literal_coercion(element, **kw) + _resolve_string_only = True class _ReturnsStringKey(object): @@ -465,7 +462,7 @@ class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole): class OrderByImpl(ByOfImpl, RoleImpl): __slots__ = () - def _post_coercion(self, resolved): + def _post_coercion(self, resolved, **kw): if ( isinstance(resolved, self._role_class) and resolved._order_by_label_element is not None @@ -490,7 +487,7 @@ class GroupByImpl(ByOfImpl, RoleImpl): class DMLColumnImpl(_ReturnsStringKey, RoleImpl): __slots__ = () - def _post_coercion(self, element, as_key=False): + def _post_coercion(self, element, as_key=False, **kw): if as_key: return element.key else: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index d0f4fef60..5a55fe5f2 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from .. import util + class SQLRole(object): """Define a "role" within a SQL statement structure. @@ -145,16 +147,12 @@ class CoerceTextStatementRole(SQLRole): _role_name = "Executable SQL or text() construct" -# _executable_statement = None - - class StatementRole(CoerceTextStatementRole): _role_name = "Executable SQL or text() construct" _is_future = False - def _get_plugin_compile_state_cls(self, name): - return None + _propagate_attrs = util.immutabledict() class ReturnsRowsRole(StatementRole): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 008959aec..170e016a5 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3428,7 +3428,14 @@ class DeprecatedSelectGenerations(object): @CompileState.plugin_for("default", "select") -class SelectState(CompileState): +class SelectState(util.MemoizedSlots, CompileState): + __slots__ = ( + "from_clauses", + "froms", + "columns_plus_names", + "_label_resolve_dict", + ) + class default_select_compile_options(CacheableOptions): _cache_key_traversal = [] @@ -3547,8 +3554,7 @@ class SelectState(CompileState): return froms - @util.memoized_property - def _label_resolve_dict(self): + def _memoized_attr__label_resolve_dict(self): with_cols = dict( (c._resolve_label or c._label or c.key, c) for c in _select_iterables(self.statement._raw_columns) diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 482248ada..a38088a27 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -179,7 +179,7 @@ class HasCacheKey(object): if NO_CACHE in _anon_map: return None else: - return CacheKey(key, bindparams, self) + return CacheKey(key, bindparams) @classmethod def _generate_cache_key_for_object(cls, obj): @@ -190,7 +190,7 @@ class HasCacheKey(object): if NO_CACHE in _anon_map: return None else: - return CacheKey(key, bindparams, obj) + return CacheKey(key, bindparams) class MemoizedHasCacheKey(HasCacheKey, HasMemoized): @@ -199,13 +199,13 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized): return HasCacheKey._generate_cache_key(self) -class CacheKey(namedtuple("CacheKey", ["key", "bindparams", "statement"])): +class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): def __hash__(self): """CacheKey itself is not hashable - hash the .key portion""" return None - def to_offline_string(self, statement_cache, parameters): + def to_offline_string(self, statement_cache, statement, parameters): """generate an "offline string" form of this :class:`.CacheKey` The "offline string" is basically the string SQL for the @@ -222,7 +222,7 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams", "statement"])): """ if self.key not in statement_cache: - statement_cache[self.key] = sql_str = str(self.statement) + statement_cache[self.key] = sql_str = str(statement) else: sql_str = statement_cache[self.key] |