summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/base.py32
-rw-r--r--lib/sqlalchemy/sql/elements.py23
-rw-r--r--lib/sqlalchemy/sql/selectable.py32
-rw-r--r--lib/sqlalchemy/sql/traversals.py18
-rw-r--r--lib/sqlalchemy/sql/util.py7
-rw-r--r--lib/sqlalchemy/sql/visitors.py11
6 files changed, 92 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 6415d4b37..f14319089 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -522,7 +522,12 @@ class _MetaOptions(type):
def __init__(cls, classname, bases, dict_):
cls._cache_attrs = tuple(
- sorted(d for d in dict_ if not d.startswith("__"))
+ sorted(
+ d
+ for d in dict_
+ if not d.startswith("__")
+ and d not in ("_cache_key_traversal",)
+ )
)
type.__init__(cls, classname, bases, dict_)
@@ -561,6 +566,31 @@ class Options(util.with_metaclass(_MetaOptions)):
def _state_dict(cls):
return cls._state_dict_const
+ @classmethod
+ def safe_merge(cls, other):
+ d = other._state_dict()
+
+ # only support a merge with another object of our class
+ # and which does not have attrs that we dont. otherwise
+ # we risk having state that might not be part of our cache
+ # key strategy
+
+ if (
+ cls is not other.__class__
+ and other._cache_attrs
+ and set(other._cache_attrs).difference(cls._cache_attrs)
+ ):
+ raise TypeError(
+ "other element %r is not empty, is not of type %s, "
+ "and contains attributes not covered here %r"
+ % (
+ other,
+ cls,
+ set(other._cache_attrs).difference(cls._cache_attrs),
+ )
+ )
+ return cls + d
+
class CacheableOptions(Options, HasCacheKey):
@hybridmethod
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 287e53724..fa2888a23 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -878,6 +878,7 @@ class ColumnElement(
key = self._proxy_key
else:
key = name
+
co = ColumnClause(
coercions.expect(roles.TruncatedLabelRole, name)
if name_is_truncatable
@@ -885,6 +886,7 @@ class ColumnElement(
type_=getattr(self, "type", None),
_selectable=selectable,
)
+
co._propagate_attrs = selectable._propagate_attrs
co._proxies = [self]
if selectable._is_clone_of is not None:
@@ -1284,6 +1286,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
"""
+
if required is NO_ARG:
required = value is NO_ARG and callable_ is None
if value is NO_ARG:
@@ -1302,6 +1305,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
id(self),
re.sub(r"[%\(\) \$]+", "_", key).strip("_")
if key is not None
+ and not isinstance(key, _anonymous_label)
else "param",
)
)
@@ -4182,16 +4186,27 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
return self.element._from_objects
def _make_proxy(self, selectable, name=None, **kw):
+ name = self.name if not name else name
+
key, e = self.element._make_proxy(
selectable,
- name=name if name else self.name,
+ name=name,
disallow_is_literal=True,
+ name_is_truncatable=isinstance(name, _truncated_label),
)
+ # TODO: want to remove this assertion at some point. all
+ # _make_proxy() implementations will give us back the key that
+ # is our "name" in the first place. based on this we can
+ # safely return our "self.key" as the key here, to support a new
+ # case where the key and name are separate.
+ assert key == self.name
+
e._propagate_attrs = selectable._propagate_attrs
e._proxies.append(self)
if self._type is not None:
e.type = self._type
- return key, e
+
+ return self.key, e
class ColumnClause(
@@ -4240,7 +4255,7 @@ class ColumnClause(
__visit_name__ = "column"
_traverse_internals = [
- ("name", InternalTraversal.dp_string),
+ ("name", InternalTraversal.dp_anon_name),
("type", InternalTraversal.dp_type),
("table", InternalTraversal.dp_clauseelement),
("is_literal", InternalTraversal.dp_boolean),
@@ -4410,10 +4425,8 @@ class ColumnClause(
def _gen_label(self, name, dedupe_on_key=True):
t = self.table
-
if self.is_literal:
return None
-
elif t is not None and t.named_with_column:
if getattr(t, "schema", None):
label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 170e016a5..d6845e05f 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -3451,8 +3451,8 @@ class SelectState(util.MemoizedSlots, CompileState):
self.columns_plus_names = statement._generate_columns_plus_names(True)
def _get_froms(self, statement):
- froms = []
seen = set()
+ froms = []
for item in itertools.chain(
itertools.chain.from_iterable(
@@ -3474,6 +3474,16 @@ class SelectState(util.MemoizedSlots, CompileState):
froms.append(item)
seen.update(item._cloned_set)
+ toremove = set(
+ itertools.chain.from_iterable(
+ [_expand_cloned(f._hide_froms) for f in froms]
+ )
+ )
+ if toremove:
+ # filter out to FROM clauses not in the list,
+ # using a list to maintain ordering
+ froms = [f for f in froms if f not in toremove]
+
return froms
def _get_display_froms(
@@ -3490,16 +3500,6 @@ class SelectState(util.MemoizedSlots, CompileState):
froms = self.froms
- toremove = set(
- itertools.chain.from_iterable(
- [_expand_cloned(f._hide_froms) for f in froms]
- )
- )
- if toremove:
- # filter out to FROM clauses not in the list,
- # using a list to maintain ordering
- froms = [f for f in froms if f not in toremove]
-
if self.statement._correlate:
to_correlate = self.statement._correlate
if to_correlate:
@@ -3557,7 +3557,7 @@ class SelectState(util.MemoizedSlots, CompileState):
def _memoized_attr__label_resolve_dict(self):
with_cols = dict(
(c._resolve_label or c._label or c.key, c)
- for c in _select_iterables(self.statement._raw_columns)
+ for c in self.statement._exported_columns_iterator()
if c._allow_label_resolve
)
only_froms = dict(
@@ -3578,6 +3578,10 @@ class SelectState(util.MemoizedSlots, CompileState):
else:
return None
+ @classmethod
+ def exported_columns_iterator(cls, statement):
+ return _select_iterables(statement._raw_columns)
+
def _setup_joins(self, args):
for (right, onclause, left, flags) in args:
isouter = flags["isouter"]
@@ -4599,7 +4603,7 @@ class Select(
pa = None
collection = []
- for c in _select_iterables(self._raw_columns):
+ for c in self._exported_columns_iterator():
# we use key_label since this name is intended for targeting
# within the ColumnCollection only, it's not related to SQL
# rendering which always uses column name for SQL label names
@@ -4630,7 +4634,7 @@ class Select(
return self
def _generate_columns_plus_names(self, anon_for_dupe_key):
- cols = _select_iterables(self._raw_columns)
+ cols = self._exported_columns_iterator()
# when use_labels is on:
# in all cases == if we see the same label name, use _label_anon_label
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index a38088a27..388097e45 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -18,6 +18,7 @@ NO_CACHE = util.symbol("no_cache")
CACHE_IN_PLACE = util.symbol("cache_in_place")
CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
STATIC_CACHE_KEY = util.symbol("static_cache_key")
+ANON_NAME = util.symbol("anon_name")
def compare(obj1, obj2, **kw):
@@ -33,6 +34,7 @@ class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
__slots__ = ()
+ @util.preload_module("sqlalchemy.sql.elements")
def _gen_cache_key(self, anon_map, bindparams):
"""return an optional cache key.
@@ -54,6 +56,8 @@ class HasCacheKey(object):
"""
+ elements = util.preloaded.sql_elements
+
idself = id(self)
if anon_map is not None:
@@ -102,6 +106,10 @@ class HasCacheKey(object):
result += (attrname, obj)
elif meth is STATIC_CACHE_KEY:
result += (attrname, obj._static_cache_key)
+ elif meth is ANON_NAME:
+ if elements._anonymous_label in obj.__class__.__mro__:
+ obj = obj.apply_map(anon_map)
+ result += (attrname, obj)
elif meth is CALL_GEN_CACHE_KEY:
result += (
attrname,
@@ -321,6 +329,7 @@ class _CacheKey(ExtendedInternalTraversal):
) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
visit_statement_hint_list = CACHE_IN_PLACE
visit_type = STATIC_CACHE_KEY
+ visit_anon_name = ANON_NAME
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
@@ -387,15 +396,6 @@ class _CacheKey(ExtendedInternalTraversal):
attrname, obj, parent, anon_map, bindparams
)
- def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams):
- from . import elements
-
- name = obj
- if isinstance(name, elements._anonymous_label):
- name = name.apply_map(anon_map)
-
- return (attrname, name)
-
def visit_fromclause_ordered_set(
self, attrname, obj, parent, anon_map, bindparams
):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 377aa4fe0..e8726000b 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -822,9 +822,14 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
# is another join or selectable that contains a table which our
# selectable derives from, that we want to process
return None
+
elif not isinstance(col, ColumnElement):
return None
- elif self.include_fn and not self.include_fn(col):
+
+ if "adapt_column" in col._annotations:
+ col = col._annotations["adapt_column"]
+
+ if self.include_fn and not self.include_fn(col):
return None
elif self.exclude_fn and self.exclude_fn(col):
return None
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 683f545dd..5de68f504 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -50,6 +50,13 @@ def _generate_compiler_dispatch(cls):
"""
visit_name = cls.__visit_name__
+ if "_compiler_dispatch" in cls.__dict__:
+ # class has a fixed _compiler_dispatch() method.
+ # copy it to "original" so that we can get it back if
+ # sqlalchemy.ext.compiles overrides it.
+ cls._original_compiler_dispatch = cls._compiler_dispatch
+ return
+
if not isinstance(visit_name, util.compat.string_types):
raise exc.InvalidRequestError(
"__visit_name__ on class %s must be a string at the class level"
@@ -76,7 +83,9 @@ def _generate_compiler_dispatch(cls):
+ self.__visit_name__ on the visitor, and call it with the same
kw params.
"""
- cls._compiler_dispatch = _compiler_dispatch
+ cls._compiler_dispatch = (
+ cls._original_compiler_dispatch
+ ) = _compiler_dispatch
class TraversibleType(type):