diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-03-29 14:24:39 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-03-29 22:27:41 -0400 |
commit | a65d5c250e9fd7090311ef12f28d7d959c6c738e (patch) | |
tree | 349f0b4c1127c3d87d9cffb62b5d6e02979a3a9a /lib/sqlalchemy/sql/selectable.py | |
parent | 8e857e3f6beecf7510f741428d8d0ba24f5cb71b (diff) | |
download | sqlalchemy-a65d5c250e9fd7090311ef12f28d7d959c6c738e.tar.gz |
Add a third labeling mode for SELECT statements
Enhanced the disambiguating labels feature of the
:func:`~.sql.expression.select` construct such that when a select statement
is used in a subquery, repeated column names from different tables are now
automatically labeled with a unique label name, without the need to use the
full "apply_labels()" feature that conbines tablename plus column name.
The disambigated labels are available as plain string keys in the .c
collection of the subquery, and most importantly the feature allows an ORM
:func:`.orm.aliased` construct against the combination of an entity and an
arbitrary subquery to work correctly, targeting the correct columns despite
same-named columns in the source tables, without the need for an "apply
labels" warning.
The existing labeling style is now called
LABEL_STYLE_TABLENAME_PLUS_COL. This labeling style will remain used
throughout the ORM as has been the case for over a decade, however,
the new disambiguation scheme could theoretically replace this scheme
entirely. The new scheme would dramatically alter how SQL looks
when rendered from the ORM to be more succinct but arguably harder
to read.
The tablename_columnname scheme used by Join.c is unaffected here,
as that's still hardcoded to that scheme.
Fixes: #5221
Change-Id: Ib47d9e0f35046b3afc77bef6e65709b93d0c3026
Diffstat (limited to 'lib/sqlalchemy/sql/selectable.py')
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 190 |
1 files changed, 152 insertions, 38 deletions
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5ffdf23d8..4eab60801 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -36,6 +36,7 @@ from .base import Generative from .base import HasCompileState from .base import HasMemoized from .base import Immutable +from .base import prefix_anon_map from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import and_ @@ -1770,7 +1771,9 @@ class Subquery(AliasedReturnsRows): "use the :meth:`.Query.scalar_subquery` method.", ) def as_scalar(self): - return self.element.scalar_subquery() + return self.element._set_label_style( + LABEL_STYLE_NONE + ).scalar_subquery() class FromGrouping(GroupedElement, FromClause): @@ -2281,6 +2284,9 @@ class SelectBase( :meth:`.SelectBase.scalar_subquery`. """ + if self._label_style is not LABEL_STYLE_NONE: + self = self._set_label_style(LABEL_STYLE_NONE) + return ScalarSelect(self) def label(self, name): @@ -2356,7 +2362,16 @@ class SelectBase( .. versionadded:: 1.4 """ - return Subquery._construct(self, name) + + return Subquery._construct(self._ensure_disambiguated_names(), name) + + def _ensure_disambiguated_names(self): + """Ensure that the names generated by this selectbase will be + disambiguated in some way, if possible. + + """ + + raise NotImplementedError() def alias(self, name=None, flat=False): """Return a named subquery against this :class:`.SelectBase`. @@ -2391,6 +2406,17 @@ class SelectStatementGrouping(GroupedElement, SelectBase): # type: (SelectBase) -> None self.element = coercions.expect(roles.SelectStatementRole, element) + def _ensure_disambiguated_names(self): + new_element = self.element._ensure_disambiguated_names() + if new_element is not self.element: + return SelectStatementGrouping(new_element) + else: + return self + + @property + def _label_style(self): + return self.element._label_style + @property def select_statement(self): return self.element @@ -2470,6 +2496,11 @@ class DeprecatedSelectBaseGenerations(object): self.group_by.non_generative(self, *clauses) +LABEL_STYLE_NONE = util.symbol("LABEL_STYLE_NONE") +LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol("LABEL_STYLE_TABLENAME_PLUS_COL") +LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol("LABEL_STYLE_DISAMBIGUATE_ONLY") + + class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """Base class for SELECT statements where additional elements can be added. @@ -2491,6 +2522,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): def __init__( self, + _label_style=LABEL_STYLE_NONE, use_labels=False, limit=None, offset=None, @@ -2498,7 +2530,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): group_by=None, bind=None, ): - self.use_labels = use_labels + if use_labels: + _label_style = LABEL_STYLE_TABLENAME_PLUS_COL + + self._label_style = _label_style if limit is not None: self.limit.non_generative(self, limit) @@ -2572,7 +2607,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): key_share=key_share, ) - @_generative + @property + def use_labels(self): + return self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + def apply_labels(self): """return a new selectable with the 'use_labels' flag set to True. @@ -2584,7 +2622,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """ - self.use_labels = True + return self._set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + + def _set_label_style(self, style): + if self._label_style is not style: + self = self._generate() + self._label_style = style + return self @property def _group_by_clause(self): @@ -2951,6 +2995,22 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return True return False + def _set_label_style(self, style): + if self._label_style is not style: + self = self._generate() + select_0 = self.selects[0]._set_label_style(style) + self.selects = [select_0] + self.selects[1:] + + return self + + def _ensure_disambiguated_names(self): + new_select = self.selects[0]._ensure_disambiguated_names() + if new_select is not self.selects[0]: + self = self._generate() + self.selects = [new_select] + self.selects[1:] + + return self + def _generate_fromclause_column_proxies(self, subquery): # this is a slightly hacky thing - the union exports a @@ -2963,8 +3023,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect): # ForeignKeys in. this would allow the union() to have all # those fks too. select_0 = self.selects[0] - if self.use_labels: - select_0 = select_0.apply_labels() + if self._label_style is not LABEL_STYLE_NONE: + select_0 = select_0._set_label_style(self._label_style) select_0._generate_fromclause_column_proxies(subquery) # hand-construct the "_proxies" collection to include all @@ -3299,6 +3359,7 @@ class Select( ("_hints", InternalTraversal.dp_table_hint_list), ("_distinct", InternalTraversal.dp_boolean), ("_distinct_on", InternalTraversal.dp_clauseelement_list), + ("_label_style", InternalTraversal.dp_plain_obj), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals @@ -4081,29 +4142,35 @@ class Select( """ names = set() + pa = None + collection = [] - cols = _select_iterables(self._raw_columns) - - def name_for_col(c): + for c in _select_iterables(self._raw_columns): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL # rendering which always uses column name for SQL label names - if self.use_labels: + if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL: name = c._key_label else: name = c._proxy_key if name in names: - if self.use_labels: - name = c._label_anon_label + if pa is None: + pa = prefix_anon_map() + + if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL: + name = c._label_anon_label % pa else: - name = c.anon_label + name = c.anon_label % pa else: names.add(name) - return name + collection.append((name, c)) - return ColumnCollection( - (name_for_col(c), c) for c in cols - ).as_immutable() + return ColumnCollection(collection).as_immutable() + + def _ensure_disambiguated_names(self): + if self._label_style is LABEL_STYLE_NONE: + self = self._set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) + return self def _generate_columns_plus_names(self, anon_for_dupe_key): cols = _select_iterables(self._raw_columns) @@ -4117,30 +4184,54 @@ class Select( # anon label, apply _dedupe_label_anon_label to the subsequent # occurrences of it. - if self.use_labels: + if self._label_style is LABEL_STYLE_NONE: + # don't generate any labels + same_cols = set() + + return [ + (None, c, c in same_cols or same_cols.add(c)) for c in cols + ] + else: names = {} + use_tablename_labels = ( + self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + ) + def name_for_col(c): if not c._render_label_in_columns_clause: return (None, c, False) - elif c._label is None: + elif use_tablename_labels: + if c._label is None: + repeated = c.anon_label in names + names[c.anon_label] = c + return (None, c, repeated) + elif getattr(c, "name", None) is None: + # this is a scalar_select(). need to improve this case repeated = c.anon_label in names names[c.anon_label] = c return (None, c, repeated) - name = c._label + if use_tablename_labels: + name = effective_name = c._label + else: + name = None + effective_name = c.name repeated = False - if name in names: + if effective_name in names: # when looking to see if names[name] is the same column as # c, use hash(), so that an annotated version of the column # is seen as the same as the non-annotated - if hash(names[name]) != hash(c): + if hash(names[effective_name]) != hash(c): # different column under the same name. apply # disambiguating label - name = c._label_anon_label + if use_tablename_labels: + name = c._label_anon_label + else: + name = c.anon_label if anon_for_dupe_key and name in names: # here, c._label_anon_label is definitely unique to @@ -4154,27 +4245,26 @@ class Select( # already present. apply the "dedupe" label to # subsequent occurrences of the column so that the # original stays non-ambiguous - name = c._dedupe_label_anon_label + if use_tablename_labels: + name = c._dedupe_label_anon_label + else: + name = c._dedupe_anon_label repeated = True else: names[name] = c elif anon_for_dupe_key: # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous - name = c._dedupe_label_anon_label + if use_tablename_labels: + name = c._dedupe_label_anon_label + else: + name = c._dedupe_anon_label repeated = True else: - names[name] = c + names[effective_name] = c return name, c, repeated return [name_for_col(c) for c in cols] - else: - # repeated name logic only for use labels at the moment - same_cols = set() - - return [ - (None, c, c in same_cols or same_cols.add(c)) for c in cols - ] def _generate_fromclause_column_proxies(self, subquery): """generate column proxies to place in the exported .c collection @@ -4183,17 +4273,33 @@ class Select( keys_seen = set() prox = [] + pa = None + + tablename_plus_col = ( + self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + ) + disambiguate_only = self._label_style is LABEL_STYLE_DISAMBIGUATE_ONLY + for name, c, repeated in self._generate_columns_plus_names(False): if not hasattr(c, "_make_proxy"): continue - if name is None: - key = None - elif self.use_labels: + elif tablename_plus_col: key = c._key_label if key is not None and key in keys_seen: - key = c._label_anon_label + if pa is None: + pa = prefix_anon_map() + key = c._label_anon_label % pa + keys_seen.add(key) + elif disambiguate_only: + key = c.key + if key is not None and key in keys_seen: + if pa is None: + pa = prefix_anon_map() + key = c.anon_label % pa keys_seen.add(key) else: + # one of the above label styles is set for subqueries + # as of #5221 so this codepath is likely not called now. key = None prox.append( c._make_proxy( @@ -4435,6 +4541,8 @@ class TextualSelect(SelectBase): __visit_name__ = "textual_select" + _label_style = LABEL_STYLE_NONE + _traverse_internals = [ ("element", InternalTraversal.dp_clauseelement), ("column_args", InternalTraversal.dp_clauseelement_list), @@ -4472,6 +4580,12 @@ class TextualSelect(SelectBase): (c.key, c) for c in self.column_args ).as_immutable() + def _set_label_style(self, style): + return self + + def _ensure_disambiguated_names(self): + return self + @property def _bind(self): return self.element._bind |