summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/selectable.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/selectable.py')
-rw-r--r--lib/sqlalchemy/sql/selectable.py190
1 files changed, 152 insertions, 38 deletions
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 5ffdf23d8..4eab60801 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -36,6 +36,7 @@ from .base import Generative
from .base import HasCompileState
from .base import HasMemoized
from .base import Immutable
+from .base import prefix_anon_map
from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import and_
@@ -1770,7 +1771,9 @@ class Subquery(AliasedReturnsRows):
"use the :meth:`.Query.scalar_subquery` method.",
)
def as_scalar(self):
- return self.element.scalar_subquery()
+ return self.element._set_label_style(
+ LABEL_STYLE_NONE
+ ).scalar_subquery()
class FromGrouping(GroupedElement, FromClause):
@@ -2281,6 +2284,9 @@ class SelectBase(
:meth:`.SelectBase.scalar_subquery`.
"""
+ if self._label_style is not LABEL_STYLE_NONE:
+ self = self._set_label_style(LABEL_STYLE_NONE)
+
return ScalarSelect(self)
def label(self, name):
@@ -2356,7 +2362,16 @@ class SelectBase(
.. versionadded:: 1.4
"""
- return Subquery._construct(self, name)
+
+ return Subquery._construct(self._ensure_disambiguated_names(), name)
+
+ def _ensure_disambiguated_names(self):
+ """Ensure that the names generated by this selectbase will be
+ disambiguated in some way, if possible.
+
+ """
+
+ raise NotImplementedError()
def alias(self, name=None, flat=False):
"""Return a named subquery against this :class:`.SelectBase`.
@@ -2391,6 +2406,17 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
# type: (SelectBase) -> None
self.element = coercions.expect(roles.SelectStatementRole, element)
+ def _ensure_disambiguated_names(self):
+ new_element = self.element._ensure_disambiguated_names()
+ if new_element is not self.element:
+ return SelectStatementGrouping(new_element)
+ else:
+ return self
+
+ @property
+ def _label_style(self):
+ return self.element._label_style
+
@property
def select_statement(self):
return self.element
@@ -2470,6 +2496,11 @@ class DeprecatedSelectBaseGenerations(object):
self.group_by.non_generative(self, *clauses)
+LABEL_STYLE_NONE = util.symbol("LABEL_STYLE_NONE")
+LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol("LABEL_STYLE_TABLENAME_PLUS_COL")
+LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol("LABEL_STYLE_DISAMBIGUATE_ONLY")
+
+
class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
"""Base class for SELECT statements where additional elements can be
added.
@@ -2491,6 +2522,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
def __init__(
self,
+ _label_style=LABEL_STYLE_NONE,
use_labels=False,
limit=None,
offset=None,
@@ -2498,7 +2530,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
group_by=None,
bind=None,
):
- self.use_labels = use_labels
+ if use_labels:
+ _label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+
+ self._label_style = _label_style
if limit is not None:
self.limit.non_generative(self, limit)
@@ -2572,7 +2607,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
key_share=key_share,
)
- @_generative
+ @property
+ def use_labels(self):
+ return self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+
def apply_labels(self):
"""return a new selectable with the 'use_labels' flag set to True.
@@ -2584,7 +2622,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
"""
- self.use_labels = True
+ return self._set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ def _set_label_style(self, style):
+ if self._label_style is not style:
+ self = self._generate()
+ self._label_style = style
+ return self
@property
def _group_by_clause(self):
@@ -2951,6 +2995,22 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return True
return False
+ def _set_label_style(self, style):
+ if self._label_style is not style:
+ self = self._generate()
+ select_0 = self.selects[0]._set_label_style(style)
+ self.selects = [select_0] + self.selects[1:]
+
+ return self
+
+ def _ensure_disambiguated_names(self):
+ new_select = self.selects[0]._ensure_disambiguated_names()
+ if new_select is not self.selects[0]:
+ self = self._generate()
+ self.selects = [new_select] + self.selects[1:]
+
+ return self
+
def _generate_fromclause_column_proxies(self, subquery):
# this is a slightly hacky thing - the union exports a
@@ -2963,8 +3023,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
# ForeignKeys in. this would allow the union() to have all
# those fks too.
select_0 = self.selects[0]
- if self.use_labels:
- select_0 = select_0.apply_labels()
+ if self._label_style is not LABEL_STYLE_NONE:
+ select_0 = select_0._set_label_style(self._label_style)
select_0._generate_fromclause_column_proxies(subquery)
# hand-construct the "_proxies" collection to include all
@@ -3299,6 +3359,7 @@ class Select(
("_hints", InternalTraversal.dp_table_hint_list),
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_list),
+ ("_label_style", InternalTraversal.dp_plain_obj),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ HasSuffixes._has_suffixes_traverse_internals
@@ -4081,29 +4142,35 @@ class Select(
"""
names = set()
+ pa = None
+ collection = []
- cols = _select_iterables(self._raw_columns)
-
- def name_for_col(c):
+ for c in _select_iterables(self._raw_columns):
# we use key_label since this name is intended for targeting
# within the ColumnCollection only, it's not related to SQL
# rendering which always uses column name for SQL label names
- if self.use_labels:
+ if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL:
name = c._key_label
else:
name = c._proxy_key
if name in names:
- if self.use_labels:
- name = c._label_anon_label
+ if pa is None:
+ pa = prefix_anon_map()
+
+ if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL:
+ name = c._label_anon_label % pa
else:
- name = c.anon_label
+ name = c.anon_label % pa
else:
names.add(name)
- return name
+ collection.append((name, c))
- return ColumnCollection(
- (name_for_col(c), c) for c in cols
- ).as_immutable()
+ return ColumnCollection(collection).as_immutable()
+
+ def _ensure_disambiguated_names(self):
+ if self._label_style is LABEL_STYLE_NONE:
+ self = self._set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
+ return self
def _generate_columns_plus_names(self, anon_for_dupe_key):
cols = _select_iterables(self._raw_columns)
@@ -4117,30 +4184,54 @@ class Select(
# anon label, apply _dedupe_label_anon_label to the subsequent
# occurrences of it.
- if self.use_labels:
+ if self._label_style is LABEL_STYLE_NONE:
+ # don't generate any labels
+ same_cols = set()
+
+ return [
+ (None, c, c in same_cols or same_cols.add(c)) for c in cols
+ ]
+ else:
names = {}
+ use_tablename_labels = (
+ self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+
def name_for_col(c):
if not c._render_label_in_columns_clause:
return (None, c, False)
- elif c._label is None:
+ elif use_tablename_labels:
+ if c._label is None:
+ repeated = c.anon_label in names
+ names[c.anon_label] = c
+ return (None, c, repeated)
+ elif getattr(c, "name", None) is None:
+ # this is a scalar_select(). need to improve this case
repeated = c.anon_label in names
names[c.anon_label] = c
return (None, c, repeated)
- name = c._label
+ if use_tablename_labels:
+ name = effective_name = c._label
+ else:
+ name = None
+ effective_name = c.name
repeated = False
- if name in names:
+ if effective_name in names:
# when looking to see if names[name] is the same column as
# c, use hash(), so that an annotated version of the column
# is seen as the same as the non-annotated
- if hash(names[name]) != hash(c):
+ if hash(names[effective_name]) != hash(c):
# different column under the same name. apply
# disambiguating label
- name = c._label_anon_label
+ if use_tablename_labels:
+ name = c._label_anon_label
+ else:
+ name = c.anon_label
if anon_for_dupe_key and name in names:
# here, c._label_anon_label is definitely unique to
@@ -4154,27 +4245,26 @@ class Select(
# already present. apply the "dedupe" label to
# subsequent occurrences of the column so that the
# original stays non-ambiguous
- name = c._dedupe_label_anon_label
+ if use_tablename_labels:
+ name = c._dedupe_label_anon_label
+ else:
+ name = c._dedupe_anon_label
repeated = True
else:
names[name] = c
elif anon_for_dupe_key:
# same column under the same name. apply the "dedupe"
# label so that the original stays non-ambiguous
- name = c._dedupe_label_anon_label
+ if use_tablename_labels:
+ name = c._dedupe_label_anon_label
+ else:
+ name = c._dedupe_anon_label
repeated = True
else:
- names[name] = c
+ names[effective_name] = c
return name, c, repeated
return [name_for_col(c) for c in cols]
- else:
- # repeated name logic only for use labels at the moment
- same_cols = set()
-
- return [
- (None, c, c in same_cols or same_cols.add(c)) for c in cols
- ]
def _generate_fromclause_column_proxies(self, subquery):
"""generate column proxies to place in the exported .c collection
@@ -4183,17 +4273,33 @@ class Select(
keys_seen = set()
prox = []
+ pa = None
+
+ tablename_plus_col = (
+ self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ disambiguate_only = self._label_style is LABEL_STYLE_DISAMBIGUATE_ONLY
+
for name, c, repeated in self._generate_columns_plus_names(False):
if not hasattr(c, "_make_proxy"):
continue
- if name is None:
- key = None
- elif self.use_labels:
+ elif tablename_plus_col:
key = c._key_label
if key is not None and key in keys_seen:
- key = c._label_anon_label
+ if pa is None:
+ pa = prefix_anon_map()
+ key = c._label_anon_label % pa
+ keys_seen.add(key)
+ elif disambiguate_only:
+ key = c.key
+ if key is not None and key in keys_seen:
+ if pa is None:
+ pa = prefix_anon_map()
+ key = c.anon_label % pa
keys_seen.add(key)
else:
+ # one of the above label styles is set for subqueries
+ # as of #5221 so this codepath is likely not called now.
key = None
prox.append(
c._make_proxy(
@@ -4435,6 +4541,8 @@ class TextualSelect(SelectBase):
__visit_name__ = "textual_select"
+ _label_style = LABEL_STYLE_NONE
+
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("column_args", InternalTraversal.dp_clauseelement_list),
@@ -4472,6 +4580,12 @@ class TextualSelect(SelectBase):
(c.key, c) for c in self.column_args
).as_immutable()
+ def _set_label_style(self, style):
+ return self
+
+ def _ensure_disambiguated_names(self):
+ return self
+
@property
def _bind(self):
return self.element._bind