diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 32 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 23 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 32 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 18 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 11 |
6 files changed, 92 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6415d4b37..f14319089 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -522,7 +522,12 @@ class _MetaOptions(type): def __init__(cls, classname, bases, dict_): cls._cache_attrs = tuple( - sorted(d for d in dict_ if not d.startswith("__")) + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) ) type.__init__(cls, classname, bases, dict_) @@ -561,6 +566,31 @@ class Options(util.with_metaclass(_MetaOptions)): def _state_dict(cls): return cls._state_dict_const + @classmethod + def safe_merge(cls, other): + d = other._state_dict() + + # only support a merge with another object of our class + # and which does not have attrs that we dont. otherwise + # we risk having state that might not be part of our cache + # key strategy + + if ( + cls is not other.__class__ + and other._cache_attrs + and set(other._cache_attrs).difference(cls._cache_attrs) + ): + raise TypeError( + "other element %r is not empty, is not of type %s, " + "and contains attributes not covered here %r" + % ( + other, + cls, + set(other._cache_attrs).difference(cls._cache_attrs), + ) + ) + return cls + d + class CacheableOptions(Options, HasCacheKey): @hybridmethod diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 287e53724..fa2888a23 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -878,6 +878,7 @@ class ColumnElement( key = self._proxy_key else: key = name + co = ColumnClause( coercions.expect(roles.TruncatedLabelRole, name) if name_is_truncatable @@ -885,6 +886,7 @@ class ColumnElement( type_=getattr(self, "type", None), _selectable=selectable, ) + co._propagate_attrs = selectable._propagate_attrs co._proxies = [self] if selectable._is_clone_of is not None: @@ -1284,6 +1286,7 @@ class BindParameter(roles.InElementRole, ColumnElement): """ + if required is NO_ARG: required = value is NO_ARG and callable_ is None if value is NO_ARG: @@ -1302,6 +1305,7 @@ class BindParameter(roles.InElementRole, ColumnElement): id(self), re.sub(r"[%\(\) \$]+", "_", key).strip("_") if key is not None + and not isinstance(key, _anonymous_label) else "param", ) ) @@ -4182,16 +4186,27 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): + name = self.name if not name else name + key, e = self.element._make_proxy( selectable, - name=name if name else self.name, + name=name, disallow_is_literal=True, + name_is_truncatable=isinstance(name, _truncated_label), ) + # TODO: want to remove this assertion at some point. all + # _make_proxy() implementations will give us back the key that + # is our "name" in the first place. based on this we can + # safely return our "self.key" as the key here, to support a new + # case where the key and name are separate. + assert key == self.name + e._propagate_attrs = selectable._propagate_attrs e._proxies.append(self) if self._type is not None: e.type = self._type - return key, e + + return self.key, e class ColumnClause( @@ -4240,7 +4255,7 @@ class ColumnClause( __visit_name__ = "column" _traverse_internals = [ - ("name", InternalTraversal.dp_string), + ("name", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("table", InternalTraversal.dp_clauseelement), ("is_literal", InternalTraversal.dp_boolean), @@ -4410,10 +4425,8 @@ class ColumnClause( def _gen_label(self, name, dedupe_on_key=True): t = self.table - if self.is_literal: return None - elif t is not None and t.named_with_column: if getattr(t, "schema", None): label = t.schema.replace(".", "_") + "_" + t.name + "_" + name diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 170e016a5..d6845e05f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3451,8 +3451,8 @@ class SelectState(util.MemoizedSlots, CompileState): self.columns_plus_names = statement._generate_columns_plus_names(True) def _get_froms(self, statement): - froms = [] seen = set() + froms = [] for item in itertools.chain( itertools.chain.from_iterable( @@ -3474,6 +3474,16 @@ class SelectState(util.MemoizedSlots, CompileState): froms.append(item) seen.update(item._cloned_set) + toremove = set( + itertools.chain.from_iterable( + [_expand_cloned(f._hide_froms) for f in froms] + ) + ) + if toremove: + # filter out to FROM clauses not in the list, + # using a list to maintain ordering + froms = [f for f in froms if f not in toremove] + return froms def _get_display_froms( @@ -3490,16 +3500,6 @@ class SelectState(util.MemoizedSlots, CompileState): froms = self.froms - toremove = set( - itertools.chain.from_iterable( - [_expand_cloned(f._hide_froms) for f in froms] - ) - ) - if toremove: - # filter out to FROM clauses not in the list, - # using a list to maintain ordering - froms = [f for f in froms if f not in toremove] - if self.statement._correlate: to_correlate = self.statement._correlate if to_correlate: @@ -3557,7 +3557,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _memoized_attr__label_resolve_dict(self): with_cols = dict( (c._resolve_label or c._label or c.key, c) - for c in _select_iterables(self.statement._raw_columns) + for c in self.statement._exported_columns_iterator() if c._allow_label_resolve ) only_froms = dict( @@ -3578,6 +3578,10 @@ class SelectState(util.MemoizedSlots, CompileState): else: return None + @classmethod + def exported_columns_iterator(cls, statement): + return _select_iterables(statement._raw_columns) + def _setup_joins(self, args): for (right, onclause, left, flags) in args: isouter = flags["isouter"] @@ -4599,7 +4603,7 @@ class Select( pa = None collection = [] - for c in _select_iterables(self._raw_columns): + for c in self._exported_columns_iterator(): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL # rendering which always uses column name for SQL label names @@ -4630,7 +4634,7 @@ class Select( return self def _generate_columns_plus_names(self, anon_for_dupe_key): - cols = _select_iterables(self._raw_columns) + cols = self._exported_columns_iterator() # when use_labels is on: # in all cases == if we see the same label name, use _label_anon_label diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index a38088a27..388097e45 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -18,6 +18,7 @@ NO_CACHE = util.symbol("no_cache") CACHE_IN_PLACE = util.symbol("cache_in_place") CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") STATIC_CACHE_KEY = util.symbol("static_cache_key") +ANON_NAME = util.symbol("anon_name") def compare(obj1, obj2, **kw): @@ -33,6 +34,7 @@ class HasCacheKey(object): _cache_key_traversal = NO_CACHE __slots__ = () + @util.preload_module("sqlalchemy.sql.elements") def _gen_cache_key(self, anon_map, bindparams): """return an optional cache key. @@ -54,6 +56,8 @@ class HasCacheKey(object): """ + elements = util.preloaded.sql_elements + idself = id(self) if anon_map is not None: @@ -102,6 +106,10 @@ class HasCacheKey(object): result += (attrname, obj) elif meth is STATIC_CACHE_KEY: result += (attrname, obj._static_cache_key) + elif meth is ANON_NAME: + if elements._anonymous_label in obj.__class__.__mro__: + obj = obj.apply_map(anon_map) + result += (attrname, obj) elif meth is CALL_GEN_CACHE_KEY: result += ( attrname, @@ -321,6 +329,7 @@ class _CacheKey(ExtendedInternalTraversal): ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE visit_statement_hint_list = CACHE_IN_PLACE visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) @@ -387,15 +396,6 @@ class _CacheKey(ExtendedInternalTraversal): attrname, obj, parent, anon_map, bindparams ) - def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams): - from . import elements - - name = obj - if isinstance(name, elements._anonymous_label): - name = name.apply_map(anon_map) - - return (attrname, name) - def visit_fromclause_ordered_set( self, attrname, obj, parent, anon_map, bindparams ): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 377aa4fe0..e8726000b 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -822,9 +822,14 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # is another join or selectable that contains a table which our # selectable derives from, that we want to process return None + elif not isinstance(col, ColumnElement): return None - elif self.include_fn and not self.include_fn(col): + + if "adapt_column" in col._annotations: + col = col._annotations["adapt_column"] + + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 683f545dd..5de68f504 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -50,6 +50,13 @@ def _generate_compiler_dispatch(cls): """ visit_name = cls.__visit_name__ + if "_compiler_dispatch" in cls.__dict__: + # class has a fixed _compiler_dispatch() method. + # copy it to "original" so that we can get it back if + # sqlalchemy.ext.compiles overrides it. + cls._original_compiler_dispatch = cls._compiler_dispatch + return + if not isinstance(visit_name, util.compat.string_types): raise exc.InvalidRequestError( "__visit_name__ on class %s must be a string at the class level" @@ -76,7 +83,9 @@ def _generate_compiler_dispatch(cls): + self.__visit_name__ on the visitor, and call it with the same kw params. """ - cls._compiler_dispatch = _compiler_dispatch + cls._compiler_dispatch = ( + cls._original_compiler_dispatch + ) = _compiler_dispatch class TraversibleType(type): |