diff options
Diffstat (limited to 'lib/sqlalchemy/sql/selectable.py')
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 190 |
1 files changed, 152 insertions, 38 deletions
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5ffdf23d8..4eab60801 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -36,6 +36,7 @@ from .base import Generative from .base import HasCompileState from .base import HasMemoized from .base import Immutable +from .base import prefix_anon_map from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import and_ @@ -1770,7 +1771,9 @@ class Subquery(AliasedReturnsRows): "use the :meth:`.Query.scalar_subquery` method.", ) def as_scalar(self): - return self.element.scalar_subquery() + return self.element._set_label_style( + LABEL_STYLE_NONE + ).scalar_subquery() class FromGrouping(GroupedElement, FromClause): @@ -2281,6 +2284,9 @@ class SelectBase( :meth:`.SelectBase.scalar_subquery`. """ + if self._label_style is not LABEL_STYLE_NONE: + self = self._set_label_style(LABEL_STYLE_NONE) + return ScalarSelect(self) def label(self, name): @@ -2356,7 +2362,16 @@ class SelectBase( .. versionadded:: 1.4 """ - return Subquery._construct(self, name) + + return Subquery._construct(self._ensure_disambiguated_names(), name) + + def _ensure_disambiguated_names(self): + """Ensure that the names generated by this selectbase will be + disambiguated in some way, if possible. + + """ + + raise NotImplementedError() def alias(self, name=None, flat=False): """Return a named subquery against this :class:`.SelectBase`. @@ -2391,6 +2406,17 @@ class SelectStatementGrouping(GroupedElement, SelectBase): # type: (SelectBase) -> None self.element = coercions.expect(roles.SelectStatementRole, element) + def _ensure_disambiguated_names(self): + new_element = self.element._ensure_disambiguated_names() + if new_element is not self.element: + return SelectStatementGrouping(new_element) + else: + return self + + @property + def _label_style(self): + return self.element._label_style + @property def select_statement(self): return self.element @@ -2470,6 +2496,11 @@ class DeprecatedSelectBaseGenerations(object): self.group_by.non_generative(self, *clauses) +LABEL_STYLE_NONE = util.symbol("LABEL_STYLE_NONE") +LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol("LABEL_STYLE_TABLENAME_PLUS_COL") +LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol("LABEL_STYLE_DISAMBIGUATE_ONLY") + + class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """Base class for SELECT statements where additional elements can be added. @@ -2491,6 +2522,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): def __init__( self, + _label_style=LABEL_STYLE_NONE, use_labels=False, limit=None, offset=None, @@ -2498,7 +2530,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): group_by=None, bind=None, ): - self.use_labels = use_labels + if use_labels: + _label_style = LABEL_STYLE_TABLENAME_PLUS_COL + + self._label_style = _label_style if limit is not None: self.limit.non_generative(self, limit) @@ -2572,7 +2607,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): key_share=key_share, ) - @_generative + @property + def use_labels(self): + return self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + def apply_labels(self): """return a new selectable with the 'use_labels' flag set to True. @@ -2584,7 +2622,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """ - self.use_labels = True + return self._set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + + def _set_label_style(self, style): + if self._label_style is not style: + self = self._generate() + self._label_style = style + return self @property def _group_by_clause(self): @@ -2951,6 +2995,22 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return True return False + def _set_label_style(self, style): + if self._label_style is not style: + self = self._generate() + select_0 = self.selects[0]._set_label_style(style) + self.selects = [select_0] + self.selects[1:] + + return self + + def _ensure_disambiguated_names(self): + new_select = self.selects[0]._ensure_disambiguated_names() + if new_select is not self.selects[0]: + self = self._generate() + self.selects = [new_select] + self.selects[1:] + + return self + def _generate_fromclause_column_proxies(self, subquery): # this is a slightly hacky thing - the union exports a @@ -2963,8 +3023,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect): # ForeignKeys in. this would allow the union() to have all # those fks too. select_0 = self.selects[0] - if self.use_labels: - select_0 = select_0.apply_labels() + if self._label_style is not LABEL_STYLE_NONE: + select_0 = select_0._set_label_style(self._label_style) select_0._generate_fromclause_column_proxies(subquery) # hand-construct the "_proxies" collection to include all @@ -3299,6 +3359,7 @@ class Select( ("_hints", InternalTraversal.dp_table_hint_list), ("_distinct", InternalTraversal.dp_boolean), ("_distinct_on", InternalTraversal.dp_clauseelement_list), + ("_label_style", InternalTraversal.dp_plain_obj), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals @@ -4081,29 +4142,35 @@ class Select( """ names = set() + pa = None + collection = [] - cols = _select_iterables(self._raw_columns) - - def name_for_col(c): + for c in _select_iterables(self._raw_columns): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL # rendering which always uses column name for SQL label names - if self.use_labels: + if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL: name = c._key_label else: name = c._proxy_key if name in names: - if self.use_labels: - name = c._label_anon_label + if pa is None: + pa = prefix_anon_map() + + if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL: + name = c._label_anon_label % pa else: - name = c.anon_label + name = c.anon_label % pa else: names.add(name) - return name + collection.append((name, c)) - return ColumnCollection( - (name_for_col(c), c) for c in cols - ).as_immutable() + return ColumnCollection(collection).as_immutable() + + def _ensure_disambiguated_names(self): + if self._label_style is LABEL_STYLE_NONE: + self = self._set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) + return self def _generate_columns_plus_names(self, anon_for_dupe_key): cols = _select_iterables(self._raw_columns) @@ -4117,30 +4184,54 @@ class Select( # anon label, apply _dedupe_label_anon_label to the subsequent # occurrences of it. - if self.use_labels: + if self._label_style is LABEL_STYLE_NONE: + # don't generate any labels + same_cols = set() + + return [ + (None, c, c in same_cols or same_cols.add(c)) for c in cols + ] + else: names = {} + use_tablename_labels = ( + self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + ) + def name_for_col(c): if not c._render_label_in_columns_clause: return (None, c, False) - elif c._label is None: + elif use_tablename_labels: + if c._label is None: + repeated = c.anon_label in names + names[c.anon_label] = c + return (None, c, repeated) + elif getattr(c, "name", None) is None: + # this is a scalar_select(). need to improve this case repeated = c.anon_label in names names[c.anon_label] = c return (None, c, repeated) - name = c._label + if use_tablename_labels: + name = effective_name = c._label + else: + name = None + effective_name = c.name repeated = False - if name in names: + if effective_name in names: # when looking to see if names[name] is the same column as # c, use hash(), so that an annotated version of the column # is seen as the same as the non-annotated - if hash(names[name]) != hash(c): + if hash(names[effective_name]) != hash(c): # different column under the same name. apply # disambiguating label - name = c._label_anon_label + if use_tablename_labels: + name = c._label_anon_label + else: + name = c.anon_label if anon_for_dupe_key and name in names: # here, c._label_anon_label is definitely unique to @@ -4154,27 +4245,26 @@ class Select( # already present. apply the "dedupe" label to # subsequent occurrences of the column so that the # original stays non-ambiguous - name = c._dedupe_label_anon_label + if use_tablename_labels: + name = c._dedupe_label_anon_label + else: + name = c._dedupe_anon_label repeated = True else: names[name] = c elif anon_for_dupe_key: # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous - name = c._dedupe_label_anon_label + if use_tablename_labels: + name = c._dedupe_label_anon_label + else: + name = c._dedupe_anon_label repeated = True else: - names[name] = c + names[effective_name] = c return name, c, repeated return [name_for_col(c) for c in cols] - else: - # repeated name logic only for use labels at the moment - same_cols = set() - - return [ - (None, c, c in same_cols or same_cols.add(c)) for c in cols - ] def _generate_fromclause_column_proxies(self, subquery): """generate column proxies to place in the exported .c collection @@ -4183,17 +4273,33 @@ class Select( keys_seen = set() prox = [] + pa = None + + tablename_plus_col = ( + self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + ) + disambiguate_only = self._label_style is LABEL_STYLE_DISAMBIGUATE_ONLY + for name, c, repeated in self._generate_columns_plus_names(False): if not hasattr(c, "_make_proxy"): continue - if name is None: - key = None - elif self.use_labels: + elif tablename_plus_col: key = c._key_label if key is not None and key in keys_seen: - key = c._label_anon_label + if pa is None: + pa = prefix_anon_map() + key = c._label_anon_label % pa + keys_seen.add(key) + elif disambiguate_only: + key = c.key + if key is not None and key in keys_seen: + if pa is None: + pa = prefix_anon_map() + key = c.anon_label % pa keys_seen.add(key) else: + # one of the above label styles is set for subqueries + # as of #5221 so this codepath is likely not called now. key = None prox.append( c._make_proxy( @@ -4435,6 +4541,8 @@ class TextualSelect(SelectBase): __visit_name__ = "textual_select" + _label_style = LABEL_STYLE_NONE + _traverse_internals = [ ("element", InternalTraversal.dp_clauseelement), ("column_args", InternalTraversal.dp_clauseelement_list), @@ -4472,6 +4580,12 @@ class TextualSelect(SelectBase): (c.key, c) for c in self.column_args ).as_immutable() + def _set_label_style(self, style): + return self + + def _ensure_disambiguated_names(self): + return self + @property def _bind(self): return self.element._bind |