diff options
35 files changed, 1495 insertions, 927 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index c32427644..1d832e4af 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -117,8 +117,6 @@ class CursorResultMetaData(ResultMetaData): compiled_statement = context.compiled.statement invoked_statement = context.invoked_statement - # same statement was invoked as the one we cached against, - # return self if compiled_statement is invoked_statement: return self diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 7ac556dcc..f95a30fda 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -228,7 +228,7 @@ class BakedQuery(object): # in 1.4, this is where before_compile() event is # invoked - statement = query._statement_20(orm_results=True) + statement = query._statement_20() # if the query is not safe to cache, we still do everything as though # we did cache it, since the receiver of _bake() assumes subqueryload diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 32975a949..7736a1290 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -401,7 +401,6 @@ Example usage:: from .. import exc from .. import util from ..sql import sqltypes -from ..sql import visitors def compiles(class_, *specs): @@ -456,12 +455,12 @@ def compiles(class_, *specs): def deregister(class_): """Remove all custom compilers associated with a given - :class:`_expression.ClauseElement` type.""" + :class:`_expression.ClauseElement` type. + + """ if hasattr(class_, "_compiler_dispatcher"): - # regenerate default _compiler_dispatch - visitors._generate_compiler_dispatch(class_) - # remove custom directive + class_._compiler_dispatch = class_._original_compiler_dispatch del class_._compiler_dispatcher diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py index 74cc13501..407ec9633 100644 --- a/lib/sqlalchemy/future/selectable.py +++ b/lib/sqlalchemy/future/selectable.py @@ -71,6 +71,10 @@ class Select(_LegacySelect): return self.where(*criteria) + def _exported_columns_iterator(self): + meth = SelectState.get_plugin_class(self).exported_columns_iterator + return meth(self) + def _filter_by_zero(self): if self._setup_joins: meth = SelectState.get_plugin_class( diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 5589f0e0c..ba30d203b 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -5,7 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php - from . import attributes from . import interfaces from . import loading @@ -27,6 +26,7 @@ from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.base import _select_iterables from ..sql.base import CacheableOptions from ..sql.base import CompileState from ..sql.base import Options @@ -90,7 +90,7 @@ class QueryContext(object): self.execution_options = execution_options or _EMPTY_DICT self.bind_arguments = bind_arguments or _EMPTY_DICT self.compile_state = compile_state - self.query = query = compile_state.query + self.query = query = compile_state.select_statement self.session = session self.propagated_loader_options = { @@ -119,10 +119,14 @@ class QueryContext(object): class ORMCompileState(CompileState): + # note this is a dictionary, but the + # default_compile_options._with_polymorphic_adapt_map is a tuple + _with_polymorphic_adapt_map = _EMPTY_DICT + class default_compile_options(CacheableOptions): _cache_key_traversal = [ ("_use_legacy_query_style", InternalTraversal.dp_boolean), - ("_orm_results", InternalTraversal.dp_boolean), + ("_for_statement", InternalTraversal.dp_boolean), ("_bake_ok", InternalTraversal.dp_boolean), ( "_with_polymorphic_adapt_map", @@ -137,8 +141,18 @@ class ORMCompileState(CompileState): ("_for_refresh_state", InternalTraversal.dp_boolean), ] + # set to True by default from Query._statement_20(), to indicate + # the rendered query should look like a legacy ORM query. right + # now this basically indicates we should use tablename_columnname + # style labels. Generally indicates the statement originated + # from a Query object. _use_legacy_query_style = False - _orm_results = True + + # set *only* when we are coming from the Query.statement + # accessor, or a Query-level equivalent such as + # query.subquery(). this supersedes "toplevel". + _for_statement = False + _bake_ok = True _with_polymorphic_adapt_map = () _current_path = _path_registry @@ -149,42 +163,24 @@ class ORMCompileState(CompileState): _set_base_alias = False _for_refresh_state = False - @classmethod - def merge(cls, other): - return cls + other._state_dict() - current_path = _path_registry def __init__(self, *arg, **kw): raise NotImplementedError() - def dispose(self): - self.attributes.clear() - @classmethod def create_for_statement(cls, statement_container, compiler, **kw): - raise NotImplementedError() + """Create a context for a statement given a :class:`.Compiler`. - @classmethod - def _create_for_legacy_query(cls, query, toplevel, for_statement=False): - stmt = query._statement_20(orm_results=not for_statement) - - # this chooses between ORMFromStatementCompileState and - # ORMSelectCompileState. We could also base this on - # query._statement is not None as we have the ORM Query here - # however this is the more general path. - compile_state_cls = CompileState._get_plugin_class_for_plugin( - stmt, "orm" - ) + This method is always invoked in the context of SQLCompiler.process(). - return compile_state_cls._create_for_statement_or_query( - stmt, toplevel, for_statement=for_statement - ) + For a Select object, this would be invoked from + SQLCompiler.visit_select(). For the special FromStatement object used + by Query to indicate "Query.from_statement()", this is called by + FromStatement._compiler_dispatch() that would be called by + SQLCompiler.process(). - @classmethod - def _create_for_statement_or_query( - cls, statement_container, for_statement=False, - ): + """ raise NotImplementedError() @classmethod @@ -266,21 +262,20 @@ class ORMCompileState(CompileState): and ext_info.mapper.persist_selectable not in self._polymorphic_adapters ): - self._mapper_loads_polymorphically_with( - ext_info.mapper, - sql_util.ColumnAdapter( - selectable, ext_info.mapper._equivalent_columns - ), - ) + for mp in ext_info.mapper.iterate_to_root(): + self._mapper_loads_polymorphically_with( + mp, + sql_util.ColumnAdapter(selectable, mp._equivalent_columns), + ) def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers or [mapper]: self._polymorphic_adapters[m2] = adapter - for m in m2.iterate_to_root(): + for m in m2.iterate_to_root(): # TODO: redundant ? self._polymorphic_adapters[m.local_table] = adapter -@sql.base.CompileState.plugin_for("orm", "grouping") +@sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _aliased_generations = util.immutabledict() _from_obj_alias = None @@ -294,31 +289,23 @@ class ORMFromStatementCompileState(ORMCompileState): @classmethod def create_for_statement(cls, statement_container, compiler, **kw): - compiler._rewrites_selected_columns = True - toplevel = not compiler.stack - return cls._create_for_statement_or_query( - statement_container, toplevel - ) - @classmethod - def _create_for_statement_or_query( - cls, statement_container, toplevel, for_statement=False, - ): - # from .query import FromStatement - - # assert isinstance(statement_container, FromStatement) + if compiler is not None: + compiler._rewrites_selected_columns = True + toplevel = not compiler.stack + else: + toplevel = True self = cls.__new__(cls) self._primary_entity = None - self.use_orm_style = ( + self.use_legacy_query_style = ( statement_container.compile_options._use_legacy_query_style ) - self.statement_container = self.query = statement_container - self.requested_statement = statement_container.element + self.statement_container = self.select_statement = statement_container + self.requested_statement = statement = statement_container.element self._entities = [] - self._with_polymorphic_adapt_map = {} self._polymorphic_adapters = {} self._no_yield_pers = set() @@ -349,12 +336,6 @@ class ORMFromStatementCompileState(ORMCompileState): self.create_eager_joins = [] self._fallback_from_clauses = [] - self._setup_for_statement() - - return self - - def _setup_for_statement(self): - statement = self.requested_statement if ( isinstance(statement, expression.SelectBase) and not statement._is_textual @@ -392,6 +373,8 @@ class ORMFromStatementCompileState(ORMCompileState): # for entity in self._entities: # entity.setup_compile_state(self) + return self + def _adapt_col_list(self, cols, current_adapter): return cols @@ -401,7 +384,8 @@ class ORMFromStatementCompileState(ORMCompileState): @sql.base.CompileState.plugin_for("orm", "select") class ORMSelectCompileState(ORMCompileState, SelectState): - _joinpath = _joinpoint = util.immutabledict() + _joinpath = _joinpoint = _EMPTY_DICT + _from_obj_alias = None _has_mapper_entities = False @@ -417,77 +401,71 @@ class ORMSelectCompileState(ORMCompileState, SelectState): @classmethod def create_for_statement(cls, statement, compiler, **kw): + """compiler hook, we arrive here from compiler.visit_select() only.""" + if not statement._is_future: return SelectState(statement, compiler, **kw) - toplevel = not compiler.stack + if compiler is not None: + toplevel = not compiler.stack + compiler._rewrites_selected_columns = True + else: + toplevel = True - compiler._rewrites_selected_columns = True + select_statement = statement - orm_state = cls._create_for_statement_or_query( - statement, for_statement=True, toplevel=toplevel - ) - SelectState.__init__(orm_state, orm_state.statement, compiler, **kw) - return orm_state - - @classmethod - def _create_for_statement_or_query( - cls, query, toplevel, for_statement=False, _entities_only=False - ): - assert isinstance(query, future.Select) - - query.compile_options = cls.default_compile_options.merge( - query.compile_options + # if we are a select() that was never a legacy Query, we won't + # have ORM level compile options. + statement.compile_options = cls.default_compile_options.safe_merge( + statement.compile_options ) self = cls.__new__(cls) - self._primary_entity = None - - self.query = query - self.use_orm_style = query.compile_options._use_legacy_query_style + self.select_statement = select_statement - self.select_statement = select_statement = query + # indicates this select() came from Query.statement + self.for_statement = ( + for_statement + ) = select_statement.compile_options._for_statement - if not hasattr(select_statement.compile_options, "_orm_results"): - select_statement.compile_options = cls.default_compile_options - select_statement.compile_options += {"_orm_results": for_statement} - else: - for_statement = not select_statement.compile_options._orm_results + if not for_statement and not toplevel: + # for subqueries, turn off eagerloads. + # if "for_statement" mode is set, Query.subquery() + # would have set this flag to False already if that's what's + # desired + select_statement.compile_options += { + "_enable_eagerloads": False, + } - self.query = query + # generally if we are from Query or directly from a select() + self.use_legacy_query_style = ( + select_statement.compile_options._use_legacy_query_style + ) self._entities = [] - + self._primary_entity = None self._aliased_generations = {} self._polymorphic_adapters = {} self._no_yield_pers = set() # legacy: only for query.with_polymorphic() - self._with_polymorphic_adapt_map = wpam = dict( - select_statement.compile_options._with_polymorphic_adapt_map - ) - if wpam: + if select_statement.compile_options._with_polymorphic_adapt_map: + self._with_polymorphic_adapt_map = dict( + select_statement.compile_options._with_polymorphic_adapt_map + ) self._setup_with_polymorphics() _QueryEntity.to_compile_state(self, select_statement._raw_columns) - if _entities_only: - return self - - self.compile_options = query.compile_options - - # TODO: the name of this flag "for_statement" has to change, - # as it is difficult to distinguish from the "query._statement" use - # case which is something totally different - self.for_statement = for_statement + self.compile_options = select_statement.compile_options # determine label style. we can make different decisions here. # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY # rather than LABEL_STYLE_NONE, and if we can use disambiguate style # for new style ORM selects too. if self.select_statement._label_style is LABEL_STYLE_NONE: - if self.use_orm_style and not for_statement: + if self.use_legacy_query_style and not self.for_statement: self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL else: self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY @@ -522,129 +500,16 @@ class ORMSelectCompileState(ORMCompileState, SelectState): info.selectable for info in select_statement._from_obj ] + # this is a fairly arbitrary break into a second method, + # so it might be nicer to break up create_for_statement() + # and _setup_for_generate into three or four logical sections self._setup_for_generate() - return self - - @classmethod - def _create_entities_collection(cls, query): - """Creates a partial ORMSelectCompileState that includes - the full collection of _MapperEntity and other _QueryEntity objects. - - Supports a few remaining use cases that are pre-compilation - but still need to gather some of the column / adaption information. - - """ - self = cls.__new__(cls) - - self._entities = [] - self._primary_entity = None - self._aliased_generations = {} - self._polymorphic_adapters = {} - - # legacy: only for query.with_polymorphic() - self._with_polymorphic_adapt_map = wpam = dict( - query.compile_options._with_polymorphic_adapt_map - ) - if wpam: - self._setup_with_polymorphics() + if compiler is not None: + SelectState.__init__(self, self.statement, compiler, **kw) - _QueryEntity.to_compile_state(self, query._raw_columns) return self - @classmethod - def determine_last_joined_entity(cls, statement): - setup_joins = statement._setup_joins - - if not setup_joins: - return None - - (target, onclause, from_, flags) = setup_joins[-1] - - if isinstance(target, interfaces.PropComparator): - return target.entity - else: - return target - - def _setup_with_polymorphics(self): - # legacy: only for query.with_polymorphic() - for ext_info, wp in self._with_polymorphic_adapt_map.items(): - self._mapper_loads_polymorphically_with(ext_info, wp._adapter) - - def _set_select_from_alias(self): - - query = self.select_statement # query - - assert self.compile_options._set_base_alias - assert len(query._from_obj) == 1 - - adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) - if adapter: - self.compile_options += {"_enable_single_crit": False} - self._from_obj_alias = adapter - - def _get_select_from_alias_from_obj(self, from_obj): - info = from_obj - - if "parententity" in info._annotations: - info = info._annotations["parententity"] - - if hasattr(info, "mapper"): - if not info.is_aliased_class: - raise sa_exc.ArgumentError( - "A selectable (FromClause) instance is " - "expected when the base alias is being set." - ) - else: - return info._adapter - - elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): - equivs = self._all_equivs() - return sql_util.ColumnAdapter(info, equivs) - else: - return None - - def _mapper_zero(self): - """return the Mapper associated with the first QueryEntity.""" - return self._entities[0].mapper - - def _entity_zero(self): - """Return the 'entity' (mapper or AliasedClass) associated - with the first QueryEntity, or alternatively the 'select from' - entity if specified.""" - - for ent in self.from_clauses: - if "parententity" in ent._annotations: - return ent._annotations["parententity"] - for qent in self._entities: - if qent.entity_zero: - return qent.entity_zero - - return None - - def _only_full_mapper_zero(self, methname): - if self._entities != [self._primary_entity]: - raise sa_exc.InvalidRequestError( - "%s() can only be used against " - "a single mapped class." % methname - ) - return self._primary_entity.entity_zero - - def _only_entity_zero(self, rationale=None): - if len(self._entities) > 1: - raise sa_exc.InvalidRequestError( - rationale - or "This operation requires a Query " - "against a single mapper." - ) - return self._entity_zero() - - def _all_equivs(self): - equivs = {} - for ent in self._mapper_entities: - equivs.update(ent.mapper._equivalent_columns) - return equivs - def _setup_for_generate(self): query = self.select_statement @@ -772,6 +637,140 @@ class ORMSelectCompileState(ORMCompileState, SelectState): {"deepentity": ezero} ) + @classmethod + def _create_entities_collection(cls, query): + """Creates a partial ORMSelectCompileState that includes + the full collection of _MapperEntity and other _QueryEntity objects. + + Supports a few remaining use cases that are pre-compilation + but still need to gather some of the column / adaption information. + + """ + self = cls.__new__(cls) + + self._entities = [] + self._primary_entity = None + self._aliased_generations = {} + self._polymorphic_adapters = {} + + # legacy: only for query.with_polymorphic() + if query.compile_options._with_polymorphic_adapt_map: + self._with_polymorphic_adapt_map = dict( + query.compile_options._with_polymorphic_adapt_map + ) + self._setup_with_polymorphics() + + _QueryEntity.to_compile_state(self, query._raw_columns) + return self + + @classmethod + def determine_last_joined_entity(cls, statement): + setup_joins = statement._setup_joins + + if not setup_joins: + return None + + (target, onclause, from_, flags) = setup_joins[-1] + + if isinstance(target, interfaces.PropComparator): + return target.entity + else: + return target + + @classmethod + def exported_columns_iterator(cls, statement): + for element in statement._raw_columns: + if ( + element.is_selectable + and "entity_namespace" in element._annotations + ): + for elem in _select_iterables( + element._annotations["entity_namespace"].columns + ): + yield elem + else: + for elem in _select_iterables([element]): + yield elem + + def _setup_with_polymorphics(self): + # legacy: only for query.with_polymorphic() + for ext_info, wp in self._with_polymorphic_adapt_map.items(): + self._mapper_loads_polymorphically_with(ext_info, wp._adapter) + + def _set_select_from_alias(self): + + query = self.select_statement # query + + assert self.compile_options._set_base_alias + assert len(query._from_obj) == 1 + + adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) + if adapter: + self.compile_options += {"_enable_single_crit": False} + self._from_obj_alias = adapter + + def _get_select_from_alias_from_obj(self, from_obj): + info = from_obj + + if "parententity" in info._annotations: + info = info._annotations["parententity"] + + if hasattr(info, "mapper"): + if not info.is_aliased_class: + raise sa_exc.ArgumentError( + "A selectable (FromClause) instance is " + "expected when the base alias is being set." + ) + else: + return info._adapter + + elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): + equivs = self._all_equivs() + return sql_util.ColumnAdapter(info, equivs) + else: + return None + + def _mapper_zero(self): + """return the Mapper associated with the first QueryEntity.""" + return self._entities[0].mapper + + def _entity_zero(self): + """Return the 'entity' (mapper or AliasedClass) associated + with the first QueryEntity, or alternatively the 'select from' + entity if specified.""" + + for ent in self.from_clauses: + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + for qent in self._entities: + if qent.entity_zero: + return qent.entity_zero + + return None + + def _only_full_mapper_zero(self, methname): + if self._entities != [self._primary_entity]: + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." % methname + ) + return self._primary_entity.entity_zero + + def _only_entity_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale + or "This operation requires a Query " + "against a single mapper." + ) + return self._entity_zero() + + def _all_equivs(self): + equivs = {} + for ent in self._mapper_entities: + equivs.update(ent.mapper._equivalent_columns) + return equivs + def _compound_eager_statement(self): # for eager joins present and LIMIT/OFFSET/DISTINCT, # wrap the query inside a select, @@ -920,6 +919,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement = Select.__new__(Select) statement._raw_columns = raw_columns statement._from_obj = from_obj + statement._label_style = label_style if where_criteria: @@ -1653,31 +1653,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "target." ) - aliased_entity = ( - right_mapper - and not right_is_aliased - and ( - # TODO: there is a reliance here on aliasing occurring - # when we join to a polymorphic mapper that doesn't actually - # need aliasing. When this condition is present, we should - # be able to say mapper_loads_polymorphically_with() - # and render the straight polymorphic selectable. this - # does not appear to be possible at the moment as the - # adapter no longer takes place on the rest of the query - # and it's not clear where that's failing to happen. - ( - right_mapper.with_polymorphic - and isinstance( - right_mapper._with_polymorphic_selectable, - expression.AliasedReturnsRows, - ) - ) - or overlap - # test for overlap: - # orm/inheritance/relationships.py - # SelfReferentialM2MTest - ) - ) + # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest + aliased_entity = right_mapper and not right_is_aliased and overlap if not need_adapter and (create_aliases or aliased_entity): # there are a few places in the ORM that automatic aliasing @@ -1707,7 +1686,30 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._aliased_generations[aliased_generation] = ( adapter, ) + self._aliased_generations.get(aliased_generation, ()) - + elif ( + not r_info.is_clause_element + and not right_is_aliased + and right_mapper.with_polymorphic + and isinstance( + right_mapper._with_polymorphic_selectable, + expression.AliasedReturnsRows, + ) + ): + # for the case where the target mapper has a with_polymorphic + # set up, ensure an adapter is set up for criteria that works + # against this mapper. Previously, this logic used to + # use the "create_aliases or aliased_entity" case to generate + # an aliased() object, but this creates an alias that isn't + # strictly necessary. + # see test/orm/test_core_compilation.py + # ::RelNaturalAliasedJoinsTest::test_straight + # and similar + self._mapper_loads_polymorphically_with( + right_mapper, + sql_util.ColumnAdapter( + right_mapper.selectable, right_mapper._equivalent_columns, + ), + ) # if the onclause is a ClauseElement, adapt it with any # adapters that are in place right now if isinstance(onclause, expression.ClauseElement): @@ -1755,8 +1757,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "offset_clause": self.select_statement._offset_clause, "distinct": self.distinct, "distinct_on": self.distinct_on, - "prefixes": self.query._prefixes, - "suffixes": self.query._suffixes, + "prefixes": self.select_statement._prefixes, + "suffixes": self.select_statement._suffixes, "group_by": self.group_by or None, } @@ -2036,7 +2038,14 @@ class _MapperEntity(_QueryEntity): self._with_polymorphic_mappers = ext_info.with_polymorphic_mappers self._polymorphic_discriminator = ext_info.polymorphic_on - if mapper.with_polymorphic or mapper._requires_row_aliasing: + if ( + mapper.with_polymorphic + # controversy - only if inheriting mapper is also + # polymorphic? + # or (mapper.inherits and mapper.inherits.with_polymorphic) + or mapper.inherits + or mapper._requires_row_aliasing + ): compile_state._create_with_polymorphic_adapter( ext_info, self.selectable ) @@ -2361,7 +2370,7 @@ class _ORMColumnEntity(_ColumnEntity): _entity._post_inspect self.entity_zero = self.entity_zero_or_selectable = ezero = _entity - self.mapper = _entity.mapper + self.mapper = mapper = _entity.mapper if parent_bundle: parent_bundle._entities.append(self) @@ -2373,7 +2382,11 @@ class _ORMColumnEntity(_ColumnEntity): self._extra_entities = (self.expr, self.column) - if self.mapper.with_polymorphic: + if ( + mapper.with_polymorphic + or mapper.inherits + or mapper._requires_row_aliasing + ): compile_state._create_with_polymorphic_adapter( ezero, ezero.selectable ) @@ -2414,6 +2427,7 @@ class _ORMColumnEntity(_ColumnEntity): column = current_adapter(self.column, False) else: column = self.column + ezero = self.entity_zero single_table_crit = self.mapper._single_table_criterion diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 88d01eb0f..424ed5dfe 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -345,7 +345,7 @@ def load_on_pk_identity( if load_options is None: load_options = QueryContext.default_load_options - compile_options = ORMCompileState.default_compile_options.merge( + compile_options = ORMCompileState.default_compile_options.safe_merge( q.compile_options ) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 7bfe70c36..4166e6d2a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1743,6 +1743,11 @@ class Mapper( or prop.columns[0] is self.polymorphic_on ) + if isinstance(col, expression.Label): + # new in 1.4, get column property against expressions + # to be addressable in subqueries + col.key = col._key_label = key + self.columns.add(col, key) for col in prop.columns + prop._orig_columns: for col in col.proxy_set: @@ -2282,6 +2287,29 @@ class Mapper( ) ) + def _columns_plus_keys(self, polymorphic_mappers=()): + if polymorphic_mappers: + poly_properties = self._iterate_polymorphic_properties( + polymorphic_mappers + ) + else: + poly_properties = self._polymorphic_properties + + return [ + (prop.key, prop.columns[0]) + for prop in poly_properties + if isinstance(prop, properties.ColumnProperty) + ] + + @HasMemoized.memoized_attribute + def _polymorphic_adapter(self): + if self.with_polymorphic: + return sql_util.ColumnAdapter( + self.selectable, equivalents=self._equivalent_columns + ) + else: + return None + def _iterate_polymorphic_properties(self, mappers=None): """Return an iterator of MapperProperty objects which will render into a SELECT.""" diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 4cf501e3f..02f0752a5 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -264,6 +264,7 @@ class ColumnProperty(StrategizedProperty): def do_init(self): super(ColumnProperty, self).do_init() + if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( self.columns ): @@ -339,28 +340,51 @@ class ColumnProperty(StrategizedProperty): __slots__ = "__clause_element__", "info", "expressions" + def _orm_annotate_column(self, column): + """annotate and possibly adapt a column to be returned + as the mapped-attribute exposed version of the column. + + The column in this context needs to act as much like the + column in an ORM mapped context as possible, so includes + annotations to give hints to various ORM functions as to + the source entity of this column. It also adapts it + to the mapper's with_polymorphic selectable if one is + present. + + """ + + pe = self._parententity + annotations = { + "entity_namespace": pe, + "parententity": pe, + "parentmapper": pe, + "orm_key": self.prop.key, + } + + col = column + + # for a mapper with polymorphic_on and an adapter, return + # the column against the polymorphic selectable. + # see also orm.util._orm_downgrade_polymorphic_columns + # for the reverse operation. + if self._parentmapper._polymorphic_adapter: + mapper_local_col = col + col = self._parentmapper._polymorphic_adapter.traverse(col) + + # this is a clue to the ORM Query etc. that this column + # was adapted to the mapper's polymorphic_adapter. the + # ORM uses this hint to know which column its adapting. + annotations["adapt_column"] = mapper_local_col + + return col._annotate(annotations)._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": pe} + ) + def _memoized_method___clause_element__(self): if self.adapter: return self.adapter(self.prop.columns[0], self.prop.key) else: - pe = self._parententity - # no adapter, so we aren't aliased - # assert self._parententity is self._parentmapper - return ( - self.prop.columns[0] - ._annotate( - { - "entity_namespace": pe, - "parententity": pe, - "parentmapper": pe, - "orm_key": self.prop.key, - "compile_state_plugin": "orm", - } - ) - ._set_propagate_attrs( - {"compile_state_plugin": "orm", "plugin_subject": pe} - ) - ) + return self._orm_annotate_column(self.prop.columns[0]) def _memoized_attr_info(self): """The .info dictionary for this attribute.""" @@ -384,23 +408,8 @@ class ColumnProperty(StrategizedProperty): for col in self.prop.columns ] else: - # no adapter, so we aren't aliased - # assert self._parententity is self._parentmapper return [ - col._annotate( - { - "parententity": self._parententity, - "parentmapper": self._parententity, - "orm_key": self.prop.key, - "compile_state_plugin": "orm", - } - )._set_propagate_attrs( - { - "compile_state_plugin": "orm", - "plugin_subject": self._parententity, - } - ) - for col in self.prop.columns + self._orm_annotate_column(col) for col in self.prop.columns ] def _fallback_getattr(self, key): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 97a81e30f..5137f9b1d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -360,7 +360,7 @@ class Query( ): # if we don't have legacy top level aliasing features in use # then convert to a future select() directly - stmt = self._statement_20() + stmt = self._statement_20(for_statement=True) else: stmt = self._compile_state(for_statement=True).statement @@ -371,7 +371,24 @@ class Query( return stmt - def _statement_20(self, orm_results=False): + def _final_statement(self, legacy_query_style=True): + """Return the 'final' SELECT statement for this :class:`.Query`. + + This is the Core-only select() that will be rendered by a complete + compilation of this query, and is what .statement used to return + in 1.3. + + This method creates a complete compile state so is fairly expensive. + + """ + + q = self._clone() + + return q._compile_state( + use_legacy_query_style=legacy_query_style + ).statement + + def _statement_20(self, for_statement=False, use_legacy_query_style=True): # TODO: this event needs to be deprecated, as it currently applies # only to ORM query and occurs at this spot that is now more # or less an artificial spot @@ -384,7 +401,10 @@ class Query( self.compile_options += {"_bake_ok": False} compile_options = self.compile_options - compile_options += {"_use_legacy_query_style": True} + compile_options += { + "_for_statement": for_statement, + "_use_legacy_query_style": use_legacy_query_style, + } if self._statement is not None: stmt = FromStatement(self._raw_columns, self._statement) @@ -404,13 +424,16 @@ class Query( compile_options=compile_options, ) - if not orm_results: - stmt.compile_options += {"_orm_results": False} - stmt._propagate_attrs = self._propagate_attrs return stmt - def subquery(self, name=None, with_labels=False, reduce_columns=False): + def subquery( + self, + name=None, + with_labels=False, + reduce_columns=False, + _legacy_core_statement=False, + ): """return the full SELECT statement represented by this :class:`_query.Query`, embedded within an :class:`_expression.Alias`. @@ -436,7 +459,11 @@ class Query( q = self.enable_eagerloads(False) if with_labels: q = q.with_labels() - q = q.statement + + if _legacy_core_statement: + q = q._compile_state(for_statement=True).statement + else: + q = q.statement if reduce_columns: q = q.reduce_columns() @@ -943,7 +970,7 @@ class Query( # tablename_colname style is used which at the moment is asserted # in a lot of unit tests :) - statement = self._statement_20(orm_results=True).apply_labels() + statement = self._statement_20().apply_labels() return db_load_fn( self.session, statement, @@ -1328,13 +1355,13 @@ class Query( self.with_labels() .enable_eagerloads(False) .correlate(None) - .subquery() + .subquery(_legacy_core_statement=True) ._anonymous_fromclause() ) parententity = self._raw_columns[0]._annotations.get("parententity") if parententity: - ac = aliased(parententity, alias=fromclause) + ac = aliased(parententity.mapper, alias=fromclause) q = self._from_selectable(ac) else: q = self._from_selectable(fromclause) @@ -2782,7 +2809,7 @@ class Query( def _iter(self): # new style execution. params = self.load_options._params - statement = self._statement_20(orm_results=True) + statement = self._statement_20() result = self.session.execute( statement, params, @@ -2808,7 +2835,7 @@ class Query( ) def __str__(self): - statement = self._statement_20(orm_results=True) + statement = self._statement_20() try: bind = ( @@ -2879,9 +2906,8 @@ class Query( "for linking ORM results to arbitrary select constructs.", version="1.4", ) - compile_state = ORMCompileState._create_for_legacy_query( - self, toplevel=True - ) + compile_state = self._compile_state(for_statement=False) + context = QueryContext( compile_state, self.session, self.load_options ) @@ -3294,10 +3320,35 @@ class Query( return update_op.rowcount def _compile_state(self, for_statement=False, **kw): - return ORMCompileState._create_for_legacy_query( - self, toplevel=True, for_statement=for_statement, **kw + """Create an out-of-compiler ORMCompileState object. + + The ORMCompileState object is normally created directly as a result + of the SQLCompiler.process() method being handed a Select() + or FromStatement() object that uses the "orm" plugin. This method + provides a means of creating this ORMCompileState object directly + without using the compiler. + + This method is used only for deprecated cases, which include + the .from_self() method for a Query that has multiple levels + of .from_self() in use, as well as the instances() method. It is + also used within the test suite to generate ORMCompileState objects + for test purposes. + + """ + + stmt = self._statement_20(for_statement=for_statement, **kw) + assert for_statement == stmt.compile_options._for_statement + + # this chooses between ORMFromStatementCompileState and + # ORMSelectCompileState. We could also base this on + # query._statement is not None as we have the ORM Query here + # however this is the more general path. + compile_state_cls = ORMCompileState._get_plugin_class_for_plugin( + stmt, "orm" ) + return compile_state_cls.create_for_statement(stmt, None) + def _compile_context(self, for_statement=False): compile_state = self._compile_state(for_statement=for_statement) context = QueryContext(compile_state, self.session, self.load_options) @@ -3311,6 +3362,8 @@ class FromStatement(SelectStatementGrouping, Executable): """ + __visit_name__ = "orm_from_statement" + compile_options = ORMFromStatementCompileState.default_compile_options _compile_state_factory = ORMFromStatementCompileState.create_for_statement @@ -3329,6 +3382,14 @@ class FromStatement(SelectStatementGrouping, Executable): super(FromStatement, self).__init__(element) def _compiler_dispatch(self, compiler, **kw): + + """provide a fixed _compiler_dispatch method. + + This is roughly similar to using the sqlalchemy.ext.compiler + ``@compiles`` extension. + + """ + compile_state = self._compile_state_factory(self, compiler, **kw) toplevel = not compiler.stack diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index e82cd174f..683f2b978 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1170,9 +1170,9 @@ class RelationshipProperty(StrategizedProperty): def __clause_element__(self): adapt_from = self._source_selectable() if self._of_type: - of_type_mapper = inspect(self._of_type).mapper + of_type_entity = inspect(self._of_type) else: - of_type_mapper = None + of_type_entity = None ( pj, @@ -1184,7 +1184,7 @@ class RelationshipProperty(StrategizedProperty): ) = self.property._create_joins( source_selectable=adapt_from, source_polymorphic=True, - of_type_mapper=of_type_mapper, + of_type_entity=of_type_entity, alias_secondary=True, ) if sj is not None: @@ -1311,7 +1311,6 @@ class RelationshipProperty(StrategizedProperty): secondary, target_adapter, ) = self.property._create_joins( - dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable, ) @@ -2424,9 +2423,8 @@ class RelationshipProperty(StrategizedProperty): self, source_polymorphic=False, source_selectable=None, - dest_polymorphic=False, dest_selectable=None, - of_type_mapper=None, + of_type_entity=None, alias_secondary=False, ): @@ -2439,9 +2437,17 @@ class RelationshipProperty(StrategizedProperty): if source_polymorphic and self.parent.with_polymorphic: source_selectable = self.parent._with_polymorphic_selectable + if of_type_entity: + dest_mapper = of_type_entity.mapper + if dest_selectable is None: + dest_selectable = of_type_entity.selectable + aliased = True + else: + dest_mapper = self.mapper + if dest_selectable is None: dest_selectable = self.entity.selectable - if dest_polymorphic and self.mapper.with_polymorphic: + if self.mapper.with_polymorphic: aliased = True if self._is_self_referential and source_selectable is None: @@ -2453,8 +2459,6 @@ class RelationshipProperty(StrategizedProperty): ): aliased = True - dest_mapper = of_type_mapper or self.mapper - single_crit = dest_mapper._single_table_criterion aliased = aliased or ( source_selectable is not None diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 626018997..2b8c384c9 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1143,7 +1143,7 @@ class SubqueryLoader(PostLoader): ) = self._get_leftmost(subq_path) orig_query = compile_state.attributes.get( - ("orig_query", SubqueryLoader), compile_state.query + ("orig_query", SubqueryLoader), compile_state.select_statement ) # generate a new Query from the original, then diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ce37d962e..85f4f85d1 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -466,7 +466,7 @@ class AliasedClass(object): def __init__( self, - cls, + mapped_class_or_ac, alias=None, name=None, flat=False, @@ -478,7 +478,9 @@ class AliasedClass(object): use_mapper_path=False, represents_outer_join=False, ): - mapper = _class_to_mapper(cls) + insp = inspection.inspect(mapped_class_or_ac) + mapper = insp.mapper + if alias is None: alias = mapper._with_polymorphic_selectable._anonymous_fromclause( name=name, flat=flat @@ -486,7 +488,7 @@ class AliasedClass(object): self._aliased_insp = AliasedInsp( self, - mapper, + insp, alias, name, with_polymorphic_mappers @@ -617,7 +619,7 @@ class AliasedInsp( def __init__( self, entity, - mapper, + inspected, selectable, name, with_polymorphic_mappers, @@ -627,6 +629,10 @@ class AliasedInsp( adapt_on_names, represents_outer_join, ): + + mapped_class_or_ac = inspected.entity + mapper = inspected.mapper + self._weak_entity = weakref.ref(entity) self.mapper = mapper self.selectable = ( @@ -665,9 +671,12 @@ class AliasedInsp( adapt_on_names=adapt_on_names, anonymize_labels=True, ) + if inspected.is_aliased_class: + self._adapter = inspected._adapter.wrap(self._adapter) self._adapt_on_names = adapt_on_names - self._target = mapper.class_ + self._target = mapped_class_or_ac + # self._target = mapper.class_ # mapped_class_or_ac @property def entity(self): @@ -795,6 +804,21 @@ class AliasedInsp( def _memoized_values(self): return {} + @util.memoized_property + def columns(self): + if self._is_with_polymorphic: + cols_plus_keys = self.mapper._columns_plus_keys( + [ent.mapper for ent in self._with_polymorphic_entities] + ) + else: + cols_plus_keys = self.mapper._columns_plus_keys() + + cols_plus_keys = [ + (key, self._adapt_element(col)) for key, col in cols_plus_keys + ] + + return ColumnCollection(cols_plus_keys) + def _memo(self, key, callable_, *args, **kw): if key in self._memoized_values: return self._memoized_values[key] @@ -1290,8 +1314,7 @@ class _ORMJoin(expression.Join): source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, - dest_polymorphic=True, - of_type_mapper=right_info.mapper, + of_type_entity=right_info, alias_secondary=True, ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6415d4b37..f14319089 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -522,7 +522,12 @@ class _MetaOptions(type): def __init__(cls, classname, bases, dict_): cls._cache_attrs = tuple( - sorted(d for d in dict_ if not d.startswith("__")) + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) ) type.__init__(cls, classname, bases, dict_) @@ -561,6 +566,31 @@ class Options(util.with_metaclass(_MetaOptions)): def _state_dict(cls): return cls._state_dict_const + @classmethod + def safe_merge(cls, other): + d = other._state_dict() + + # only support a merge with another object of our class + # and which does not have attrs that we dont. otherwise + # we risk having state that might not be part of our cache + # key strategy + + if ( + cls is not other.__class__ + and other._cache_attrs + and set(other._cache_attrs).difference(cls._cache_attrs) + ): + raise TypeError( + "other element %r is not empty, is not of type %s, " + "and contains attributes not covered here %r" + % ( + other, + cls, + set(other._cache_attrs).difference(cls._cache_attrs), + ) + ) + return cls + d + class CacheableOptions(Options, HasCacheKey): @hybridmethod diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 287e53724..fa2888a23 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -878,6 +878,7 @@ class ColumnElement( key = self._proxy_key else: key = name + co = ColumnClause( coercions.expect(roles.TruncatedLabelRole, name) if name_is_truncatable @@ -885,6 +886,7 @@ class ColumnElement( type_=getattr(self, "type", None), _selectable=selectable, ) + co._propagate_attrs = selectable._propagate_attrs co._proxies = [self] if selectable._is_clone_of is not None: @@ -1284,6 +1286,7 @@ class BindParameter(roles.InElementRole, ColumnElement): """ + if required is NO_ARG: required = value is NO_ARG and callable_ is None if value is NO_ARG: @@ -1302,6 +1305,7 @@ class BindParameter(roles.InElementRole, ColumnElement): id(self), re.sub(r"[%\(\) \$]+", "_", key).strip("_") if key is not None + and not isinstance(key, _anonymous_label) else "param", ) ) @@ -4182,16 +4186,27 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): + name = self.name if not name else name + key, e = self.element._make_proxy( selectable, - name=name if name else self.name, + name=name, disallow_is_literal=True, + name_is_truncatable=isinstance(name, _truncated_label), ) + # TODO: want to remove this assertion at some point. all + # _make_proxy() implementations will give us back the key that + # is our "name" in the first place. based on this we can + # safely return our "self.key" as the key here, to support a new + # case where the key and name are separate. + assert key == self.name + e._propagate_attrs = selectable._propagate_attrs e._proxies.append(self) if self._type is not None: e.type = self._type - return key, e + + return self.key, e class ColumnClause( @@ -4240,7 +4255,7 @@ class ColumnClause( __visit_name__ = "column" _traverse_internals = [ - ("name", InternalTraversal.dp_string), + ("name", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("table", InternalTraversal.dp_clauseelement), ("is_literal", InternalTraversal.dp_boolean), @@ -4410,10 +4425,8 @@ class ColumnClause( def _gen_label(self, name, dedupe_on_key=True): t = self.table - if self.is_literal: return None - elif t is not None and t.named_with_column: if getattr(t, "schema", None): label = t.schema.replace(".", "_") + "_" + t.name + "_" + name diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 170e016a5..d6845e05f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3451,8 +3451,8 @@ class SelectState(util.MemoizedSlots, CompileState): self.columns_plus_names = statement._generate_columns_plus_names(True) def _get_froms(self, statement): - froms = [] seen = set() + froms = [] for item in itertools.chain( itertools.chain.from_iterable( @@ -3474,6 +3474,16 @@ class SelectState(util.MemoizedSlots, CompileState): froms.append(item) seen.update(item._cloned_set) + toremove = set( + itertools.chain.from_iterable( + [_expand_cloned(f._hide_froms) for f in froms] + ) + ) + if toremove: + # filter out to FROM clauses not in the list, + # using a list to maintain ordering + froms = [f for f in froms if f not in toremove] + return froms def _get_display_froms( @@ -3490,16 +3500,6 @@ class SelectState(util.MemoizedSlots, CompileState): froms = self.froms - toremove = set( - itertools.chain.from_iterable( - [_expand_cloned(f._hide_froms) for f in froms] - ) - ) - if toremove: - # filter out to FROM clauses not in the list, - # using a list to maintain ordering - froms = [f for f in froms if f not in toremove] - if self.statement._correlate: to_correlate = self.statement._correlate if to_correlate: @@ -3557,7 +3557,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _memoized_attr__label_resolve_dict(self): with_cols = dict( (c._resolve_label or c._label or c.key, c) - for c in _select_iterables(self.statement._raw_columns) + for c in self.statement._exported_columns_iterator() if c._allow_label_resolve ) only_froms = dict( @@ -3578,6 +3578,10 @@ class SelectState(util.MemoizedSlots, CompileState): else: return None + @classmethod + def exported_columns_iterator(cls, statement): + return _select_iterables(statement._raw_columns) + def _setup_joins(self, args): for (right, onclause, left, flags) in args: isouter = flags["isouter"] @@ -4599,7 +4603,7 @@ class Select( pa = None collection = [] - for c in _select_iterables(self._raw_columns): + for c in self._exported_columns_iterator(): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL # rendering which always uses column name for SQL label names @@ -4630,7 +4634,7 @@ class Select( return self def _generate_columns_plus_names(self, anon_for_dupe_key): - cols = _select_iterables(self._raw_columns) + cols = self._exported_columns_iterator() # when use_labels is on: # in all cases == if we see the same label name, use _label_anon_label diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index a38088a27..388097e45 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -18,6 +18,7 @@ NO_CACHE = util.symbol("no_cache") CACHE_IN_PLACE = util.symbol("cache_in_place") CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") STATIC_CACHE_KEY = util.symbol("static_cache_key") +ANON_NAME = util.symbol("anon_name") def compare(obj1, obj2, **kw): @@ -33,6 +34,7 @@ class HasCacheKey(object): _cache_key_traversal = NO_CACHE __slots__ = () + @util.preload_module("sqlalchemy.sql.elements") def _gen_cache_key(self, anon_map, bindparams): """return an optional cache key. @@ -54,6 +56,8 @@ class HasCacheKey(object): """ + elements = util.preloaded.sql_elements + idself = id(self) if anon_map is not None: @@ -102,6 +106,10 @@ class HasCacheKey(object): result += (attrname, obj) elif meth is STATIC_CACHE_KEY: result += (attrname, obj._static_cache_key) + elif meth is ANON_NAME: + if elements._anonymous_label in obj.__class__.__mro__: + obj = obj.apply_map(anon_map) + result += (attrname, obj) elif meth is CALL_GEN_CACHE_KEY: result += ( attrname, @@ -321,6 +329,7 @@ class _CacheKey(ExtendedInternalTraversal): ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE visit_statement_hint_list = CACHE_IN_PLACE visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) @@ -387,15 +396,6 @@ class _CacheKey(ExtendedInternalTraversal): attrname, obj, parent, anon_map, bindparams ) - def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams): - from . import elements - - name = obj - if isinstance(name, elements._anonymous_label): - name = name.apply_map(anon_map) - - return (attrname, name) - def visit_fromclause_ordered_set( self, attrname, obj, parent, anon_map, bindparams ): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 377aa4fe0..e8726000b 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -822,9 +822,14 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # is another join or selectable that contains a table which our # selectable derives from, that we want to process return None + elif not isinstance(col, ColumnElement): return None - elif self.include_fn and not self.include_fn(col): + + if "adapt_column" in col._annotations: + col = col._annotations["adapt_column"] + + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 683f545dd..5de68f504 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -50,6 +50,13 @@ def _generate_compiler_dispatch(cls): """ visit_name = cls.__visit_name__ + if "_compiler_dispatch" in cls.__dict__: + # class has a fixed _compiler_dispatch() method. + # copy it to "original" so that we can get it back if + # sqlalchemy.ext.compiles overrides it. + cls._original_compiler_dispatch = cls._compiler_dispatch + return + if not isinstance(visit_name, util.compat.string_types): raise exc.InvalidRequestError( "__visit_name__ on class %s must be a string at the class level" @@ -76,7 +83,9 @@ def _generate_compiler_dispatch(cls): + self.__visit_name__ on the visitor, and call it with the same kw params. """ - cls._compiler_dispatch = _compiler_dispatch + cls._compiler_dispatch = ( + cls._original_compiler_dispatch + ) = _compiler_dispatch class TraversibleType(type): diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 79de3c978..247dbc13c 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -399,7 +399,7 @@ def reraise(tp, value, tb=None, cause=None): raise_(value, with_traceback=tb, from_=cause) -def with_metaclass(meta, *bases): +def with_metaclass(meta, *bases, **kw): """Create a base class with a metaclass. Drops the middle class upon creation. @@ -414,8 +414,15 @@ def with_metaclass(meta, *bases): def __new__(cls, name, this_bases, d): if this_bases is None: - return type.__new__(cls, name, (), d) - return meta(name, bases, d) + cls = type.__new__(cls, name, (), d) + else: + cls = meta(name, bases, d) + + if hasattr(cls, "__init_subclass__") and hasattr( + cls.__init_subclass__, "__func__" + ): + cls.__init_subclass__.__func__(cls, **kw) + return cls return metaclass("temporary_class", None, {}) diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index f261bc811..5dbbc2f5c 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -857,7 +857,10 @@ class JoinedEagerLoadTest(fixtures.MappedTest): exec_opts = {} bind_arguments = {} ORMCompileState.orm_pre_session_exec( - sess, compile_state.query, exec_opts, bind_arguments + sess, + compile_state.select_statement, + exec_opts, + bind_arguments, ) r = sess.connection().execute( diff --git a/test/orm/inheritance/_poly_fixtures.py b/test/orm/inheritance/_poly_fixtures.py index 5d23e7801..da2ad4cdf 100644 --- a/test/orm/inheritance/_poly_fixtures.py +++ b/test/orm/inheritance/_poly_fixtures.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship +from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import config from sqlalchemy.testing import fixtures @@ -54,6 +55,8 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): run_setup_mappers = "once" run_deletes = None + label_style = LABEL_STYLE_TABLENAME_PLUS_COL + @classmethod def define_tables(cls, metadata): global people, engineers, managers, boss @@ -427,14 +430,16 @@ class _PolymorphicAliasedJoins(_PolymorphicFixtureBase): person_join = ( people.outerjoin(engineers) .outerjoin(managers) - .select(use_labels=True) - .alias("pjoin") + .select() + ._set_label_style(cls.label_style) + .subquery("pjoin") ) manager_join = ( people.join(managers) .outerjoin(boss) - .select(use_labels=True) - .alias("mjoin") + .select() + ._set_label_style(cls.label_style) + .subquery("mjoin") ) person_with_polymorphic = ([Person, Manager, Engineer], person_join) manager_with_polymorphic = ("*", manager_join) diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 62f2097d3..514f4ba76 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -1342,6 +1342,7 @@ class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults): ) .order_by(Person.name) ) + eq_( list(r), [ diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 5e832e934..e38758ee2 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -1357,6 +1357,7 @@ class EagerTargetingTest(fixtures.MappedTest): bid = b1.id sess.expunge_all() + node = sess.query(B).filter(B.id == bid).all()[0] eq_(node, B(id=1, name="b1", b_data="i")) eq_(node.children[0], B(id=2, name="b2", b_data="l")) diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index 549414507..e7e2530b2 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -1522,6 +1522,32 @@ class _PolymorphicTestBase(object): expected, ) + def test_self_referential_two_newstyle(self): + # TODO: this is the first test *EVER* of an aliased class of + # an aliased class. we should add many more tests for this. + # new case added in Id810f485c5f7ed971529489b84694e02a3356d6d + sess = create_session() + expected = [(m1, e1), (m1, e2), (m1, b1)] + + p1 = aliased(Person) + p2 = aliased(Person) + stmt = ( + future_select(p1, p2) + .filter(p1.company_id == p2.company_id) + .filter(p1.name == "dogbert") + .filter(p1.person_id > p2.person_id) + ) + subq = stmt.subquery() + + pa1 = aliased(p1, subq) + pa2 = aliased(p2, subq) + + stmt = future_select(pa1, pa2).order_by(pa1.person_id, pa2.person_id) + + eq_( + sess.execute(stmt).unique().all(), expected, + ) + def test_nesting_queries(self): # query.statement places a flag "no_adapt" on the returned # statement. This prevents the polymorphic adaptation in the diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py index ea5b9f96b..5ba482649 100644 --- a/test/orm/inheritance/test_relationship.py +++ b/test/orm/inheritance/test_relationship.py @@ -6,6 +6,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import backref +from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import contains_eager from sqlalchemy.orm import create_session from sqlalchemy.orm import joinedload @@ -765,6 +766,48 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): "secondary_1.left_id", ) + def test_query_crit_core_workaround(self): + # do a test in the style of orm/test_core_compilation.py + + Child1, Child2 = self.classes.Child1, self.classes.Child2 + secondary = self.tables.secondary + + configure_mappers() + + from sqlalchemy.sql import join + + C1 = aliased(Child1, flat=True) + + # figure out all the things we need to do in Core to make + # the identical query that the ORM renders. + + salias = secondary.alias() + stmt = ( + select([Child2]) + .select_from( + join( + Child2, + salias, + Child2.id.expressions[1] == salias.c.left_id, + ).join(C1, salias.c.right_id == C1.id.expressions[1]) + ) + .where(C1.left_child2 == Child2(id=1)) + ) + + self.assert_compile( + stmt.apply_labels(), + "SELECT parent.id AS parent_id, " + "parent.cls AS parent_cls, child2.id AS child2_id " + "FROM secondary AS secondary_1, " + "parent JOIN child2 ON parent.id = child2.id JOIN secondary AS " + "secondary_2 ON parent.id = secondary_2.left_id JOIN " + "(parent AS parent_1 JOIN child1 AS child1_1 " + "ON parent_1.id = child1_1.id) " + "ON parent_1.id = secondary_2.right_id WHERE " + "parent_1.id = secondary_1.right_id AND :param_1 = " + "secondary_1.left_id", + ) + def test_eager_join(self): Child1, Child2 = self.classes.Child1, self.classes.Child2 sess = create_session() diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 61df1d277..a26d0ae26 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -2,21 +2,30 @@ from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import literal_column +from sqlalchemy import or_ from sqlalchemy import testing +from sqlalchemy import util from sqlalchemy.future import select from sqlalchemy.orm import aliased from sqlalchemy.orm import column_property +from sqlalchemy.orm import contains_eager from sqlalchemy.orm import join as orm_join +from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper +from sqlalchemy.orm import query_expression +from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.orm import with_expression from sqlalchemy.orm import with_polymorphic from sqlalchemy.sql.selectable import Join as core_join +from sqlalchemy.sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY +from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import eq_ from .inheritance import _poly_fixtures from .test_query import QueryTest - # TODO: # composites / unions, etc. @@ -178,6 +187,344 @@ class JoinTest(QueryTest, AssertsCompiledSQL): ) +class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): + """The Query object calls eanble_eagerloads(False) when you call + .subquery(). With Core select, we don't have that information, we instead + have to look at the "toplevel" flag to know where we are. make sure + the many different combinations that these two objects and still + too many flags at the moment work as expected on the outside. + + """ + + __dialect__ = "default" + + run_setup_mappers = None + + @testing.fixture + def joinedload_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="joined")}, + ) + + mapper(Address, addresses) + + return User, Address + + def test_no_joinedload_in_subquery_select_rows(self, joinedload_fixture): + User, Address = joinedload_fixture + + sess = Session() + stmt1 = sess.query(User).subquery() + stmt1 = sess.query(stmt1) + + stmt2 = select(User).subquery() + + stmt2 = select(stmt2) + + expected = ( + "SELECT anon_1.id, anon_1.name FROM " + "(SELECT users.id AS id, users.name AS name " + "FROM users) AS anon_1" + ) + self.assert_compile( + stmt1._final_statement(legacy_query_style=False), expected, + ) + + self.assert_compile(stmt2, expected) + + def test_no_joinedload_in_subquery_select_entity(self, joinedload_fixture): + User, Address = joinedload_fixture + + sess = Session() + stmt1 = sess.query(User).subquery() + ua = aliased(User, stmt1) + stmt1 = sess.query(ua) + + stmt2 = select(User).subquery() + + ua = aliased(User, stmt2) + stmt2 = select(ua) + + expected = ( + "SELECT anon_1.id, anon_1.name, addresses_1.id AS id_1, " + "addresses_1.user_id, addresses_1.email_address FROM " + "(SELECT users.id AS id, users.name AS name FROM users) AS anon_1 " + "LEFT OUTER JOIN addresses AS addresses_1 " + "ON anon_1.id = addresses_1.user_id" + ) + + self.assert_compile( + stmt1._final_statement(legacy_query_style=False), expected, + ) + + self.assert_compile(stmt2, expected) + + # TODO: need to test joinedload options, deferred mappings, deferred + # options. these are all loader options that should *only* have an + # effect on the outermost statement, never a subquery. + + +class ExtraColsTest(QueryTest, AssertsCompiledSQL): + __dialect__ = "default" + + run_setup_mappers = None + + @testing.fixture + def query_expression_fixture(self): + users, User = ( + self.tables.users, + self.classes.User, + ) + + mapper( + User, + users, + properties=util.OrderedDict([("value", query_expression())]), + ) + return User + + @testing.fixture + def column_property_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties=util.OrderedDict( + [ + ("concat", column_property((users.c.id * 2))), + ( + "count", + column_property( + select(func.count(addresses.c.id)) + .where(users.c.id == addresses.c.user_id,) + .correlate(users) + .scalar_subquery() + ), + ), + ] + ), + ) + + mapper(Address, addresses, properties={"user": relationship(User,)}) + + return User, Address + + @testing.fixture + def plain_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, users, + ) + + mapper(Address, addresses, properties={"user": relationship(User,)}) + + return User, Address + + def test_no_joinedload_embedded(self, plain_fixture): + User, Address = plain_fixture + + stmt = select(Address).options(joinedload(Address.user)) + + subq = stmt.subquery() + + s2 = select(subq) + + self.assert_compile( + s2, + "SELECT anon_1.id, anon_1.user_id, anon_1.email_address " + "FROM (SELECT addresses.id AS id, addresses.user_id AS " + "user_id, addresses.email_address AS email_address " + "FROM addresses) AS anon_1", + ) + + def test_with_expr_one(self, query_expression_fixture): + User = query_expression_fixture + + stmt = select(User).options( + with_expression(User.value, User.name + "foo") + ) + + self.assert_compile( + stmt, + "SELECT users.name || :name_1 AS anon_1, users.id, " + "users.name FROM users", + ) + + def test_with_expr_two(self, query_expression_fixture): + User = query_expression_fixture + + stmt = select(User.id, User.name, (User.name + "foo").label("foo")) + + subq = stmt.subquery() + u1 = aliased(User, subq) + + stmt = select(u1).options(with_expression(u1.value, subq.c.foo)) + + self.assert_compile( + stmt, + "SELECT anon_1.foo, anon_1.id, anon_1.name FROM " + "(SELECT users.id AS id, users.name AS name, " + "users.name || :name_1 AS foo FROM users) AS anon_1", + ) + + def test_joinedload_outermost(self, plain_fixture): + User, Address = plain_fixture + + stmt = select(Address).options(joinedload(Address.user)) + + # render joined eager loads with stringify + self.assert_compile( + stmt, + "SELECT addresses.id, addresses.user_id, addresses.email_address, " + "users_1.id AS id_1, users_1.name FROM addresses " + "LEFT OUTER JOIN users AS users_1 " + "ON users_1.id = addresses.user_id", + ) + + def test_contains_eager_outermost(self, plain_fixture): + User, Address = plain_fixture + + stmt = ( + select(Address) + .join(Address.user) + .options(contains_eager(Address.user)) + ) + + # render joined eager loads with stringify + self.assert_compile( + stmt, + "SELECT users.id, users.name, addresses.id AS id_1, " + "addresses.user_id, " + "addresses.email_address " + "FROM addresses JOIN users ON users.id = addresses.user_id", + ) + + def test_column_properties(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables.""" + + User, Address = column_property_fixture + + stmt = select(User) + + self.assert_compile( + stmt, + "SELECT users.id * :id_1 AS anon_1, " + "(SELECT count(addresses.id) AS count_1 FROM addresses " + "WHERE users.id = addresses.user_id) AS anon_2, users.id, " + "users.name FROM users", + checkparams={"id_1": 2}, + ) + + def test_column_properties_can_we_use(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables. """ + + # User, Address = column_property_fixture + + # stmt = select(User) + + # TODO: shouldn't we be able to get at count ? + + # stmt = stmt.where(stmt.selected_columns.count > 5) + + # self.assert_compile(stmt, "") + + def test_column_properties_subquery(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables.""" + + User, Address = column_property_fixture + + stmt = select(User) + + # here, the subquery needs to export the columns that include + # the column properties + stmt = select(stmt.subquery()) + + # TODO: shouldnt we be able to get to stmt.subquery().c.count ? + self.assert_compile( + stmt, + "SELECT anon_2.anon_1, anon_2.anon_3, anon_2.id, anon_2.name " + "FROM (SELECT users.id * :id_1 AS anon_1, " + "(SELECT count(addresses.id) AS count_1 FROM addresses " + "WHERE users.id = addresses.user_id) AS anon_3, users.id AS id, " + "users.name AS name FROM users) AS anon_2", + checkparams={"id_1": 2}, + ) + + def test_column_properties_subquery_two(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables.""" + + User, Address = column_property_fixture + + # col properties will retain anonymous labels, however will + # adopt the .key within the subquery collection so they can + # be addressed. + stmt = select(User.id, User.name, User.concat, User.count,) + + subq = stmt.subquery() + # here, the subquery needs to export the columns that include + # the column properties + stmt = select(subq).where(subq.c.concat == "foo") + + self.assert_compile( + stmt, + "SELECT anon_1.id, anon_1.name, anon_1.anon_2, anon_1.anon_3 " + "FROM (SELECT users.id AS id, users.name AS name, " + "users.id * :id_1 AS anon_2, " + "(SELECT count(addresses.id) AS count_1 " + "FROM addresses WHERE users.id = addresses.user_id) AS anon_3 " + "FROM users) AS anon_1 WHERE anon_1.anon_2 = :param_1", + checkparams={"id_1": 2, "param_1": "foo"}, + ) + + def test_column_properties_aliased_subquery(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables.""" + + User, Address = column_property_fixture + + u1 = aliased(User) + stmt = select(u1) + + # here, the subquery needs to export the columns that include + # the column properties + stmt = select(stmt.subquery()) + self.assert_compile( + stmt, + "SELECT anon_2.anon_1, anon_2.anon_3, anon_2.id, anon_2.name " + "FROM (SELECT users_1.id * :id_1 AS anon_1, " + "(SELECT count(addresses.id) AS count_1 FROM addresses " + "WHERE users_1.id = addresses.user_id) AS anon_3, " + "users_1.id AS id, users_1.name AS name " + "FROM users AS users_1) AS anon_2", + checkparams={"id_1": 2}, + ) + + class RelationshipNaturalCompileTest(QueryTest, AssertsCompiledSQL): """test using core join() with relationship attributes. @@ -193,7 +540,6 @@ class RelationshipNaturalCompileTest(QueryTest, AssertsCompiledSQL): __dialect__ = "default" - @testing.fails("need to have of_type() expressions render directly") def test_of_type_implicit_join(self): User, Address = self.classes("User", "Address") @@ -201,7 +547,12 @@ class RelationshipNaturalCompileTest(QueryTest, AssertsCompiledSQL): a1 = aliased(Address) stmt1 = select(u1).where(u1.addresses.of_type(a1)) - stmt2 = Session().query(u1).filter(u1.addresses.of_type(a1)) + stmt2 = ( + Session() + .query(u1) + .filter(u1.addresses.of_type(a1)) + ._final_statement(legacy_query_style=False) + ) expected = ( "SELECT users_1.id, users_1.name FROM users AS users_1, " @@ -260,6 +611,118 @@ class InheritedTest(_poly_fixtures._Polymorphic): run_setup_mappers = "once" +class ExplicitWithPolymorhpicTest( + _poly_fixtures._PolymorphicUnions, AssertsCompiledSQL +): + + __dialect__ = "default" + + default_punion = ( + "(SELECT pjoin.person_id AS person_id, " + "pjoin.company_id AS company_id, " + "pjoin.name AS name, pjoin.type AS type, " + "pjoin.status AS status, pjoin.engineer_name AS engineer_name, " + "pjoin.primary_language AS primary_language, " + "pjoin.manager_name AS manager_name " + "FROM (SELECT engineers.person_id AS person_id, " + "people.company_id AS company_id, people.name AS name, " + "people.type AS type, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "CAST(NULL AS VARCHAR(50)) AS manager_name " + "FROM people JOIN engineers ON people.person_id = engineers.person_id " + "UNION ALL SELECT managers.person_id AS person_id, " + "people.company_id AS company_id, people.name AS name, " + "people.type AS type, managers.status AS status, " + "CAST(NULL AS VARCHAR(50)) AS engineer_name, " + "CAST(NULL AS VARCHAR(50)) AS primary_language, " + "managers.manager_name AS manager_name FROM people " + "JOIN managers ON people.person_id = managers.person_id) AS pjoin) " + "AS anon_1" + ) + + def test_subquery_col_expressions_wpoly_one(self): + Person, Manager, Engineer = self.classes( + "Person", "Manager", "Engineer" + ) + + wp1 = with_polymorphic(Person, [Manager, Engineer]) + + subq1 = select(wp1).subquery() + + wp2 = with_polymorphic(Person, [Engineer, Manager]) + subq2 = select(wp2).subquery() + + # first thing we see, is that when we go through with_polymorphic, + # the entities that get placed into the aliased class go through + # Mapper._mappers_from_spec(), which matches them up to the + # existing Mapper.self_and_descendants collection, meaning, + # the order is the same every time. Assert here that's still + # happening. If a future internal change modifies this assumption, + # that's not necessarily bad, but it would change things. + + eq_( + subq1.c.keys(), + [ + "person_id", + "company_id", + "name", + "type", + "person_id_1", + "status", + "engineer_name", + "primary_language", + "person_id_1", + "status_1", + "manager_name", + ], + ) + eq_( + subq2.c.keys(), + [ + "person_id", + "company_id", + "name", + "type", + "person_id_1", + "status", + "engineer_name", + "primary_language", + "person_id_1", + "status_1", + "manager_name", + ], + ) + + def test_subquery_col_expressions_wpoly_two(self): + Person, Manager, Engineer = self.classes( + "Person", "Manager", "Engineer" + ) + + wp1 = with_polymorphic(Person, [Manager, Engineer]) + + subq1 = select(wp1).subquery() + + stmt = select(subq1).where( + or_( + subq1.c.engineer_name == "dilbert", + subq1.c.manager_name == "dogbert", + ) + ) + + self.assert_compile( + stmt, + "SELECT anon_1.person_id, anon_1.company_id, anon_1.name, " + "anon_1.type, anon_1.person_id AS person_id_1, anon_1.status, " + "anon_1.engineer_name, anon_1.primary_language, " + "anon_1.person_id AS person_id_2, anon_1.status AS status_1, " + "anon_1.manager_name FROM " + "%s WHERE " + "anon_1.engineer_name = :engineer_name_1 " + "OR anon_1.manager_name = :manager_name_1" % (self.default_punion), + ) + + class ImplicitWithPolymorphicTest( _poly_fixtures._PolymorphicUnions, AssertsCompiledSQL ): @@ -310,7 +773,9 @@ class ImplicitWithPolymorphicTest( ) self.assert_compile(stmt, expected) - self.assert_compile(q.statement, expected) + self.assert_compile( + q._final_statement(legacy_query_style=False), expected, + ) def test_select_where_baseclass(self): Person = self.classes.Person @@ -349,7 +814,9 @@ class ImplicitWithPolymorphicTest( ) self.assert_compile(stmt, expected) - self.assert_compile(q.statement, expected) + self.assert_compile( + q._final_statement(legacy_query_style=False), expected, + ) def test_select_where_subclass(self): @@ -397,7 +864,10 @@ class ImplicitWithPolymorphicTest( # in context.py self.assert_compile(stmt, disambiguate_expected) - self.assert_compile(q.statement, disambiguate_expected) + self.assert_compile( + q._final_statement(legacy_query_style=False), + disambiguate_expected, + ) def test_select_where_columns_subclass(self): @@ -436,7 +906,9 @@ class ImplicitWithPolymorphicTest( ) self.assert_compile(stmt, expected) - self.assert_compile(q.statement, expected) + self.assert_compile( + q._final_statement(legacy_query_style=False), expected, + ) class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): @@ -506,15 +978,14 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): orm_join(Company, Person, Company.employees) ) stmt2 = select(Company).join(Company.employees) - stmt3 = Session().query(Company).join(Company.employees).statement - - # TODO: can't get aliasing to not happen for .join() verion - self.assert_compile( - stmt1, - self.straight_company_to_person_expected.replace( - "pjoin_1", "pjoin" - ), + stmt3 = ( + Session() + .query(Company) + .join(Company.employees) + ._final_statement(legacy_query_style=False) ) + + self.assert_compile(stmt1, self.straight_company_to_person_expected) self.assert_compile(stmt2, self.straight_company_to_person_expected) self.assert_compile(stmt3, self.straight_company_to_person_expected) @@ -532,12 +1003,11 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): "Company", "Person", "Manager", "Engineer" ) - # TODO: fails - # stmt1 = ( - # select(Company) - # .select_from(orm_join(Company, Person, Company.employees)) - # .where(Person.name == "ed") - # ) + stmt1 = ( + select(Company) + .select_from(orm_join(Company, Person, Company.employees)) + .where(Person.name == "ed") + ) stmt2 = ( select(Company).join(Company.employees).where(Person.name == "ed") @@ -547,20 +1017,10 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): .query(Company) .join(Company.employees) .filter(Person.name == "ed") - .statement + ._final_statement(legacy_query_style=False) ) - # TODO: more inheriance woes, the first statement doesn't know that - # it loads polymorphically with Person. should we have mappers and - # ORM attributes return their polymorphic entity for - # __clause_element__() ? or should we know to look inside the - # orm_join and find all the entities that are important? it is - # looking like having ORM expressions use their polymoprhic selectable - # will solve a lot but not all of these problems. - - # self.assert_compile(stmt1, self.c_to_p_whereclause) - - # self.assert_compile(stmt1, self.c_to_p_whereclause) + self.assert_compile(stmt1, self.c_to_p_whereclause) self.assert_compile(stmt2, self.c_to_p_whereclause) self.assert_compile(stmt3, self.c_to_p_whereclause) @@ -581,16 +1041,12 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): .query(Company) .join(Company.employees) .join(Person.paperwork) - .statement + ._final_statement(legacy_query_style=False) ) self.assert_compile(stmt1, self.person_paperwork_expected) - self.assert_compile( - stmt2, self.person_paperwork_expected.replace("pjoin", "pjoin_1") - ) - self.assert_compile( - stmt3, self.person_paperwork_expected.replace("pjoin", "pjoin_1") - ) + self.assert_compile(stmt2, self.person_paperwork_expected) + self.assert_compile(stmt3, self.person_paperwork_expected) def test_wpoly_of_type(self): Company, Person, Manager, Engineer = self.classes( @@ -608,7 +1064,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): Session() .query(Company) .join(Company.employees.of_type(p1)) - .statement + ._final_statement(legacy_query_style=False) ) expected = ( "SELECT companies.company_id, companies.name " @@ -633,7 +1089,11 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): stmt2 = select(Company).join(p1, Company.employees.of_type(p1)) - stmt3 = s.query(Company).join(Company.employees.of_type(p1)).statement + stmt3 = ( + s.query(Company) + .join(Company.employees.of_type(p1)) + ._final_statement(legacy_query_style=False) + ) expected = ( "SELECT companies.company_id, companies.name FROM companies " @@ -661,7 +1121,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): Session() .query(Company) .join(Company.employees.of_type(p1)) - .statement + ._final_statement(legacy_query_style=False) ) expected = ( @@ -677,9 +1137,12 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): class RelNaturalAliasedJoinsTest( _poly_fixtures._PolymorphicAliasedJoins, RelationshipNaturalInheritedTest ): + + # this is the label style for the polymorphic selectable, not the + # outside query + label_style = LABEL_STYLE_TABLENAME_PLUS_COL + straight_company_to_person_expected = ( - # TODO: would rather not have the aliasing here but can't fix - # that right now "SELECT companies.company_id, companies.name FROM companies " "JOIN (SELECT people.person_id AS people_person_id, people.company_id " "AS people_company_id, people.name AS people_name, people.type " @@ -691,8 +1154,8 @@ class RelNaturalAliasedJoinsTest( "managers.manager_name AS managers_manager_name FROM people " "LEFT OUTER JOIN engineers ON people.person_id = " "engineers.person_id LEFT OUTER JOIN managers ON people.person_id = " - "managers.person_id) AS pjoin_1 ON companies.company_id = " - "pjoin_1.people_company_id" + "managers.person_id) AS pjoin ON companies.company_id = " + "pjoin.people_company_id" ) person_paperwork_expected = ( @@ -768,8 +1231,8 @@ class RelNaturalAliasedJoinsTest( "FROM people LEFT OUTER JOIN engineers " "ON people.person_id = engineers.person_id " "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " - "AS pjoin_1 ON companies.company_id = pjoin_1.people_company_id " - "WHERE pjoin_1.people_name = :name_1" + "AS pjoin ON companies.company_id = pjoin.people_company_id " + "WHERE pjoin.people_name = :people_name_1" ) poly_columns = ( @@ -788,6 +1251,113 @@ class RelNaturalAliasedJoinsTest( ) +class RelNaturalAliasedJoinsDisamTest( + _poly_fixtures._PolymorphicAliasedJoins, RelationshipNaturalInheritedTest +): + # this is the label style for the polymorphic selectable, not the + # outside query + label_style = LABEL_STYLE_DISAMBIGUATE_ONLY + + straight_company_to_person_expected = ( + "SELECT companies.company_id, companies.name FROM companies JOIN " + "(SELECT people.person_id AS person_id, " + "people.company_id AS company_id, people.name AS name, " + "people.type AS type, engineers.person_id AS person_id_1, " + "engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "managers.person_id AS person_id_2, managers.status AS status_1, " + "managers.manager_name AS manager_name FROM people " + "LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " + "AS pjoin ON companies.company_id = pjoin.company_id" + ) + + person_paperwork_expected = ( + "SELECT companies.company_id, companies.name FROM companies " + "JOIN (SELECT people.person_id AS person_id, people.company_id " + "AS company_id, people.name AS name, people.type AS type, " + "engineers.person_id AS person_id_1, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, managers.person_id " + "AS person_id_2, managers.status AS status_1, managers.manager_name " + "AS manager_name FROM people LEFT OUTER JOIN engineers " + "ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " + "AS pjoin ON companies.company_id = pjoin.company_id " + "JOIN paperwork ON pjoin.person_id = paperwork.person_id" + ) + + default_pjoin = ( + "(SELECT people.person_id AS person_id, people.company_id AS " + "company_id, people.name AS name, people.type AS type, " + "engineers.person_id AS person_id_1, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, engineers.primary_language " + "AS primary_language, managers.person_id AS person_id_2, " + "managers.status AS status_1, managers.manager_name AS manager_name " + "FROM people LEFT OUTER JOIN engineers ON people.person_id = " + "engineers.person_id LEFT OUTER JOIN managers ON people.person_id = " + "managers.person_id) AS pjoin " + "ON companies.company_id = pjoin.company_id" + ) + flat_aliased_pjoin = ( + "(SELECT people.person_id AS person_id, people.company_id AS " + "company_id, people.name AS name, people.type AS type, " + "engineers.person_id AS person_id_1, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "managers.person_id AS person_id_2, managers.status AS status_1, " + "managers.manager_name AS manager_name FROM people " + "LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " + "AS pjoin_1 ON companies.company_id = pjoin_1.company_id" + ) + + aliased_pjoin = ( + "(SELECT people.person_id AS person_id, people.company_id AS " + "company_id, people.name AS name, people.type AS type, " + "engineers.person_id AS person_id_1, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, engineers.primary_language " + "AS primary_language, managers.person_id AS person_id_2, " + "managers.status AS status_1, managers.manager_name AS manager_name " + "FROM people LEFT OUTER JOIN engineers ON people.person_id = " + "engineers.person_id LEFT OUTER JOIN managers ON people.person_id = " + "managers.person_id) AS pjoin_1 " + "ON companies.company_id = pjoin_1.company_id" + ) + + c_to_p_whereclause = ( + "SELECT companies.company_id, companies.name FROM companies JOIN " + "(SELECT people.person_id AS person_id, " + "people.company_id AS company_id, people.name AS name, " + "people.type AS type, engineers.person_id AS person_id_1, " + "engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "managers.person_id AS person_id_2, managers.status AS status_1, " + "managers.manager_name AS manager_name FROM people " + "LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " + "AS pjoin ON companies.company_id = pjoin.company_id " + "WHERE pjoin.name = :name_1" + ) + + poly_columns = ( + "SELECT pjoin.person_id FROM (SELECT people.person_id AS " + "person_id, people.company_id AS company_id, people.name AS name, " + "people.type AS type, engineers.person_id AS person_id_1, " + "engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "managers.person_id AS person_id_2, " + "managers.status AS status_1, managers.manager_name AS manager_name " + "FROM people LEFT OUTER JOIN engineers " + "ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers " + "ON people.person_id = managers.person_id) AS pjoin" + ) + + class RawSelectTest(QueryTest, AssertsCompiledSQL): """older tests from test_query. Here, they are converted to use future selects with ORM compilation. @@ -808,7 +1378,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): User = self.classes.User stmt1 = select(User).where(User.addresses) - stmt2 = Session().query(User).filter(User.addresses).statement + stmt2 = ( + Session() + .query(User) + .filter(User.addresses) + ._final_statement(legacy_query_style=False) + ) expected = ( "SELECT users.id, users.name FROM users, addresses " @@ -829,7 +1404,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ) stmt1 = select(Item).where(Item.keywords) - stmt2 = Session().query(Item).filter(Item.keywords).statement + stmt2 = ( + Session() + .query(Item) + .filter(Item.keywords) + ._final_statement(legacy_query_style=False) + ) self.assert_compile(stmt1, expected) self.assert_compile(stmt2, expected) @@ -839,7 +1419,10 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): expected = "SELECT * FROM users" stmt1 = select(literal_column("*")).select_from(User) stmt2 = ( - Session().query(literal_column("*")).select_from(User).statement + Session() + .query(literal_column("*")) + .select_from(User) + ._final_statement(legacy_query_style=False) ) self.assert_compile(stmt1, expected) @@ -850,7 +1433,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ua = aliased(User, name="ua") stmt1 = select(literal_column("*")).select_from(ua) - stmt2 = Session().query(literal_column("*")).select_from(ua) + stmt2 = ( + Session() + .query(literal_column("*")) + .select_from(ua) + ._final_statement(legacy_query_style=False) + ) expected = "SELECT * FROM users AS ua" @@ -886,7 +1474,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): .correlate(User) .scalar_subquery(), ) - .statement + ._final_statement(legacy_query_style=False) ) self.assert_compile(stmt1, expected) @@ -916,7 +1504,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): .correlate(uu) .scalar_subquery(), ) - .statement + ._final_statement(legacy_query_style=False) ) expected = ( @@ -935,7 +1523,9 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): expected = "SELECT users.id, users.name FROM users" stmt1 = select(User) - stmt2 = Session().query(User).statement + stmt2 = ( + Session().query(User)._final_statement(legacy_query_style=False) + ) self.assert_compile(stmt1, expected) self.assert_compile(stmt2, expected) @@ -946,7 +1536,11 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): expected = "SELECT users.id, users.name FROM users" stmt1 = select(User.id, User.name) - stmt2 = Session().query(User.id, User.name).statement + stmt2 = ( + Session() + .query(User.id, User.name) + ._final_statement(legacy_query_style=False) + ) self.assert_compile(stmt1, expected) self.assert_compile(stmt2, expected) @@ -956,7 +1550,11 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ua = aliased(User, name="ua") stmt1 = select(ua.id, ua.name) - stmt2 = Session().query(ua.id, ua.name).statement + stmt2 = ( + Session() + .query(ua.id, ua.name) + ._final_statement(legacy_query_style=False) + ) expected = "SELECT ua.id, ua.name FROM users AS ua" self.assert_compile(stmt1, expected) @@ -967,7 +1565,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ua = aliased(User, name="ua") stmt1 = select(ua) - stmt2 = Session().query(ua).statement + stmt2 = Session().query(ua)._final_statement(legacy_query_style=False) expected = "SELECT ua.id, ua.name FROM users AS ua" self.assert_compile(stmt1, expected) @@ -1081,7 +1679,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): .query(Foo) .filter(Foo.foob == "somename") .order_by(Foo.foob) - .statement + ._final_statement(legacy_query_style=False) ) expected = ( diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index ce687fdee..1fde343d8 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -2478,6 +2478,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): "email_address" ), ).group_by(Address.user_id) + ag1 = aliased(Address, agg_address.subquery()) ag2 = aliased(Address, agg_address.subquery(), adapt_on_names=True) diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index 300670a70..fe3f2a721 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -17,7 +17,6 @@ from sqlalchemy import true from sqlalchemy.engine import default from sqlalchemy.orm import aliased from sqlalchemy.orm import backref -from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import create_session from sqlalchemy.orm import join from sqlalchemy.orm import joinedload @@ -33,283 +32,15 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.schema import Column from test.orm import _fixtures +from .inheritance import _poly_fixtures +from .test_query import QueryTest -class QueryTest(_fixtures.FixtureTest): +class InheritedTest(_poly_fixtures._Polymorphic): run_setup_mappers = "once" - run_inserts = "once" - run_deletes = None - - @classmethod - def setup_mappers(cls): - ( - Node, - composite_pk_table, - users, - Keyword, - items, - Dingaling, - order_items, - item_keywords, - Item, - User, - dingalings, - Address, - keywords, - CompositePk, - nodes, - Order, - orders, - addresses, - ) = ( - cls.classes.Node, - cls.tables.composite_pk_table, - cls.tables.users, - cls.classes.Keyword, - cls.tables.items, - cls.classes.Dingaling, - cls.tables.order_items, - cls.tables.item_keywords, - cls.classes.Item, - cls.classes.User, - cls.tables.dingalings, - cls.classes.Address, - cls.tables.keywords, - cls.classes.CompositePk, - cls.tables.nodes, - cls.classes.Order, - cls.tables.orders, - cls.tables.addresses, - ) - - mapper( - User, - users, - properties={ - "addresses": relationship( - Address, backref="user", order_by=addresses.c.id - ), - # o2m, m2o - "orders": relationship( - Order, backref="user", order_by=orders.c.id - ), - }, - ) - mapper( - Address, - addresses, - properties={ - # o2o - "dingaling": relationship( - Dingaling, uselist=False, backref="address" - ) - }, - ) - mapper(Dingaling, dingalings) - mapper( - Order, - orders, - properties={ - # m2m - "items": relationship( - Item, secondary=order_items, order_by=items.c.id - ), - "address": relationship(Address), # m2o - }, - ) - mapper( - Item, - items, - properties={ - "keywords": relationship( - Keyword, secondary=item_keywords - ) # m2m - }, - ) - mapper(Keyword, keywords) - - mapper( - Node, - nodes, - properties={ - "children": relationship( - Node, backref=backref("parent", remote_side=[nodes.c.id]) - ) - }, - ) - - mapper(CompositePk, composite_pk_table) - - configure_mappers() - - -class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): - run_setup_mappers = "once" - - @classmethod - def define_tables(cls, metadata): - Table( - "companies", - metadata, - Column( - "company_id", - Integer, - primary_key=True, - test_needs_autoincrement=True, - ), - Column("name", String(50)), - ) - - Table( - "people", - metadata, - Column( - "person_id", - Integer, - primary_key=True, - test_needs_autoincrement=True, - ), - Column("company_id", Integer, ForeignKey("companies.company_id")), - Column("name", String(50)), - Column("type", String(30)), - ) - - Table( - "engineers", - metadata, - Column( - "person_id", - Integer, - ForeignKey("people.person_id"), - primary_key=True, - ), - Column("status", String(30)), - Column("engineer_name", String(50)), - Column("primary_language", String(50)), - ) - Table( - "machines", - metadata, - Column( - "machine_id", - Integer, - primary_key=True, - test_needs_autoincrement=True, - ), - Column("name", String(50)), - Column("engineer_id", Integer, ForeignKey("engineers.person_id")), - ) - - Table( - "managers", - metadata, - Column( - "person_id", - Integer, - ForeignKey("people.person_id"), - primary_key=True, - ), - Column("status", String(30)), - Column("manager_name", String(50)), - ) - - Table( - "boss", - metadata, - Column( - "boss_id", - Integer, - ForeignKey("managers.person_id"), - primary_key=True, - ), - Column("golf_swing", String(30)), - ) - - Table( - "paperwork", - metadata, - Column( - "paperwork_id", - Integer, - primary_key=True, - test_needs_autoincrement=True, - ), - Column("description", String(50)), - Column("person_id", Integer, ForeignKey("people.person_id")), - ) - - @classmethod - def setup_classes(cls): - paperwork, people, companies, boss, managers, machines, engineers = ( - cls.tables.paperwork, - cls.tables.people, - cls.tables.companies, - cls.tables.boss, - cls.tables.managers, - cls.tables.machines, - cls.tables.engineers, - ) - - class Company(cls.Comparable): - pass - - class Person(cls.Comparable): - pass - - class Engineer(Person): - pass - - class Manager(Person): - pass - - class Boss(Manager): - pass - - class Machine(cls.Comparable): - pass - - class Paperwork(cls.Comparable): - pass - - mapper( - Company, - companies, - properties={ - "employees": relationship(Person, order_by=people.c.person_id) - }, - ) - - mapper(Machine, machines) - - mapper( - Person, - people, - polymorphic_on=people.c.type, - polymorphic_identity="person", - properties={ - "paperwork": relationship( - Paperwork, order_by=paperwork.c.paperwork_id - ) - }, - ) - mapper( - Engineer, - engineers, - inherits=Person, - polymorphic_identity="engineer", - properties={ - "machines": relationship( - Machine, order_by=machines.c.machine_id - ) - }, - ) - mapper( - Manager, managers, inherits=Person, polymorphic_identity="manager" - ) - mapper(Boss, boss, inherits=Manager, polymorphic_identity="boss") - mapper(Paperwork, paperwork) +class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): def test_single_prop(self): Company = self.classes.Company diff --git a/test/orm/test_of_type.py b/test/orm/test_of_type.py index 82930f754..daac38dc2 100644 --- a/test/orm/test_of_type.py +++ b/test/orm/test_of_type.py @@ -54,7 +54,7 @@ class _PolymorphicTestBase(object): def test_any_four(self): sess = Session() - any_ = Company.employees.of_type(Boss).any( + any_ = Company.employees.of_type(Manager).any( Manager.manager_name == "pointy" ) eq_(sess.query(Company).filter(any_).one(), self.c1) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 76706b37b..478fc7147 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -16,7 +16,6 @@ from sqlalchemy import exc as sa_exc from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func -from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal @@ -553,227 +552,6 @@ class BindSensitiveStringifyTest(fixtures.TestBase): self._test(True, True, True, True) -class RawSelectTest(QueryTest, AssertsCompiledSQL): - __dialect__ = "default" - - def test_select_from_entity(self): - User = self.classes.User - - self.assert_compile( - select(["*"]).select_from(User), "SELECT * FROM users" - ) - - def test_where_relationship(self): - User = self.classes.User - - self.assert_compile( - select([User]).where(User.addresses), - "SELECT users.id, users.name FROM users, addresses " - "WHERE users.id = addresses.user_id", - ) - - def test_where_m2m_relationship(self): - Item = self.classes.Item - - self.assert_compile( - select([Item]).where(Item.keywords), - "SELECT items.id, items.description FROM items, " - "item_keywords AS item_keywords_1, keywords " - "WHERE items.id = item_keywords_1.item_id " - "AND keywords.id = item_keywords_1.keyword_id", - ) - - def test_inline_select_from_entity(self): - User = self.classes.User - - self.assert_compile( - select(["*"], from_obj=User), "SELECT * FROM users" - ) - - def test_select_from_aliased_entity(self): - User = self.classes.User - ua = aliased(User, name="ua") - self.assert_compile( - select(["*"]).select_from(ua), "SELECT * FROM users AS ua" - ) - - def test_correlate_entity(self): - User = self.classes.User - Address = self.classes.Address - - self.assert_compile( - select( - [ - User.name, - Address.id, - select([func.count(Address.id)]) - .where(User.id == Address.user_id) - .correlate(User) - .scalar_subquery(), - ] - ), - "SELECT users.name, addresses.id, " - "(SELECT count(addresses.id) AS count_1 " - "FROM addresses WHERE users.id = addresses.user_id) AS anon_1 " - "FROM users, addresses", - ) - - def test_correlate_aliased_entity(self): - User = self.classes.User - Address = self.classes.Address - uu = aliased(User, name="uu") - - self.assert_compile( - select( - [ - uu.name, - Address.id, - select([func.count(Address.id)]) - .where(uu.id == Address.user_id) - .correlate(uu) - .scalar_subquery(), - ] - ), - # for a long time, "uu.id = address.user_id" was reversed; - # this was resolved as of #2872 and had to do with - # InstrumentedAttribute.__eq__() taking precedence over - # QueryableAttribute.__eq__() - "SELECT uu.name, addresses.id, " - "(SELECT count(addresses.id) AS count_1 " - "FROM addresses WHERE uu.id = addresses.user_id) AS anon_1 " - "FROM users AS uu, addresses", - ) - - def test_columns_clause_entity(self): - User = self.classes.User - - self.assert_compile( - select([User]), "SELECT users.id, users.name FROM users" - ) - - def test_columns_clause_columns(self): - User = self.classes.User - - self.assert_compile( - select([User.id, User.name]), - "SELECT users.id, users.name FROM users", - ) - - def test_columns_clause_aliased_columns(self): - User = self.classes.User - ua = aliased(User, name="ua") - self.assert_compile( - select([ua.id, ua.name]), "SELECT ua.id, ua.name FROM users AS ua" - ) - - def test_columns_clause_aliased_entity(self): - User = self.classes.User - ua = aliased(User, name="ua") - self.assert_compile( - select([ua]), "SELECT ua.id, ua.name FROM users AS ua" - ) - - def test_core_join(self): - User = self.classes.User - Address = self.classes.Address - from sqlalchemy.sql import join - - self.assert_compile( - select([User]).select_from(join(User, Address)), - "SELECT users.id, users.name FROM users " - "JOIN addresses ON users.id = addresses.user_id", - ) - - def test_insert_from_query(self): - User = self.classes.User - Address = self.classes.Address - - s = Session() - q = s.query(User.id, User.name).filter_by(name="ed") - self.assert_compile( - insert(Address).from_select(("id", "email_address"), q), - "INSERT INTO addresses (id, email_address) " - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.name = :name_1", - ) - - def test_insert_from_query_col_attr(self): - User = self.classes.User - Address = self.classes.Address - - s = Session() - q = s.query(User.id, User.name).filter_by(name="ed") - self.assert_compile( - insert(Address).from_select( - (Address.id, Address.email_address), q - ), - "INSERT INTO addresses (id, email_address) " - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.name = :name_1", - ) - - def test_update_from_entity(self): - from sqlalchemy.sql import update - - User = self.classes.User - self.assert_compile( - update(User), "UPDATE users SET id=:id, name=:name" - ) - - self.assert_compile( - update(User).values(name="ed").where(User.id == 5), - "UPDATE users SET name=:name WHERE users.id = :id_1", - checkparams={"id_1": 5, "name": "ed"}, - ) - - def test_delete_from_entity(self): - from sqlalchemy.sql import delete - - User = self.classes.User - self.assert_compile(delete(User), "DELETE FROM users") - - self.assert_compile( - delete(User).where(User.id == 5), - "DELETE FROM users WHERE users.id = :id_1", - checkparams={"id_1": 5}, - ) - - def test_insert_from_entity(self): - from sqlalchemy.sql import insert - - User = self.classes.User - self.assert_compile( - insert(User), "INSERT INTO users (id, name) VALUES (:id, :name)" - ) - - self.assert_compile( - insert(User).values(name="ed"), - "INSERT INTO users (name) VALUES (:name)", - checkparams={"name": "ed"}, - ) - - def test_col_prop_builtin_function(self): - class Foo(object): - pass - - mapper( - Foo, - self.tables.users, - properties={ - "foob": column_property( - func.coalesce(self.tables.users.c.name) - ) - }, - ) - - self.assert_compile( - select([Foo]).where(Foo.foob == "somename").order_by(Foo.foob), - "SELECT users.id, users.name FROM users " - "WHERE coalesce(users.name) = :param_1 " - "ORDER BY coalesce(users.name)", - ) - - class GetTest(QueryTest): def test_get_composite_pk_keyword_based_no_result(self): CompositePk = self.classes.CompositePk diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index 5e3f51606..1d9882678 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -220,7 +220,6 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): "parententity": point_mapper, "parentmapper": point_mapper, "orm_key": "x_alone", - "compile_state_plugin": "orm", }, ) eq_( @@ -230,7 +229,6 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): "parententity": point_mapper, "parentmapper": point_mapper, "orm_key": "x", - "compile_state_plugin": "orm", }, ) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index d3d21cb0e..2d84ab676 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -661,6 +661,56 @@ class CoreFixtures(object): fixtures.append(_statements_w_context_options_fixtures) + def _statements_w_anonymous_col_names(): + def one(): + c = column("q") + + l = c.label(None) + + # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d + subq = select([l]).subquery() + + # this creates a ColumnClause as a proxy to the Label() that has + # an anoymous name, so the column has one too. + anon_col = subq.c[0] + + # then when BindParameter is created, it checks the label + # and doesn't double up on the anonymous name which is uncachable + return anon_col > 5 + + def two(): + c = column("p") + + l = c.label(None) + + # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d + subq = select([l]).subquery() + + # this creates a ColumnClause as a proxy to the Label() that has + # an anoymous name, so the column has one too. + anon_col = subq.c[0] + + # then when BindParameter is created, it checks the label + # and doesn't double up on the anonymous name which is uncachable + return anon_col > 5 + + def three(): + + l1, l2 = table_a.c.a.label(None), table_a.c.b.label(None) + + stmt = select([table_a.c.a, table_a.c.b, l1, l2]) + + subq = stmt.subquery() + return select([subq]).where(subq.c[2] == 10) + + return ( + one(), + two(), + three(), + ) + + fixtures.append(_statements_w_anonymous_col_names) + class CacheKeyFixture(object): def _run_cache_key_fixture(self, fixture, compare_values): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 20f31ba1e..4c87c0a46 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -71,6 +71,7 @@ from sqlalchemy.engine import default from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import column from sqlalchemy.sql import compiler +from sqlalchemy.sql import elements from sqlalchemy.sql import label from sqlalchemy.sql import operators from sqlalchemy.sql import table @@ -3294,6 +3295,29 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): checkparams={"3foo_1": "foo", "4_foo_1": "bar"}, ) + def test_bind_given_anon_name_dont_double(self): + c = column("id") + l = c.label(None) + + # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d + subq = select([l]).subquery() + + # this creates a ColumnClause as a proxy to the Label() that has + # an anoymous name, so the column has one too. + anon_col = subq.c[0] + assert isinstance(anon_col.name, elements._anonymous_label) + + # then when BindParameter is created, it checks the label + # and doesn't double up on the anonymous name which is uncachable + expr = anon_col > 5 + + self.assert_compile( + expr, "anon_1.id_1 > :param_1", checkparams={"param_1": 5} + ) + + # see also test_compare.py -> _statements_w_anonymous_col_names + # fixture for cache key + def test_bind_as_col(self): t = table("foo", column("id")) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index e509c9f95..d53ee3385 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -147,6 +147,66 @@ class SelectableTest( assert s1.corresponding_column(scalar_select) is s1.c.foo assert s2.corresponding_column(scalar_select) is s2.c.foo + def test_labels_name_w_separate_key(self): + label = select([table1.c.col1]).label("foo") + label.key = "bar" + + s1 = select([label]) + assert s1.corresponding_column(label) is s1.selected_columns.bar + + # renders as foo + self.assert_compile( + s1, "SELECT (SELECT table1.col1 FROM table1) AS foo" + ) + + def test_labels_anon_w_separate_key(self): + label = select([table1.c.col1]).label(None) + label.key = "bar" + + s1 = select([label]) + + # .bar is there + assert s1.corresponding_column(label) is s1.selected_columns.bar + + # renders as anon_1 + self.assert_compile( + s1, "SELECT (SELECT table1.col1 FROM table1) AS anon_1" + ) + + def test_labels_anon_w_separate_key_subquery(self): + label = select([table1.c.col1]).label(None) + label.key = label._key_label = "bar" + + s1 = select([label]) + + subq = s1.subquery() + + s2 = select([subq]).where(subq.c.bar > 5) + self.assert_compile( + s2, + "SELECT anon_2.anon_1 FROM (SELECT (SELECT table1.col1 " + "FROM table1) AS anon_1) AS anon_2 " + "WHERE anon_2.anon_1 > :param_1", + checkparams={"param_1": 5}, + ) + + def test_labels_anon_generate_binds_subquery(self): + label = select([table1.c.col1]).label(None) + label.key = label._key_label = "bar" + + s1 = select([label]) + + subq = s1.subquery() + + s2 = select([subq]).where(subq.c[0] > 5) + self.assert_compile( + s2, + "SELECT anon_2.anon_1 FROM (SELECT (SELECT table1.col1 " + "FROM table1) AS anon_1) AS anon_2 " + "WHERE anon_2.anon_1 > :param_1", + checkparams={"param_1": 5}, + ) + def test_select_label_grouped_still_corresponds(self): label = select([table1.c.col1]).label("foo") label2 = label.self_group() diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 4e713dd28..d68a74475 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -4,8 +4,11 @@ from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy.sql import base as sql_base from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy.testing import assert_raises +from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -48,3 +51,41 @@ class MiscTest(fixtures.TestBase): set(sql_util.find_tables(subset_select, include_aliases=True)), {common, calias, subset_select}, ) + + def test_options_merge(self): + class opt1(sql_base.CacheableOptions): + _cache_key_traversal = [] + + class opt2(sql_base.CacheableOptions): + _cache_key_traversal = [] + + foo = "bar" + + class opt3(sql_base.CacheableOptions): + _cache_key_traversal = [] + + foo = "bar" + bat = "hi" + + o2 = opt2.safe_merge(opt1) + eq_(o2.__dict__, {}) + eq_(o2.foo, "bar") + + assert_raises_message( + TypeError, + r"other element .*opt2.* is not empty, is not of type .*opt1.*, " + r"and contains attributes not covered here .*'foo'.*", + opt1.safe_merge, + opt2, + ) + + o2 = opt2 + {"foo": "bat"} + o3 = opt2.safe_merge(o2) + + eq_(o3.foo, "bat") + + o4 = opt3.safe_merge(o2) + eq_(o4.foo, "bat") + eq_(o4.bat, "hi") + + assert_raises(TypeError, opt2.safe_merge, o4) |