summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/selectable.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-03-29 14:24:39 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-03-29 22:27:41 -0400
commita65d5c250e9fd7090311ef12f28d7d959c6c738e (patch)
tree349f0b4c1127c3d87d9cffb62b5d6e02979a3a9a /lib/sqlalchemy/sql/selectable.py
parent8e857e3f6beecf7510f741428d8d0ba24f5cb71b (diff)
downloadsqlalchemy-a65d5c250e9fd7090311ef12f28d7d959c6c738e.tar.gz
Add a third labeling mode for SELECT statements
Enhanced the disambiguating labels feature of the :func:`~.sql.expression.select` construct such that when a select statement is used in a subquery, repeated column names from different tables are now automatically labeled with a unique label name, without the need to use the full "apply_labels()" feature that conbines tablename plus column name. The disambigated labels are available as plain string keys in the .c collection of the subquery, and most importantly the feature allows an ORM :func:`.orm.aliased` construct against the combination of an entity and an arbitrary subquery to work correctly, targeting the correct columns despite same-named columns in the source tables, without the need for an "apply labels" warning. The existing labeling style is now called LABEL_STYLE_TABLENAME_PLUS_COL. This labeling style will remain used throughout the ORM as has been the case for over a decade, however, the new disambiguation scheme could theoretically replace this scheme entirely. The new scheme would dramatically alter how SQL looks when rendered from the ORM to be more succinct but arguably harder to read. The tablename_columnname scheme used by Join.c is unaffected here, as that's still hardcoded to that scheme. Fixes: #5221 Change-Id: Ib47d9e0f35046b3afc77bef6e65709b93d0c3026
Diffstat (limited to 'lib/sqlalchemy/sql/selectable.py')
-rw-r--r--lib/sqlalchemy/sql/selectable.py190
1 files changed, 152 insertions, 38 deletions
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 5ffdf23d8..4eab60801 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -36,6 +36,7 @@ from .base import Generative
from .base import HasCompileState
from .base import HasMemoized
from .base import Immutable
+from .base import prefix_anon_map
from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import and_
@@ -1770,7 +1771,9 @@ class Subquery(AliasedReturnsRows):
"use the :meth:`.Query.scalar_subquery` method.",
)
def as_scalar(self):
- return self.element.scalar_subquery()
+ return self.element._set_label_style(
+ LABEL_STYLE_NONE
+ ).scalar_subquery()
class FromGrouping(GroupedElement, FromClause):
@@ -2281,6 +2284,9 @@ class SelectBase(
:meth:`.SelectBase.scalar_subquery`.
"""
+ if self._label_style is not LABEL_STYLE_NONE:
+ self = self._set_label_style(LABEL_STYLE_NONE)
+
return ScalarSelect(self)
def label(self, name):
@@ -2356,7 +2362,16 @@ class SelectBase(
.. versionadded:: 1.4
"""
- return Subquery._construct(self, name)
+
+ return Subquery._construct(self._ensure_disambiguated_names(), name)
+
+ def _ensure_disambiguated_names(self):
+ """Ensure that the names generated by this selectbase will be
+ disambiguated in some way, if possible.
+
+ """
+
+ raise NotImplementedError()
def alias(self, name=None, flat=False):
"""Return a named subquery against this :class:`.SelectBase`.
@@ -2391,6 +2406,17 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
# type: (SelectBase) -> None
self.element = coercions.expect(roles.SelectStatementRole, element)
+ def _ensure_disambiguated_names(self):
+ new_element = self.element._ensure_disambiguated_names()
+ if new_element is not self.element:
+ return SelectStatementGrouping(new_element)
+ else:
+ return self
+
+ @property
+ def _label_style(self):
+ return self.element._label_style
+
@property
def select_statement(self):
return self.element
@@ -2470,6 +2496,11 @@ class DeprecatedSelectBaseGenerations(object):
self.group_by.non_generative(self, *clauses)
+LABEL_STYLE_NONE = util.symbol("LABEL_STYLE_NONE")
+LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol("LABEL_STYLE_TABLENAME_PLUS_COL")
+LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol("LABEL_STYLE_DISAMBIGUATE_ONLY")
+
+
class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
"""Base class for SELECT statements where additional elements can be
added.
@@ -2491,6 +2522,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
def __init__(
self,
+ _label_style=LABEL_STYLE_NONE,
use_labels=False,
limit=None,
offset=None,
@@ -2498,7 +2530,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
group_by=None,
bind=None,
):
- self.use_labels = use_labels
+ if use_labels:
+ _label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+
+ self._label_style = _label_style
if limit is not None:
self.limit.non_generative(self, limit)
@@ -2572,7 +2607,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
key_share=key_share,
)
- @_generative
+ @property
+ def use_labels(self):
+ return self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+
def apply_labels(self):
"""return a new selectable with the 'use_labels' flag set to True.
@@ -2584,7 +2622,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
"""
- self.use_labels = True
+ return self._set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ def _set_label_style(self, style):
+ if self._label_style is not style:
+ self = self._generate()
+ self._label_style = style
+ return self
@property
def _group_by_clause(self):
@@ -2951,6 +2995,22 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return True
return False
+ def _set_label_style(self, style):
+ if self._label_style is not style:
+ self = self._generate()
+ select_0 = self.selects[0]._set_label_style(style)
+ self.selects = [select_0] + self.selects[1:]
+
+ return self
+
+ def _ensure_disambiguated_names(self):
+ new_select = self.selects[0]._ensure_disambiguated_names()
+ if new_select is not self.selects[0]:
+ self = self._generate()
+ self.selects = [new_select] + self.selects[1:]
+
+ return self
+
def _generate_fromclause_column_proxies(self, subquery):
# this is a slightly hacky thing - the union exports a
@@ -2963,8 +3023,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
# ForeignKeys in. this would allow the union() to have all
# those fks too.
select_0 = self.selects[0]
- if self.use_labels:
- select_0 = select_0.apply_labels()
+ if self._label_style is not LABEL_STYLE_NONE:
+ select_0 = select_0._set_label_style(self._label_style)
select_0._generate_fromclause_column_proxies(subquery)
# hand-construct the "_proxies" collection to include all
@@ -3299,6 +3359,7 @@ class Select(
("_hints", InternalTraversal.dp_table_hint_list),
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_list),
+ ("_label_style", InternalTraversal.dp_plain_obj),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ HasSuffixes._has_suffixes_traverse_internals
@@ -4081,29 +4142,35 @@ class Select(
"""
names = set()
+ pa = None
+ collection = []
- cols = _select_iterables(self._raw_columns)
-
- def name_for_col(c):
+ for c in _select_iterables(self._raw_columns):
# we use key_label since this name is intended for targeting
# within the ColumnCollection only, it's not related to SQL
# rendering which always uses column name for SQL label names
- if self.use_labels:
+ if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL:
name = c._key_label
else:
name = c._proxy_key
if name in names:
- if self.use_labels:
- name = c._label_anon_label
+ if pa is None:
+ pa = prefix_anon_map()
+
+ if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL:
+ name = c._label_anon_label % pa
else:
- name = c.anon_label
+ name = c.anon_label % pa
else:
names.add(name)
- return name
+ collection.append((name, c))
- return ColumnCollection(
- (name_for_col(c), c) for c in cols
- ).as_immutable()
+ return ColumnCollection(collection).as_immutable()
+
+ def _ensure_disambiguated_names(self):
+ if self._label_style is LABEL_STYLE_NONE:
+ self = self._set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
+ return self
def _generate_columns_plus_names(self, anon_for_dupe_key):
cols = _select_iterables(self._raw_columns)
@@ -4117,30 +4184,54 @@ class Select(
# anon label, apply _dedupe_label_anon_label to the subsequent
# occurrences of it.
- if self.use_labels:
+ if self._label_style is LABEL_STYLE_NONE:
+ # don't generate any labels
+ same_cols = set()
+
+ return [
+ (None, c, c in same_cols or same_cols.add(c)) for c in cols
+ ]
+ else:
names = {}
+ use_tablename_labels = (
+ self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+
def name_for_col(c):
if not c._render_label_in_columns_clause:
return (None, c, False)
- elif c._label is None:
+ elif use_tablename_labels:
+ if c._label is None:
+ repeated = c.anon_label in names
+ names[c.anon_label] = c
+ return (None, c, repeated)
+ elif getattr(c, "name", None) is None:
+ # this is a scalar_select(). need to improve this case
repeated = c.anon_label in names
names[c.anon_label] = c
return (None, c, repeated)
- name = c._label
+ if use_tablename_labels:
+ name = effective_name = c._label
+ else:
+ name = None
+ effective_name = c.name
repeated = False
- if name in names:
+ if effective_name in names:
# when looking to see if names[name] is the same column as
# c, use hash(), so that an annotated version of the column
# is seen as the same as the non-annotated
- if hash(names[name]) != hash(c):
+ if hash(names[effective_name]) != hash(c):
# different column under the same name. apply
# disambiguating label
- name = c._label_anon_label
+ if use_tablename_labels:
+ name = c._label_anon_label
+ else:
+ name = c.anon_label
if anon_for_dupe_key and name in names:
# here, c._label_anon_label is definitely unique to
@@ -4154,27 +4245,26 @@ class Select(
# already present. apply the "dedupe" label to
# subsequent occurrences of the column so that the
# original stays non-ambiguous
- name = c._dedupe_label_anon_label
+ if use_tablename_labels:
+ name = c._dedupe_label_anon_label
+ else:
+ name = c._dedupe_anon_label
repeated = True
else:
names[name] = c
elif anon_for_dupe_key:
# same column under the same name. apply the "dedupe"
# label so that the original stays non-ambiguous
- name = c._dedupe_label_anon_label
+ if use_tablename_labels:
+ name = c._dedupe_label_anon_label
+ else:
+ name = c._dedupe_anon_label
repeated = True
else:
- names[name] = c
+ names[effective_name] = c
return name, c, repeated
return [name_for_col(c) for c in cols]
- else:
- # repeated name logic only for use labels at the moment
- same_cols = set()
-
- return [
- (None, c, c in same_cols or same_cols.add(c)) for c in cols
- ]
def _generate_fromclause_column_proxies(self, subquery):
"""generate column proxies to place in the exported .c collection
@@ -4183,17 +4273,33 @@ class Select(
keys_seen = set()
prox = []
+ pa = None
+
+ tablename_plus_col = (
+ self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ disambiguate_only = self._label_style is LABEL_STYLE_DISAMBIGUATE_ONLY
+
for name, c, repeated in self._generate_columns_plus_names(False):
if not hasattr(c, "_make_proxy"):
continue
- if name is None:
- key = None
- elif self.use_labels:
+ elif tablename_plus_col:
key = c._key_label
if key is not None and key in keys_seen:
- key = c._label_anon_label
+ if pa is None:
+ pa = prefix_anon_map()
+ key = c._label_anon_label % pa
+ keys_seen.add(key)
+ elif disambiguate_only:
+ key = c.key
+ if key is not None and key in keys_seen:
+ if pa is None:
+ pa = prefix_anon_map()
+ key = c.anon_label % pa
keys_seen.add(key)
else:
+ # one of the above label styles is set for subqueries
+ # as of #5221 so this codepath is likely not called now.
key = None
prox.append(
c._make_proxy(
@@ -4435,6 +4541,8 @@ class TextualSelect(SelectBase):
__visit_name__ = "textual_select"
+ _label_style = LABEL_STYLE_NONE
+
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("column_args", InternalTraversal.dp_clauseelement_list),
@@ -4472,6 +4580,12 @@ class TextualSelect(SelectBase):
(c.key, c) for c in self.column_args
).as_immutable()
+ def _set_label_style(self, style):
+ return self
+
+ def _ensure_disambiguated_names(self):
+ return self
+
@property
def _bind(self):
return self.element._bind