diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-04-27 12:58:12 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-05-25 13:56:37 -0400 |
commit | 6930dfc032c3f9f474e71ab4e021c0ef8384930e (patch) | |
tree | 34b919a3c34edaffda1750f161a629fc5b9a8020 /lib/sqlalchemy | |
parent | dce8c7a125cb99fad62c76cd145752d5afefae36 (diff) | |
download | sqlalchemy-6930dfc032c3f9f474e71ab4e021c0ef8384930e.tar.gz |
Convert execution to move through Session
This patch replaces the ORM execution flow with a
single pathway through Session.execute() for all queries,
including Core and ORM.
Currently included is full support for ORM Query,
Query.from_statement(), select(), as well as the
baked query and horizontal shard systems. Initial
changes have also been made to the dogpile caching
example, which like baked query makes use of a
new ORM-specific execution hook that replaces the
use of both QueryEvents.before_compile() as well
as Query._execute_and_instances() as the central
ORM interception hooks.
select() and Query() constructs alike can be passed to
Session.execute() where they will return ORM
results in a Result object. This API is currently
used internally by Query. Full support for
Session.execute()->results to behave in a fully
2.0 fashion will be in later changesets.
bulk update/delete with ORM support will also
be delivered via the update() and delete()
constructs, however these have not yet been adapted
to the new system and may follow in a subsequent
update.
Performance is also beginning to lag as of this
commit and some previous ones. It is hoped that
a few central functions such as the coercions
functions can be rewritten in C to regain
performance. Additionally, query caching
is now available and some subsequent patches
will attempt to cache more of the per-execution
work from the ORM layer, e.g. column getters
and adapters.
This patch also contains initial "turn on" of the
caching system enginewide via the query_cache_size
parameter to create_engine(). It still defaults to
zero, meaning "no caching". The caching system still
needs adjustments in order to gain adequate performance.
Change-Id: I047a7ebb26aa85dc01f6789fac2bff561dcd555d
Diffstat (limited to 'lib/sqlalchemy')
40 files changed, 2007 insertions, 1031 deletions
diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c index ff6cadac0..ed6f57470 100644 --- a/lib/sqlalchemy/cextension/resultproxy.c +++ b/lib/sqlalchemy/cextension/resultproxy.c @@ -651,6 +651,7 @@ static int BaseRow_setkeystyle(BaseRow *self, PyObject *value, void *closure) { if (value == NULL) { + PyErr_SetString( PyExc_TypeError, "Cannot delete the 'key_style' attribute"); diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 05c34c171..3345d555f 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1704,12 +1704,14 @@ class MSSQLCompiler(compiler.SQLCompiler): self.process(element.typeclause, **kw), ) - def visit_select(self, select, **kwargs): + def translate_select_structure(self, select_stmt, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. MSSQL 2012 and above are excluded """ + select = select_stmt + if ( not self.dialect._supports_offset_fetch and ( @@ -1741,7 +1743,7 @@ class MSSQLCompiler(compiler.SQLCompiler): limit_clause = select._limit_clause offset_clause = select._offset_clause - kwargs["select_wraps_for"] = select + select = select._generate() select._mssql_visit = True select = ( @@ -1766,9 +1768,9 @@ class MSSQLCompiler(compiler.SQLCompiler): ) else: limitselect = limitselect.where(mssql_rn <= (limit_clause)) - return self.process(limitselect, **kwargs) + return limitselect else: - return compiler.SQLCompiler.visit_select(self, select, **kwargs) + return select @_with_legacy_schema_aliasing def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index dd7d6a4d1..481ea7263 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -975,16 +975,8 @@ class 
OracleCompiler(compiler.SQLCompiler): return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) - def _TODO_visit_compound_select(self, select): - """Need to determine how to get ``LIMIT``/``OFFSET`` into a - ``UNION`` for Oracle. - """ - pass - - def visit_select(self, select, **kwargs): - """Look for ``LIMIT`` and OFFSET in a select statement, and if - so tries to wrap it in a subquery with ``rownum`` criterion. - """ + def translate_select_structure(self, select_stmt, **kwargs): + select = select_stmt if not getattr(select, "_oracle_visit", None): if not self.dialect.use_ansi: @@ -1003,7 +995,7 @@ class OracleCompiler(compiler.SQLCompiler): # https://blogs.oracle.com/oraclemagazine/\ # on-rownum-and-limiting-results - kwargs["select_wraps_for"] = orig_select = select + orig_select = select select = select._generate() select._oracle_visit = True @@ -1136,7 +1128,7 @@ class OracleCompiler(compiler.SQLCompiler): offsetselect._for_update_arg = for_update select = offsetselect - return compiler.SQLCompiler.visit_select(self, select, **kwargs) + return select def limit_clause(self, select, **kw): return "" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index ee02899f6..0193ea47c 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -225,14 +225,10 @@ class Connection(Connectable): A dictionary where :class:`.Compiled` objects will be cached when the :class:`_engine.Connection` compiles a clause - expression into a :class:`.Compiled` object. - It is the user's responsibility to - manage the size of this dictionary, which will have keys - corresponding to the dialect, clause element, the column - names within the VALUES or SET clause of an INSERT or UPDATE, - as well as the "batch" mode for an INSERT or UPDATE statement. - The format of this dictionary is not guaranteed to stay the - same in future releases. + expression into a :class:`.Compiled` object. 
This dictionary will + supersede the statement cache that may be configured on the + :class:`_engine.Engine` itself. If set to None, caching + is disabled, even if the engine has a configured cache size. Note that the ORM makes use of its own "compiled" caches for some operations, including flush operations. The caching @@ -1159,13 +1155,17 @@ class Connection(Connectable): schema_translate_map = exec_opts.get("schema_translate_map", None) - if "compiled_cache" in exec_opts: + compiled_cache = exec_opts.get( + "compiled_cache", self.dialect._compiled_cache + ) + + if compiled_cache is not None: elem_cache_key = elem._generate_cache_key() else: elem_cache_key = None if elem_cache_key: - cache_key, extracted_params = elem_cache_key + cache_key, extracted_params, _ = elem_cache_key key = ( dialect, cache_key, @@ -1173,8 +1173,7 @@ class Connection(Connectable): bool(schema_translate_map), len(distilled_params) > 1, ) - cache = exec_opts["compiled_cache"] - compiled_sql = cache.get(key) + compiled_sql = compiled_cache.get(key) if compiled_sql is None: compiled_sql = elem.compile( @@ -1185,12 +1184,8 @@ class Connection(Connectable): schema_translate_map=schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, - compile_state_factories=exec_opts.get( - "compile_state_factories", None - ), ) - cache[key] = compiled_sql - + compiled_cache[key] = compiled_sql else: extracted_params = None compiled_sql = elem.compile( @@ -1199,9 +1194,6 @@ class Connection(Connectable): inline=len(distilled_params) > 1, schema_translate_map=schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, - compile_state_factories=exec_opts.get( - "compile_state_factories", None - ), ) ret = self._execute_context( @@ -1430,18 +1422,35 @@ class Connection(Connectable): ) if self._echo: + self.engine.logger.info(statement) + + # stats = context._get_cache_stats() + if not self.engine.hide_parameters: + # TODO: I love the stats but a ton of tests 
that are hardcoded. + # to certain log output are failing. self.engine.logger.info( "%r", sql_util._repr_params( parameters, batches=10, ismulti=context.executemany ), ) + # self.engine.logger.info( + # "[%s] %r", + # stats, + # sql_util._repr_params( + # parameters, batches=10, ismulti=context.executemany + # ), + # ) else: self.engine.logger.info( "[SQL parameters hidden due to hide_parameters=True]" ) + # self.engine.logger.info( + # "[%s] [SQL parameters hidden due to hide_parameters=True]" + # % (stats,) + # ) evt_handled = False try: @@ -1502,19 +1511,14 @@ class Connection(Connectable): # for "connectionless" execution, we have to close this # Connection after the statement is complete. - if branched.should_close_with_result: + # legacy stuff. + if branched.should_close_with_result and context._soft_closed: assert not self._is_future assert not context._is_future_result # CursorResult already exhausted rows / has no rows. - # close us now. note this is where we call .close() - # on the "branched" connection if we're doing that. - if result._soft_closed: - branched.close() - else: - # CursorResult will close this Connection when no more - # rows to fetch. - result._autoclose_connection = True + # close us now + branched.close() except BaseException as e: self._handle_dbapi_exception( e, statement, parameters, cursor, context diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index e683b6297..4c912349e 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -435,6 +435,23 @@ def create_engine(url, **kwargs): .. versionadded:: 1.2.3 + :param query_cache_size: size of the cache used to cache the SQL string + form of queries. Defaults to zero, which disables caching. + + Caching is accomplished on a per-statement basis by generating a + cache key that represents the statement's structure, then generating + string SQL for the current dialect only if that key is not present + in the cache. 
All statements support caching, however some features + such as an INSERT with a large set of parameters will intentionally + bypass the cache. SQL logging will indicate statistics for each + statement whether or not it were pull from the cache. + + .. seealso:: + + ``engine_caching`` - TODO: this will be an upcoming section describing + the SQL caching system. + + .. versionadded:: 1.4 """ # noqa diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 8d1a1bb57..fdbf826ed 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -57,6 +57,9 @@ class CursorResultMetaData(ResultMetaData): returns_rows = True + def _has_key(self, key): + return key in self._keymap + def _for_freeze(self): return SimpleResultMetaData( self._keys, @@ -1203,6 +1206,7 @@ class BaseCursorResult(object): out_parameters = None _metadata = None + _metadata_from_cache = False _soft_closed = False closed = False @@ -1213,7 +1217,6 @@ class BaseCursorResult(object): obj = CursorResult(context) else: obj = LegacyCursorResult(context) - return obj def __init__(self, context): @@ -1247,8 +1250,9 @@ class BaseCursorResult(object): def _init_metadata(self, context, cursor_description): if context.compiled: if context.compiled._cached_metadata: - cached_md = context.compiled._cached_metadata - self._metadata = cached_md._adapt_to_context(context) + cached_md = self.context.compiled._cached_metadata + self._metadata = cached_md + self._metadata_from_cache = True else: self._metadata = ( diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index e30daaeb8..b5cb2a1b2 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -16,6 +16,7 @@ as the base class for their own corresponding classes. import codecs import random import re +import time import weakref from . 
import cursor as _cursor @@ -226,6 +227,7 @@ class DefaultDialect(interfaces.Dialect): supports_native_boolean=None, max_identifier_length=None, label_length=None, + query_cache_size=0, # int() is because the @deprecated_params decorator cannot accommodate # the direct reference to the "NO_LINTING" object compiler_linting=int(compiler.NO_LINTING), @@ -257,6 +259,10 @@ class DefaultDialect(interfaces.Dialect): if supports_native_boolean is not None: self.supports_native_boolean = supports_native_boolean self.case_sensitive = case_sensitive + if query_cache_size != 0: + self._compiled_cache = util.LRUCache(query_cache_size) + else: + self._compiled_cache = None self._user_defined_max_identifier_length = max_identifier_length if self._user_defined_max_identifier_length: @@ -702,11 +708,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): result_column_struct = None returned_defaults = None execution_options = util.immutabledict() + + cache_stats = None + invoked_statement = None + _is_implicit_returning = False _is_explicit_returning = False _is_future_result = False _is_server_side = False + _soft_closed = False + # a hook for SQLite's translation of # result column names # NOTE: pyhive is using this hook, can't remove it :( @@ -1011,6 +1023,16 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor = self.create_cursor() return self + def _get_cache_stats(self): + if self.compiled is None: + return "raw SQL" + + now = time.time() + if self.compiled.cache_key is None: + return "gen %.5fs" % (now - self.compiled._gen_time,) + else: + return "cached %.5fs" % (now - self.compiled._gen_time,) + @util.memoized_property def engine(self): return self.root_connection.engine @@ -1234,6 +1256,33 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ): self._setup_out_parameters(result) + if not self._is_future_result: + conn = self.root_connection + assert not conn._is_future + + if not result._soft_closed and 
conn.should_close_with_result: + result._autoclose_connection = True + + self._soft_closed = result._soft_closed + + # result rewrite/ adapt step. two translations can occur here. + # one is if we are invoked against a cached statement, we want + # to rewrite the ResultMetaData to reflect the column objects + # that are in our current selectable, not the cached one. the + # other is, the CompileState can return an alternative Result + # object. Finally, CompileState might want to tell us to not + # actually do the ResultMetaData adapt step if it in fact has + # changed the selected columns in any case. + compiled = self.compiled + if compiled: + adapt_metadata = ( + result._metadata_from_cache + and not compiled._rewrites_selected_columns + ) + + if adapt_metadata: + result._metadata = result._metadata._adapt_to_context(self) + return result def _setup_out_parameters(self, result): diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 4e6b22820..0ee80ede4 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -56,6 +56,9 @@ class ResultMetaData(object): def keys(self): return RMKeyView(self) + def _has_key(self, key): + raise NotImplementedError() + def _for_freeze(self): raise NotImplementedError() @@ -171,6 +174,9 @@ class SimpleResultMetaData(ResultMetaData): self._processors = _processors + def _has_key(self, key): + return key in self._keymap + def _for_freeze(self): unique_filters = self._unique_filters if unique_filters and self._tuplefilter: @@ -287,6 +293,8 @@ class Result(InPlaceGenerative): _no_scalar_onerow = False _yield_per = None + _attributes = util.immutabledict() + def __init__(self, cursor_metadata): self._metadata = cursor_metadata @@ -548,10 +556,21 @@ class Result(InPlaceGenerative): self._generate_rows = True def _row_getter(self): - if self._source_supports_scalars and not self._generate_rows: - return None + if self._source_supports_scalars: + if not self._generate_rows: + return 
None + else: + _proc = self._process_row + + def process_row( + metadata, processors, keymap, key_style, scalar_obj + ): + return _proc( + metadata, processors, keymap, key_style, (scalar_obj,) + ) - process_row = self._process_row + else: + process_row = self._process_row key_style = self._process_row._default_key_style metadata = self._metadata @@ -771,16 +790,15 @@ class Result(InPlaceGenerative): uniques, strategy = self._unique_strategy def filterrows(make_row, rows, strategy, uniques): + if make_row: + rows = [make_row(row) for row in rows] + if strategy: made_rows = ( - (made_row, strategy(made_row)) - for made_row in [make_row(row) for row in rows] + (made_row, strategy(made_row)) for made_row in rows ) else: - made_rows = ( - (made_row, made_row) - for made_row in [make_row(row) for row in rows] - ) + made_rows = ((made_row, made_row) for made_row in rows) return [ made_row for made_row, sig_row in made_rows @@ -831,7 +849,8 @@ class Result(InPlaceGenerative): num = self._yield_per rows = self._fetchmany_impl(num) - rows = [make_row(row) for row in rows] + if make_row: + rows = [make_row(row) for row in rows] if post_creational_filter: rows = [post_creational_filter(row) for row in rows] return rows @@ -1114,24 +1133,42 @@ class FrozenResult(object): def __init__(self, result): self.metadata = result._metadata._for_freeze() self._post_creational_filter = result._post_creational_filter - self._source_supports_scalars = result._source_supports_scalars self._generate_rows = result._generate_rows + self._source_supports_scalars = result._source_supports_scalars + self._attributes = result._attributes result._post_creational_filter = None - self.data = result.fetchall() + if self._source_supports_scalars: + self.data = list(result._raw_row_iterator()) + else: + self.data = result.fetchall() + + def rewrite_rows(self): + if self._source_supports_scalars: + return [[elem] for elem in self.data] + else: + return [list(row) for row in self.data] - def 
with_data(self, data): + def with_new_rows(self, tuple_data): fr = FrozenResult.__new__(FrozenResult) fr.metadata = self.metadata fr._post_creational_filter = self._post_creational_filter - fr.data = data + fr._generate_rows = self._generate_rows + fr._attributes = self._attributes + fr._source_supports_scalars = self._source_supports_scalars + + if self._source_supports_scalars: + fr.data = [d[0] for d in tuple_data] + else: + fr.data = tuple_data return fr def __call__(self): result = IteratorResult(self.metadata, iter(self.data)) result._post_creational_filter = self._post_creational_filter - result._source_supports_scalars = self._source_supports_scalars result._generate_rows = self._generate_rows + result._attributes = self._attributes + result._source_supports_scalars = self._source_supports_scalars return result @@ -1143,9 +1180,10 @@ class IteratorResult(Result): """ - def __init__(self, cursor_metadata, iterator): + def __init__(self, cursor_metadata, iterator, raw=None): self._metadata = cursor_metadata self.iterator = iterator + self.raw = raw def _soft_close(self, **kw): self.iterator = iter([]) @@ -1189,28 +1227,23 @@ class ChunkedIteratorResult(IteratorResult): """ - def __init__(self, cursor_metadata, chunks, source_supports_scalars=False): + def __init__( + self, cursor_metadata, chunks, source_supports_scalars=False, raw=None + ): self._metadata = cursor_metadata self.chunks = chunks self._source_supports_scalars = source_supports_scalars - - self.iterator = itertools.chain.from_iterable( - self.chunks(None, self._generate_rows) - ) + self.raw = raw + self.iterator = itertools.chain.from_iterable(self.chunks(None)) def _column_slices(self, indexes): result = super(ChunkedIteratorResult, self)._column_slices(indexes) - self.iterator = itertools.chain.from_iterable( - self.chunks(self._yield_per, self._generate_rows) - ) return result @_generative def yield_per(self, num): self._yield_per = num - self.iterator = itertools.chain.from_iterable( - 
self.chunks(num, self._generate_rows) - ) + self.iterator = itertools.chain.from_iterable(self.chunks(num)) class MergedResult(IteratorResult): @@ -1238,8 +1271,14 @@ class MergedResult(IteratorResult): self._post_creational_filter = results[0]._post_creational_filter self._no_scalar_onerow = results[0]._no_scalar_onerow self._yield_per = results[0]._yield_per + + # going to try someting w/ this in next rev self._source_supports_scalars = results[0]._source_supports_scalars + self._generate_rows = results[0]._generate_rows + self._attributes = self._attributes.merge_with( + *[r._attributes for r in results] + ) def close(self): self._soft_close(hard=True) diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 24af454b6..112e245f7 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -19,7 +19,6 @@ from .. import exc as sa_exc from .. import util from ..orm import exc as orm_exc from ..orm import strategy_options -from ..orm.context import QueryContext from ..orm.query import Query from ..orm.session import Session from ..sql import func @@ -201,11 +200,12 @@ class BakedQuery(object): self.spoil(full=True) else: for opt in options: - cache_key = opt._generate_path_cache_key(cache_path) - if cache_key is False: - self.spoil(full=True) - elif cache_key is not None: - key += cache_key + if opt._is_legacy_option or opt._is_compile_state: + cache_key = opt._generate_path_cache_key(cache_path) + if cache_key is False: + self.spoil(full=True) + elif cache_key is not None: + key += cache_key self.add_criteria( lambda q: q._with_current_path(effective_path).options(*options), @@ -224,41 +224,32 @@ class BakedQuery(object): def _bake(self, session): query = self._as_query(session) + query.session = None - compile_state = query._compile_state() + # in 1.4, this is where before_compile() event is + # invoked + statement = query._statement_20(orm_results=True) - self._bake_subquery_loaders(session, compile_state) - - # TODO: 
compile_state clearly needs to be simplified here. - # if the session remains, fails memusage test - compile_state.orm_query = ( - query - ) = ( - compile_state.select_statement - ) = compile_state.query = compile_state.orm_query.with_session(None) - query._execution_options = query._execution_options.union( - {"compiled_cache": self._bakery} - ) - - # we'll be holding onto the query for some of its state, - # so delete some compilation-use-only attributes that can take up - # space - for attr in ( - "_correlate", - "_from_obj", - "_mapper_adapter_map", - "_joinpath", - "_joinpoint", - ): - query.__dict__.pop(attr, None) + # the before_compile() event can create a new Query object + # before it makes the statement. + query = statement.compile_options._orm_query # if the query is not safe to cache, we still do everything as though # we did cache it, since the receiver of _bake() assumes subqueryload # context was set up, etc. - if compile_state.compile_options._bake_ok: - self._bakery[self._effective_key(session)] = compile_state + # + # note also we want to cache the statement itself because this + # allows the statement itself to hold onto its cache key that is + # used by the Connection, which in itself is more expensive to + # generate than what BakedQuery was able to provide in 1.3 and prior + + if query.compile_options._bake_ok: + self._bakery[self._effective_key(session)] = ( + query, + statement, + ) - return compile_state + return query, statement def to_query(self, query_or_session): """Return the :class:`_query.Query` object for use as a subquery. @@ -321,50 +312,6 @@ class BakedQuery(object): return query - def _bake_subquery_loaders(self, session, compile_state): - """convert subquery eager loaders in the cache into baked queries. - - For subquery eager loading to work, all we need here is that the - Query point to the correct session when it is run. 
However, since - we are "baking" anyway, we may as well also turn the query into - a "baked" query so that we save on performance too. - - """ - compile_state.attributes["baked_queries"] = baked_queries = [] - for k, v in list(compile_state.attributes.items()): - if isinstance(v, dict) and "query" in v: - if "subqueryload_data" in k: - query = v["query"] - bk = BakedQuery(self._bakery, lambda *args: query) - bk._cache_key = self._cache_key + k - bk._bake(session) - baked_queries.append((k, bk._cache_key, v)) - del compile_state.attributes[k] - - def _unbake_subquery_loaders( - self, session, compile_state, context, params, post_criteria - ): - """Retrieve subquery eager loaders stored by _bake_subquery_loaders - and turn them back into Result objects that will iterate just - like a Query object. - - """ - if "baked_queries" not in compile_state.attributes: - return - - for k, cache_key, v in compile_state.attributes["baked_queries"]: - query = v["query"] - bk = BakedQuery( - self._bakery, lambda sess, q=query: q.with_session(sess) - ) - bk._cache_key = cache_key - q = bk.for_session(session) - for fn in post_criteria: - q = q.with_post_criteria(fn) - v = dict(v) - v["query"] = q.params(**params) - context.attributes[k] = v - class Result(object): """Invokes a :class:`.BakedQuery` against a :class:`.Session`. @@ -406,17 +353,19 @@ class Result(object): This adds a function that will be run against the :class:`_query.Query` object after it is retrieved from the - cache. Functions here can be used to alter the query in ways - that **do not affect the SQL output**, such as execution options - and shard identifiers (when using a shard-enabled query object) + cache. This currently includes **only** the + :meth:`_query.Query.params` and :meth:`_query.Query.execution_options` + methods. .. 
warning:: :meth:`_baked.Result.with_post_criteria` functions are applied to the :class:`_query.Query` object **after** the query's SQL statement - object has been retrieved from the cache. Any operations here - which intend to modify the SQL should ensure that - :meth:`.BakedQuery.spoil` was called first. + object has been retrieved from the cache. Only + :meth:`_query.Query.params` and + :meth:`_query.Query.execution_options` + methods should be used. + .. versionadded:: 1.2 @@ -438,40 +387,41 @@ class Result(object): def _iter(self): bq = self.bq + if not self.session.enable_baked_queries or bq._spoiled: return self._as_query()._iter() - baked_compile_state = bq._bakery.get( - bq._effective_key(self.session), None + query, statement = bq._bakery.get( + bq._effective_key(self.session), (None, None) ) - if baked_compile_state is None: - baked_compile_state = bq._bake(self.session) - - context = QueryContext(baked_compile_state, self.session) - context.session = self.session - - bq._unbake_subquery_loaders( - self.session, - baked_compile_state, - context, - self._params, - self._post_criteria, - ) - - # asserts true - # if isinstance(baked_compile_state.statement, expression.Select): - # assert baked_compile_state.statement._label_style == \ - # LABEL_STYLE_TABLENAME_PLUS_COL + if query is None: + query, statement = bq._bake(self.session) - if context.autoflush and not context.populate_existing: - self.session._autoflush() - q = context.orm_query.params(self._params).with_session(self.session) + q = query.params(self._params) for fn in self._post_criteria: q = fn(q) params = q.load_options._params + q.load_options += {"_orm_query": q} + execution_options = dict(q._execution_options) + execution_options.update( + { + "_sa_orm_load_options": q.load_options, + "compiled_cache": bq._bakery, + } + ) + + result = self.session.execute( + statement, params, execution_options=execution_options + ) + + if result._attributes.get("is_single_entity", False): + result = 
result.scalars() + + if result._attributes.get("filtered", False): + result = result.unique() - return q._execute_and_instances(context, params=params) + return result def count(self): """return the 'count'. @@ -583,10 +533,10 @@ class Result(object): query = self.bq.steps[0](self.session) return query._get_impl(ident, self._load_on_pk_identity) - def _load_on_pk_identity(self, query, primary_key_identity): + def _load_on_pk_identity(self, session, query, primary_key_identity, **kw): """Load the given primary key identity from the database.""" - mapper = query._only_full_mapper_zero("load_on_pk_identity") + mapper = query._raw_columns[0]._annotations["parententity"] _get_clause, _get_params = mapper._get_clause diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 919f4409a..1375a24cd 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -15,10 +15,8 @@ the source distribution. """ -import copy - +from sqlalchemy import event from .. import inspect -from .. import util from ..orm.query import Query from ..orm.session import Session @@ -37,54 +35,32 @@ class ShardedQuery(Query): all subsequent operations with the returned query will be against the single shard regardless of other state. - """ - q = self._clone() - q._shard_id = shard_id - return q + The shard_id can be passed for a 2.0 style execution to the + bind_arguments dictionary of :meth:`.Session.execute`:: - def _execute_and_instances(self, context, params=None): - if params is None: - params = self.load_options._params - - def iter_for_shard(shard_id): - # shallow copy, so that each context may be used by - # ORM load events and similar. 
- copied_context = copy.copy(context) - copied_context.attributes = context.attributes.copy() - - copied_context.attributes[ - "shard_id" - ] = copied_context.identity_token = shard_id - result_ = self._connection_from_session( - mapper=context.compile_state._bind_mapper(), shard_id=shard_id - ).execute( - copied_context.compile_state.statement, - self.load_options._params, + results = session.execute( + stmt, + bind_arguments={"shard_id": "my_shard"} ) - return self.instances(result_, copied_context) - if context.identity_token is not None: - return iter_for_shard(context.identity_token) - elif self._shard_id is not None: - return iter_for_shard(self._shard_id) - else: - partial = [] - for shard_id in self.query_chooser(self): - result_ = iter_for_shard(shard_id) - partial.append(result_) + """ - return partial[0].merge(*partial[1:]) + q = self._clone() + q._shard_id = shard_id + return q def _execute_crud(self, stmt, mapper): def exec_for_shard(shard_id): - conn = self._connection_from_session( + conn = self.session.connection( mapper=mapper, shard_id=shard_id, clause=stmt, close_with_result=True, ) - result = conn.execute(stmt, self.load_options._params) + result = conn._execute_20( + stmt, self.load_options._params, self._execution_options + ) return result if self._shard_id is not None: @@ -99,38 +75,6 @@ class ShardedQuery(Query): return ShardedResult(results, rowcount) - def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): - """Override the default Query._get_impl() method so that we emit - a query to the DB for each possible identity token, if we don't - have one already. - - """ - - def _db_load_fn(query, primary_key_identity): - # load from the database. The original db_load_fn will - # use the given Query object to load from the DB, so our - # shard_id is what will indicate the DB that we query from. 
- if self._shard_id is not None: - return db_load_fn(self, primary_key_identity) - else: - ident = util.to_list(primary_key_identity) - # build a ShardedQuery for each shard identifier and - # try to load from the DB - for shard_id in self.id_chooser(self, ident): - q = self.set_shard(shard_id) - o = db_load_fn(q, ident) - if o is not None: - return o - else: - return None - - if identity_token is None and self._shard_id is not None: - identity_token = self._shard_id - - return super(ShardedQuery, self)._get_impl( - primary_key_identity, _db_load_fn, identity_token=identity_token - ) - class ShardedResult(object): """A value object that represents multiple :class:`_engine.CursorResult` @@ -190,11 +134,14 @@ class ShardedSession(Session): """ super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) + + event.listen( + self, "do_orm_execute", execute_and_instances, retval=True + ) self.shard_chooser = shard_chooser self.id_chooser = id_chooser self.query_chooser = query_chooser self.__binds = {} - self.connection_callable = self.connection if shards is not None: for k in shards: self.bind_shard(k, shards[k]) @@ -207,8 +154,8 @@ class ShardedSession(Session): lazy_loaded_from=None, **kw ): - """override the default :meth:`.Session._identity_lookup` method so that we - search for a given non-token primary key identity across all + """override the default :meth:`.Session._identity_lookup` method so + that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from @@ -255,7 +202,14 @@ class ShardedSession(Session): state.identity_token = shard_id return shard_id - def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): + def connection_callable( + self, mapper=None, instance=None, shard_id=None, **kwargs + ): + """Provide a :class:`_engine.Connection` to use in the unit of work + flush process. 
+ + """ + if shard_id is None: shard_id = self._choose_shard_and_assign(mapper, instance) @@ -267,7 +221,7 @@ class ShardedSession(Session): ).connect(**kwargs) def get_bind( - self, mapper, shard_id=None, instance=None, clause=None, **kw + self, mapper=None, shard_id=None, instance=None, clause=None, **kw ): if shard_id is None: shard_id = self._choose_shard_and_assign( @@ -277,3 +231,55 @@ class ShardedSession(Session): def bind_shard(self, shard_id, bind): self.__binds[shard_id] = bind + + +def execute_and_instances(orm_context): + if orm_context.bind_arguments.get("_horizontal_shard", False): + return None + + params = orm_context.parameters + + load_options = orm_context.load_options + session = orm_context.session + orm_query = orm_context.orm_query + + if params is None: + params = load_options._params + + def iter_for_shard(shard_id, load_options): + execution_options = dict(orm_context.execution_options) + + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["_horizontal_shard"] = True + bind_arguments["shard_id"] = shard_id + + load_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_load_options"] = load_options + + return session.execute( + orm_context.statement, + orm_context.parameters, + execution_options, + bind_arguments, + ) + + if load_options._refresh_identity_token is not None: + shard_id = load_options._refresh_identity_token + elif orm_query is not None and orm_query._shard_id is not None: + shard_id = orm_query._shard_id + elif "shard_id" in orm_context.bind_arguments: + shard_id = orm_context.bind_arguments["shard_id"] + else: + shard_id = None + + if shard_id is not None: + return iter_for_shard(shard_id, load_options) + else: + partial = [] + for shard_id in session.query_chooser( + orm_query if orm_query is not None else orm_context.statement + ): + result_ = iter_for_shard(shard_id, load_options) + partial.append(result_) + + return partial[0].merge(*partial[1:]) diff --git 
a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py index 2b76245e0..58fced887 100644 --- a/lib/sqlalchemy/future/selectable.py +++ b/lib/sqlalchemy/future/selectable.py @@ -56,7 +56,9 @@ class Select(_LegacySelect): self = cls.__new__(cls) self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self) + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) for ent in entities ] @@ -71,9 +73,9 @@ class Select(_LegacySelect): def _filter_by_zero(self): if self._setup_joins: - meth = SelectState.get_plugin_classmethod( - self, "determine_last_joined_entity" - ) + meth = SelectState.get_plugin_class( + self + ).determine_last_joined_entity _last_joined_entity = meth(self) if _last_joined_entity is not None: return _last_joined_entity @@ -106,7 +108,7 @@ class Select(_LegacySelect): """ target = coercions.expect( - roles.JoinTargetRole, target, apply_plugins=self + roles.JoinTargetRole, target, apply_propagate_attrs=self ) self._setup_joins += ( (target, onclause, None, {"isouter": isouter, "full": full}), @@ -123,12 +125,15 @@ class Select(_LegacySelect): """ + # note the order of parsing from vs. target is important here, as we + # are also deriving the source of the plugin (i.e. 
the subject mapper + # in an ORM query) which should favor the "from_" over the "target" - target = coercions.expect( - roles.JoinTargetRole, target, apply_plugins=self - ) from_ = coercions.expect( - roles.FromClauseRole, from_, apply_plugins=self + roles.FromClauseRole, from_, apply_propagate_attrs=self + ) + target = coercions.expect( + roles.JoinTargetRole, target, apply_propagate_attrs=self ) self._setup_joins += ( diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 0a353f81c..110c27811 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -30,6 +30,7 @@ from .mapper import reconstructor # noqa from .mapper import validates # noqa from .properties import ColumnProperty # noqa from .query import AliasOption # noqa +from .query import FromStatement # noqa from .query import Query # noqa from .relationships import foreign # noqa from .relationships import RelationshipProperty # noqa @@ -39,8 +40,10 @@ from .session import close_all_sessions # noqa from .session import make_transient # noqa from .session import make_transient_to_detached # noqa from .session import object_session # noqa +from .session import ORMExecuteState # noqa from .session import Session # noqa from .session import sessionmaker # noqa +from .session import SessionTransaction # noqa from .strategy_options import Load # noqa from .util import aliased # noqa from .util import Bundle # noqa diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 7b4415bfe..262a1efc9 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -207,6 +207,10 @@ class QueryableAttribute( def __clause_element__(self): return self.expression + @property + def _from_objects(self): + return self.expression._from_objects + def _bulk_update_tuples(self, value): """Return setter tuples for a bulk UPDATE.""" diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 
0a3701134..3acab7df7 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -18,19 +18,21 @@ from .util import Bundle from .util import join as orm_join from .util import ORMAdapter from .. import exc as sa_exc +from .. import future from .. import inspect from .. import sql from .. import util -from ..future.selectable import Select as FutureSelect from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors from ..sql.base import CacheableOptions +from ..sql.base import CompileState from ..sql.base import Options +from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY +from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL -from ..sql.selectable import Select from ..sql.selectable import SelectState from ..sql.visitors import ExtendedInternalTraversal from ..sql.visitors import InternalTraversal @@ -44,6 +46,8 @@ class QueryContext(object): "orm_query", "query", "load_options", + "bind_arguments", + "execution_options", "session", "autoflush", "populate_existing", @@ -51,7 +55,7 @@ class QueryContext(object): "version_check", "refresh_state", "create_eager_joins", - "propagate_options", + "propagated_loader_options", "attributes", "runid", "partials", @@ -70,20 +74,30 @@ class QueryContext(object): _yield_per = None _refresh_state = None _lazy_loaded_from = None + _orm_query = None _params = util.immutabledict() - def __init__(self, compile_state, session): - query = compile_state.query + def __init__( + self, + compile_state, + session, + load_options, + execution_options=None, + bind_arguments=None, + ): + self.load_options = load_options + self.execution_options = execution_options or {} + self.bind_arguments = bind_arguments or {} self.compile_state = compile_state self.orm_query = compile_state.orm_query - self.query = compile_state.query + self.query = query = compile_state.query self.session = 
session - self.load_options = load_options = query.load_options - self.propagate_options = set( + self.propagated_loader_options = { o for o in query._with_options if o.propagate_to_loaders - ) + } + self.attributes = dict(compile_state.attributes) self.autoflush = load_options._autoflush @@ -92,11 +106,7 @@ class QueryContext(object): self.version_check = load_options._version_check self.refresh_state = load_options._refresh_state self.yield_per = load_options._yield_per - - if self.refresh_state is not None: - self.identity_token = load_options._refresh_identity_token - else: - self.identity_token = None + self.identity_token = load_options._refresh_identity_token if self.yield_per and compile_state._no_yield_pers: raise sa_exc.InvalidRequestError( @@ -119,25 +129,10 @@ class QueryContext(object): ) -class QueryCompileState(sql.base.CompileState): - _joinpath = _joinpoint = util.immutabledict() - _from_obj_alias = None - _has_mapper_entities = False - - _has_orm_entities = False - multi_row_eager_loaders = False - compound_eager_adapter = None - loaders_require_buffering = False - loaders_require_uniquing = False - - correlate = None - _where_criteria = () - _having_criteria = () - - orm_query = None - +class ORMCompileState(CompileState): class default_compile_options(CacheableOptions): _cache_key_traversal = [ + ("_orm_results", InternalTraversal.dp_boolean), ("_bake_ok", InternalTraversal.dp_boolean), ( "_with_polymorphic_adapt_map", @@ -153,136 +148,310 @@ class QueryCompileState(sql.base.CompileState): ("_for_refresh_state", InternalTraversal.dp_boolean), ] + _orm_results = True _bake_ok = True _with_polymorphic_adapt_map = () _current_path = _path_registry _enable_single_crit = True - _statement = None _enable_eagerloads = True _orm_only_from_obj_alias = True _only_load_props = None _set_base_alias = False _for_refresh_state = False + # non-cache-key elements mostly for legacy use + _statement = None + _orm_query = None + + @classmethod + def merge(cls, 
other): + return cls + other._state_dict() + + orm_query = None + current_path = _path_registry + def __init__(self, *arg, **kw): raise NotImplementedError() @classmethod - def _create_for_select(cls, statement, compiler, **kw): - if not statement._is_future: - return SelectState(statement, compiler, **kw) + def create_for_statement(cls, statement_container, compiler, **kw): + raise NotImplementedError() - self = cls.__new__(cls) + @classmethod + def _create_for_legacy_query(cls, query, for_statement=False): + stmt = query._statement_20(orm_results=not for_statement) - if not isinstance( - statement.compile_options, cls.default_compile_options - ): - statement.compile_options = cls.default_compile_options - orm_state = self._create_for_legacy_query_via_either(statement) - compile_state = SelectState(orm_state.statement, compiler, **kw) - compile_state._orm_state = orm_state - return compile_state + if query.compile_options._statement is not None: + compile_state_cls = ORMFromStatementCompileState + else: + compile_state_cls = ORMSelectCompileState + + # true in all cases except for two tests in test/orm/test_events.py + # assert stmt.compile_options._orm_query is query + return compile_state_cls._create_for_statement_or_query( + stmt, for_statement=for_statement + ) @classmethod - def _create_future_select_from_query(cls, query): - stmt = FutureSelect.__new__(FutureSelect) - - # the internal state of Query is now a mirror of that of - # Select which can be transferred directly. The Select - # supports compilation into its correct form taking all ORM - # features into account via the plugin and the compile options. - # however it does not export its columns or other attributes - # correctly if deprecated ORM features that adapt plain mapped - # elements are used; for this reason the Select() returned here - # can always support direct execution, but for composition in a larger - # select only works if it does not represent legacy ORM adaption - # features. 
- stmt.__dict__.update( - dict( - _raw_columns=query._raw_columns, - _compile_state_plugin="orm", # ;) - _where_criteria=query._where_criteria, - _from_obj=query._from_obj, - _legacy_setup_joins=query._legacy_setup_joins, - _order_by_clauses=query._order_by_clauses, - _group_by_clauses=query._group_by_clauses, - _having_criteria=query._having_criteria, - _distinct=query._distinct, - _distinct_on=query._distinct_on, - _with_options=query._with_options, - _with_context_options=query._with_context_options, - _hints=query._hints, - _statement_hints=query._statement_hints, - _correlate=query._correlate, - _auto_correlate=query._auto_correlate, - _limit_clause=query._limit_clause, - _offset_clause=query._offset_clause, - _for_update_arg=query._for_update_arg, - _prefixes=query._prefixes, - _suffixes=query._suffixes, - _label_style=query._label_style, - compile_options=query.compile_options, - # this will be moving but for now make it work like orm.Query - load_options=query.load_options, + def _create_for_statement_or_query( + cls, statement_container, for_statement=False, + ): + raise NotImplementedError() + + @classmethod + def orm_pre_session_exec( + cls, session, statement, execution_options, bind_arguments + ): + if execution_options: + # TODO: will have to provide public API to set some load + # options and also extract them from that API here, likely + # execution options + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options ) + else: + load_options = QueryContext.default_load_options + + bind_arguments["clause"] = statement + + # new in 1.4 - the coercions system is leveraged to allow the + # "subject" mapper of a statement be propagated to the top + # as the statement is built. "subject" mapper is the generally + # standard object used as an identifier for multi-database schemes. 
+ + if "plugin_subject" in statement._propagate_attrs: + bind_arguments["mapper"] = statement._propagate_attrs[ + "plugin_subject" + ].mapper + + if load_options._autoflush: + session._autoflush() + + @classmethod + def orm_setup_cursor_result(cls, session, bind_arguments, result): + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + # cover edge case where ORM entities used in legacy select + # were passed to session.execute: + # session.execute(legacy_select([User.id, User.name])) + # see test_query->test_legacy_tuple_old_select + if not execution_context.compiled.statement._is_future: + return result + + execution_options = execution_context.execution_options + + # we are getting these right above in orm_pre_session_exec(), + # then getting them again right here. + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + querycontext = QueryContext( + compile_state, + session, + load_options, + execution_options, + bind_arguments, ) + return loading.instances(result, querycontext) - return stmt + @property + def _mapper_entities(self): + return ( + ent for ent in self._entities if isinstance(ent, _MapperEntity) + ) + + def _create_with_polymorphic_adapter(self, ext_info, selectable): + if ( + not ext_info.is_aliased_class + and ext_info.mapper.persist_selectable + not in self._polymorphic_adapters + ): + self._mapper_loads_polymorphically_with( + ext_info.mapper, + sql_util.ColumnAdapter( + selectable, ext_info.mapper._equivalent_columns + ), + ) + + def _mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers or [mapper]: + self._polymorphic_adapters[m2] = adapter + for m in m2.iterate_to_root(): + self._polymorphic_adapters[m.local_table] = adapter + + +@sql.base.CompileState.plugin_for("orm", "grouping") +class ORMFromStatementCompileState(ORMCompileState): + _aliased_generations = util.immutabledict() + _from_obj_alias 
= None + _has_mapper_entities = False + + _has_orm_entities = False + multi_row_eager_loaders = False + compound_eager_adapter = None + loaders_require_buffering = False + loaders_require_uniquing = False + + @classmethod + def create_for_statement(cls, statement_container, compiler, **kw): + compiler._rewrites_selected_columns = True + return cls._create_for_statement_or_query(statement_container) @classmethod - def _create_for_legacy_query( - cls, query, for_statement=False, entities_only=False + def _create_for_statement_or_query( + cls, statement_container, for_statement=False, ): - # as we are seeking to use Select() with ORM state as the - # primary executable element, have all Query objects that are not - # from_statement() convert to a Select() first, then run on that. + # from .query import FromStatement - if query.compile_options._statement is not None: - return cls._create_for_legacy_query_via_either( - query, - for_statement=for_statement, - entities_only=entities_only, - orm_query=query, - ) + # assert isinstance(statement_container, FromStatement) + + self = cls.__new__(cls) + self._primary_entity = None + + self.orm_query = statement_container.compile_options._orm_query + + self.statement_container = self.query = statement_container + self.requested_statement = statement_container.element + + self._entities = [] + self._with_polymorphic_adapt_map = {} + self._polymorphic_adapters = {} + self._no_yield_pers = set() + + _QueryEntity.to_compile_state(self, statement_container._raw_columns) + + self.compile_options = statement_container.compile_options + + self.current_path = statement_container.compile_options._current_path + + if statement_container._with_options: + self.attributes = {"_unbound_load_dedupes": set()} + + for opt in statement_container._with_options: + if opt._is_compile_state: + opt.process_compile_state(self) + else: + self.attributes = {} + + if statement_container._with_context_options: + for fn, key in 
statement_container._with_context_options: + fn(self) + + self.primary_columns = [] + self.secondary_columns = [] + self.eager_joins = {} + self.single_inh_entities = {} + self.create_eager_joins = [] + self._fallback_from_clauses = [] + self._setup_for_statement() + + return self + + def _setup_for_statement(self): + statement = self.requested_statement + if ( + isinstance(statement, expression.SelectBase) + and not statement._is_textual + and not statement.use_labels + ): + self.statement = statement.apply_labels() else: - assert query.compile_options._statement is None + self.statement = statement + self.order_by = None - stmt = cls._create_future_select_from_query(query) + if isinstance(self.statement, expression.TextClause): + # setup for all entities. Currently, this is not useful + # for eager loaders, as the eager loaders that work are able + # to do their work entirely in row_processor. + for entity in self._entities: + entity.setup_compile_state(self) - return cls._create_for_legacy_query_via_either( - stmt, - for_statement=for_statement, - entities_only=entities_only, - orm_query=query, + # we did the setup just to get primary columns. + self.statement = expression.TextualSelect( + self.statement, self.primary_columns, positional=False ) + else: + # allow TextualSelect with implicit columns as well + # as select() with ad-hoc columns, see test_query::TextTest + self._from_obj_alias = sql.util.ColumnAdapter( + self.statement, adapt_on_names=True + ) + # set up for eager loaders, however if we fix subqueryload + # it should not need to do this here. the model of eager loaders + # that can work entirely in row_processor might be interesting + # here though subqueryloader has a lot of upfront work to do + # see test/orm/test_query.py -> test_related_eagerload_against_text + # for where this part makes a difference. would rather have + # subqueryload figure out what it needs more intelligently. 
+ # for entity in self._entities: + # entity.setup_compile_state(self) + + def _adapt_col_list(self, cols, current_adapter): + return cols + + def _get_current_adapter(self): + return None + + +@sql.base.CompileState.plugin_for("orm", "select") +class ORMSelectCompileState(ORMCompileState, SelectState): + _joinpath = _joinpoint = util.immutabledict() + _from_obj_alias = None + _has_mapper_entities = False + + _has_orm_entities = False + multi_row_eager_loaders = False + compound_eager_adapter = None + loaders_require_buffering = False + loaders_require_uniquing = False + + correlate = None + _where_criteria = () + _having_criteria = () + + orm_query = None + + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + if not statement._is_future: + return SelectState(statement, compiler, **kw) + + compiler._rewrites_selected_columns = True + + orm_state = cls._create_for_statement_or_query( + statement, for_statement=True + ) + SelectState.__init__(orm_state, orm_state.statement, compiler, **kw) + return orm_state @classmethod - def _create_for_legacy_query_via_either( - cls, query, for_statement=False, entities_only=False, orm_query=None + def _create_for_statement_or_query( + cls, query, for_statement=False, _entities_only=False, ): + assert isinstance(query, future.Select) + + query.compile_options = cls.default_compile_options.merge( + query.compile_options + ) self = cls.__new__(cls) self._primary_entity = None - self.has_select = isinstance(query, Select) + self.orm_query = query.compile_options._orm_query - if orm_query: - self.orm_query = orm_query - self.query = query - self.has_orm_query = True - else: - self.query = query - if not self.has_select: - self.orm_query = query - self.has_orm_query = True - else: - self.orm_query = None - self.has_orm_query = False + self.query = query self.select_statement = select_statement = query + if not hasattr(select_statement.compile_options, "_orm_results"): + select_statement.compile_options = 
cls.default_compile_options + select_statement.compile_options += {"_orm_results": for_statement} + else: + for_statement = not select_statement.compile_options._orm_results + self.query = query self._entities = [] @@ -300,19 +469,28 @@ class QueryCompileState(sql.base.CompileState): _QueryEntity.to_compile_state(self, select_statement._raw_columns) - if entities_only: + if _entities_only: return self self.compile_options = query.compile_options + + # TODO: the name of this flag "for_statement" has to change, + # as it is difficult to distinguish from the "query._statement" use + # case which is something totally different self.for_statement = for_statement - if self.has_orm_query and not for_statement: - self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL + # determine label style. we can make different decisions here. + # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY + # rather than LABEL_STYLE_NONE, and if we can use disambiguate style + # for new style ORM selects too. 
+ if self.select_statement._label_style is LABEL_STYLE_NONE: + if self.orm_query and not for_statement: + self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL + else: + self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY else: self.label_style = self.select_statement._label_style - self.labels = self.label_style is LABEL_STYLE_TABLENAME_PLUS_COL - self.current_path = select_statement.compile_options._current_path self.eager_order_by = () @@ -321,7 +499,7 @@ class QueryCompileState(sql.base.CompileState): self.attributes = {"_unbound_load_dedupes": set()} for opt in self.select_statement._with_options: - if not opt._is_legacy_option: + if opt._is_compile_state: opt.process_compile_state(self) else: self.attributes = {} @@ -341,13 +519,50 @@ class QueryCompileState(sql.base.CompileState): info.selectable for info in select_statement._from_obj ] - if self.compile_options._statement is not None: - self._setup_for_statement() - else: - self._setup_for_generate() + self._setup_for_generate() return self + @classmethod + def _create_entities_collection(cls, query): + """Creates a partial ORMSelectCompileState that includes + the full collection of _MapperEntity and other _QueryEntity objects. + + Supports a few remaining use cases that are pre-compilation + but still need to gather some of the column / adaption information. 
+ + """ + self = cls.__new__(cls) + + self._entities = [] + self._primary_entity = None + self._aliased_generations = {} + self._polymorphic_adapters = {} + + # legacy: only for query.with_polymorphic() + self._with_polymorphic_adapt_map = wpam = dict( + query.compile_options._with_polymorphic_adapt_map + ) + if wpam: + self._setup_with_polymorphics() + + _QueryEntity.to_compile_state(self, query._raw_columns) + return self + + @classmethod + def determine_last_joined_entity(cls, statement): + setup_joins = statement._setup_joins + + if not setup_joins: + return None + + (target, onclause, from_, flags) = setup_joins[-1] + + if isinstance(target, interfaces.PropComparator): + return target.entity + else: + return target + def _setup_with_polymorphics(self): # legacy: only for query.with_polymorphic() for ext_info, wp in self._with_polymorphic_adapt_map.items(): @@ -404,34 +619,6 @@ class QueryCompileState(sql.base.CompileState): return None - def _deep_entity_zero(self): - """Return a 'deep' entity; this is any entity we can find associated - with the first entity / column experssion. this is used only for - session.get_bind(). - - it is hoped this concept can be removed in an upcoming change - to the ORM execution model. 
- - """ - for ent in self.from_clauses: - if "parententity" in ent._annotations: - return ent._annotations["parententity"].mapper - for ent in self._entities: - ezero = ent._deep_entity_zero() - if ezero is not None: - return ezero.mapper - else: - return None - - @property - def _mapper_entities(self): - for ent in self._entities: - if isinstance(ent, _MapperEntity): - yield ent - - def _bind_mapper(self): - return self._deep_entity_zero() - def _only_full_mapper_zero(self, methname): if self._entities != [self._primary_entity]: raise sa_exc.InvalidRequestError( @@ -490,7 +677,7 @@ class QueryCompileState(sql.base.CompileState): else query._order_by_clauses ) - if query._having_criteria is not None: + if query._having_criteria: self._having_criteria = tuple( current_adapter(crit, True, True) if current_adapter else crit for crit in query._having_criteria @@ -527,7 +714,7 @@ class QueryCompileState(sql.base.CompileState): for s in query._correlate ) ) - elif self.has_select and not query._auto_correlate: + elif not query._auto_correlate: self.correlate = (None,) # PART II @@ -582,33 +769,6 @@ class QueryCompileState(sql.base.CompileState): {"deepentity": ezero} ) - def _setup_for_statement(self): - compile_options = self.compile_options - - if ( - isinstance(compile_options._statement, expression.SelectBase) - and not compile_options._statement._is_textual - and not compile_options._statement.use_labels - ): - self.statement = compile_options._statement.apply_labels() - else: - self.statement = compile_options._statement - self.order_by = None - - if isinstance(self.statement, expression.TextClause): - # setup for all entities, including contains_eager entities. 
- for entity in self._entities: - entity.setup_compile_state(self) - self.statement = expression.TextualSelect( - self.statement, self.primary_columns, positional=False - ) - else: - # allow TextualSelect with implicit columns as well - # as select() with ad-hoc columns, see test_query::TextTest - self._from_obj_alias = sql.util.ColumnAdapter( - self.statement, adapt_on_names=True - ) - def _compound_eager_statement(self): # for eager joins present and LIMIT/OFFSET/DISTINCT, # wrap the query inside a select, @@ -659,9 +819,10 @@ class QueryCompileState(sql.base.CompileState): self.compound_eager_adapter = sql_util.ColumnAdapter(inner, equivs) - statement = sql.select( - [inner] + self.secondary_columns, use_labels=self.labels + statement = future.select( + *([inner] + self.secondary_columns) # use_labels=self.labels ) + statement._label_style = self.label_style # Oracle however does not allow FOR UPDATE on the subquery, # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL @@ -752,6 +913,7 @@ class QueryCompileState(sql.base.CompileState): group_by, ): + Select = future.Select statement = Select.__new__(Select) statement._raw_columns = raw_columns statement._from_obj = from_obj @@ -794,25 +956,6 @@ class QueryCompileState(sql.base.CompileState): return statement - def _create_with_polymorphic_adapter(self, ext_info, selectable): - if ( - not ext_info.is_aliased_class - and ext_info.mapper.persist_selectable - not in self._polymorphic_adapters - ): - self._mapper_loads_polymorphically_with( - ext_info.mapper, - sql_util.ColumnAdapter( - selectable, ext_info.mapper._equivalent_columns - ), - ) - - def _mapper_loads_polymorphically_with(self, mapper, adapter): - for m2 in mapper._with_polymorphic_mappers or [mapper]: - self._polymorphic_adapters[m2] = adapter - for m in m2.iterate_to_root(): - self._polymorphic_adapters[m.local_table] = adapter - def _adapt_polymorphic_element(self, element): if "parententity" in element._annotations: search = 
element._annotations["parententity"] @@ -924,6 +1067,8 @@ class QueryCompileState(sql.base.CompileState): # onclause = right right = None + elif "parententity" in right._annotations: + right = right._annotations["parententity"].entity if onclause is None: r_info = inspect(right) @@ -932,7 +1077,6 @@ class QueryCompileState(sql.base.CompileState): "Expected mapped entity or " "selectable/table as join target" ) - if isinstance(onclause, interfaces.PropComparator): of_type = getattr(onclause, "_of_type", None) else: @@ -1584,7 +1728,7 @@ class QueryCompileState(sql.base.CompileState): "aliased_generation": aliased_generation, } - return right, inspect(right), onclause + return inspect(right), right, onclause def _update_joinpoint(self, jp): self._joinpoint = jp @@ -1668,14 +1812,8 @@ class QueryCompileState(sql.base.CompileState): def _column_descriptions(query_or_select_stmt): - # TODO: this is a hack for now, as it is a little bit non-performant - # to build up QueryEntity for every entity right now. 
- ctx = QueryCompileState._create_for_legacy_query_via_either( - query_or_select_stmt, - entities_only=True, - orm_query=query_or_select_stmt - if not isinstance(query_or_select_stmt, Select) - else None, + ctx = ORMSelectCompileState._create_entities_collection( + query_or_select_stmt ) return [ { @@ -1731,23 +1869,6 @@ def _entity_from_pre_ent_zero(query_or_augmented_select): return ent -@sql.base.CompileState.plugin_for( - "orm", "select", "determine_last_joined_entity" -) -def _determine_last_joined_entity(statement): - setup_joins = statement._setup_joins - - if not setup_joins: - return None - - (target, onclause, from_, flags) = setup_joins[-1] - - if isinstance(target, interfaces.PropComparator): - return target.entity - else: - return target - - def _legacy_determine_last_joined_entity(setup_joins, entity_zero): """given the legacy_setup_joins collection at a point in time, figure out what the "filter by entity" would be in terms @@ -1929,9 +2050,6 @@ class _MapperEntity(_QueryEntity): def entity_zero_or_selectable(self): return self.entity_zero - def _deep_entity_zero(self): - return self.entity_zero - def corresponds_to(self, entity): return _entity_corresponds_to(self.entity_zero, entity) @@ -2093,14 +2211,6 @@ class _BundleEntity(_QueryEntity): else: return None - def _deep_entity_zero(self): - for ent in self._entities: - ezero = ent._deep_entity_zero() - if ezero is not None: - return ezero - else: - return None - def setup_compile_state(self, compile_state): for ent in self._entities: ent.setup_compile_state(compile_state) @@ -2175,17 +2285,6 @@ class _RawColumnEntity(_ColumnEntity): ) self._extra_entities = (self.expr, self.column) - def _deep_entity_zero(self): - for obj in visitors.iterate( - self.column, {"column_tables": True, "column_collections": False}, - ): - if "parententity" in obj._annotations: - return obj._annotations["parententity"] - elif "deepentity" in obj._annotations: - return obj._annotations["deepentity"] - else: - return None 
- def corresponds_to(self, entity): return False @@ -2276,9 +2375,6 @@ class _ORMColumnEntity(_ColumnEntity): ezero, ezero.selectable ) - def _deep_entity_zero(self): - return self.mapper - def corresponds_to(self, entity): if _is_aliased_class(entity): # TODO: polymorphic subclasses ? @@ -2342,8 +2438,3 @@ class _ORMColumnEntity(_ColumnEntity): compile_state.primary_columns.append(column) compile_state.attributes[("fetch_column", self)] = column - - -sql.base.CompileState.plugin_for("orm", "select")( - QueryCompileState._create_for_select -) diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index f5d191860..be7aa272e 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1397,6 +1397,43 @@ class SessionEvents(event.Events): event_key.base_listen(**kw) + def do_orm_execute(self, orm_execute_state): + """Intercept statement executions that occur in terms of a :class:`.Session`. + + This event is invoked for all top-level SQL statements invoked + from the :meth:`_orm.Session.execute` method. As of SQLAlchemy 1.4, + all ORM queries emitted on behalf of a :class:`_orm.Session` will + flow through this method, so this event hook provides the single + point at which ORM queries of all types may be intercepted before + they are invoked, and additionally to replace their execution with + a different process. + + This event is a ``do_`` event, meaning it has the capability to replace + the operation that the :meth:`_orm.Session.execute` method normally + performs. The intended use for this includes sharding and + result-caching schemes which may seek to invoke the same statement + across multiple database connections, returning a result that is + merged from each of them, or which don't invoke the statement at all, + instead returning data from a cache. + + The hook intends to replace the use of the + ``Query._execute_and_instances`` method that could be subclassed prior + to SQLAlchemy 1.4. 
+ + :param orm_execute_state: an instance of :class:`.ORMExecuteState` + which contains all information about the current execution, as well + as helper functions used to derive other commonly required + information. See that object for details. + + .. seealso:: + + :class:`.ORMExecuteState` + + + .. versionadded:: 1.4 + + """ + def after_transaction_create(self, session, transaction): """Execute when a new :class:`.SessionTransaction` is created. diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 313f2fda8..6c0f5d3ef 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -64,6 +64,12 @@ __all__ = ( ) +class ORMStatementRole(roles.CoerceTextStatementRole): + _role_name = ( + "Executable SQL or text() construct, including ORM " "aware objects" + ) + + class ORMColumnsClauseRole(roles.ColumnsClauseRole): _role_name = "ORM mapped entity, aliased entity, or Column expression" @@ -662,8 +668,15 @@ class StrategizedProperty(MapperProperty): ) -class LoaderOption(HasCacheKey): - """Describe a modification to an ORM statement at compilation time. +class ORMOption(object): + """Base class for option objects that are passed to ORM queries. + + These options may be consumed by :meth:`.Query.options`, + :meth:`.Select.options`, or in a more general sense by any + :meth:`.Executable.options` method. They are interpreted at + statement compile time or execution time in modern use. The + deprecated :class:`.MapperOption` is consumed at ORM query construction + time. .. versionadded:: 1.4 @@ -680,6 +693,18 @@ class LoaderOption(HasCacheKey): """ + _is_compile_state = False + + +class LoaderOption(HasCacheKey, ORMOption): + """Describe a loader modification to an ORM statement at compilation time. + + .. 
versionadded:: 1.4 + + """ + + _is_compile_state = True + def process_compile_state(self, compile_state): """Apply a modification to a given :class:`.CompileState`.""" @@ -693,18 +718,39 @@ class LoaderOption(HasCacheKey): return False +class UserDefinedOption(ORMOption): + """Base class for a user-defined option that can be consumed from the + :meth:`.SessionEvents.do_orm_execute` event hook. + + """ + + _is_legacy_option = False + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" Query objects produced during lazy loads + or refresh operations. + + """ + + def __init__(self, payload=None): + self.payload = payload + + def _gen_cache_key(self, *arg, **kw): + return () + + @util.deprecated_cls( "1.4", "The :class:`.MapperOption class is deprecated and will be removed " - "in a future release. ORM options now run within the compilation " - "phase and are based on the :class:`.LoaderOption` class which is " - "intended for internal consumption only. For " - "modifications to queries on a per-execution basis, the " - ":meth:`.before_execute` hook will now intercept ORM :class:`.Query` " - "objects before they are invoked", + "in a future release. For " + "modifications to queries on a per-execution basis, use the " + ":class:`.UserDefinedOption` class to establish state within a " + ":class:`.Query` or other Core statement, then use the " + ":meth:`.SessionEvents.before_orm_execute` hook to consume them.", constructor=None, ) -class MapperOption(object): +class MapperOption(ORMOption): """Describe a modification to a Query""" _is_legacy_option = True @@ -735,23 +781,6 @@ class MapperOption(object): def _generate_path_cache_key(self, path): """Used by the "baked lazy loader" to see if this option can be cached. - The "baked lazy loader" refers to the :class:`_query.Query` that is - produced during a lazy load operation for a mapped relationship. 
- It does not yet apply to the "lazy" load operation for deferred - or expired column attributes, however this may change in the future. - - This loader generates SQL for a query only once and attempts to cache - it; from that point on, if the SQL has been cached it will no longer - run the :meth:`_query.Query.options` method of the - :class:`_query.Query`. The - :class:`.MapperOption` object that wishes to participate within a lazy - load operation therefore needs to tell the baked loader that it either - needs to forego this caching, or that it needs to include the state of - the :class:`.MapperOption` itself as part of its cache key, otherwise - SQL or other query state that has been affected by the - :class:`.MapperOption` may be cached in place of a query that does not - include these modifications, or the option may not be invoked at all. - By default, this method returns the value ``False``, which means the :class:`.BakedQuery` generated by the lazy loader will not cache the SQL when this :class:`.MapperOption` is present. @@ -760,26 +789,10 @@ class MapperOption(object): an unlimited number of :class:`_query.Query` objects for an unlimited number of :class:`.MapperOption` objects. - .. versionchanged:: 1.2.8 the default return value of - :meth:`.MapperOption._generate_cache_key` is False; previously it - was ``None`` indicating "safe to cache, don't include as part of - the cache key" - - To enable caching of :class:`_query.Query` objects within lazy loaders - , a - given :class:`.MapperOption` that returns a cache key must return a key - that uniquely identifies the complete state of this option, which will - match any other :class:`.MapperOption` that itself retains the - identical state. This includes path options, flags, etc. It should - be a state that is repeatable and part of a limited set of possible - options. 
- - If the :class:`.MapperOption` does not apply to the given path and - would not affect query results on such a path, it should return None, - indicating the :class:`_query.Query` is safe to cache for this given - loader path and that this :class:`.MapperOption` need not be - part of the cache key. - + For caching support it is recommended to use the + :class:`.UserDefinedOption` class in conjunction with + the :meth:`.Session.do_orm_execute` method so that statements may + be modified before they are cached. """ return False diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 48641685e..616e757a3 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -26,6 +26,7 @@ from .base import _SET_DEFERRED_EXPIRED from .util import _none_set from .util import state_str from .. import exc as sa_exc +from .. import future from .. import util from ..engine import result_tuple from ..engine.result import ChunkedIteratorResult @@ -36,8 +37,20 @@ from ..sql import util as sql_util _new_runid = util.counter() -def instances(query, cursor, context): - """Return an ORM result as an iterator.""" +def instances(cursor, context): + """Return a :class:`.Result` given an ORM query context. + + :param cursor: a :class:`.CursorResult`, generated by a statement + which came from :class:`.ORMCompileState` + + :param context: a :class:`.QueryContext` object + + :return: a :class:`.Result` object representing ORM results + + .. versionchanged:: 1.4 The instances() function now uses + :class:`.Result` objects and has an all new interface. 
+ + """ context.runid = _new_runid() context.post_load_paths = {} @@ -80,7 +93,7 @@ def instances(query, cursor, context): ], ) - def chunks(size, as_tuples): + def chunks(size): while True: yield_per = size @@ -94,7 +107,7 @@ def instances(query, cursor, context): else: fetch = cursor.fetchall() - if not as_tuples: + if single_entity: proc = process[0] rows = [proc(row) for row in fetch] else: @@ -111,20 +124,62 @@ def instances(query, cursor, context): break result = ChunkedIteratorResult( - row_metadata, chunks, source_supports_scalars=single_entity + row_metadata, chunks, source_supports_scalars=single_entity, raw=cursor + ) + + result._attributes = result._attributes.union( + dict(filtered=filtered, is_single_entity=single_entity) ) + if context.yield_per: result.yield_per(context.yield_per) - if single_entity: - result = result.scalars() + return result - filtered = context.compile_state._has_mapper_entities - if filtered: - result = result.unique() +@util.preload_module("sqlalchemy.orm.context") +def merge_frozen_result(session, statement, frozen_result, load=True): + querycontext = util.preloaded.orm_context - return result + if load: + # flush current contents if we expect to load data + session._autoflush() + + ctx = querycontext.ORMSelectCompileState._create_entities_collection( + statement + ) + + autoflush = session.autoflush + try: + session.autoflush = False + mapped_entities = [ + i + for i, e in enumerate(ctx._entities) + if isinstance(e, querycontext._MapperEntity) + ] + keys = [ent._label_name for ent in ctx._entities] + + keyed_tuple = result_tuple( + keys, [ent._extra_entities for ent in ctx._entities] + ) + + result = [] + for newrow in frozen_result.rewrite_rows(): + for i in mapped_entities: + if newrow[i] is not None: + newrow[i] = session._merge( + attributes.instance_state(newrow[i]), + attributes.instance_dict(newrow[i]), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + + result.append(keyed_tuple(newrow)) + + return 
frozen_result.with_new_rows(result) + finally: + session.autoflush = autoflush @util.preload_module("sqlalchemy.orm.context") @@ -145,9 +200,7 @@ def merge_result(query, iterator, load=True): else: frozen_result = None - ctx = querycontext.QueryCompileState._create_for_legacy_query( - query, entities_only=True - ) + ctx = querycontext.ORMSelectCompileState._create_entities_collection(query) autoflush = session.autoflush try: @@ -235,12 +288,15 @@ def get_from_identity(session, mapper, key, passive): def load_on_ident( - query, + session, + statement, key, + load_options=None, refresh_state=None, with_for_update=None, only_load_props=None, no_autoflush=False, + bind_arguments=util.immutabledict(), ): """Load the given identity key from the database.""" if key is not None: @@ -249,38 +305,59 @@ def load_on_ident( else: ident = identity_token = None - if no_autoflush: - query = query.autoflush(False) - return load_on_pk_identity( - query, + session, + statement, ident, + load_options=load_options, refresh_state=refresh_state, with_for_update=with_for_update, only_load_props=only_load_props, identity_token=identity_token, + no_autoflush=no_autoflush, + bind_arguments=bind_arguments, ) def load_on_pk_identity( - query, + session, + statement, primary_key_identity, + load_options=None, refresh_state=None, with_for_update=None, only_load_props=None, identity_token=None, + no_autoflush=False, + bind_arguments=util.immutabledict(), ): """Load the given primary key identity from the database.""" + query = statement + q = query._clone() + + # TODO: fix these imports .... + from .context import QueryContext, ORMCompileState + + if load_options is None: + load_options = QueryContext.default_load_options + + compile_options = ORMCompileState.default_compile_options.merge( + q.compile_options + ) + + # checking that query doesnt have criteria on it + # just delete it here w/ optional assertion? 
since we are setting a + # where clause also if refresh_state is None: - q = query._clone() - q._get_condition() - else: - q = query._clone() + _no_criterion_assertion(q, "get", order_by=False, distinct=False) if primary_key_identity is not None: - mapper = query._only_full_mapper_zero("load_on_pk_identity") + # mapper = query._only_full_mapper_zero("load_on_pk_identity") + + # TODO: error checking? + mapper = query._raw_columns[0]._annotations["parententity"] (_get_clause, _get_params) = mapper._get_clause @@ -320,9 +397,8 @@ def load_on_pk_identity( ] ) - q.load_options += {"_params": params} + load_options += {"_params": params} - # with_for_update needs to be query.LockmodeArg() if with_for_update is not None: version_check = True q._for_update_arg = with_for_update @@ -333,11 +409,15 @@ def load_on_pk_identity( version_check = False if refresh_state and refresh_state.load_options: - # if refresh_state.load_path.parent: - q = q._with_current_path(refresh_state.load_path.parent) - q = q.options(refresh_state.load_options) + compile_options += {"_current_path": refresh_state.load_path.parent} + q = q.options(*refresh_state.load_options) - q._get_options( + # TODO: most of the compile_options that are not legacy only involve this + # function, so try to see if handling of them can mostly be local to here + + q.compile_options, load_options = _set_get_options( + compile_options, + load_options, populate_existing=bool(refresh_state), version_check=version_check, only_load_props=only_load_props, @@ -346,12 +426,76 @@ def load_on_pk_identity( ) q._order_by = None + if no_autoflush: + load_options += {"_autoflush": False} + + result = ( + session.execute( + q, + params=load_options._params, + execution_options={"_sa_orm_load_options": load_options}, + bind_arguments=bind_arguments, + ) + .unique() + .scalars() + ) + try: - return q.one() + return result.one() except orm_exc.NoResultFound: return None +def _no_criterion_assertion(stmt, meth, order_by=True, 
distinct=True): + if ( + stmt._where_criteria + or stmt.compile_options._statement is not None + or stmt._from_obj + or stmt._legacy_setup_joins + or stmt._limit_clause is not None + or stmt._offset_clause is not None + or stmt._group_by_clauses + or (order_by and stmt._order_by_clauses) + or (distinct and stmt._distinct) + ): + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth + ) + + +def _set_get_options( + compile_opt, + load_opt, + populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None, + identity_token=None, +): + + compile_options = {} + load_options = {} + if version_check: + load_options["_version_check"] = version_check + if populate_existing: + load_options["_populate_existing"] = populate_existing + if refresh_state: + load_options["_refresh_state"] = refresh_state + compile_options["_for_refresh_state"] = True + if only_load_props: + compile_options["_only_load_props"] = frozenset(only_load_props) + if identity_token: + load_options["_refresh_identity_token"] = identity_token + + if load_options: + load_opt += load_options + if compile_options: + compile_opt += compile_options + + return compile_opt, load_opt + + def _setup_entity_query( compile_state, mapper, @@ -487,7 +631,7 @@ def _instance_processor( context, path, mapper, result, adapter, populators ) - propagate_options = context.propagate_options + propagated_loader_options = context.propagated_loader_options load_path = ( context.compile_state.current_path + path if context.compile_state.current_path.path @@ -639,8 +783,8 @@ def _instance_processor( # be conservative about setting load_path when populate_existing # is in effect; want to maintain options from the original # load. 
see test_expire->test_refresh_maintains_deferred_options - if isnew and (propagate_options or not populate_existing): - state.load_options = propagate_options + if isnew and (propagated_loader_options or not populate_existing): + state.load_options = propagated_loader_options state.load_path = load_path _populate_full( @@ -1055,7 +1199,7 @@ def load_scalar_attributes(mapper, state, attribute_names, passive): result = False - no_autoflush = passive & attributes.NO_AUTOFLUSH + no_autoflush = bool(passive & attributes.NO_AUTOFLUSH) # in the case of inheritance, particularly concrete and abstract # concrete inheritance, the class manager might have some keys @@ -1080,10 +1224,16 @@ def load_scalar_attributes(mapper, state, attribute_names, passive): # note: using from_statement() here means there is an adaption # with adapt_on_names set up. the other option is to make the # aliased() against a subquery which affects the SQL. + + from .query import FromStatement + + stmt = FromStatement(mapper, statement).options( + strategy_options.Load(mapper).undefer("*") + ) + result = load_on_ident( - session.query(mapper) - .options(strategy_options.Load(mapper).undefer("*")) - .from_statement(statement), + session, + stmt, None, only_load_props=attribute_names, refresh_state=state, @@ -1121,7 +1271,8 @@ def load_scalar_attributes(mapper, state, attribute_names, passive): return result = load_on_ident( - session.query(mapper), + session, + future.select(mapper).apply_labels(), identity_key, refresh_state=state, only_load_props=attribute_names, diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a6fb1039f..7bfe70c36 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2237,6 +2237,8 @@ class Mapper( "parentmapper": self, "compile_state_plugin": "orm", } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} ) @property diff --git a/lib/sqlalchemy/orm/path_registry.py 
b/lib/sqlalchemy/orm/path_registry.py index 1698a5181..2e5941713 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -228,10 +228,29 @@ class RootRegistry(PathRegistry): PathRegistry.root = RootRegistry() +class PathToken(HasCacheKey, str): + """cacheable string token""" + + _intern = {} + + def _gen_cache_key(self, anon_map, bindparams): + return (str(self),) + + @classmethod + def intern(cls, strvalue): + if strvalue in cls._intern: + return cls._intern[strvalue] + else: + cls._intern[strvalue] = result = PathToken(strvalue) + return result + + class TokenRegistry(PathRegistry): __slots__ = ("token", "parent", "path", "natural_path") def __init__(self, parent, token): + token = PathToken.intern(token) + self.token = token self.parent = parent self.path = parent.path + (token,) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index d14f6c27b..163ebf22a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -25,6 +25,7 @@ from . import loading from . import sync from .base import state_str from .. import exc as sa_exc +from .. import future from .. import sql from .. 
import util from ..sql import coercions @@ -1424,8 +1425,10 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): if toload_now: state.key = base_mapper._identity_key_from_state(state) + stmt = future.select(mapper).apply_labels() loading.load_on_ident( - uowtransaction.session.query(mapper), + uowtransaction.session, + stmt, state.key, refresh_state=state, only_load_props=toload_now, @@ -1723,7 +1726,7 @@ class BulkUD(object): self.context ) = compile_state = query._compile_state() - self.mapper = compile_state._bind_mapper() + self.mapper = compile_state._entity_zero() if isinstance( compile_state._entities[0], query_context._RawColumnEntity, diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 027786c19..4cf501e3f 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -346,14 +346,20 @@ class ColumnProperty(StrategizedProperty): pe = self._parententity # no adapter, so we aren't aliased # assert self._parententity is self._parentmapper - return self.prop.columns[0]._annotate( - { - "entity_namespace": pe, - "parententity": pe, - "parentmapper": pe, - "orm_key": self.prop.key, - "compile_state_plugin": "orm", - } + return ( + self.prop.columns[0] + ._annotate( + { + "entity_namespace": pe, + "parententity": pe, + "parentmapper": pe, + "orm_key": self.prop.key, + "compile_state_plugin": "orm", + } + ) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": pe} + ) ) def _memoized_attr_info(self): @@ -388,6 +394,11 @@ class ColumnProperty(StrategizedProperty): "orm_key": self.prop.key, "compile_state_plugin": "orm", } + )._set_propagate_attrs( + { + "compile_state_plugin": "orm", + "plugin_subject": self._parententity, + } ) for col in self.prop.columns ] diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8a861c3dc..25d6f4736 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -18,6 +18,7 @@ ORM session, 
whereas the ``Select`` construct interacts directly with the database to return iterable result sets. """ +import itertools from . import attributes from . import exc as orm_exc @@ -28,7 +29,8 @@ from .base import _assertions from .context import _column_descriptions from .context import _legacy_determine_last_joined_entity from .context import _legacy_filter_by_entity_zero -from .context import QueryCompileState +from .context import ORMCompileState +from .context import ORMFromStatementCompileState from .context import QueryContext from .interfaces import ORMColumnsClauseRole from .util import aliased @@ -42,18 +44,22 @@ from .. import inspection from .. import log from .. import sql from .. import util +from ..future.selectable import Select as FutureSelect from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util +from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _generative from ..sql.base import Executable +from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectStatementGrouping from ..sql.util import _entity_namespace_key from ..util import collections_abc @@ -62,7 +68,15 @@ __all__ = ["Query", "QueryContext", "aliased"] @inspection._self_inspects @log.class_logger -class Query(HasPrefixes, HasSuffixes, HasHints, Executable): +class Query( + _SelectFromElements, + SupportsCloneAnnotations, + HasPrefixes, + HasSuffixes, + HasHints, + Executable, +): + """ORM-level SQL construction object. 
:class:`_query.Query` @@ -105,7 +119,7 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): _legacy_setup_joins = () _label_style = LABEL_STYLE_NONE - compile_options = QueryCompileState.default_compile_options + compile_options = ORMCompileState.default_compile_options load_options = QueryContext.default_load_options @@ -115,6 +129,11 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): _enable_assertions = True _last_joined_entity = None + # mirrors that of ClauseElement, used to propagate the "orm" + # plugin as well as the "subject" of the plugin, e.g. the mapper + # we are querying against. + _propagate_attrs = util.immutabledict() + def __init__(self, entities, session=None): """Construct a :class:`_query.Query` directly. @@ -148,7 +167,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): def _set_entities(self, entities): self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, ent) + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) for ent in util.to_list(entities) ] @@ -183,7 +204,10 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): def _set_select_from(self, obj, set_base_alias): fa = [ coercions.expect( - roles.StrictFromClauseRole, elem, allow_select=True + roles.StrictFromClauseRole, + elem, + allow_select=True, + apply_propagate_attrs=self, ) for elem in obj ] @@ -332,15 +356,13 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): if ( not self.compile_options._set_base_alias and not self.compile_options._with_polymorphic_adapt_map - and self.compile_options._statement is None + # and self.compile_options._statement is None ): # if we don't have legacy top level aliasing features in use # then convert to a future select() directly stmt = self._statement_20() else: - stmt = QueryCompileState._create_for_legacy_query( - self, for_statement=True - ).statement + stmt = self._compile_state(for_statement=True).statement if self.load_options._params: # this 
is the search and replace thing. this is kind of nuts @@ -349,8 +371,67 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): return stmt - def _statement_20(self): - return QueryCompileState._create_future_select_from_query(self) + def _statement_20(self, orm_results=False): + # TODO: this event needs to be deprecated, as it currently applies + # only to ORM query and occurs at this spot that is now more + # or less an artificial spot + if self.dispatch.before_compile: + for fn in self.dispatch.before_compile: + new_query = fn(self) + if new_query is not None and new_query is not self: + self = new_query + if not fn._bake_ok: + self.compile_options += {"_bake_ok": False} + + if self.compile_options._statement is not None: + stmt = FromStatement( + self._raw_columns, self.compile_options._statement + ) + # TODO: once SubqueryLoader uses select(), we can remove + # "_orm_query" from this structure + stmt.__dict__.update( + _with_options=self._with_options, + _with_context_options=self._with_context_options, + compile_options=self.compile_options + + {"_orm_query": self.with_session(None)}, + _execution_options=self._execution_options, + ) + stmt._propagate_attrs = self._propagate_attrs + else: + stmt = FutureSelect.__new__(FutureSelect) + + stmt.__dict__.update( + _raw_columns=self._raw_columns, + _where_criteria=self._where_criteria, + _from_obj=self._from_obj, + _legacy_setup_joins=self._legacy_setup_joins, + _order_by_clauses=self._order_by_clauses, + _group_by_clauses=self._group_by_clauses, + _having_criteria=self._having_criteria, + _distinct=self._distinct, + _distinct_on=self._distinct_on, + _with_options=self._with_options, + _with_context_options=self._with_context_options, + _hints=self._hints, + _statement_hints=self._statement_hints, + _correlate=self._correlate, + _auto_correlate=self._auto_correlate, + _limit_clause=self._limit_clause, + _offset_clause=self._offset_clause, + _for_update_arg=self._for_update_arg, + _prefixes=self._prefixes, + 
_suffixes=self._suffixes, + _label_style=self._label_style, + compile_options=self.compile_options + + {"_orm_query": self.with_session(None)}, + _execution_options=self._execution_options, + ) + + if not orm_results: + stmt.compile_options += {"_orm_results": False} + + stmt._propagate_attrs = self._propagate_attrs + return stmt def subquery(self, name=None, with_labels=False, reduce_columns=False): """return the full SELECT statement represented by @@ -879,7 +960,17 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): elif instance is attributes.PASSIVE_CLASS_MISMATCH: return None - return db_load_fn(self, primary_key_identity) + # apply_labels() not strictly necessary, however this will ensure that + # tablename_colname style is used which at the moment is asserted + # in a lot of unit tests :) + + statement = self._statement_20(orm_results=True).apply_labels() + return db_load_fn( + self.session, + statement, + primary_key_identity, + load_options=self.load_options, + ) @property def lazy_loaded_from(self): @@ -1059,7 +1150,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): self._raw_columns = list(self._raw_columns) self._raw_columns.append( - coercions.expect(roles.ColumnsClauseRole, entity) + coercions.expect( + roles.ColumnsClauseRole, entity, apply_propagate_attrs=self + ) ) @_generative @@ -1397,7 +1490,10 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): self._raw_columns = list(self._raw_columns) self._raw_columns.extend( - coercions.expect(roles.ColumnsClauseRole, c) for c in column + coercions.expect( + roles.ColumnsClauseRole, c, apply_propagate_attrs=self + ) + for c in column ) @util.deprecated( @@ -1584,7 +1680,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): """ for criterion in list(criterion): - criterion = coercions.expect(roles.WhereHavingRole, criterion) + criterion = coercions.expect( + roles.WhereHavingRole, criterion, apply_propagate_attrs=self + ) # legacy 
vvvvvvvvvvvvvvvvvvvvvvvvvvv if self._aliased_generation: @@ -1742,7 +1840,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): """ self._having_criteria += ( - coercions.expect(roles.WhereHavingRole, criterion), + coercions.expect( + roles.WhereHavingRole, criterion, apply_propagate_attrs=self + ), ) def _set_op(self, expr_fn, *q): @@ -2177,7 +2277,12 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): self._legacy_setup_joins += tuple( ( - coercions.expect(roles.JoinTargetRole, prop[0], legacy=True), + coercions.expect( + roles.JoinTargetRole, + prop[0], + legacy=True, + apply_propagate_attrs=self, + ), prop[1] if len(prop) == 2 else None, None, { @@ -2605,7 +2710,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): ORM tutorial """ - statement = coercions.expect(roles.SelectStatementRole, statement) + statement = coercions.expect( + roles.SelectStatementRole, statement, apply_propagate_attrs=self + ) self.compile_options += {"_statement": statement} def first(self): @@ -2711,76 +2818,50 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): def __iter__(self): return self._iter().__iter__() - # TODO: having _iter(), _execute_and_instances, _connection_from_session, - # etc., is all too much. + def _iter(self): + # new style execution. + params = self.load_options._params + statement = self._statement_20(orm_results=True) + result = self.session.execute( + statement, + params, + execution_options={"_sa_orm_load_options": self.load_options}, + ) - # new recipes / extensions should be based on an event hook of some kind, - # can allow an execution that would return a Result to take in all the - # information and return a different Result. this has to be at - # the session / connection .execute() level, and can perhaps be - # before_execute() but needs to be focused around rewriting of results. 
+ # legacy: automatically set scalars, unique + if result._attributes.get("is_single_entity", False): + result = result.scalars() - # the dialect do_execute() *may* be this but that seems a bit too low - # level. it may need to be ORM session based and be a session event, - # becasue it might not invoke the cursor, might invoke for multiple - # connections, etc. OK really has to be a session level event in this - # case to support horizontal sharding. + if result._attributes.get("filtered", False): + result = result.unique() - def _iter(self): - context = self._compile_context() + return result + + def _execute_crud(self, stmt, mapper): + conn = self.session.connection( + mapper=mapper, clause=stmt, close_with_result=True + ) - if self.load_options._autoflush: - self.session._autoflush() - return self._execute_and_instances(context) + return conn._execute_20( + stmt, self.load_options._params, self._execution_options + ) def __str__(self): - compile_state = self._compile_state() + statement = self._statement_20(orm_results=True) + try: bind = ( - self._get_bind_args(compile_state, self.session.get_bind) + self._get_bind_args(statement, self.session.get_bind) if self.session else None ) except sa_exc.UnboundExecutionError: bind = None - return str(compile_state.statement.compile(bind)) - - def _connection_from_session(self, **kw): - conn = self.session.connection(**kw) - if self._execution_options: - conn = conn.execution_options(**self._execution_options) - return conn - - def _execute_and_instances(self, querycontext, params=None): - conn = self._get_bind_args( - querycontext.compile_state, - self._connection_from_session, - close_with_result=True, - ) - if params is None: - params = querycontext.load_options._params + return str(statement.compile(bind)) - result = conn._execute_20( - querycontext.compile_state.statement, - params, - # execution_options=self.session._orm_execution_options(), - ) - return loading.instances(querycontext.query, result, querycontext) 
- - def _execute_crud(self, stmt, mapper): - conn = self._connection_from_session( - mapper=mapper, clause=stmt, close_with_result=True - ) - - return conn.execute(stmt, self.load_options._params) - - def _get_bind_args(self, compile_state, fn, **kw): - return fn( - mapper=compile_state._bind_mapper(), - clause=compile_state.statement, - **kw - ) + def _get_bind_args(self, statement, fn, **kw): + return fn(clause=statement, **kw) @property def column_descriptions(self): @@ -2837,10 +2918,21 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): "for linking ORM results to arbitrary select constructs.", version="1.4", ) - compile_state = QueryCompileState._create_for_legacy_query(self) - context = QueryContext(compile_state, self.session) + compile_state = ORMCompileState._create_for_legacy_query(self) + context = QueryContext( + compile_state, self.session, self.load_options + ) + + result = loading.instances(result_proxy, context) + + # legacy: automatically set scalars, unique + if result._attributes.get("is_single_entity", False): + result = result.scalars() + + if result._attributes.get("filtered", False): + result = result.unique() - return loading.instances(self, result_proxy, context) + return result def merge_result(self, iterator, load=True): """Merge a result into this :class:`_query.Query` object's Session. @@ -3239,36 +3331,62 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): return update_op.rowcount def _compile_state(self, for_statement=False, **kw): - # TODO: this needs to become a general event for all - # Executable objects as well (all ClauseElement?) - # but then how do we clarify that this event is only for - # *top level* compile, not as an embedded element is visted? - # how does that even work because right now a Query that does things - # like from_self() will in fact invoke before_compile for each - # inner element. 
- # OK perhaps with 2.0 style folks will continue using before_execute() - # as they can now, as a select() with ORM elements will be delivered - # there, OK. sort of fixes the "bake_ok" problem too. - if self.dispatch.before_compile: - for fn in self.dispatch.before_compile: - new_query = fn(self) - if new_query is not None and new_query is not self: - self = new_query - if not fn._bake_ok: - self.compile_options += {"_bake_ok": False} - - compile_state = QueryCompileState._create_for_legacy_query( + return ORMCompileState._create_for_legacy_query( self, for_statement=for_statement, **kw ) - return compile_state def _compile_context(self, for_statement=False): compile_state = self._compile_state(for_statement=for_statement) - context = QueryContext(compile_state, self.session) + context = QueryContext(compile_state, self.session, self.load_options) return context +class FromStatement(SelectStatementGrouping, Executable): + """Core construct that represents a load of ORM objects from a finished + select or text construct. 
+ + """ + + compile_options = ORMFromStatementCompileState.default_compile_options + + _compile_state_factory = ORMFromStatementCompileState.create_for_statement + + _is_future = True + + _for_update_arg = None + + def __init__(self, entities, element): + self._raw_columns = [ + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) + for ent in util.to_list(entities) + ] + super(FromStatement, self).__init__(element) + + def _compiler_dispatch(self, compiler, **kw): + compile_state = self._compile_state_factory(self, self, **kw) + + toplevel = not compiler.stack + + if toplevel: + compiler.compile_state = compile_state + + return compiler.process(compile_state.statement, **kw) + + def _ensure_disambiguated_names(self): + return self + + def get_children(self, **kw): + for elem in itertools.chain.from_iterable( + element._from_objects for element in self._raw_columns + ): + yield elem + for elem in super(FromStatement, self).get_children(**kw): + yield elem + + class AliasOption(interfaces.LoaderOption): @util.deprecated( "1.4", diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index f539e968f..e82cd174f 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2737,12 +2737,12 @@ class JoinCondition(object): def replace(element): if "remote" in element._annotations: - v = element._annotations.copy() + v = dict(element._annotations) del v["remote"] v["local"] = True return element._with_annotations(v) elif "local" in element._annotations: - v = element._annotations.copy() + v = dict(element._annotations) del v["local"] v["remote"] = True return element._with_annotations(v) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 6cb8a0062..8d2f13df3 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -12,6 +12,7 @@ import sys import weakref from . import attributes +from . import context from . import exc from . 
import identity from . import loading @@ -28,13 +29,12 @@ from .base import state_str from .unitofwork import UOWTransaction from .. import engine from .. import exc as sa_exc -from .. import sql +from .. import future from .. import util from ..inspection import inspect from ..sql import coercions from ..sql import roles -from ..sql import util as sql_util - +from ..sql import visitors __all__ = ["Session", "SessionTransaction", "sessionmaker"] @@ -98,6 +98,160 @@ DEACTIVE = util.symbol("DEACTIVE") CLOSED = util.symbol("CLOSED") +class ORMExecuteState(object): + """Stateful object used for the :meth:`.SessionEvents.do_orm_execute` + + .. versionadded:: 1.4 + + """ + + __slots__ = ( + "session", + "statement", + "parameters", + "execution_options", + "bind_arguments", + ) + + def __init__( + self, session, statement, parameters, execution_options, bind_arguments + ): + self.session = session + self.statement = statement + self.parameters = parameters + self.execution_options = execution_options + self.bind_arguments = bind_arguments + + def invoke_statement( + self, + statement=None, + params=None, + execution_options=None, + bind_arguments=None, + ): + """Execute the statement represented by this + :class:`.ORMExecuteState`, without re-invoking events. + + This method essentially performs a re-entrant execution of the + current statement for which the :meth:`.SessionEvents.do_orm_execute` + event is being currently invoked. The use case for this is + for event handlers that want to override how the ultimate results + object is returned, such as for schemes that retrieve results from + an offline cache or which concatenate results from multiple executions. + + :param statement: optional statement to be invoked, in place of the + statement currently represented by :attr:`.ORMExecuteState.statement`. + + :param params: optional dictionary of parameters which will be merged + into the existing :attr:`.ORMExecuteState.parameters` of this + :class:`.ORMExecuteState`. 
+ + :param execution_options: optional dictionary of execution options + will be merged into the existing + :attr:`.ORMExecuteState.execution_options` of this + :class:`.ORMExecuteState`. + + :param bind_arguments: optional dictionary of bind_arguments + which will be merged amongst the current + :attr:`.ORMExecuteState.bind_arguments` + of this :class:`.ORMExecuteState`. + + :return: a :class:`_engine.Result` object with ORM-level results. + + .. seealso:: + + :ref:`examples_caching` - includes example use of the + :meth:`.SessionEvents.do_orm_execute` hook as well as the + :meth:`.ORMExecuteState.invoke_query` method. + + + """ + + if statement is None: + statement = self.statement + + _bind_arguments = dict(self.bind_arguments) + if bind_arguments: + _bind_arguments.update(bind_arguments) + _bind_arguments["_sa_skip_events"] = True + + if params: + _params = dict(self.parameters) + _params.update(params) + else: + _params = self.parameters + + if execution_options: + _execution_options = dict(self.execution_options) + _execution_options.update(execution_options) + else: + _execution_options = self.execution_options + + return self.session.execute( + statement, _params, _execution_options, _bind_arguments + ) + + @property + def orm_query(self): + """Return the :class:`_orm.Query` object associated with this + execution. + + For SQLAlchemy-2.0 style usage, the :class:`_orm.Query` object + is not used at all, and this attribute will return None. + + """ + load_opts = self.load_options + if load_opts._orm_query: + return load_opts._orm_query + + opts = self._orm_compile_options() + if opts is not None: + return opts._orm_query + else: + return None + + def _orm_compile_options(self): + opts = self.statement.compile_options + if isinstance(opts, context.ORMCompileState.default_compile_options): + return opts + else: + return None + + @property + def loader_strategy_path(self): + """Return the :class:`.PathRegistry` for the current load path. 
+ + This object represents the "path" in a query along relationships + when a particular object or collection is being loaded. + + """ + opts = self._orm_compile_options() + if opts is not None: + return opts._current_path + else: + return None + + @property + def load_options(self): + """Return the load_options that will be used for this execution.""" + + return self.execution_options.get( + "_sa_orm_load_options", context.QueryContext.default_load_options + ) + + @property + def user_defined_options(self): + """The sequence of :class:`.UserDefinedOptions` that have been + associated with the statement being invoked. + + """ + return [ + opt + for opt in self.statement._with_options + if not opt._is_compile_state and not opt._is_legacy_option + ] + + class SessionTransaction(object): """A :class:`.Session`-level transaction. @@ -1032,9 +1186,7 @@ class Session(_SessionClassMethods): def connection( self, - mapper=None, - clause=None, - bind=None, + bind_arguments=None, close_with_result=False, execution_options=None, **kw @@ -1059,23 +1211,18 @@ class Session(_SessionClassMethods): resolved through any of the optional keyword arguments. This ultimately makes usage of the :meth:`.get_bind` method for resolution. + :param bind_arguments: dictionary of bind arguments. may include + "mapper", "bind", "clause", other custom arguments that are passed + to :meth:`.Session.get_bind`. + :param bind: - Optional :class:`_engine.Engine` to be used as the bind. If - this engine is already involved in an ongoing transaction, - that connection will be used. This argument takes precedence - over ``mapper``, ``clause``. + deprecated; use bind_arguments :param mapper: - Optional :func:`.mapper` mapped class, used to identify - the appropriate bind. This argument takes precedence over - ``clause``. + deprecated; use bind_arguments :param clause: - A :class:`_expression.ClauseElement` (i.e. - :func:`_expression.select`, - :func:`_expression.text`, - etc.) 
which will be used to locate a bind, if a bind - cannot otherwise be identified. + deprecated; use bind_arguments :param close_with_result: Passed to :meth:`_engine.Engine.connect`, indicating the :class:`_engine.Connection` should be considered @@ -1097,13 +1244,16 @@ class Session(_SessionClassMethods): :ref:`session_transaction_isolation` :param \**kw: - Additional keyword arguments are sent to :meth:`get_bind()`, - allowing additional arguments to be passed to custom - implementations of :meth:`get_bind`. + deprecated; use bind_arguments """ + + if not bind_arguments: + bind_arguments = kw + + bind = bind_arguments.pop("bind", None) if bind is None: - bind = self.get_bind(mapper, clause=clause, **kw) + bind = self.get_bind(**bind_arguments) return self._connection_for_bind( bind, @@ -1124,7 +1274,14 @@ class Session(_SessionClassMethods): conn = conn.execution_options(**execution_options) return conn - def execute(self, clause, params=None, mapper=None, bind=None, **kw): + def execute( + self, + statement, + params=None, + execution_options=util.immutabledict(), + bind_arguments=None, + **kw + ): r"""Execute a SQL expression construct or string statement within the current transaction. @@ -1222,22 +1379,19 @@ class Session(_SessionClassMethods): "executemany" will be invoked. The keys in each dictionary must correspond to parameter names present in the statement. + :param bind_arguments: dictionary of additional arguments to determine + the bind. may include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + :param mapper: - Optional :func:`.mapper` or mapped class, used to identify - the appropriate bind. This argument takes precedence over - ``clause`` when locating a bind. See :meth:`.Session.get_bind` - for more details. + deprecated; use the bind_arguments dictionary :param bind: - Optional :class:`_engine.Engine` to be used as the bind. 
If - this engine is already involved in an ongoing transaction, - that connection will be used. This argument takes - precedence over ``mapper`` and ``clause`` when locating - a bind. + deprecated; use the bind_arguments dictionary :param \**kw: - Additional keyword arguments are sent to :meth:`.Session.get_bind()` - to allow extensibility of "bind" schemes. + deprecated; use the bind_arguments dictionary .. seealso:: @@ -1253,20 +1407,63 @@ class Session(_SessionClassMethods): in order to execute the statement. """ - clause = coercions.expect(roles.CoerceTextStatementRole, clause) - if bind is None: - bind = self.get_bind(mapper, clause=clause, **kw) + statement = coercions.expect(roles.CoerceTextStatementRole, statement) - return self._connection_for_bind( - bind, close_with_result=True - )._execute_20(clause, params,) + if not bind_arguments: + bind_arguments = kw + elif kw: + bind_arguments.update(kw) + + compile_state_cls = statement._get_plugin_compile_state_cls("orm") + if compile_state_cls: + compile_state_cls.orm_pre_session_exec( + self, statement, execution_options, bind_arguments + ) + else: + bind_arguments.setdefault("clause", statement) + if statement._is_future: + execution_options = util.immutabledict().merge_with( + execution_options, {"future_result": True} + ) + + if self.dispatch.do_orm_execute: + skip_events = bind_arguments.pop("_sa_skip_events", False) + + if not skip_events: + orm_exec_state = ORMExecuteState( + self, statement, params, execution_options, bind_arguments + ) + for fn in self.dispatch.do_orm_execute: + result = fn(orm_exec_state) + if result: + return result + + bind = self.get_bind(**bind_arguments) + + conn = self._connection_for_bind(bind, close_with_result=True) + result = conn._execute_20(statement, params or {}, execution_options) - def scalar(self, clause, params=None, mapper=None, bind=None, **kw): + if compile_state_cls: + result = compile_state_cls.orm_setup_cursor_result( + self, bind_arguments, result + ) + + 
return result + + def scalar( + self, + statement, + params=None, + execution_options=None, + mapper=None, + bind=None, + **kw + ): """Like :meth:`~.Session.execute` but return a scalar result.""" return self.execute( - clause, params=params, mapper=mapper, bind=bind, **kw + statement, params=params, mapper=mapper, bind=bind, **kw ).scalar() def close(self): @@ -1422,7 +1619,7 @@ class Session(_SessionClassMethods): """ self._add_bind(table, bind) - def get_bind(self, mapper=None, clause=None): + def get_bind(self, mapper=None, clause=None, bind=None): """Return a "bind" to which this :class:`.Session` is bound. The "bind" is usually an instance of :class:`_engine.Engine`, @@ -1497,6 +1694,8 @@ class Session(_SessionClassMethods): :meth:`.Session.bind_table` """ + if bind: + return bind if mapper is clause is None: if self.bind: @@ -1520,6 +1719,8 @@ class Session(_SessionClassMethods): raise if self.__binds: + # matching mappers and selectables to entries in the + # binds dictionary; supported use case. if mapper: for cls in mapper.class_.__mro__: if cls in self.__binds: @@ -1528,18 +1729,32 @@ class Session(_SessionClassMethods): clause = mapper.persist_selectable if clause is not None: - for t in sql_util.find_tables(clause, include_crud=True): - if t in self.__binds: - return self.__binds[t] + for obj in visitors.iterate(clause): + if obj in self.__binds: + return self.__binds[obj] + # session has a single bind; supported use case. if self.bind: return self.bind - if isinstance(clause, sql.expression.ClauseElement) and clause.bind: - return clause.bind + # now we are in legacy territory. looking for "bind" on tables + # that are via bound metadata. this goes away in 2.0. 
+ if mapper and clause is None: + clause = mapper.persist_selectable - if mapper and mapper.persist_selectable.bind: - return mapper.persist_selectable.bind + if clause is not None: + if clause.bind: + return clause.bind + # for obj in visitors.iterate(clause): + # if obj.bind: + # return obj.bind + + if mapper: + if mapper.persist_selectable.bind: + return mapper.persist_selectable.bind + # for obj in visitors.iterate(mapper.persist_selectable): + # if obj.bind: + # return obj.bind context = [] if mapper is not None: @@ -1722,9 +1937,11 @@ class Session(_SessionClassMethods): else: with_for_update = None + stmt = future.select(object_mapper(instance)) if ( loading.load_on_ident( - self.query(object_mapper(instance)), + self, + stmt, state.key, refresh_state=state, with_for_update=with_for_update, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index c0c090b3d..a7d501b53 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -33,6 +33,7 @@ from .util import _none_set from .util import aliased from .. import event from .. import exc as sa_exc +from .. import future from .. import inspect from .. import log from .. 
import sql @@ -440,10 +441,13 @@ class DeferredColumnLoader(LoaderStrategy): if self.raiseload: self._invoke_raise_load(state, passive, "raise") - query = session.query(localparent) if ( loading.load_on_ident( - query, state.key, only_load_props=group, refresh_state=state + session, + future.select(localparent).apply_labels(), + state.key, + only_load_props=group, + refresh_state=state, ) is None ): @@ -897,7 +901,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): q(session) .with_post_criteria(lambda q: q._set_lazyload_from(state)) ._load_on_pk_identity( - session.query(self.mapper), primary_key_identity + session, session.query(self.mapper), primary_key_identity ) ) @@ -1090,7 +1094,6 @@ class SubqueryLoader(PostLoader): parentmapper=None, **kwargs ): - if ( not compile_state.compile_options._enable_eagerloads or compile_state.compile_options._for_refresh_state @@ -1146,6 +1149,7 @@ class SubqueryLoader(PostLoader): # generate a new Query from the original, then # produce a subquery from it. 
left_alias = self._generate_from_original_query( + compile_state, orig_query, leftmost_mapper, leftmost_attr, @@ -1164,7 +1168,9 @@ class SubqueryLoader(PostLoader): def set_state_options(compile_state): compile_state.attributes.update( { - ("orig_query", SubqueryLoader): orig_query, + ("orig_query", SubqueryLoader): orig_query.with_session( + None + ), ("subquery_path", None): subq_path, } ) @@ -1188,6 +1194,7 @@ class SubqueryLoader(PostLoader): # by create_row_processor # NOTE: be sure to consult baked.py for some hardcoded logic # about this structure as well + assert q.session is None path.set( compile_state.attributes, "subqueryload_data", {"query": q}, ) @@ -1218,6 +1225,7 @@ class SubqueryLoader(PostLoader): def _generate_from_original_query( self, + orig_compile_state, orig_query, leftmost_mapper, leftmost_attr, @@ -1243,11 +1251,18 @@ class SubqueryLoader(PostLoader): } ) - cs = q._clone() + # NOTE: keystone has a test which is counting before_compile + # events. That test is in one case dependent on an extra + # call that was occurring here within the subqueryloader setup + # process, probably when the subquery() method was called. + # Ultimately that call will not be occurring here. + # the event has already been called on the original query when + # we are here in any case, so keystone will need to adjust that + # test. - # using the _compile_state method so that the before_compile() - # event is hit here. keystone is testing for this. - compile_state = cs._compile_state(entities_only=True) + # for column information, look to the compile state that is + # already being passed through + compile_state = orig_compile_state # select from the identity columns of the outer (specifically, these # are the 'local_cols' of the property). 
This will remove @@ -1260,7 +1275,6 @@ class SubqueryLoader(PostLoader): ], compile_state._get_current_adapter(), ) - # q.add_columns.non_generative(q, target_cols) q._set_entities(target_cols) distinct_target_key = leftmost_relationship.distinct_target_key @@ -1428,10 +1442,20 @@ class SubqueryLoader(PostLoader): """ - __slots__ = ("subq_info", "subq", "_data") + __slots__ = ( + "session", + "execution_options", + "load_options", + "subq", + "_data", + ) - def __init__(self, subq_info): - self.subq_info = subq_info + def __init__(self, context, subq_info): + # avoid creating a cycle by storing context + # even though that's preferable + self.session = context.session + self.execution_options = context.execution_options + self.load_options = context.load_options self.subq = subq_info["query"] self._data = None @@ -1443,7 +1467,17 @@ class SubqueryLoader(PostLoader): def _load(self): self._data = collections.defaultdict(list) - rows = list(self.subq) + q = self.subq + assert q.session is None + if "compiled_cache" in self.execution_options: + q = q.execution_options( + compiled_cache=self.execution_options["compiled_cache"] + ) + q = q.with_session(self.session) + + # to work with baked query, the parameters may have been + # updated since this query was created, so take these into account + rows = list(q.params(self.load_options._params)) for k, v in itertools.groupby(rows, lambda x: x[1:]): self._data[k].extend(vv[0] for vv in v) @@ -1474,14 +1508,7 @@ class SubqueryLoader(PostLoader): subq = subq_info["query"] - if subq.session is None: - subq.session = context.session - assert subq.session is context.session, ( - "Subquery session doesn't refer to that of " - "our context. Are there broken context caching " - "schemes being used?" 
- ) - + assert subq.session is None local_cols = self.parent_property.local_columns # cache the loaded collections in the context @@ -1489,7 +1516,7 @@ class SubqueryLoader(PostLoader): # call upon create_row_processor again collections = path.get(context.attributes, "collections") if collections is None: - collections = self._SubqCollections(subq_info) + collections = self._SubqCollections(context, subq_info) path.set(context.attributes, "collections", collections) if adapter: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 1e415e49c..ce37d962e 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -41,6 +41,7 @@ from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection @@ -694,6 +695,8 @@ class AliasedInsp( "entity_namespace": self, "compile_state_plugin": "orm", } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} ) @property @@ -748,10 +751,20 @@ class AliasedInsp( ) def _adapt_element(self, elem, key=None): - d = {"parententity": self, "parentmapper": self.mapper} + d = { + "parententity": self, + "parentmapper": self.mapper, + "compile_state_plugin": "orm", + } if key: d["orm_key"] = key - return self._adapter.traverse(elem)._annotate(d) + return ( + self._adapter.traverse(elem) + ._annotate(d) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + ) def _entity_for_mapper(self, mapper): self_poly = self.with_polymorphic_mappers @@ -1037,7 +1050,7 @@ def with_polymorphic( @inspection._self_inspects -class Bundle(ORMColumnsClauseRole, InspectionAttr): +class Bundle(ORMColumnsClauseRole, SupportsCloneAnnotations, InspectionAttr): """A grouping of SQL expressions that are returned by a :class:`.Query` under one namespace. 
@@ -1070,6 +1083,8 @@ class Bundle(ORMColumnsClauseRole, InspectionAttr): is_bundle = True + _propagate_attrs = util.immutabledict() + def __init__(self, name, *exprs, **kw): r"""Construct a new :class:`.Bundle`. @@ -1090,7 +1105,10 @@ class Bundle(ORMColumnsClauseRole, InspectionAttr): """ self.name = self._label = name self.exprs = exprs = [ - coercions.expect(roles.ColumnsClauseRole, expr) for expr in exprs + coercions.expect( + roles.ColumnsClauseRole, expr, apply_propagate_attrs=self + ) + for expr in exprs ] self.c = self.columns = ColumnCollection( @@ -1145,11 +1163,14 @@ class Bundle(ORMColumnsClauseRole, InspectionAttr): return cloned def __clause_element__(self): + annotations = self._annotations.union( + {"bundle": self, "entity_namespace": self} + ) return expression.ClauseList( _literal_as_text_role=roles.ColumnsClauseRole, group=False, *[e._annotations.get("bundle", e) for e in self.exprs] - )._annotate({"bundle": self, "entity_namespace": self}) + )._annotate(annotations) @property def clauses(self): diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 71d05f38f..08ed121d3 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -17,8 +17,12 @@ from .traversals import anon_map from .visitors import InternalTraversal from .. 
import util +EMPTY_ANNOTATIONS = util.immutabledict() + class SupportsAnnotations(object): + _annotations = EMPTY_ANNOTATIONS + @util.memoized_property def _annotations_cache_key(self): anon_map_ = anon_map() @@ -40,7 +44,6 @@ class SupportsAnnotations(object): class SupportsCloneAnnotations(SupportsAnnotations): - _annotations = util.immutabledict() _clone_annotations_traverse_internals = [ ("_annotations", InternalTraversal.dp_annotations_key) @@ -113,12 +116,9 @@ class SupportsWrappingAnnotations(SupportsAnnotations): """ if clone: - # clone is used when we are also copying - # the expression for a deep deannotation - return self._clone() + s = self._clone() + return s else: - # if no clone, since we have no annotations we return - # self return self @@ -163,12 +163,11 @@ class Annotated(object): self.__dict__.pop("_annotations_cache_key", None) self.__dict__.pop("_generate_cache_key", None) self.__element = element - self._annotations = values + self._annotations = util.immutabledict(values) self._hash = hash(element) def _annotate(self, values): - _values = self._annotations.copy() - _values.update(values) + _values = self._annotations.union(values) return self._with_annotations(_values) def _with_annotations(self, values): @@ -183,10 +182,15 @@ class Annotated(object): if values is None: return self.__element else: - _values = self._annotations.copy() - for v in values: - _values.pop(v, None) - return self._with_annotations(_values) + return self._with_annotations( + util.immutabledict( + { + key: value + for key, value in self._annotations.items() + if key not in values + } + ) + ) def _compiler_dispatch(self, visitor, **kw): return self.__element.__class__._compiler_dispatch(self, visitor, **kw) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 04cc34480..bb606a4d6 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -439,46 +439,53 @@ class CompileState(object): plugins = {} @classmethod - def _create(cls, 
statement, compiler, **kw): + def create_for_statement(cls, statement, compiler, **kw): # factory construction. - if statement._compile_state_plugin is not None: - constructor = cls.plugins.get( - ( - statement._compile_state_plugin, - statement.__visit_name__, - None, - ), - cls, + if statement._propagate_attrs: + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" ) else: - constructor = cls + plugin_name = "default" + + klass = cls.plugins[(plugin_name, statement.__visit_name__)] - return constructor(statement, compiler, **kw) + if klass is cls: + return cls(statement, compiler, **kw) + else: + return klass.create_for_statement(statement, compiler, **kw) def __init__(self, statement, compiler, **kw): self.statement = statement @classmethod - def get_plugin_classmethod(cls, statement, name): - if statement._compile_state_plugin is not None: - fn = cls.plugins.get( - ( - statement._compile_state_plugin, - statement.__visit_name__, - name, - ), - None, - ) - if fn is not None: - return fn - return getattr(cls, name) + def get_plugin_class(cls, statement): + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + try: + return cls.plugins[(plugin_name, statement.__visit_name__)] + except KeyError: + return None @classmethod - def plugin_for(cls, plugin_name, visit_name, method_name=None): - def decorate(fn): - cls.plugins[(plugin_name, visit_name, method_name)] = fn - return fn + def _get_plugin_compile_state_cls(cls, statement, plugin_name): + statement_plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + if statement_plugin_name != plugin_name: + return None + try: + return cls.plugins[(plugin_name, statement.__visit_name__)] + except KeyError: + return None + + @classmethod + def plugin_for(cls, plugin_name, visit_name): + def decorate(cls_to_decorate): + cls.plugins[(plugin_name, visit_name)] = cls_to_decorate + return cls_to_decorate return decorate @@ -508,12 
+515,12 @@ class InPlaceGenerative(HasMemoized): class HasCompileState(Generative): """A class that has a :class:`.CompileState` associated with it.""" - _compile_state_factory = CompileState._create - _compile_state_plugin = None _attributes = util.immutabledict() + _compile_state_factory = CompileState.create_for_statement + class _MetaOptions(type): """metaclass for the Options class.""" @@ -549,6 +556,16 @@ class Options(util.with_metaclass(_MetaOptions)): def add_to_element(self, name, value): return self + {name: getattr(self, name) + value} + @hybridmethod + def _state_dict(self): + return self.__dict__ + + _state_dict_const = util.immutabledict() + + @_state_dict.classlevel + def _state_dict(cls): + return cls._state_dict_const + class CacheableOptions(Options, HasCacheKey): @hybridmethod @@ -590,6 +607,9 @@ class Executable(Generative): def _disable_caching(self): self._cache_enable = HasCacheKey() + def _get_plugin_compile_state_cls(self, plugin_name): + return CompileState._get_plugin_compile_state_cls(self, plugin_name) + @_generative def options(self, *options): """Apply options to this statement. 
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 2fc63b82f..d8ef0222a 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -50,19 +50,26 @@ def _document_text_coercion(paramname, meth_rst, param_rst): ) -def expect(role, element, **kw): +def expect(role, element, apply_propagate_attrs=None, **kw): # major case is that we are given a ClauseElement already, skip more # elaborate logic up front if possible impl = _impl_lookup[role] if not isinstance( element, - (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue,), + (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue), ): resolved = impl._resolve_for_clause_element(element, **kw) else: resolved = element + if ( + apply_propagate_attrs is not None + and not apply_propagate_attrs._propagate_attrs + and resolved._propagate_attrs + ): + apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs + if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: resolved = impl._post_coercion(resolved, **kw) @@ -106,32 +113,32 @@ class RoleImpl(object): self.name = role_class._role_name self._use_inspection = issubclass(role_class, roles.UsesInspection) - def _resolve_for_clause_element( - self, element, argname=None, apply_plugins=None, **kw - ): + def _resolve_for_clause_element(self, element, argname=None, **kw): original_element = element is_clause_element = False + while hasattr(element, "__clause_element__"): is_clause_element = True if not getattr(element, "is_clause_element", False): element = element.__clause_element__() else: - break - - should_apply_plugins = ( - apply_plugins is not None - and apply_plugins._compile_state_plugin is None - ) + return element + + if not is_clause_element: + if self._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + insp._post_inspect + try: + element = insp.__clause_element__() + except AttributeError: + 
self._raise_for_expected(original_element, argname) + else: + return element - if is_clause_element: - if ( - should_apply_plugins - and "compile_state_plugin" in element._annotations - ): - apply_plugins._compile_state_plugin = element._annotations[ - "compile_state_plugin" - ] + return self._literal_coercion(element, argname=argname, **kw) + else: return element if self._use_inspection: @@ -142,14 +149,6 @@ class RoleImpl(object): element = insp.__clause_element__() except AttributeError: self._raise_for_expected(original_element, argname) - else: - if ( - should_apply_plugins - and "compile_state_plugin" in element._annotations - ): - plugin = element._annotations["compile_state_plugin"] - apply_plugins._compile_state_plugin = plugin - return element return self._literal_coercion(element, argname=argname, **kw) @@ -649,8 +648,8 @@ class SelectStatementImpl(_NoTextCoercion, RoleImpl): self._raise_for_expected(original_element, argname, resolved) -class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole): - pass +class HasCTEImpl(ReturnsRowsImpl): + __slots__ = () class JoinTargetImpl(RoleImpl): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9a7646743..8eae0ab7d 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ import contextlib import itertools import operator import re +import time from . import base from . import coercions @@ -380,6 +381,54 @@ class Compiled(object): sub-elements of the statement can modify these. """ + compile_state = None + """Optional :class:`.CompileState` object that maintains additional + state used by the compiler. + + Major executable objects such as :class:`_expression.Insert`, + :class:`_expression.Update`, :class:`_expression.Delete`, + :class:`_expression.Select` will generate this + state when compiled in order to calculate additional information about the + object. 
For the top level object that is to be executed, the state can be + stored here where it can also have applicability towards result set + processing. + + .. versionadded:: 1.4 + + """ + + _rewrites_selected_columns = False + """if True, indicates the compile_state object rewrites an incoming + ReturnsRows (like a Select) so that the columns we compile against in the + result set are not what were expressed on the outside. this is a hint to + the execution context to not link the statement.selected_columns to the + columns mapped in the result object. + + That is, when this flag is False:: + + stmt = some_statement() + + result = conn.execute(stmt) + row = result.first() + + # selected_columns are in a 1-1 relationship with the + # columns in the result, and are targetable in mapping + for col in stmt.selected_columns: + assert col in row._mapping + + When True:: + + # selected columns are not what are in the rows. the context + # rewrote the statement for some other set of selected_columns. + for col in stmt.selected_columns: + assert col not in row._mapping + + + """ + + cache_key = None + _gen_time = None + def __init__( self, dialect, @@ -433,6 +482,7 @@ class Compiled(object): self.string = self.preparer._render_schema_translates( self.string, schema_translate_map ) + self._gen_time = time.time() def _execute_on_connection( self, connection, multiparams, params, execution_options @@ -637,28 +687,6 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () - compile_state = None - """Optional :class:`.CompileState` object that maintains additional - state used by the compiler. - - Major executable objects such as :class:`_expression.Insert`, - :class:`_expression.Update`, :class:`_expression.Delete`, - :class:`_expression.Select` will generate this - state when compiled in order to calculate additional information about the - object. 
For the top level object that is to be executed, the state can be - stored here where it can also have applicability towards result set - processing. - - .. versionadded:: 1.4 - - """ - - compile_state_factories = util.immutabledict() - """Dictionary of alternate :class:`.CompileState` factories for given - classes, identified by their visit_name. - - """ - def __init__( self, dialect, @@ -667,7 +695,6 @@ class SQLCompiler(Compiled): column_keys=None, inline=False, linting=NO_LINTING, - compile_state_factories=None, **kwargs ): """Construct a new :class:`.SQLCompiler` object. @@ -734,9 +761,6 @@ class SQLCompiler(Compiled): # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} - if compile_state_factories: - self.compile_state_factories = compile_state_factories - Compiled.__init__(self, dialect, statement, **kwargs) if ( @@ -1542,7 +1566,7 @@ class SQLCompiler(Compiled): compile_state = cs._compile_state_factory(cs, self, **kwargs) - if toplevel: + if toplevel and not self.compile_state: self.compile_state = compile_state entry = self._default_stack_entry if toplevel else self.stack[-1] @@ -2541,6 +2565,13 @@ class SQLCompiler(Compiled): ) return froms + translate_select_structure = None + """if none None, should be a callable which accepts (select_stmt, **kw) + and returns a select object. this is used for structural changes + mostly to accommodate for LIMIT/OFFSET schemes + + """ + def visit_select( self, select_stmt, @@ -2552,7 +2583,17 @@ class SQLCompiler(Compiled): from_linter=None, **kwargs ): + assert select_wraps_for is None, ( + "SQLAlchemy 1.4 requires use of " + "the translate_select_structure hook for structural " + "translations of SELECT objects" + ) + # initial setup of SELECT. the compile_state_factory may now + # be creating a totally different SELECT from the one that was + # passed in. for ORM use this will convert from an ORM-state + # SELECT to a regular "Core" SELECT. 
other composed operations + # such as computation of joins will be performed. compile_state = select_stmt._compile_state_factory( select_stmt, self, **kwargs ) @@ -2560,9 +2601,29 @@ class SQLCompiler(Compiled): toplevel = not self.stack - if toplevel: + if toplevel and not self.compile_state: self.compile_state = compile_state + # translate step for Oracle, SQL Server which often need to + # restructure the SELECT to allow for LIMIT/OFFSET and possibly + # other conditions + if self.translate_select_structure: + new_select_stmt = self.translate_select_structure( + select_stmt, asfrom=asfrom, **kwargs + ) + + # if SELECT was restructured, maintain a link to the originals + # and assemble a new compile state + if new_select_stmt is not select_stmt: + compile_state_wraps_for = compile_state + select_wraps_for = select_stmt + select_stmt = new_select_stmt + + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + entry = self._default_stack_entry if toplevel else self.stack[-1] populate_result_map = need_column_expressions = ( @@ -2624,13 +2685,9 @@ class SQLCompiler(Compiled): ] if populate_result_map and select_wraps_for is not None: - # if this select is a compiler-generated wrapper, + # if this select was generated from translate_select, # rewrite the targeted columns in the result map - compile_state_wraps_for = select_wraps_for._compile_state_factory( - select_wraps_for, self, **kwargs - ) - translate = dict( zip( [ @@ -3013,7 +3070,8 @@ class SQLCompiler(Compiled): if toplevel: self.isinsert = True - self.compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state self.stack.append( { diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 3dc4e917c..467a764d6 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -39,54 +39,8 @@ class DMLState(CompileState): isdelete = False isinsert = False - @classmethod - def 
_create_insert(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isinsert=True, **kw) - - @classmethod - def _create_update(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isupdate=True, **kw) - - @classmethod - def _create_delete(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isdelete=True, **kw) - - def __init__( - self, - statement, - compiler, - isinsert=False, - isupdate=False, - isdelete=False, - **kw - ): - self.statement = statement - - if isupdate: - self.isupdate = True - self._preserve_parameter_order = ( - statement._preserve_parameter_order - ) - if statement._ordered_values is not None: - self._process_ordered_values(statement) - elif statement._values is not None: - self._process_values(statement) - elif statement._multi_values: - self._process_multi_values(statement) - self._extra_froms = self._make_extra_froms(statement) - elif isinsert: - self.isinsert = True - if statement._select_names: - self._process_select_values(statement) - if statement._values is not None: - self._process_values(statement) - if statement._multi_values: - self._process_multi_values(statement) - elif isdelete: - self.isdelete = True - self._extra_froms = self._make_extra_froms(statement) - else: - assert False, "one of isinsert, isupdate, or isdelete must be set" + def __init__(self, statement, compiler, **kw): + raise NotImplementedError() def _make_extra_froms(self, statement): froms = [] @@ -174,6 +128,51 @@ class DMLState(CompileState): ) +@CompileState.plugin_for("default", "insert") +class InsertDMLState(DMLState): + isinsert = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isinsert = True + if statement._select_names: + self._process_select_values(statement) + if statement._values is not None: + self._process_values(statement) + if statement._multi_values: + self._process_multi_values(statement) + + +@CompileState.plugin_for("default", "update") +class 
UpdateDMLState(DMLState): + isupdate = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isupdate = True + self._preserve_parameter_order = statement._preserve_parameter_order + if statement._ordered_values is not None: + self._process_ordered_values(statement) + elif statement._values is not None: + self._process_values(statement) + elif statement._multi_values: + self._process_multi_values(statement) + self._extra_froms = self._make_extra_froms(statement) + + +@CompileState.plugin_for("default", "delete") +class DeleteDMLState(DMLState): + isdelete = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isdelete = True + self._extra_froms = self._make_extra_froms(statement) + + class UpdateBase( roles.DMLRole, HasCTE, @@ -754,8 +753,6 @@ class Insert(ValuesBase): _supports_multi_parameters = True - _compile_state_factory = DMLState._create_insert - select = None include_insert_from_select_defaults = False @@ -964,8 +961,6 @@ class Update(DMLWhereBase, ValuesBase): __visit_name__ = "update" - _compile_state_factory = DMLState._create_update - _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), @@ -1210,8 +1205,6 @@ class Delete(DMLWhereBase, UpdateBase): __visit_name__ = "delete" - _compile_state_factory = DMLState._create_delete - _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index c1bc9edbc..287e53724 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -191,7 +191,12 @@ class ClauseElement( __visit_name__ = "clause" - _annotations = {} + _propagate_attrs = util.immutabledict() + """like annotations, however these propagate outwards liberally + as SQL constructs are built, and are set up at construction time. 
+ + """ + supports_execution = False _from_objects = [] bind = None @@ -215,6 +220,16 @@ class ClauseElement( _cache_key_traversal = None + def _set_propagate_attrs(self, values): + # usually, self._propagate_attrs is empty here. one case where it's + # not is a subquery against ORM select, that is then pulled as a + # property of an aliased class. should all be good + + # assert not self._propagate_attrs + + self._propagate_attrs = util.immutabledict(values) + return self + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -870,6 +885,7 @@ class ColumnElement( type_=getattr(self, "type", None), _selectable=selectable, ) + co._propagate_attrs = selectable._propagate_attrs co._proxies = [self] if selectable._is_clone_of is not None: co._is_clone_of = selectable._is_clone_of.columns.get(key) @@ -1495,6 +1511,8 @@ class TextClause( _render_label_in_columns_clause = False + _hide_froms = () + def __and__(self, other): # support use in select.where(), query.filter() return and_(self, other) @@ -1509,10 +1527,6 @@ class TextClause( _allow_label_resolve = False - @property - def _hide_froms(self): - return [] - def __init__(self, text, bind=None): self._bind = bind self._bindparams = {} @@ -2093,14 +2107,16 @@ class ClauseList( ) if self.group_contents: self.clauses = [ - coercions.expect(text_converter_role, clause).self_group( - against=self.operator - ) + coercions.expect( + text_converter_role, clause, apply_propagate_attrs=self + ).self_group(against=self.operator) for clause in clauses ] else: self.clauses = [ - coercions.expect(text_converter_role, clause) + coercions.expect( + text_converter_role, clause, apply_propagate_attrs=self + ) for clause in clauses ] self._is_implicitly_boolean = operators.is_boolean(self.operator) @@ -2641,7 +2657,9 @@ class Case(ColumnElement): whenlist = [ ( coercions.expect( - roles.ExpressionElementRole, c + roles.ExpressionElementRole, + c, + apply_propagate_attrs=self, ).self_group(), 
coercions.expect(roles.ExpressionElementRole, r), ) @@ -2650,7 +2668,9 @@ class Case(ColumnElement): else: whenlist = [ ( - coercions.expect(roles.ColumnArgumentRole, c).self_group(), + coercions.expect( + roles.ColumnArgumentRole, c, apply_propagate_attrs=self + ).self_group(), coercions.expect(roles.ExpressionElementRole, r), ) for (c, r) in whens @@ -2805,7 +2825,10 @@ class Cast(WrapsColumnExpression, ColumnElement): """ self.type = type_api.to_instance(type_) self.clause = coercions.expect( - roles.ExpressionElementRole, expression, type_=self.type + roles.ExpressionElementRole, + expression, + type_=self.type, + apply_propagate_attrs=self, ) self.typeclause = TypeClause(self.type) @@ -2906,7 +2929,10 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): """ self.type = type_api.to_instance(type_) self.clause = coercions.expect( - roles.ExpressionElementRole, expression, type_=self.type + roles.ExpressionElementRole, + expression, + type_=self.type, + apply_propagate_attrs=self, ) @property @@ -3031,6 +3057,7 @@ class UnaryExpression(ColumnElement): ): self.operator = operator self.modifier = modifier + self._propagate_attrs = element._propagate_attrs self.element = element.self_group( against=self.operator or self.modifier ) @@ -3474,6 +3501,7 @@ class BinaryExpression(ColumnElement): if isinstance(operator, util.string_types): operator = operators.custom_op(operator) self._orig = (left.__hash__(), right.__hash__()) + self._propagate_attrs = left._propagate_attrs or right._propagate_attrs self.left = left.self_group(against=operator) self.right = right.self_group(against=operator) self.operator = operator @@ -4159,6 +4187,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): name=name if name else self.name, disallow_is_literal=True, ) + e._propagate_attrs = selectable._propagate_attrs e._proxies.append(self) if self._type is not None: e.type = self._type @@ -4340,16 +4369,10 @@ class ColumnClause( return 
other.proxy_set.intersection(self.proxy_set) def get_children(self, column_tables=False, **kw): - if column_tables and self.table is not None: - # TODO: this is only used by ORM query deep_entity_zero. - # this is being removed in a later release so remove - # column_tables also at that time. - return [self.table] - else: - # override base get_children() to not return the Table - # or selectable that is parent to this column. Traversals - # expect the columns of tables and subqueries to be leaf nodes. - return [] + # override base get_children() to not return the Table + # or selectable that is parent to this column. Traversals + # expect the columns of tables and subqueries to be leaf nodes. + return [] @HasMemoized.memoized_attribute def _from_objects(self): @@ -4474,6 +4497,7 @@ class ColumnClause( _selectable=selectable, is_literal=is_literal, ) + c._propagate_attrs = selectable._propagate_attrs if name is None: c.key = self.key c._proxies = [self] diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index cedb76f55..6b1172eba 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -107,6 +107,7 @@ class FunctionElement(Executable, ColumnElement, FromClause): roles.ExpressionElementRole, c, name=getattr(self, "name", None), + apply_propagate_attrs=self, ) for c in clauses ] @@ -749,7 +750,10 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): if parsed_args is None: parsed_args = [ coercions.expect( - roles.ExpressionElementRole, c, name=self.name + roles.ExpressionElementRole, + c, + name=self.name, + apply_propagate_attrs=self, ) for c in args ] @@ -813,7 +817,12 @@ class ReturnTypeFromArgs(GenericFunction): def __init__(self, *args, **kwargs): args = [ - coercions.expect(roles.ExpressionElementRole, c, name=self.name) + coercions.expect( + roles.ExpressionElementRole, + c, + name=self.name, + apply_propagate_attrs=self, + ) for c in args ] kwargs.setdefault("type_", 
_type_from_args(args)) @@ -944,7 +953,12 @@ class array_agg(GenericFunction): type = sqltypes.ARRAY def __init__(self, *args, **kwargs): - args = [coercions.expect(roles.ExpressionElementRole, c) for c in args] + args = [ + coercions.expect( + roles.ExpressionElementRole, c, apply_propagate_attrs=self + ) + for c in args + ] default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) if "type_" not in kwargs: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index b861f721b..d0f4fef60 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -142,12 +142,20 @@ class AnonymizedFromClauseRole(StrictFromClauseRole): class CoerceTextStatementRole(SQLRole): - _role_name = "Executable SQL, text() construct, or string statement" + _role_name = "Executable SQL or text() construct" + + +# _executable_statement = None class StatementRole(CoerceTextStatementRole): _role_name = "Executable SQL or text() construct" + _is_future = False + + def _get_plugin_compile_state_cls(self, name): + return None + class ReturnsRowsRole(StatementRole): _role_name = ( diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 65f8bd81c..263f579de 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1632,6 +1632,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) c.table = selectable + c._propagate_attrs = selectable._propagate_attrs if selectable._is_clone_of is not None: c._is_clone_of = selectable._is_clone_of.columns.get(c.key) if self.primary_key: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6a552c18c..008959aec 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1342,7 +1342,9 @@ class AliasedReturnsRows(NoInit, FromClause): raise NotImplementedError() def _init(self, selectable, name=None): - self.element = selectable + self.element = coercions.expect( + roles.ReturnsRowsRole, selectable, 
apply_propagate_attrs=self + ) self.supports_execution = selectable.supports_execution if self.supports_execution: self._execution_options = selectable._execution_options @@ -3026,6 +3028,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): ) +@CompileState.plugin_for("default", "compound_select") class CompoundSelectState(CompileState): @util.memoized_property def _label_resolve_dict(self): @@ -3058,7 +3061,6 @@ class CompoundSelect(HasCompileState, GenerativeSelect): """ __visit_name__ = "compound_select" - _compile_state_factory = CompoundSelectState._create _traverse_internals = [ ("selects", InternalTraversal.dp_clauseelement_list), @@ -3425,6 +3427,7 @@ class DeprecatedSelectGenerations(object): self.select_from.non_generative(self, fromclause) +@CompileState.plugin_for("default", "select") class SelectState(CompileState): class default_select_compile_options(CacheableOptions): _cache_key_traversal = [] @@ -3462,7 +3465,7 @@ class SelectState(CompileState): ) if not seen.intersection(item._cloned_set): froms.append(item) - seen.update(item._cloned_set) + seen.update(item._cloned_set) return froms @@ -3714,12 +3717,29 @@ class SelectState(CompileState): return replace_from_obj_index +class _SelectFromElements(object): + def _iterate_from_elements(self): + # note this does not include elements + # in _setup_joins or _legacy_setup_joins + + return itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in self._raw_columns] + ), + itertools.chain.from_iterable( + [element._from_objects for element in self._where_criteria] + ), + self._from_obj, + ) + + class Select( HasPrefixes, HasSuffixes, HasHints, HasCompileState, DeprecatedSelectGenerations, + _SelectFromElements, GenerativeSelect, ): """Represents a ``SELECT`` statement. 
@@ -3728,7 +3748,6 @@ class Select( __visit_name__ = "select" - _compile_state_factory = SelectState._create _is_future = False _setup_joins = () _legacy_setup_joins = () @@ -4047,7 +4066,7 @@ class Select( if cols_present: self._raw_columns = [ coercions.expect( - roles.ColumnsClauseRole, c, apply_plugins=self + roles.ColumnsClauseRole, c, apply_propagate_attrs=self ) for c in columns ] @@ -4073,17 +4092,6 @@ class Select( cols = list(elem._select_iterable) return cols[0].type - def _iterate_from_elements(self): - return itertools.chain( - itertools.chain.from_iterable( - [element._from_objects for element in self._raw_columns] - ), - itertools.chain.from_iterable( - [element._from_objects for element in self._where_criteria] - ), - self._from_obj, - ) - @property def froms(self): """Return the displayed list of FromClause elements.""" @@ -4192,14 +4200,16 @@ class Select( self._raw_columns = self._raw_columns + [ coercions.expect( - roles.ColumnsClauseRole, column, apply_plugins=self + roles.ColumnsClauseRole, column, apply_propagate_attrs=self ) for column in columns ] def _set_entities(self, entities): self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self) + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) for ent in util.to_list(entities) ] @@ -4342,14 +4352,24 @@ class Select( self._raw_columns = rc @property - def _whereclause(self): - """Legacy, return the WHERE clause as a """ - """:class:`_expression.BooleanClauseList`""" + def whereclause(self): + """Return the completed WHERE clause for this :class:`.Select` + statement. + + This assembles the current collection of WHERE criteria + into a single :class:`_expression.BooleanClauseList` construct. + + + .. 
versionadded:: 1.4 + + """ return BooleanClauseList._construct_for_whereclause( self._where_criteria ) + _whereclause = whereclause + @_generative def where(self, whereclause): """return a new select() construct with the given expression added to @@ -4430,7 +4450,7 @@ class Select( self._from_obj += tuple( coercions.expect( - roles.FromClauseRole, fromclause, apply_plugins=self + roles.FromClauseRole, fromclause, apply_propagate_attrs=self ) for fromclause in froms ) diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index a308feb7c..482248ada 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -179,7 +179,7 @@ class HasCacheKey(object): if NO_CACHE in _anon_map: return None else: - return CacheKey(key, bindparams) + return CacheKey(key, bindparams, self) @classmethod def _generate_cache_key_for_object(cls, obj): @@ -190,7 +190,7 @@ class HasCacheKey(object): if NO_CACHE in _anon_map: return None else: - return CacheKey(key, bindparams) + return CacheKey(key, bindparams, obj) class MemoizedHasCacheKey(HasCacheKey, HasMemoized): @@ -199,9 +199,42 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized): return HasCacheKey._generate_cache_key(self) -class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): +class CacheKey(namedtuple("CacheKey", ["key", "bindparams", "statement"])): def __hash__(self): - return hash(self.key) + """CacheKey itself is not hashable - hash the .key portion""" + + return None + + def to_offline_string(self, statement_cache, parameters): + """generate an "offline string" form of this :class:`.CacheKey` + + The "offline string" is basically the string SQL for the + statement plus a repr of the bound parameter values in series. + Whereas the :class:`.CacheKey` object is dependent on in-memory + identities in order to work as a cache key, the "offline" version + is suitable for a cache that will work for other processes as well. 
+ + The given "statement_cache" is a dictionary-like object where the + string form of the statement itself will be cached. this dictionary + should be in a longer lived scope in order to reduce the time spent + stringifying statements. + + + """ + if self.key not in statement_cache: + statement_cache[self.key] = sql_str = str(self.statement) + else: + sql_str = statement_cache[self.key] + + return repr( + ( + sql_str, + tuple( + parameters.get(bindparam.key, bindparam.value) + for bindparam in self.bindparams + ), + ) + ) def __eq__(self, other): return self.key == other.key @@ -411,7 +444,6 @@ class _CacheKey(ExtendedInternalTraversal): def visit_setup_join_tuple( self, attrname, obj, parent, anon_map, bindparams ): - # TODO: look at attrname for "legacy_join" and use different structure return tuple( ( target._gen_cache_key(anon_map, bindparams), @@ -596,7 +628,6 @@ class _CopyInternals(InternalTraversal): def visit_setup_join_tuple( self, attrname, parent, element, clone=_clone, **kw ): - # TODO: look at attrname for "legacy_join" and use different structure return tuple( ( clone(target, **kw) if target is not None else None, @@ -668,6 +699,15 @@ class _CopyInternals(InternalTraversal): _copy_internals = _CopyInternals() +def _flatten_clauseelement(element): + while hasattr(element, "__clause_element__") and not getattr( + element, "is_clause_element", False + ): + element = element.__clause_element__() + + return element + + class _GetChildren(InternalTraversal): """Generate a _children_traversal internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -696,6 +736,17 @@ class _GetChildren(InternalTraversal): def visit_clauseelement_unordered_set(self, element, **kw): return element + def visit_setup_join_tuple(self, element, **kw): + for (target, onclause, from_, flags) in element: + if from_ is not None: + yield from_ + + if not isinstance(target, str): + yield _flatten_clauseelement(target) + + # if onclause is not None and not 
isinstance(onclause, str): + # yield _flatten_clauseelement(onclause) + def visit_dml_ordered_values(self, element, **kw): for k, v in element: if hasattr(k, "__clause_element__"): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 030fd2fde..683f545dd 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -591,6 +591,7 @@ def iterate(obj, opts=util.immutabledict()): """ yield obj children = obj.get_children(**opts) + if not children: return diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 24e96dfab..92bd452a5 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -294,6 +294,7 @@ def count_functions(variance=0.05): print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) stats.sort_stats(_profile_stats.sort) stats.print_stats() + # stats.print_callers() if _profile_stats.force_write: _profile_stats.replace(callcount) elif expected_count: diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 9a832ba1b..605686494 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -97,17 +97,13 @@ class FacadeDict(ImmutableContainer, dict): def __new__(cls, *args): new = dict.__new__(cls) - dict.__init__(new, *args) return new - def __init__(self, *args): - pass - - # note that currently, "copy()" is used as a way to get a plain dict - # from an immutabledict, while also allowing the method to work if the - # dictionary is already a plain dict. - # def copy(self): - # return immutabledict.__new__(immutabledict, self) + def copy(self): + raise NotImplementedError( + "an immutabledict shouldn't need to be copied. use dict(d) " + "if you need a mutable dictionary." + ) def __reduce__(self): return FacadeDict, (dict(self),) |