diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/ext/baked.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 21 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 190 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/util/deprecations.py | 1 |
8 files changed, 194 insertions, 62 deletions
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index cf67387e4..ff741db32 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -25,6 +25,7 @@ from ..orm.session import Session from ..sql import func from ..sql import literal_column from ..sql import util as sql_util +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import collections_abc @@ -436,7 +437,7 @@ class Result(object): self.session, context, self._params, self._post_criteria ) - context.statement.use_labels = True + context.statement._label_style = LABEL_STYLE_TABLENAME_PLUS_COL if context.autoflush and not context.populate_existing: self.session._autoflush() q = context.query.params(self._params).with_session(self.session) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index de11ba7dc..4ba82d1e8 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -56,6 +56,7 @@ from ..sql.base import _generative from ..sql.base import ColumnCollection from ..sql.base import Generative from ..sql.selectable import ForUpdateArg +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import collections_abc __all__ = ["Query", "QueryContext", "aliased"] @@ -1209,6 +1210,14 @@ class Query(Generative): self.session = session + @util.deprecated_20( + ":meth:`.Query.from_self`", + alternative="The new approach is to use the :func:`.orm.aliased` " + "construct in conjunction with a subquery. See the section " + ":ref:`Selecting from the query itself as a subquery " + "<migration_20_query_from_self>` in the 2.0 migration notes for an " + "example.", + ) def from_self(self, *entities): r"""return a Query that selects from this Query's SELECT statement. @@ -3313,7 +3322,7 @@ class Query(Generative): def __iter__(self): context = self._compile_context() - context.statement.use_labels = True + context.statement.label_style = LABEL_STYLE_TABLENAME_PLUS_COL if self._autoflush and not self._populate_existing: self.session._autoflush() return self._execute_and_instances(context) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 682ef891c..974ca6ddb 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -638,6 +638,27 @@ class Executable(Generative): return None +class prefix_anon_map(dict): + """A map that creates new keys for missing key access. + + Considers keys of the form "<ident> <name>" to produce + new symbols "<name>_<index>", where "index" is an incrementing integer + corresponding to <name>. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key): + (ident, derived) = key.split(" ", 1) + anonymous_counter = self.get(derived, 1) + self[derived] = anonymous_counter + 1 + value = derived + "_" + str(anonymous_counter) + self[key] = value + return value + + class SchemaEventTarget(object): """Base class for elements that are the targets of :class:`.DDLEvents` events. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 000fc05fa..87ae5232e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -40,6 +40,7 @@ from . import schema from . import selectable from . import sqltypes from .base import NO_ARG +from .base import prefix_anon_map from .elements import quoted_name from .. import exc from .. import util @@ -541,27 +542,6 @@ class _CompileLabel(elements.ColumnElement): return self -class prefix_anon_map(dict): - """A map that creates new keys for missing key access. - - Considers keys of the form "<ident> <name>" to produce - new symbols "<name>_<index>", where "index" is an incrementing integer - corresponding to <name>. - - Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which - is otherwise usually used for this type of operation. - - """ - - def __missing__(self, key): - (ident, derived) = key.split(" ", 1) - anonymous_counter = self.get(derived, 1) - self[derived] = anonymous_counter + 1 - value = derived + "_" + str(anonymous_counter) - self[key] = value - return value - - class SQLCompiler(Compiled): """Default implementation of :class:`.Compiled`. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index f644b16d9..2b994c513 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -940,6 +940,11 @@ class ColumnElement( return self._anon_label(getattr(self, "name", None)) @util.memoized_property + def _dedupe_anon_label(self): + label = getattr(self, "name", None) or "anon" + return self._anon_label(label + "_") + + @util.memoized_property def _label_anon_label(self): return self._anon_label(getattr(self, "_label", None)) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5ffdf23d8..4eab60801 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -36,6 +36,7 @@ from .base import Generative from .base import HasCompileState from .base import HasMemoized from .base import Immutable +from .base import prefix_anon_map from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import and_ @@ -1770,7 +1771,9 @@ class Subquery(AliasedReturnsRows): "use the :meth:`.Query.scalar_subquery` method.", ) def as_scalar(self): - return self.element.scalar_subquery() + return self.element._set_label_style( + LABEL_STYLE_NONE + ).scalar_subquery() class FromGrouping(GroupedElement, FromClause): @@ -2281,6 +2284,9 @@ class SelectBase( :meth:`.SelectBase.scalar_subquery`. """ + if self._label_style is not LABEL_STYLE_NONE: + self = self._set_label_style(LABEL_STYLE_NONE) + return ScalarSelect(self) def label(self, name): @@ -2356,7 +2362,16 @@ class SelectBase( .. versionadded:: 1.4 """ - return Subquery._construct(self, name) + + return Subquery._construct(self._ensure_disambiguated_names(), name) + + def _ensure_disambiguated_names(self): + """Ensure that the names generated by this selectbase will be + disambiguated in some way, if possible. + + """ + + raise NotImplementedError() def alias(self, name=None, flat=False): """Return a named subquery against this :class:`.SelectBase`. @@ -2391,6 +2406,17 @@ class SelectStatementGrouping(GroupedElement, SelectBase): # type: (SelectBase) -> None self.element = coercions.expect(roles.SelectStatementRole, element) + def _ensure_disambiguated_names(self): + new_element = self.element._ensure_disambiguated_names() + if new_element is not self.element: + return SelectStatementGrouping(new_element) + else: + return self + + @property + def _label_style(self): + return self.element._label_style + @property def select_statement(self): return self.element @@ -2470,6 +2496,11 @@ class DeprecatedSelectBaseGenerations(object): self.group_by.non_generative(self, *clauses) +LABEL_STYLE_NONE = util.symbol("LABEL_STYLE_NONE") +LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol("LABEL_STYLE_TABLENAME_PLUS_COL") +LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol("LABEL_STYLE_DISAMBIGUATE_ONLY") + + class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """Base class for SELECT statements where additional elements can be added. @@ -2491,6 +2522,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): def __init__( self, + _label_style=LABEL_STYLE_NONE, use_labels=False, limit=None, offset=None, @@ -2498,7 +2530,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): group_by=None, bind=None, ): - self.use_labels = use_labels + if use_labels: + _label_style = LABEL_STYLE_TABLENAME_PLUS_COL + + self._label_style = _label_style if limit is not None: self.limit.non_generative(self, limit) @@ -2572,7 +2607,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): key_share=key_share, ) - @_generative + @property + def use_labels(self): + return self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + def apply_labels(self): """return a new selectable with the 'use_labels' flag set to True. @@ -2584,7 +2622,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): """ - self.use_labels = True + return self._set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + + def _set_label_style(self, style): + if self._label_style is not style: + self = self._generate() + self._label_style = style + return self @property def _group_by_clause(self): @@ -2951,6 +2995,22 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return True return False + def _set_label_style(self, style): + if self._label_style is not style: + self = self._generate() + select_0 = self.selects[0]._set_label_style(style) + self.selects = [select_0] + self.selects[1:] + + return self + + def _ensure_disambiguated_names(self): + new_select = self.selects[0]._ensure_disambiguated_names() + if new_select is not self.selects[0]: + self = self._generate() + self.selects = [new_select] + self.selects[1:] + + return self + def _generate_fromclause_column_proxies(self, subquery): # this is a slightly hacky thing - the union exports a @@ -2963,8 +3023,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect): # ForeignKeys in. this would allow the union() to have all # those fks too. select_0 = self.selects[0] - if self.use_labels: - select_0 = select_0.apply_labels() + if self._label_style is not LABEL_STYLE_NONE: + select_0 = select_0._set_label_style(self._label_style) select_0._generate_fromclause_column_proxies(subquery) # hand-construct the "_proxies" collection to include all @@ -3299,6 +3359,7 @@ class Select( ("_hints", InternalTraversal.dp_table_hint_list), ("_distinct", InternalTraversal.dp_boolean), ("_distinct_on", InternalTraversal.dp_clauseelement_list), + ("_label_style", InternalTraversal.dp_plain_obj), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals @@ -4081,29 +4142,35 @@ class Select( """ names = set() + pa = None + collection = [] - cols = _select_iterables(self._raw_columns) - - def name_for_col(c): + for c in _select_iterables(self._raw_columns): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL # rendering which always uses column name for SQL label names - if self.use_labels: + if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL: name = c._key_label else: name = c._proxy_key if name in names: - if self.use_labels: - name = c._label_anon_label + if pa is None: + pa = prefix_anon_map() + + if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL: + name = c._label_anon_label % pa else: - name = c.anon_label + name = c.anon_label % pa else: names.add(name) - return name + collection.append((name, c)) - return ColumnCollection( - (name_for_col(c), c) for c in cols - ).as_immutable() + return ColumnCollection(collection).as_immutable() + + def _ensure_disambiguated_names(self): + if self._label_style is LABEL_STYLE_NONE: + self = self._set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) + return self def _generate_columns_plus_names(self, anon_for_dupe_key): cols = _select_iterables(self._raw_columns) @@ -4117,30 +4184,54 @@ class Select( # anon label, apply _dedupe_label_anon_label to the subsequent # occurrences of it. - if self.use_labels: + if self._label_style is LABEL_STYLE_NONE: + # don't generate any labels + same_cols = set() + + return [ + (None, c, c in same_cols or same_cols.add(c)) for c in cols + ] + else: names = {} + use_tablename_labels = ( + self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + ) + def name_for_col(c): if not c._render_label_in_columns_clause: return (None, c, False) - elif c._label is None: + elif use_tablename_labels: + if c._label is None: + repeated = c.anon_label in names + names[c.anon_label] = c + return (None, c, repeated) + elif getattr(c, "name", None) is None: + # this is a scalar_select(). need to improve this case repeated = c.anon_label in names names[c.anon_label] = c return (None, c, repeated) - name = c._label + if use_tablename_labels: + name = effective_name = c._label + else: + name = None + effective_name = c.name repeated = False - if name in names: + if effective_name in names: # when looking to see if names[name] is the same column as # c, use hash(), so that an annotated version of the column # is seen as the same as the non-annotated - if hash(names[name]) != hash(c): + if hash(names[effective_name]) != hash(c): # different column under the same name. apply # disambiguating label - name = c._label_anon_label + if use_tablename_labels: + name = c._label_anon_label + else: + name = c.anon_label if anon_for_dupe_key and name in names: # here, c._label_anon_label is definitely unique to @@ -4154,27 +4245,26 @@ class Select( # already present. apply the "dedupe" label to # subsequent occurrences of the column so that the # original stays non-ambiguous - name = c._dedupe_label_anon_label + if use_tablename_labels: + name = c._dedupe_label_anon_label + else: + name = c._dedupe_anon_label repeated = True else: names[name] = c elif anon_for_dupe_key: # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous - name = c._dedupe_label_anon_label + if use_tablename_labels: + name = c._dedupe_label_anon_label + else: + name = c._dedupe_anon_label repeated = True else: - names[name] = c + names[effective_name] = c return name, c, repeated return [name_for_col(c) for c in cols] - else: - # repeated name logic only for use labels at the moment - same_cols = set() - - return [ - (None, c, c in same_cols or same_cols.add(c)) for c in cols - ] def _generate_fromclause_column_proxies(self, subquery): """generate column proxies to place in the exported .c collection @@ -4183,17 +4273,33 @@ class Select( keys_seen = set() prox = [] + pa = None + + tablename_plus_col = ( + self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + ) + disambiguate_only = self._label_style is LABEL_STYLE_DISAMBIGUATE_ONLY + for name, c, repeated in self._generate_columns_plus_names(False): if not hasattr(c, "_make_proxy"): continue - if name is None: - key = None - elif self.use_labels: + elif tablename_plus_col: key = c._key_label if key is not None and key in keys_seen: - key = c._label_anon_label + if pa is None: + pa = prefix_anon_map() + key = c._label_anon_label % pa + keys_seen.add(key) + elif disambiguate_only: + key = c.key + if key is not None and key in keys_seen: + if pa is None: + pa = prefix_anon_map() + key = c.anon_label % pa keys_seen.add(key) else: + # one of the above label styles is set for subqueries + # as of #5221 so this codepath is likely not called now. key = None prox.append( c._make_proxy( @@ -4435,6 +4541,8 @@ class TextualSelect(SelectBase): __visit_name__ = "textual_select" + _label_style = LABEL_STYLE_NONE + _traverse_internals = [ ("element", InternalTraversal.dp_clauseelement), ("column_args", InternalTraversal.dp_clauseelement_list), @@ -4472,6 +4580,12 @@ class TextualSelect(SelectBase): (c.key, c) for c in self.column_args ).as_immutable() + def _set_label_style(self, style): + return self + + def _ensure_disambiguated_names(self): + return self + @property def _bind(self): return self.element._bind diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 7dada1394..5af29a723 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -24,6 +24,7 @@ from .. import types as sqltypes from .. import util from ..engine import default from ..engine import url +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import compat from ..util import decorator @@ -398,7 +399,7 @@ class AssertsCompiledSQL(object): if isinstance(clause, orm.Query): context = clause._compile_context() - context.statement.use_labels = True + context.statement._label_style = LABEL_STYLE_TABLENAME_PLUS_COL clause = context.statement elif isinstance(clause, orm.persistence.BulkUD): with mock.patch.object(clause, "_execute_stmt") as stmt_mock: diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index b78a71b1b..e3d4b88c4 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -201,6 +201,7 @@ def _sanitize_restructured_text(text): name += "()" return name + text = re.sub(r":ref:`(.+) <.*>`", lambda m: '"%s"' % m.group(1), text) return re.sub(r"\:(\w+)\:`~?\.?(.+?)`", repl, text) |