diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-05-28 19:28:35 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-05-28 19:28:35 +0000 |
commit | 056bad48e2bc948a08621ab841fd882cb6934262 (patch) | |
tree | 2635059b149309c2ad7a648bfce13fd5844d8dc8 /lib | |
parent | c07979e8d44a30fdf0ea73bc587aa05a52e9955a (diff) | |
parent | 77f1b7d236dba6b1c859bb428ef32d118ec372e6 (diff) | |
download | sqlalchemy-056bad48e2bc948a08621ab841fd882cb6934262.tar.gz |
Merge "callcount reductions and refinement for cached queries"
Diffstat (limited to 'lib')
24 files changed, 881 insertions, 806 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 3345d555f..5e0704597 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1529,8 +1529,7 @@ class MSExecutionContext(default.DefaultExecutionContext): elif ( self.isinsert or self.isupdate or self.isdelete ) and self.compiled.returning: - fbcr = _cursor.FullyBufferedCursorFetchStrategy - self._result_strategy = fbcr.create_from_buffer( + self.cursor_fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy( # noqa self.cursor, self.cursor.description, self.cursor.fetchall() ) @@ -1571,14 +1570,6 @@ class MSExecutionContext(default.DefaultExecutionContext): except Exception: pass - def get_result_cursor_strategy(self, result): - if self._result_strategy: - return self._result_strategy - else: - return super(MSExecutionContext, self).get_result_cursor_strategy( - result - ) - class MSSQLCompiler(compiler.SQLCompiler): returning_precedes_values = True diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index c61a1cc0a..4aae059dd 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -696,6 +696,27 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): self._generate_cursor_outputtype_handler() + def post_exec(self): + if self.compiled and self.out_parameters and self.compiled.returning: + # create a fake cursor result from the out parameters. unlike + # get_out_parameter_values(), the result-row handlers here will be + # applied at the Result level + returning_params = [ + self.dialect._returningval(self.out_parameters["ret_%d" % i]) + for i in range(len(self.out_parameters)) + ] + + fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + [ + (getattr(col, "name", col.anon_label), None) + for col in self.compiled.returning + ], + initial_buffer=[tuple(returning_params)], + ) + + self.cursor_fetch_strategy = fetch_strategy + def create_cursor(self): c = self._dbapi_connection.cursor() if self.dialect.arraysize: @@ -714,29 +735,6 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): for name in out_param_names ] - def get_result_cursor_strategy(self, result): - if self.compiled and self.out_parameters and self.compiled.returning: - # create a fake cursor result from the out parameters. unlike - # get_out_parameter_values(), the result-row handlers here will be - # applied at the Result level - returning_params = [ - self.dialect._returningval(self.out_parameters["ret_%d" % i]) - for i in range(len(self.out_parameters)) - ] - - return _cursor.FullyBufferedCursorFetchStrategy( - result.cursor, - [ - (getattr(col, "name", col.anon_label), None) - for col in result.context.compiled.returning - ], - initial_buffer=[tuple(returning_params)], - ) - else: - return super( - OracleExecutionContext_cx_oracle, self - ).get_result_cursor_strategy(result) - class OracleDialect_cx_oracle(OracleDialect): execution_ctx_cls = OracleExecutionContext_cx_oracle diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 9585dd467..a9408bcb0 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -592,13 +592,9 @@ class PGExecutionContext_psycopg2(PGExecutionContext): ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) return self._dbapi_connection.cursor(ident) - def get_result_cursor_strategy(self, result): + def post_exec(self): self._log_notices(self.cursor) - return super(PGExecutionContext, self).get_result_cursor_strategy( - result - ) - def _log_notices(self, cursor): # check also that notices is an iterable, after it's already # established that we will be iterating through it. This is to get diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 0193ea47c..a36f4eee2 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -25,6 +25,8 @@ from ..sql import util as sql_util """ +_EMPTY_EXECUTION_OPTS = util.immutabledict() + class Connection(Connectable): """Provides high-level functionality for a wrapped DB-API connection. @@ -1038,7 +1040,11 @@ class Connection(Connectable): distilled_parameters = _distill_params(multiparams, params) return self._exec_driver_sql( - object_, multiparams, params, distilled_parameters + object_, + multiparams, + params, + distilled_parameters, + _EMPTY_EXECUTION_OPTS, ) try: meth = object_._execute_on_connection @@ -1047,24 +1053,29 @@ class Connection(Connectable): exc.ObjectNotExecutableError(object_), replace_context=err ) else: - return meth(self, multiparams, params, util.immutabledict()) + return meth(self, multiparams, params, _EMPTY_EXECUTION_OPTS) - def _execute_function( - self, func, multiparams, params, execution_options=util.immutabledict() - ): + def _execute_function(self, func, multiparams, params, execution_options): """Execute a sql.FunctionElement object.""" - return self._execute_clauseelement(func.select(), multiparams, params) + return self._execute_clauseelement( + func.select(), multiparams, params, execution_options + ) def _execute_default( self, default, multiparams, params, - execution_options=util.immutabledict(), + # migrate is calling this directly :( + execution_options=_EMPTY_EXECUTION_OPTS, ): """Execute a schema.ColumnDefault object.""" + execution_options = self._execution_options.merge_with( + execution_options + ) + if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: default, multiparams, params = fn( @@ -1096,11 +1107,13 @@ class Connection(Connectable): return ret - def _execute_ddl( - self, ddl, multiparams, params, execution_options=util.immutabledict() - ): + def _execute_ddl(self, ddl, multiparams, params, execution_options): """Execute a schema.DDL object.""" + execution_options = ddl._execution_options.merge_with( + self._execution_options, execution_options + ) + if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: ddl, multiparams, params = fn( @@ -1130,11 +1143,16 @@ class Connection(Connectable): return ret def _execute_clauseelement( - self, elem, multiparams, params, execution_options=util.immutabledict() + self, elem, multiparams, params, execution_options ): """Execute a sql.ClauseElement object.""" - if self._has_events or self.engine._has_events: + execution_options = elem._execution_options.merge_with( + self._execution_options, execution_options + ) + + has_events = self._has_events or self.engine._has_events + if has_events: for fn in self.dispatch.before_execute: elem, multiparams, params = fn( self, elem, multiparams, params, execution_options @@ -1144,18 +1162,19 @@ class Connection(Connectable): if distilled_params: # ensure we don't retain a link to the view object for keys() # which links to the values, which we don't want to cache - keys = list(distilled_params[0].keys()) - + keys = sorted(distilled_params[0]) + inline = len(distilled_params) > 1 else: keys = [] + inline = False dialect = self.dialect - exec_opts = self._execution_options.merge_with(execution_options) - - schema_translate_map = exec_opts.get("schema_translate_map", None) + schema_translate_map = execution_options.get( + "schema_translate_map", None + ) - compiled_cache = exec_opts.get( + compiled_cache = execution_options.get( "compiled_cache", self.dialect._compiled_cache ) @@ -1165,13 +1184,13 @@ class Connection(Connectable): elem_cache_key = None if elem_cache_key: - cache_key, extracted_params, _ = elem_cache_key + cache_key, extracted_params = elem_cache_key key = ( dialect, cache_key, - tuple(sorted(keys)), + tuple(keys), bool(schema_translate_map), - len(distilled_params) > 1, + inline, ) compiled_sql = compiled_cache.get(key) @@ -1180,7 +1199,7 @@ class Connection(Connectable): dialect=dialect, cache_key=elem_cache_key, column_keys=keys, - inline=len(distilled_params) > 1, + inline=inline, schema_translate_map=schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, @@ -1191,7 +1210,7 @@ class Connection(Connectable): compiled_sql = elem.compile( dialect=dialect, column_keys=keys, - inline=len(distilled_params) > 1, + inline=inline, schema_translate_map=schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, ) @@ -1207,7 +1226,7 @@ class Connection(Connectable): elem, extracted_params, ) - if self._has_events or self.engine._has_events: + if has_events: self.dispatch.after_execute( self, elem, multiparams, params, execution_options, ret ) @@ -1218,9 +1237,17 @@ class Connection(Connectable): compiled, multiparams, params, - execution_options=util.immutabledict(), + execution_options=_EMPTY_EXECUTION_OPTS, ): - """Execute a sql.Compiled object.""" + """Execute a sql.Compiled object. + + TODO: why do we have this? likely deprecate or remove + + """ + + execution_options = compiled.execution_options.merge_with( + self._execution_options, execution_options + ) if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: @@ -1253,9 +1280,13 @@ class Connection(Connectable): multiparams, params, distilled_parameters, - execution_options=util.immutabledict(), + execution_options, ): + execution_options = self._execution_options.merge_with( + execution_options + ) + if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: statement, multiparams, params = fn( @@ -1282,7 +1313,7 @@ class Connection(Connectable): self, statement, parameters=None, - execution_options=util.immutabledict(), + execution_options=_EMPTY_EXECUTION_OPTS, ): multiparams, params, distilled_parameters = _distill_params_20( parameters @@ -1398,8 +1429,7 @@ class Connection(Connectable): if self._is_future and self._transaction is None: self.begin() - if context.compiled: - context.pre_exec() + context.pre_exec() cursor, statement, parameters = ( context.cursor, @@ -1495,30 +1525,35 @@ class Connection(Connectable): context.executemany, ) - if context.compiled: - context.post_exec() + context.post_exec() result = context._setup_result_proxy() - if ( - not self._is_future - # usually we're in a transaction so avoid relatively - # expensive / legacy should_autocommit call - and self._transaction is None - and context.should_autocommit - ): - self._commit_impl(autocommit=True) + if not self._is_future: + should_close_with_result = branched.should_close_with_result - # for "connectionless" execution, we have to close this - # Connection after the statement is complete. - # legacy stuff. - if branched.should_close_with_result and context._soft_closed: - assert not self._is_future - assert not context._is_future_result + if not result._soft_closed and should_close_with_result: + result._autoclose_connection = True + + if ( + # usually we're in a transaction so avoid relatively + # expensive / legacy should_autocommit call + self._transaction is None + and context.should_autocommit + ): + self._commit_impl(autocommit=True) + + # for "connectionless" execution, we have to close this + # Connection after the statement is complete. + # legacy stuff. + if should_close_with_result and context._soft_closed: + assert not self._is_future + assert not context._is_future_result + + # CursorResult already exhausted rows / has no rows. + # close us now + branched.close() - # CursorResult already exhausted rows / has no rows. - # close us now - branched.close() except BaseException as e: self._handle_dbapi_exception( e, statement, parameters, cursor, context @@ -2319,7 +2354,7 @@ class Engine(Connectable, log.Identified): """ - _execution_options = util.immutabledict() + _execution_options = _EMPTY_EXECUTION_OPTS _has_events = False _connection_cls = Connection _sqla_logger_namespace = "sqlalchemy.engine.Engine" @@ -2709,13 +2744,29 @@ class Engine(Connectable, log.Identified): """ return self.execute(statement, *multiparams, **params).scalar() - def _execute_clauseelement(self, elem, multiparams=None, params=None): + def _execute_clauseelement( + self, + elem, + multiparams=None, + params=None, + execution_options=_EMPTY_EXECUTION_OPTS, + ): connection = self.connect(close_with_result=True) - return connection._execute_clauseelement(elem, multiparams, params) + return connection._execute_clauseelement( + elem, multiparams, params, execution_options + ) - def _execute_compiled(self, compiled, multiparams, params): + def _execute_compiled( + self, + compiled, + multiparams, + params, + execution_options=_EMPTY_EXECUTION_OPTS, + ): connection = self.connect(close_with_result=True) - return connection._execute_compiled(compiled, multiparams, params) + return connection._execute_compiled( + compiled, multiparams, params, execution_options + ) def connect(self, close_with_result=False): """Return a new :class:`_engine.Connection` object. diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index fdbf826ed..c32427644 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -10,12 +10,12 @@ import collections +import functools from .result import Result from .result import ResultMetaData from .result import SimpleResultMetaData from .result import tuplegetter -from .row import _baserow_usecext from .row import LegacyRow from .. import exc from .. import util @@ -89,14 +89,6 @@ class CursorResultMetaData(ResultMetaData): for index, rec in enumerate(self._metadata_for_keys(keys)) ] new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} - if not _baserow_usecext: - # TODO: can consider assembling ints + negative ints here - new_metadata._keymap.update( - { - index: (index, new_keys[index], ()) - for index in range(len(new_keys)) - } - ) # TODO: need unit test for: # result = connection.execute("raw sql, no columns").scalars() @@ -186,25 +178,6 @@ class CursorResultMetaData(ResultMetaData): ) self._keymap = {} - if not _baserow_usecext: - # keymap indexes by integer index: this is only used - # in the pure Python BaseRow.__getitem__ - # implementation to avoid an expensive - # isinstance(key, util.int_types) in the most common - # case path - - len_raw = len(raw) - - self._keymap.update( - [ - (metadata_entry[MD_INDEX], metadata_entry) - for metadata_entry in raw - ] - + [ - (metadata_entry[MD_INDEX] - len_raw, metadata_entry) - for metadata_entry in raw - ] - ) # processors in key order for certain per-row # views like __iter__ and slices @@ -623,20 +596,23 @@ class CursorResultMetaData(ResultMetaData): return index def _indexes_for_keys(self, keys): - for rec in self._metadata_for_keys(keys): - yield rec[0] + + try: + return [self._keymap[key][0] for key in keys] + except KeyError as ke: + # ensure it raises + CursorResultMetaData._key_fallback(self, ke.args[0], ke) def _metadata_for_keys(self, keys): for key in keys: - # TODO: can consider pre-loading ints and negative ints - # into _keymap - if isinstance(key, int): + if int in key.__class__.__mro__: key = self._keys[key] try: rec = self._keymap[key] except KeyError as ke: - rec = self._key_fallback(key, ke) + # ensure it raises + CursorResultMetaData._key_fallback(self, ke.args[0], ke) index = rec[0] @@ -786,25 +762,27 @@ class ResultFetchStrategy(object): __slots__ = () - def soft_close(self, result): + alternate_cursor_description = None + + def soft_close(self, result, dbapi_cursor): raise NotImplementedError() - def hard_close(self, result): + def hard_close(self, result, dbapi_cursor): raise NotImplementedError() - def yield_per(self, result, num): + def yield_per(self, result, dbapi_cursor, num): return - def fetchone(self, result, hard_close=False): + def fetchone(self, result, dbapi_cursor, hard_close=False): raise NotImplementedError() - def fetchmany(self, result, size=None): + def fetchmany(self, result, dbapi_cursor, size=None): raise NotImplementedError() def fetchall(self, result): raise NotImplementedError() - def handle_exception(self, result, err): + def handle_exception(self, result, dbapi_cursor, err): raise err @@ -819,21 +797,19 @@ class NoCursorFetchStrategy(ResultFetchStrategy): __slots__ = () - cursor_description = None - - def soft_close(self, result): + def soft_close(self, result, dbapi_cursor): pass - def hard_close(self, result): + def hard_close(self, result, dbapi_cursor): pass - def fetchone(self, result, hard_close=False): + def fetchone(self, result, dbapi_cursor, hard_close=False): return self._non_result(result, None) - def fetchmany(self, result, size=None): + def fetchmany(self, result, dbapi_cursor, size=None): return self._non_result(result, []) - def fetchall(self, result): + def fetchall(self, result, dbapi_cursor): return self._non_result(result, []) def _non_result(self, result, default, err=None): @@ -893,71 +869,59 @@ class CursorFetchStrategy(ResultFetchStrategy): """ - __slots__ = ("dbapi_cursor", "cursor_description") - - def __init__(self, dbapi_cursor, cursor_description): - self.dbapi_cursor = dbapi_cursor - self.cursor_description = cursor_description - - @classmethod - def create(cls, result): - dbapi_cursor = result.cursor - description = dbapi_cursor.description - - if description is None: - return _NO_CURSOR_DML - else: - return cls(dbapi_cursor, description) + __slots__ = () - def soft_close(self, result): + def soft_close(self, result, dbapi_cursor): result.cursor_strategy = _NO_CURSOR_DQL - def hard_close(self, result): + def hard_close(self, result, dbapi_cursor): result.cursor_strategy = _NO_CURSOR_DQL - def handle_exception(self, result, err): + def handle_exception(self, result, dbapi_cursor, err): result.connection._handle_dbapi_exception( - err, None, None, self.dbapi_cursor, result.context + err, None, None, dbapi_cursor, result.context ) - def yield_per(self, result, num): + def yield_per(self, result, dbapi_cursor, num): result.cursor_strategy = BufferedRowCursorFetchStrategy( - self.dbapi_cursor, - self.cursor_description, - num, - collections.deque(), + dbapi_cursor, + {"max_row_buffer": num}, + initial_buffer=collections.deque(), growth_factor=0, ) - def fetchone(self, result, hard_close=False): + def fetchone(self, result, dbapi_cursor, hard_close=False): try: - row = self.dbapi_cursor.fetchone() + row = dbapi_cursor.fetchone() if row is None: result._soft_close(hard=hard_close) return row except BaseException as e: - self.handle_exception(result, e) + self.handle_exception(result, dbapi_cursor, e) - def fetchmany(self, result, size=None): + def fetchmany(self, result, dbapi_cursor, size=None): try: if size is None: - l = self.dbapi_cursor.fetchmany() + l = dbapi_cursor.fetchmany() else: - l = self.dbapi_cursor.fetchmany(size) + l = dbapi_cursor.fetchmany(size) if not l: result._soft_close() return l except BaseException as e: - self.handle_exception(result, e) + self.handle_exception(result, dbapi_cursor, e) - def fetchall(self, result): + def fetchall(self, result, dbapi_cursor): try: - rows = self.dbapi_cursor.fetchall() + rows = dbapi_cursor.fetchall() result._soft_close() return rows except BaseException as e: - self.handle_exception(result, e) + self.handle_exception(result, dbapi_cursor, e) + + +_DEFAULT_FETCH = CursorFetchStrategy() class BufferedRowCursorFetchStrategy(CursorFetchStrategy): @@ -993,18 +957,18 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): def __init__( self, dbapi_cursor, - description, - max_row_buffer, - initial_buffer, + execution_options, growth_factor=5, + initial_buffer=None, ): - super(BufferedRowCursorFetchStrategy, self).__init__( - dbapi_cursor, description - ) - self._max_row_buffer = max_row_buffer + self._max_row_buffer = execution_options.get("max_row_buffer", 1000) + + if initial_buffer is not None: + self._rowbuffer = initial_buffer + else: + self._rowbuffer = collections.deque(dbapi_cursor.fetchmany(1)) self._growth_factor = growth_factor - self._rowbuffer = initial_buffer if growth_factor: self._bufsize = min(self._max_row_buffer, self._growth_factor) @@ -1013,39 +977,19 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): @classmethod def create(cls, result): - """Buffered row strategy has to buffer the first rows *before* - cursor.description is fetched so that it works with named cursors - correctly - - """ - - dbapi_cursor = result.cursor - - # TODO: is create() called within a handle_error block externally? - # can this be guaranteed / tested / etc - initial_buffer = collections.deque(dbapi_cursor.fetchmany(1)) - - description = dbapi_cursor.description - - if description is None: - return _NO_CURSOR_DML - else: - max_row_buffer = result.context.execution_options.get( - "max_row_buffer", 1000 - ) - return cls( - dbapi_cursor, description, max_row_buffer, initial_buffer - ) + return BufferedRowCursorFetchStrategy( + result.cursor, result.context.execution_options, + ) - def _buffer_rows(self, result): + def _buffer_rows(self, result, dbapi_cursor): size = self._bufsize try: if size < 1: - new_rows = self.dbapi_cursor.fetchall() + new_rows = dbapi_cursor.fetchall() else: - new_rows = self.dbapi_cursor.fetchmany(size) + new_rows = dbapi_cursor.fetchmany(size) except BaseException as e: - self.handle_exception(result, e) + self.handle_exception(result, dbapi_cursor, e) if not new_rows: return @@ -1055,21 +999,25 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): self._max_row_buffer, size * self._growth_factor ) - def yield_per(self, result, num): + def yield_per(self, result, dbapi_cursor, num): self._growth_factor = 0 self._max_row_buffer = self._bufsize = num - def soft_close(self, result): + def soft_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(BufferedRowCursorFetchStrategy, self).soft_close(result) + super(BufferedRowCursorFetchStrategy, self).soft_close( + result, dbapi_cursor + ) - def hard_close(self, result): + def hard_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(BufferedRowCursorFetchStrategy, self).hard_close(result) + super(BufferedRowCursorFetchStrategy, self).hard_close( + result, dbapi_cursor + ) - def fetchone(self, result, hard_close=False): + def fetchone(self, result, dbapi_cursor, hard_close=False): if not self._rowbuffer: - self._buffer_rows(result) + self._buffer_rows(result, dbapi_cursor) if not self._rowbuffer: try: result._soft_close(hard=hard_close) @@ -1078,15 +1026,15 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): return None return self._rowbuffer.popleft() - def fetchmany(self, result, size=None): + def fetchmany(self, result, dbapi_cursor, size=None): if size is None: - return self.fetchall(result) + return self.fetchall(result, dbapi_cursor) buf = list(self._rowbuffer) lb = len(buf) if size > lb: try: - buf.extend(self.dbapi_cursor.fetchmany(size - lb)) + buf.extend(dbapi_cursor.fetchmany(size - lb)) except BaseException as e: self.handle_exception(result, e) @@ -1094,14 +1042,14 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): self._rowbuffer = collections.deque(buf[size:]) return result - def fetchall(self, result): + def fetchall(self, result, dbapi_cursor): try: - ret = list(self._rowbuffer) + list(self.dbapi_cursor.fetchall()) + ret = list(self._rowbuffer) + list(dbapi_cursor.fetchall()) self._rowbuffer.clear() result._soft_close() return ret except BaseException as e: - self.handle_exception(result, e) + self.handle_exception(result, dbapi_cursor, e) class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): @@ -1113,42 +1061,42 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): """ - __slots__ = ("_rowbuffer",) + __slots__ = ("_rowbuffer", "alternate_cursor_description") - def __init__(self, dbapi_cursor, description, initial_buffer=None): - super(FullyBufferedCursorFetchStrategy, self).__init__( - dbapi_cursor, description - ) + def __init__( + self, dbapi_cursor, alternate_description, initial_buffer=None + ): + self.alternate_cursor_description = alternate_description if initial_buffer is not None: self._rowbuffer = collections.deque(initial_buffer) else: - self._rowbuffer = collections.deque(self.dbapi_cursor.fetchall()) - - @classmethod - def create_from_buffer(cls, dbapi_cursor, description, buffer): - return cls(dbapi_cursor, description, buffer) + self._rowbuffer = collections.deque(dbapi_cursor.fetchall()) - def yield_per(self, result, num): + def yield_per(self, result, dbapi_cursor, num): pass - def soft_close(self, result): + def soft_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(FullyBufferedCursorFetchStrategy, self).soft_close(result) + super(FullyBufferedCursorFetchStrategy, self).soft_close( + result, dbapi_cursor + ) - def hard_close(self, result): + def hard_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(FullyBufferedCursorFetchStrategy, self).hard_close(result) + super(FullyBufferedCursorFetchStrategy, self).hard_close( + result, dbapi_cursor + ) - def fetchone(self, result, hard_close=False): + def fetchone(self, result, dbapi_cursor, hard_close=False): if self._rowbuffer: return self._rowbuffer.popleft() else: result._soft_close(hard=hard_close) return None - def fetchmany(self, result, size=None): + def fetchmany(self, result, dbapi_cursor, size=None): if size is None: - return self.fetchall(result) + return self.fetchall(result, dbapi_cursor) buf = list(self._rowbuffer) rows = buf[0:size] @@ -1157,7 +1105,7 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): result._soft_close() return rows - def fetchall(self, result): + def fetchall(self, result, dbapi_cursor): ret = self._rowbuffer self._rowbuffer = collections.deque() result._soft_close() @@ -1210,40 +1158,53 @@ class BaseCursorResult(object): _soft_closed = False closed = False - @classmethod - def _create_for_context(cls, context): - - if context._is_future_result: - obj = CursorResult(context) - else: - obj = LegacyCursorResult(context) - return obj - - def __init__(self, context): + def __init__(self, context, cursor_strategy, cursor_description): self.context = context self.dialect = context.dialect self.cursor = context.cursor + self.cursor_strategy = cursor_strategy self.connection = context.root_connection self._echo = echo = ( self.connection._echo and context.engine._should_log_debug() ) - if echo: - log = self.context.engine.logger.debug + if cursor_description is not None: + # inline of Result._row_getter(), set up an initial row + # getter assuming no transformations will be called as this + # is the most common case + + if echo: + log = self.context.engine.logger.debug + + def log_row(row): + log("Row %r", sql_util._repr_row(row)) + return row - def log_row(row): - log("Row %r", sql_util._repr_row(row)) - return row + self._row_logging_fn = log_row + else: + log_row = None + + metadata = self._init_metadata(context, cursor_description) + + keymap = metadata._keymap + processors = metadata._processors + process_row = self._process_row + key_style = process_row._default_key_style + _make_row = functools.partial( + process_row, metadata, processors, keymap, key_style + ) + if log_row: - self._row_logging_fn = log_row + def make_row(row): + made_row = _make_row(row) + log_row(made_row) + return made_row - # this is a hook used by dialects to change the strategy, - # so for the moment we have to keep calling this every time - # :( - self.cursor_strategy = strat = context.get_result_cursor_strategy(self) + self._row_getter = make_row + else: + make_row = _make_row + self._set_memoized_attribute("_row_getter", make_row) - if strat.cursor_description is not None: - self._init_metadata(context, strat.cursor_description) else: self._metadata = _NO_RESULT_METADATA @@ -1251,19 +1212,41 @@ class BaseCursorResult(object): if context.compiled: if context.compiled._cached_metadata: cached_md = self.context.compiled._cached_metadata - self._metadata = cached_md self._metadata_from_cache = True + # result rewrite/ adapt step. two translations can occur here. + # one is if we are invoked against a cached statement, we want + # to rewrite the ResultMetaData to reflect the column objects + # that are in our current selectable, not the cached one. the + # other is, the CompileState can return an alternative Result + # object. Finally, CompileState might want to tell us to not + # actually do the ResultMetaData adapt step if it in fact has + # changed the selected columns in any case. + compiled = context.compiled + if ( + compiled + and not compiled._rewrites_selected_columns + and compiled.statement is not context.invoked_statement + ): + cached_md = cached_md._adapt_to_context(context) + + self._metadata = metadata = cached_md + else: self._metadata = ( - context.compiled._cached_metadata - ) = self._cursor_metadata(self, cursor_description) + metadata + ) = context.compiled._cached_metadata = self._cursor_metadata( + self, cursor_description + ) else: - self._metadata = self._cursor_metadata(self, cursor_description) + self._metadata = metadata = self._cursor_metadata( + self, cursor_description + ) if self._echo: context.engine.logger.debug( "Col %r", tuple(x[0] for x in cursor_description) ) + return metadata def _soft_close(self, hard=False): """Soft close this :class:`_engine.CursorResult`. @@ -1294,9 +1277,9 @@ class BaseCursorResult(object): if hard: self.closed = True - self.cursor_strategy.hard_close(self) + self.cursor_strategy.hard_close(self, self.cursor) else: - self.cursor_strategy.soft_close(self) + self.cursor_strategy.soft_close(self, self.cursor) if not self._soft_closed: cursor = self.cursor @@ -1632,19 +1615,19 @@ class CursorResult(BaseCursorResult, Result): fetchone = self.cursor_strategy.fetchone while True: - row = fetchone(self) + row = fetchone(self, self.cursor) if row is None: break yield row def _fetchone_impl(self, hard_close=False): - return self.cursor_strategy.fetchone(self, hard_close) + return self.cursor_strategy.fetchone(self, self.cursor, hard_close) def _fetchall_impl(self): - return self.cursor_strategy.fetchall(self) + return self.cursor_strategy.fetchall(self, self.cursor) def _fetchmany_impl(self, size=None): - return self.cursor_strategy.fetchmany(self, size) + return self.cursor_strategy.fetchmany(self, self.cursor, size) def _raw_row_iterator(self): return self._fetchiter_impl() @@ -1674,7 +1657,7 @@ class CursorResult(BaseCursorResult, Result): @_generative def yield_per(self, num): self._yield_per = num - self.cursor_strategy.yield_per(self, num) + self.cursor_strategy.yield_per(self, self.cursor, num) class LegacyCursorResult(CursorResult): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index b5cb2a1b2..d0f5cfe96 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -709,6 +709,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): returned_defaults = None execution_options = util.immutabledict() + cursor_fetch_strategy = _cursor._DEFAULT_FETCH + cache_stats = None invoked_statement = None @@ -745,9 +747,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.compiled = compiled = compiled_ddl self.isddl = True - self.execution_options = compiled.execution_options.merge_with( - connection._execution_options, execution_options - ) + self.execution_options = execution_options self._is_future_result = ( connection._is_future @@ -802,13 +802,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.invoked_statement = invoked_statement self.compiled = compiled - # this should be caught in the engine before - # we get here - assert compiled.can_execute - - self.execution_options = compiled.execution_options.merge_with( - connection._execution_options, execution_options - ) + self.execution_options = execution_options self._is_future_result = ( connection._is_future @@ -829,7 +823,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if self.isinsert or self.isupdate or self.isdelete: self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) - self._is_implicit_returning = bool( + self._is_implicit_returning = ( compiled.returning and not compiled.statement._returning ) @@ -853,7 +847,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # this must occur before create_cursor() since the statement # has to be regexed in some cases for server side cursor - self.unicode_statement = util.text_type(compiled) + if util.py2k: + self.unicode_statement = util.text_type(compiled.string) + else: + self.unicode_statement = compiled.string self.cursor = self.create_cursor() @@ -909,32 +906,38 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # Convert the dictionary of bind parameter values # into a dict or list to be sent to the DBAPI's # execute() or executemany() method. + parameters = [] if compiled.positional: - parameters = [ - dialect.execute_sequence_format( - [ - processors[key](compiled_params[key]) - if key in processors - else compiled_params[key] - for key in positiontup - ] - ) - for compiled_params in self.compiled_parameters - ] + for compiled_params in self.compiled_parameters: + param = [ + processors[key](compiled_params[key]) + if key in processors + else compiled_params[key] + for key in positiontup + ] + parameters.append(dialect.execute_sequence_format(param)) else: encode = not dialect.supports_unicode_statements + if encode: + encoder = dialect._encoder + for compiled_params in self.compiled_parameters: - parameters = [ - { - dialect._encoder(key)[0] - if encode - else key: processors[key](value) - if key in processors - else value - for key, value in compiled_params.items() - } - for compiled_params in self.compiled_parameters - ] + if encode: + param = { + encoder(key)[0]: processors[key](compiled_params[key]) + if key in processors + else compiled_params[key] + for key in compiled_params + } + else: + param = { + key: processors[key](compiled_params[key]) + if key in processors + else compiled_params[key] + for key in compiled_params + } + + parameters.append(param) self.parameters = dialect.execute_sequence_format(parameters) @@ -958,9 +961,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.dialect = connection.dialect self.is_text = True - self.execution_options = self.execution_options.merge_with( - connection._execution_options, execution_options - ) + self.execution_options = execution_options self._is_future_result = ( connection._is_future @@ -1011,9 +1012,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self._dbapi_connection = dbapi_connection self.dialect = connection.dialect - self.execution_options = self.execution_options.merge_with( - connection._execution_options, execution_options - ) + self.execution_options = execution_options self._is_future_result = ( connection._is_future @@ -1214,25 +1213,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def handle_dbapi_exception(self, e): pass - def get_result_cursor_strategy(self, result): - """Dialect-overriable hook to return the internal strategy that - fetches results. - - - Some dialects will in some cases return special objects here that - have pre-buffered rows from some source or another, such as turning - Oracle OUT parameters into rows to accommodate for "returning", - SQL Server fetching "returning" before it resets "identity insert", - etc. - - """ - if self._is_server_side: - strat_cls = _cursor.BufferedRowCursorFetchStrategy - else: - strat_cls = _cursor.CursorFetchStrategy - - return strat_cls.create(result) - @property def rowcount(self): return self.cursor.rowcount @@ -1245,9 +1225,28 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _setup_result_proxy(self): if self.is_crud or self.is_text: - result = self._setup_crud_result_proxy() + result = self._setup_dml_or_text_result() else: - result = _cursor.CursorResult._create_for_context(self) + strategy = self.cursor_fetch_strategy + if self._is_server_side and strategy is _cursor._DEFAULT_FETCH: + strategy = _cursor.BufferedRowCursorFetchStrategy( + self.cursor, self.execution_options + ) + cursor_description = ( + strategy.alternate_cursor_description + or self.cursor.description + ) + if cursor_description is None: + strategy = _cursor._NO_CURSOR_DQL + + if self._is_future_result: + result = _cursor.CursorResult( + self, strategy, cursor_description + ) + else: + result = _cursor.LegacyCursorResult( + self, strategy, cursor_description + ) if ( self.compiled @@ -1256,33 +1255,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ): self._setup_out_parameters(result) - if not self._is_future_result: - conn = self.root_connection - assert not conn._is_future - - if not result._soft_closed and conn.should_close_with_result: - result._autoclose_connection = True - self._soft_closed = result._soft_closed - # result rewrite/ adapt step. two translations can occur here. - # one is if we are invoked against a cached statement, we want - # to rewrite the ResultMetaData to reflect the column objects - # that are in our current selectable, not the cached one. the - # other is, the CompileState can return an alternative Result - # object. Finally, CompileState might want to tell us to not - # actually do the ResultMetaData adapt step if it in fact has - # changed the selected columns in any case. - compiled = self.compiled - if compiled: - adapt_metadata = ( - result._metadata_from_cache - and not compiled._rewrites_selected_columns - ) - - if adapt_metadata: - result._metadata = result._metadata._adapt_to_context(self) - return result def _setup_out_parameters(self, result): @@ -1313,7 +1287,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): result.out_parameters = out_parameters - def _setup_crud_result_proxy(self): + def _setup_dml_or_text_result(self): if self.isinsert and not self.executemany: if ( not self._is_implicit_returning @@ -1326,7 +1300,23 @@ class DefaultExecutionContext(interfaces.ExecutionContext): elif not self._is_implicit_returning: self._setup_ins_pk_from_empty() - result = _cursor.CursorResult._create_for_context(self) + strategy = self.cursor_fetch_strategy + if self._is_server_side and strategy is _cursor._DEFAULT_FETCH: + strategy = _cursor.BufferedRowCursorFetchStrategy( + self.cursor, self.execution_options + ) + cursor_description = ( + strategy.alternate_cursor_description or self.cursor.description + ) + if cursor_description is None: + strategy = _cursor._NO_CURSOR_DML + + if self._is_future_result: + result = _cursor.CursorResult(self, strategy, cursor_description) + else: + result = _cursor.LegacyCursorResult( + self, strategy, cursor_description + ) if self.isinsert: if self._is_implicit_returning: diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index 293c7afdd..ef760bb54 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -184,15 +184,11 @@ class ConnectionEvents(event.Events): :meth:`_engine.Connection.execute`. :param multiparams: Multiple parameter sets, a list of dictionaries. :param params: Single parameter set, a single dictionary. - :param execution_options: dictionary of per-execution execution - options passed along with the statement, if any. This only applies to - the the SQLAlchemy 2.0 version of :meth:`_engine.Connection.execute` - . To - view all execution options associated with the connection, access the - :meth:`_engine.Connection.get_execution_options` - method to view the fixed - execution options dictionary, then consider elements within this local - dictionary to be unioned into that dictionary. + :param execution_options: dictionary of execution + options passed along with the statement, if any. This is a merge + of all options that will be used, including those of the statement, + the connection, and those passed in to the method itself for + the 2.0 style of execution. .. versionadded: 1.4 @@ -231,15 +227,11 @@ class ConnectionEvents(event.Events): :meth:`_engine.Connection.execute`. :param multiparams: Multiple parameter sets, a list of dictionaries. :param params: Single parameter set, a single dictionary. - :param execution_options: dictionary of per-execution execution - options passed along with the statement, if any. This only applies to - the the SQLAlchemy 2.0 version of :meth:`_engine.Connection.execute` - . To - view all execution options associated with the connection, access the - :meth:`_engine.Connection.get_execution_options` - method to view the fixed - execution options dictionary, then consider elements within this local - dictionary to be unioned into that dictionary. + :param execution_options: dictionary of execution + options passed along with the statement, if any. This is a merge + of all options that will be used, including those of the statement, + the connection, and those passed in to the method itself for + the 2.0 style of execution. .. versionadded: 1.4 diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 0ee80ede4..600229037 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -36,8 +36,11 @@ else: return lambda row: (it(row),) def _row_as_tuple(*indexes): + # circumvent LegacyRow.__getitem__ pointing to + # _get_by_key_impl_mapping for now. otherwise we could + # use itemgetter getters = [ - operator.methodcaller("_get_by_key_impl_mapping", index) + operator.methodcaller("_get_by_int_impl", index) for index in indexes ] return lambda rec: tuple([getter(rec) for getter in getters]) @@ -64,10 +67,7 @@ class ResultMetaData(object): def _key_fallback(self, key, err, raiseerr=True): assert raiseerr - if isinstance(key, int): - util.raise_(IndexError(key), replace_context=err) - else: - util.raise_(KeyError(key), replace_context=err) + util.raise_(KeyError(key), replace_context=err) def _warn_for_nonint(self, key): raise TypeError( @@ -94,7 +94,7 @@ class ResultMetaData(object): return None def _row_as_tuple_getter(self, keys): - indexes = list(self._indexes_for_keys(keys)) + indexes = self._indexes_for_keys(keys) return _row_as_tuple(*indexes) @@ -154,19 +154,15 @@ class SimpleResultMetaData(ResultMetaData): self._tuplefilter = _tuplefilter self._translated_indexes = _translated_indexes self._unique_filters = _unique_filters - len_keys = len(self._keys) if extra: recs_names = [ - ( - (index, name, index - len_keys) + extras, - (index, name, extras), - ) + ((name,) + extras, (index, name, extras),) for index, (name, extras) in enumerate(zip(self._keys, extra)) ] else: recs_names = [ - ((index, name, index - len_keys), (index, name, ())) + ((name,), (index, name, ())) for index, name in enumerate(self._keys) ] @@ -212,6 +208,8 @@ class SimpleResultMetaData(ResultMetaData): return value in row._data def _index_for_key(self, key, raiseerr=True): + if int in key.__class__.__mro__: + key = self._keys[key] try: rec = self._keymap[key] except KeyError as ke: @@ -220,11 +218,13 @@ class SimpleResultMetaData(ResultMetaData): return rec[0] def _indexes_for_keys(self, keys): - for rec in self._metadata_for_keys(keys): - yield rec[0] + return [self._keymap[key][0] for key in keys] def _metadata_for_keys(self, keys): for key in keys: + if int in key.__class__.__mro__: + key = self._keys[key] + try: rec = self._keymap[key] except KeyError as ke: @@ -234,7 +234,12 @@ class SimpleResultMetaData(ResultMetaData): def _reduce(self, keys): try: - metadata_for_keys = [self._keymap[key] for key in keys] + metadata_for_keys = [ + self._keymap[ + self._keys[key] if int in key.__class__.__mro__ else key + ] + for key in keys + ] except KeyError as ke: self._key_fallback(ke.args[0], ke, True) @@ -508,12 +513,11 @@ class Result(InPlaceGenerative): @_generative def _column_slices(self, indexes): - self._metadata = self._metadata._reduce(indexes) - if self._source_supports_scalars and len(indexes) == 1: self._generate_rows = False else: self._generate_rows = True + self._metadata = self._metadata._reduce(indexes) def _getter(self, key, raiseerr=True): """return a callable that will retrieve the given key from a @@ -551,10 +555,15 @@ class Result(InPlaceGenerative): :return: this :class:`._engine.Result` object with modifications. """ + + if self._source_supports_scalars: + self._metadata = self._metadata._reduce([0]) + self._post_creational_filter = operator.attrgetter("_mapping") self._no_scalar_onerow = False self._generate_rows = True + @HasMemoized.memoized_attribute def _row_getter(self): if self._source_supports_scalars: if not self._generate_rows: @@ -571,6 +580,7 @@ class Result(InPlaceGenerative): else: process_row = self._process_row + key_style = self._process_row._default_key_style metadata = self._metadata @@ -578,7 +588,7 @@ class Result(InPlaceGenerative): processors = metadata._processors tf = metadata._tuplefilter - if tf: + if tf and not self._source_supports_scalars: if processors: processors = tf(processors) @@ -660,7 +670,7 @@ class Result(InPlaceGenerative): @HasMemoized.memoized_attribute def _iterator_getter(self): - make_row = self._row_getter() + make_row = self._row_getter post_creational_filter = self._post_creational_filter @@ -689,60 +699,44 @@ class Result(InPlaceGenerative): return iterrows - @HasMemoized.memoized_attribute - def _allrow_getter(self): + def _raw_all_rows(self): + make_row = self._row_getter + rows = self._fetchall_impl() + return [make_row(row) for row in rows] - make_row = self._row_getter() + def _allrows(self): + + make_row = self._row_getter + + rows = self._fetchall_impl() + if make_row: + made_rows = [make_row(row) for row in rows] + else: + made_rows = rows post_creational_filter = self._post_creational_filter if self._unique_filter_state: uniques, strategy = self._unique_strategy - def allrows(self): - rows = self._fetchall_impl() - if make_row: - made_rows = [make_row(row) for row in rows] - else: - made_rows = rows - rows = [ - made_row - for made_row, sig_row in [ - ( - made_row, - strategy(made_row) if strategy else made_row, - ) - for made_row in made_rows - ] - if sig_row not in uniques and not uniques.add(sig_row) + rows = [ + made_row + for made_row, sig_row in [ + (made_row, strategy(made_row) if strategy else made_row,) + for made_row in made_rows ] - - if post_creational_filter: - rows = [post_creational_filter(row) for row in rows] - return rows - + if sig_row not in uniques and not uniques.add(sig_row) + ] else: + rows = made_rows - def allrows(self): - rows = self._fetchall_impl() - - if post_creational_filter: - if make_row: - rows = [ - post_creational_filter(make_row(row)) - for row in rows - ] - else: - rows = [post_creational_filter(row) for row in rows] - elif make_row: - rows = [make_row(row) for row in rows] - return rows - - return allrows + if post_creational_filter: + rows = [post_creational_filter(row) for row in rows] + return rows @HasMemoized.memoized_attribute def _onerow_getter(self): - make_row = self._row_getter() + make_row = self._row_getter post_creational_filter = self._post_creational_filter @@ -782,7 +776,7 @@ class Result(InPlaceGenerative): @HasMemoized.memoized_attribute def _manyrow_getter(self): - make_row = self._row_getter() + make_row = self._row_getter post_creational_filter = self._post_creational_filter @@ -884,7 +878,7 @@ class Result(InPlaceGenerative): def fetchall(self): """A synonym for the :meth:`_engine.Result.all` method.""" - return self._allrow_getter(self) + return self._allrows() def fetchone(self): """Fetch one row. @@ -955,7 +949,7 @@ class Result(InPlaceGenerative): may be returned. """ - return self._allrow_getter(self) + return self._allrows() def _only_one_row(self, raise_for_second_row, raise_for_none): onerow = self._fetchone_impl @@ -969,7 +963,7 @@ class Result(InPlaceGenerative): else: return None - make_row = self._row_getter() + make_row = self._row_getter row = make_row(row) if make_row else row @@ -1236,13 +1230,11 @@ class ChunkedIteratorResult(IteratorResult): self.raw = raw self.iterator = itertools.chain.from_iterable(self.chunks(None)) - def _column_slices(self, indexes): - result = super(ChunkedIteratorResult, self)._column_slices(indexes) - return result - @_generative def yield_per(self, num): self._yield_per = num + # TODO: this should raise if the iterator has already been started. + # we can't change the yield mid-stream like this self.iterator = itertools.chain.from_iterable(self.chunks(num)) diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 70f45c82c..fe6831e30 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -103,8 +103,10 @@ except ImportError: def __getitem__(self, key): return self._data[key] + _get_by_int_impl = __getitem__ + def _get_by_key_impl(self, key): - if self._key_style == KEY_INTEGER_ONLY: + if int in key.__class__.__mro__: return self._data[key] # the following is all LegacyRow support. none of this @@ -125,11 +127,7 @@ except ImportError: if mdindex is None: self._parent._raise_for_ambiguous_column_name(rec) - elif ( - self._key_style == KEY_OBJECTS_BUT_WARN - and mdindex != key - and not isinstance(key, int) - ): + elif self._key_style == KEY_OBJECTS_BUT_WARN and mdindex != key: self._parent._warn_for_nonint(key) return self._data[mdindex] diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 112e245f7..7ac556dcc 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -230,10 +230,6 @@ class BakedQuery(object): # invoked statement = query._statement_20(orm_results=True) - # the before_compile() event can create a new Query object - # before it makes the statement. - query = statement.compile_options._orm_query - # if the query is not safe to cache, we still do everything as though # we did cache it, since the receiver of _bake() assumes subqueryload # context was set up, etc. @@ -243,7 +239,7 @@ class BakedQuery(object): # used by the Connection, which in itself is more expensive to # generate than what BakedQuery was able to provide in 1.3 and prior - if query.compile_options._bake_ok: + if statement.compile_options._bake_ok: self._bakery[self._effective_key(session)] = ( query, statement, @@ -383,7 +379,7 @@ class Result(object): return str(self._as_query()) def __iter__(self): - return iter(self._iter()) + return self._iter().__iter__() def _iter(self): bq = self.bq @@ -397,12 +393,14 @@ class Result(object): if query is None: query, statement = bq._bake(self.session) - q = query.params(self._params) + if self._params: + q = query.params(self._params) + else: + q = query for fn in self._post_criteria: q = fn(q) params = q.load_options._params - q.load_options += {"_orm_query": q} execution_options = dict(q._execution_options) execution_options.update( { @@ -463,16 +461,15 @@ class Result(object): Equivalent to :meth:`_query.Query.first`. """ + bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) - ret = list( + return ( bq.for_session(self.session) .params(self._params) ._using_post_criteria(self._post_criteria) + ._iter() + .first() ) - if len(ret) > 0: - return ret[0] - else: - return None def one(self): """Return exactly one result or raise an exception. @@ -480,19 +477,7 @@ class Result(object): Equivalent to :meth:`_query.Query.one`. """ - try: - ret = self.one_or_none() - except orm_exc.MultipleResultsFound as err: - util.raise_( - orm_exc.MultipleResultsFound( - "Multiple rows were found for one()" - ), - replace_context=err, - ) - else: - if ret is None: - raise orm_exc.NoResultFound("No row was found for one()") - return ret + return self._iter().one() def one_or_none(self): """Return one or zero results, or raise an exception for multiple @@ -503,17 +488,7 @@ class Result(object): .. versionadded:: 1.0.9 """ - ret = list(self) - - l = len(ret) - if l == 1: - return ret[0] - elif l == 0: - return None - else: - raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one_or_none()" - ) + return self._iter().one_or_none() def all(self): """Return all rows. diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 1375a24cd..c3ac71c10 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -15,8 +15,10 @@ the source distribution. """ -from sqlalchemy import event +from .. import event +from .. import exc from .. import inspect +from .. import util from ..orm.query import Query from ..orm.session import Session @@ -28,6 +30,7 @@ class ShardedQuery(Query): super(ShardedQuery, self).__init__(*args, **kwargs) self.id_chooser = self.session.id_chooser self.query_chooser = self.session.query_chooser + self.execute_chooser = self.session.execute_chooser self._shard_id = None def set_shard(self, shard_id): @@ -45,10 +48,7 @@ class ShardedQuery(Query): ) """ - - q = self._clone() - q._shard_id = shard_id - return q + return self.execution_options(_sa_shard_id=shard_id) def _execute_crud(self, stmt, mapper): def exec_for_shard(shard_id): @@ -68,7 +68,8 @@ class ShardedQuery(Query): else: rowcount = 0 results = [] - for shard_id in self.query_chooser(self): + # TODO: this will have to be the new object + for shard_id in self.execute_chooser(self): result = exec_for_shard(shard_id) rowcount += result.rowcount results.append(result) @@ -107,7 +108,7 @@ class ShardedSession(Session): self, shard_chooser, id_chooser, - query_chooser, + execute_chooser=None, shards=None, query_cls=ShardedQuery, **kwargs @@ -125,14 +126,19 @@ class ShardedSession(Session): values, which should return a list of shard ids where the ID might reside. The databases will be queried in the order of this listing. - :param query_chooser: For a given Query, returns the list of shard_ids + :param execute_chooser: For a given :class:`.ORMExecuteState`, + returns the list of shard_ids where the query should be issued. Results from all shards returned will be combined together into a single listing. + .. versionchanged:: 1.4 The ``execute_chooser`` paramter + supersedes the ``query_chooser`` parameter. + :param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.Engine` objects. """ + query_chooser = kwargs.pop("query_chooser", None) super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) event.listen( @@ -140,6 +146,25 @@ class ShardedSession(Session): ) self.shard_chooser = shard_chooser self.id_chooser = id_chooser + + if query_chooser: + util.warn_deprecated( + "The ``query_choser`` parameter is deprecated; " + "please use ``execute_chooser``.", + "1.4", + ) + if execute_chooser: + raise exc.ArgumentError( + "Can't pass query_chooser and execute_chooser " + "at the same time." + ) + + def execute_chooser(orm_context): + return query_chooser(orm_context.statement) + + self.execute_chooser = execute_chooser + else: + self.execute_chooser = execute_chooser self.query_chooser = query_chooser self.__binds = {} if shards is not None: @@ -241,13 +266,13 @@ def execute_and_instances(orm_context): load_options = orm_context.load_options session = orm_context.session - orm_query = orm_context.orm_query + # orm_query = orm_context.orm_query if params is None: params = load_options._params def iter_for_shard(shard_id, load_options): - execution_options = dict(orm_context.execution_options) + execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) bind_arguments["_horizontal_shard"] = True @@ -265,8 +290,8 @@ def execute_and_instances(orm_context): if load_options._refresh_identity_token is not None: shard_id = load_options._refresh_identity_token - elif orm_query is not None and orm_query._shard_id is not None: - shard_id = orm_query._shard_id + elif "_sa_shard_id" in orm_context.merged_execution_options: + shard_id = orm_context.merged_execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: shard_id = orm_context.bind_arguments["shard_id"] else: @@ -276,9 +301,7 @@ def execute_and_instances(orm_context): return iter_for_shard(shard_id, load_options) else: partial = [] - for shard_id in session.query_chooser( - orm_query if orm_query is not None else orm_context.statement - ): + for shard_id in session.execute_chooser(orm_context): result_ = iter_for_shard(shard_id, load_options) partial.append(result_) diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py index 58fced887..74cc13501 100644 --- a/lib/sqlalchemy/future/selectable.py +++ b/lib/sqlalchemy/future/selectable.py @@ -98,6 +98,15 @@ class Select(_LegacySelect): ] return self.filter(*clauses) + @property + def column_descriptions(self): + """Return a 'column descriptions' structure which may be + plugin-specific. + + """ + meth = SelectState.get_plugin_class(self).get_column_descriptions + return meth(self) + @_generative def join(self, target, onclause=None, isouter=False, full=False): r"""Create a SQL JOIN against this :class:`_expresson.Select` diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 3acab7df7..09f3e7a12 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -39,11 +39,12 @@ from ..sql.visitors import InternalTraversal _path_registry = PathRegistry.root +_EMPTY_DICT = util.immutabledict() + class QueryContext(object): __slots__ = ( "compile_state", - "orm_query", "query", "load_options", "bind_arguments", @@ -74,8 +75,7 @@ class QueryContext(object): _yield_per = None _refresh_state = None _lazy_loaded_from = None - _orm_query = None - _params = util.immutabledict() + _params = _EMPTY_DICT def __init__( self, @@ -87,10 +87,9 @@ class QueryContext(object): ): self.load_options = load_options - self.execution_options = execution_options or {} - self.bind_arguments = bind_arguments or {} + self.execution_options = execution_options or _EMPTY_DICT + self.bind_arguments = bind_arguments or _EMPTY_DICT self.compile_state = compile_state - self.orm_query = compile_state.orm_query self.query = query = compile_state.query self.session = session @@ -118,20 +117,16 @@ class QueryContext(object): % ", ".join(compile_state._no_yield_pers) ) - @property - def is_single_entity(self): - # used for the check if we return a list of entities or tuples. - # this is gone in 2.0 when we no longer make this decision. - return ( - not self.load_options._only_return_tuples - and len(self.compile_state._entities) == 1 - and self.compile_state._entities[0].supports_single_entity - ) + def dispose(self): + self.attributes.clear() + self.load_options._refresh_state = None + self.load_options._lazy_loaded_from = None class ORMCompileState(CompileState): class default_compile_options(CacheableOptions): _cache_key_traversal = [ + ("_use_legacy_query_style", InternalTraversal.dp_boolean), ("_orm_results", InternalTraversal.dp_boolean), ("_bake_ok", InternalTraversal.dp_boolean), ( @@ -140,7 +135,6 @@ class ORMCompileState(CompileState): ), ("_current_path", InternalTraversal.dp_has_cache_key), ("_enable_single_crit", InternalTraversal.dp_boolean), - ("_statement", InternalTraversal.dp_clauseelement), ("_enable_eagerloads", InternalTraversal.dp_boolean), ("_orm_only_from_obj_alias", InternalTraversal.dp_boolean), ("_only_load_props", InternalTraversal.dp_plain_obj), @@ -148,6 +142,7 @@ class ORMCompileState(CompileState): ("_for_refresh_state", InternalTraversal.dp_boolean), ] + _use_legacy_query_style = False _orm_results = True _bake_ok = True _with_polymorphic_adapt_map = () @@ -159,37 +154,36 @@ class ORMCompileState(CompileState): _set_base_alias = False _for_refresh_state = False - # non-cache-key elements mostly for legacy use - _statement = None - _orm_query = None - @classmethod def merge(cls, other): return cls + other._state_dict() - orm_query = None current_path = _path_registry def __init__(self, *arg, **kw): raise NotImplementedError() + def dispose(self): + self.attributes.clear() + @classmethod def create_for_statement(cls, statement_container, compiler, **kw): raise NotImplementedError() @classmethod - def _create_for_legacy_query(cls, query, for_statement=False): + def _create_for_legacy_query(cls, query, toplevel, for_statement=False): stmt = query._statement_20(orm_results=not for_statement) - if query.compile_options._statement is not None: - compile_state_cls = ORMFromStatementCompileState - else: - compile_state_cls = ORMSelectCompileState + # this chooses between ORMFromStatementCompileState and + # ORMSelectCompileState. We could also base this on + # query._statement is not None as we have the ORM Query here + # however this is the more general path. + compile_state_cls = CompileState._get_plugin_class_for_plugin( + stmt, "orm" + ) - # true in all cases except for two tests in test/orm/test_events.py - # assert stmt.compile_options._orm_query is query return compile_state_cls._create_for_statement_or_query( - stmt, for_statement=for_statement + stmt, toplevel, for_statement=for_statement ) @classmethod @@ -199,6 +193,10 @@ class ORMCompileState(CompileState): raise NotImplementedError() @classmethod + def get_column_descriptions(self, statement): + return _column_descriptions(statement) + + @classmethod def orm_pre_session_exec( cls, session, statement, execution_options, bind_arguments ): @@ -219,10 +217,16 @@ class ORMCompileState(CompileState): # as the statement is built. "subject" mapper is the generally # standard object used as an identifier for multi-database schemes. - if "plugin_subject" in statement._propagate_attrs: - bind_arguments["mapper"] = statement._propagate_attrs[ - "plugin_subject" - ].mapper + # we are here based on the fact that _propagate_attrs contains + # "compile_state_plugin": "orm". The "plugin_subject" + # needs to be present as well. + + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" + else: + bind_arguments["mapper"] = plugin_subject.mapper if load_options._autoflush: session._autoflush() @@ -296,11 +300,14 @@ class ORMFromStatementCompileState(ORMCompileState): @classmethod def create_for_statement(cls, statement_container, compiler, **kw): compiler._rewrites_selected_columns = True - return cls._create_for_statement_or_query(statement_container) + toplevel = not compiler.stack + return cls._create_for_statement_or_query( + statement_container, toplevel + ) @classmethod def _create_for_statement_or_query( - cls, statement_container, for_statement=False, + cls, statement_container, toplevel, for_statement=False, ): # from .query import FromStatement @@ -309,8 +316,9 @@ class ORMFromStatementCompileState(ORMCompileState): self = cls.__new__(cls) self._primary_entity = None - self.orm_query = statement_container.compile_options._orm_query - + self.use_orm_style = ( + statement_container.compile_options._use_legacy_query_style + ) self.statement_container = self.query = statement_container self.requested_statement = statement_container.element @@ -325,12 +333,13 @@ class ORMFromStatementCompileState(ORMCompileState): self.current_path = statement_container.compile_options._current_path - if statement_container._with_options: + if toplevel and statement_container._with_options: self.attributes = {"_unbound_load_dedupes": set()} for opt in statement_container._with_options: if opt._is_compile_state: opt.process_compile_state(self) + else: self.attributes = {} @@ -411,24 +420,24 @@ class ORMSelectCompileState(ORMCompileState, SelectState): _where_criteria = () _having_criteria = () - orm_query = None - @classmethod def create_for_statement(cls, statement, compiler, **kw): if not statement._is_future: return SelectState(statement, compiler, **kw) + toplevel = not compiler.stack + compiler._rewrites_selected_columns = True orm_state = cls._create_for_statement_or_query( - statement, for_statement=True + statement, for_statement=True, toplevel=toplevel ) SelectState.__init__(orm_state, orm_state.statement, compiler, **kw) return orm_state @classmethod def _create_for_statement_or_query( - cls, query, for_statement=False, _entities_only=False, + cls, query, toplevel, for_statement=False, _entities_only=False ): assert isinstance(query, future.Select) @@ -440,9 +449,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._primary_entity = None - self.orm_query = query.compile_options._orm_query - self.query = query + self.use_orm_style = query.compile_options._use_legacy_query_style self.select_statement = select_statement = query @@ -484,7 +492,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # rather than LABEL_STYLE_NONE, and if we can use disambiguate style # for new style ORM selects too. if self.select_statement._label_style is LABEL_STYLE_NONE: - if self.orm_query and not for_statement: + if self.use_orm_style and not for_statement: self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL else: self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY @@ -495,7 +503,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self.eager_order_by = () - if select_statement._with_options: + if toplevel and select_statement._with_options: self.attributes = {"_unbound_load_dedupes": set()} for opt in self.select_statement._with_options: diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 616e757a3..44ab7dd63 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -14,8 +14,6 @@ as well as some of the attribute loading strategies. """ from __future__ import absolute_import -import collections - from . import attributes from . import exc as orm_exc from . import path_registry @@ -57,7 +55,11 @@ def instances(cursor, context): compile_state = context.compile_state filtered = compile_state._has_mapper_entities - single_entity = context.is_single_entity + single_entity = ( + not context.load_options._only_return_tuples + and len(compile_state._entities) == 1 + and compile_state._entities[0].supports_single_entity + ) try: (process, labels, extra) = list( @@ -105,7 +107,7 @@ def instances(cursor, context): if not fetch: break else: - fetch = cursor.fetchall() + fetch = cursor._raw_all_rows() if single_entity: proc = process[0] @@ -123,6 +125,10 @@ def instances(cursor, context): if not yield_per: break + context.dispose() + if not cursor.context.compiled.cache_key: + compile_state.attributes.clear() + result = ChunkedIteratorResult( row_metadata, chunks, source_supports_scalars=single_entity, raw=cursor ) @@ -347,12 +353,6 @@ def load_on_pk_identity( q.compile_options ) - # checking that query doesnt have criteria on it - # just delete it here w/ optional assertion? since we are setting a - # where clause also - if refresh_state is None: - _no_criterion_assertion(q, "get", order_by=False, distinct=False) - if primary_key_identity is not None: # mapper = query._only_full_mapper_zero("load_on_pk_identity") @@ -446,24 +446,6 @@ def load_on_pk_identity( return None -def _no_criterion_assertion(stmt, meth, order_by=True, distinct=True): - if ( - stmt._where_criteria - or stmt.compile_options._statement is not None - or stmt._from_obj - or stmt._legacy_setup_joins - or stmt._limit_clause is not None - or stmt._offset_clause is not None - or stmt._group_by_clauses - or (order_by and stmt._order_by_clauses) - or (distinct and stmt._distinct) - ): - raise sa_exc.InvalidRequestError( - "Query.%s() being called on a " - "Query with existing criterion. " % meth - ) - - def _set_get_options( compile_opt, load_opt, @@ -587,49 +569,110 @@ def _instance_processor( # performance-critical section in the whole ORM. identity_class = mapper._identity_class + compile_state = context.compile_state - populators = collections.defaultdict(list) + # look for "row getter" functions that have been assigned along + # with the compile state that were cached from a previous load. + # these are operator.itemgetter() objects that each will extract a + # particular column from each row. + + getter_key = ("getters", mapper) + getters = path.get(compile_state.attributes, getter_key, None) + + if getters is None: + # no getters, so go through a list of attributes we are loading for, + # and the ones that are column based will have already put information + # for us in another collection "memoized_setups", which represents the + # output of the LoaderStrategy.setup_query() method. We can just as + # easily call LoaderStrategy.create_row_processor for each, but by + # getting it all at once from setup_query we save another method call + # per attribute. + props = mapper._prop_set + if only_load_props is not None: + props = props.intersection( + mapper._props[k] for k in only_load_props + ) - props = mapper._prop_set - if only_load_props is not None: - props = props.intersection(mapper._props[k] for k in only_load_props) + quick_populators = path.get( + context.attributes, "memoized_setups", _none_set + ) - quick_populators = path.get( - context.attributes, "memoized_setups", _none_set - ) + todo = [] + cached_populators = { + "new": [], + "quick": [], + "deferred": [], + "expire": [], + "delayed": [], + "existing": [], + "eager": [], + } + + if refresh_state is None: + # we can also get the "primary key" tuple getter function + pk_cols = mapper.primary_key - for prop in props: - if prop in quick_populators: - # this is an inlined path just for column-based attributes. - col = quick_populators[prop] - if col is _DEFER_FOR_STATE: - populators["new"].append( - (prop.key, prop._deferred_column_loader) - ) - elif col is _SET_DEFERRED_EXPIRED: - # note that in this path, we are no longer - # searching in the result to see if the column might - # be present in some unexpected way. - populators["expire"].append((prop.key, False)) - elif col is _RAISE_FOR_STATE: - populators["new"].append((prop.key, prop._raise_column_loader)) - else: - getter = None - if not getter: - getter = result._getter(col, False) - if getter: - populators["quick"].append((prop.key, getter)) - else: - # fall back to the ColumnProperty itself, which - # will iterate through all of its columns - # to see if one fits - prop.create_row_processor( - context, path, mapper, result, adapter, populators - ) + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + primary_key_getter = result._tuple_getter(pk_cols) else: - prop.create_row_processor( - context, path, mapper, result, adapter, populators - ) + primary_key_getter = None + + getters = { + "cached_populators": cached_populators, + "todo": todo, + "primary_key_getter": primary_key_getter, + } + for prop in props: + if prop in quick_populators: + # this is an inlined path just for column-based attributes. + col = quick_populators[prop] + if col is _DEFER_FOR_STATE: + cached_populators["new"].append( + (prop.key, prop._deferred_column_loader) + ) + elif col is _SET_DEFERRED_EXPIRED: + # note that in this path, we are no longer + # searching in the result to see if the column might + # be present in some unexpected way. + cached_populators["expire"].append((prop.key, False)) + elif col is _RAISE_FOR_STATE: + cached_populators["new"].append( + (prop.key, prop._raise_column_loader) + ) + else: + getter = None + if not getter: + getter = result._getter(col, False) + if getter: + cached_populators["quick"].append((prop.key, getter)) + else: + # fall back to the ColumnProperty itself, which + # will iterate through all of its columns + # to see if one fits + prop.create_row_processor( + context, + path, + mapper, + result, + adapter, + cached_populators, + ) + else: + # loader strategries like subqueryload, selectinload, + # joinedload, basically relationships, these need to interact + # with the context each time to work correctly. + todo.append(prop) + + path.set(compile_state.attributes, getter_key, getters) + + cached_populators = getters["cached_populators"] + + populators = {key: list(value) for key, value in cached_populators.items()} + for prop in getters["todo"]: + prop.create_row_processor( + context, path, mapper, result, adapter, populators + ) propagated_loader_options = context.propagated_loader_options load_path = ( @@ -707,11 +750,7 @@ def _instance_processor( else: refresh_identity_key = None - pk_cols = mapper.primary_key - - if adapter: - pk_cols = [adapter.columns[c] for c in pk_cols] - tuple_getter = result._tuple_getter(pk_cols) + primary_key_getter = getters["primary_key_getter"] if mapper.allow_partial_pks: is_not_primary_key = _none_set.issuperset @@ -732,7 +771,11 @@ def _instance_processor( else: # look at the row, see if that identity is in the # session, or we have to create a new one - identitykey = (identity_class, tuple_getter(row), identity_token) + identitykey = ( + identity_class, + primary_key_getter(row), + identity_token, + ) instance = session_identity_map.get(identitykey) @@ -875,7 +918,7 @@ def _instance_processor( def ensure_no_pk(row): identitykey = ( identity_class, - tuple_getter(row), + primary_key_getter(row), identity_token, ) if not is_not_primary_key(identitykey[1]): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 25d6f4736..97a81e30f 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -128,6 +128,7 @@ class Query( _aliased_generation = None _enable_assertions = True _last_joined_entity = None + _statement = None # mirrors that of ClauseElement, used to propagate the "orm" # plugin as well as the "subject" of the plugin, e.g. the mapper @@ -232,7 +233,7 @@ class Query( return if ( self._where_criteria - or self.compile_options._statement is not None + or self._statement is not None or self._from_obj or self._legacy_setup_joins or self._limit_clause is not None @@ -250,7 +251,7 @@ class Query( self._no_criterion_assertion(meth, order_by, distinct) self._from_obj = self._legacy_setup_joins = () - if self.compile_options._statement is not None: + if self._statement is not None: self.compile_options += {"_statement": None} self._where_criteria = () self._distinct = False @@ -270,7 +271,7 @@ class Query( def _no_statement_condition(self, meth): if not self._enable_assertions: return - if self.compile_options._statement is not None: + if self._statement is not None: raise sa_exc.InvalidRequestError( ( "Query.%s() being called on a Query with an existing full " @@ -356,7 +357,6 @@ class Query( if ( not self.compile_options._set_base_alias and not self.compile_options._with_polymorphic_adapt_map - # and self.compile_options._statement is None ): # if we don't have legacy top level aliasing features in use # then convert to a future select() directly @@ -383,48 +383,25 @@ class Query( if not fn._bake_ok: self.compile_options += {"_bake_ok": False} - if self.compile_options._statement is not None: - stmt = FromStatement( - self._raw_columns, self.compile_options._statement - ) - # TODO: once SubqueryLoader uses select(), we can remove - # "_orm_query" from this structure + compile_options = self.compile_options + compile_options += {"_use_legacy_query_style": True} + + if self._statement is not None: + stmt = FromStatement(self._raw_columns, self._statement) stmt.__dict__.update( _with_options=self._with_options, _with_context_options=self._with_context_options, - compile_options=self.compile_options - + {"_orm_query": self.with_session(None)}, + compile_options=compile_options, _execution_options=self._execution_options, ) stmt._propagate_attrs = self._propagate_attrs else: + # Query / select() internal attributes are 99% cross-compatible stmt = FutureSelect.__new__(FutureSelect) - + stmt.__dict__.update(self.__dict__) stmt.__dict__.update( - _raw_columns=self._raw_columns, - _where_criteria=self._where_criteria, - _from_obj=self._from_obj, - _legacy_setup_joins=self._legacy_setup_joins, - _order_by_clauses=self._order_by_clauses, - _group_by_clauses=self._group_by_clauses, - _having_criteria=self._having_criteria, - _distinct=self._distinct, - _distinct_on=self._distinct_on, - _with_options=self._with_options, - _with_context_options=self._with_context_options, - _hints=self._hints, - _statement_hints=self._statement_hints, - _correlate=self._correlate, - _auto_correlate=self._auto_correlate, - _limit_clause=self._limit_clause, - _offset_clause=self._offset_clause, - _for_update_arg=self._for_update_arg, - _prefixes=self._prefixes, - _suffixes=self._suffixes, _label_style=self._label_style, - compile_options=self.compile_options - + {"_orm_query": self.with_session(None)}, - _execution_options=self._execution_options, + compile_options=compile_options, ) if not orm_results: @@ -897,9 +874,11 @@ class Query( :return: The object instance, or ``None``. """ + self._no_criterion_assertion("get", order_by=False, distinct=False) return self._get_impl(ident, loading.load_on_pk_identity) def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): + # convert composite types to individual args if hasattr(primary_key_identity, "__composite_values__"): primary_key_identity = primary_key_identity.__composite_values__() @@ -977,33 +956,14 @@ class Query( """An :class:`.InstanceState` that is using this :class:`_query.Query` for a lazy load operation. - The primary rationale for this attribute is to support the horizontal - sharding extension, where it is available within specific query - execution time hooks created by this extension. To that end, the - attribute is only intended to be meaningful at **query execution - time**, and importantly not any time prior to that, including query - compilation time. - - .. note:: - - Within the realm of regular :class:`_query.Query` usage, this - attribute is set by the lazy loader strategy before the query is - invoked. However there is no established hook that is available to - reliably intercept this value programmatically. It is set by the - lazy loading strategy after any mapper option objects would have - been applied, and now that the lazy loading strategy in the ORM - makes use of "baked" queries to cache SQL compilation, the - :meth:`.QueryEvents.before_compile` hook is also not reliable. + .. deprecated:: 1.4 This attribute should be viewed via the + :attr:`.ORMExecuteState.lazy_loaded_from` attribute, within + the context of the :meth:`.SessionEvents.do_orm_execute` + event. - Currently, setting the :paramref:`_orm.relationship.bake_queries` - to ``False`` on the target :func:`_orm.relationship`, and then - making use of the :meth:`.QueryEvents.before_compile` event hook, - is the only available programmatic path to intercepting this - attribute. In future releases, there will be new hooks available - that allow interception of the :class:`_query.Query` before it is - executed, rather than before it is compiled. + .. seealso:: - .. versionadded:: 1.2.9 + :attr:`.ORMExecuteState.lazy_loaded_from` """ return self.load_options._lazy_loaded_from @@ -2713,6 +2673,7 @@ class Query( statement = coercions.expect( roles.SelectStatementRole, statement, apply_propagate_attrs=self ) + self._statement = statement self.compile_options += {"_statement": statement} def first(self): @@ -2736,7 +2697,7 @@ class Query( """ # replicates limit(1) behavior - if self.compile_options._statement is not None: + if self._statement is not None: return self._iter().first() else: return self.limit(1)._iter().first() @@ -2918,7 +2879,9 @@ class Query( "for linking ORM results to arbitrary select constructs.", version="1.4", ) - compile_state = ORMCompileState._create_for_legacy_query(self) + compile_state = ORMCompileState._create_for_legacy_query( + self, toplevel=True + ) context = QueryContext( compile_state, self.session, self.load_options ) @@ -3332,7 +3295,7 @@ class Query( def _compile_state(self, for_statement=False, **kw): return ORMCompileState._create_for_legacy_query( - self, for_statement=for_statement, **kw + self, toplevel=True, for_statement=for_statement, **kw ) def _compile_context(self, for_statement=False): @@ -3366,7 +3329,7 @@ class FromStatement(SelectStatementGrouping, Executable): super(FromStatement, self).__init__(element) def _compiler_dispatch(self, compiler, **kw): - compile_state = self._compile_state_factory(self, self, **kw) + compile_state = self._compile_state_factory(self, compiler, **kw) toplevel = not compiler.stack diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 8d2f13df3..25e224348 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -35,6 +35,7 @@ from ..inspection import inspect from ..sql import coercions from ..sql import roles from ..sql import visitors +from ..sql.base import CompileState __all__ = ["Session", "SessionTransaction", "sessionmaker"] @@ -98,7 +99,7 @@ DEACTIVE = util.symbol("DEACTIVE") CLOSED = util.symbol("CLOSED") -class ORMExecuteState(object): +class ORMExecuteState(util.MemoizedSlots): """Stateful object used for the :meth:`.SessionEvents.do_orm_execute` .. versionadded:: 1.4 @@ -109,7 +110,8 @@ class ORMExecuteState(object): "session", "statement", "parameters", - "execution_options", + "_execution_options", + "_merged_execution_options", "bind_arguments", ) @@ -119,7 +121,7 @@ class ORMExecuteState(object): self.session = session self.statement = statement self.parameters = parameters - self.execution_options = execution_options + self._execution_options = execution_options self.bind_arguments = bind_arguments def invoke_statement( @@ -182,33 +184,51 @@ class ORMExecuteState(object): _params = self.parameters if execution_options: - _execution_options = dict(self.execution_options) + _execution_options = dict(self._execution_options) _execution_options.update(execution_options) else: - _execution_options = self.execution_options + _execution_options = self._execution_options return self.session.execute( statement, _params, _execution_options, _bind_arguments ) @property - def orm_query(self): - """Return the :class:`_orm.Query` object associated with this - execution. + def execution_options(self): + """Placeholder for execution options. + + Raises an informative message, as there are local options + vs. merged options that can be viewed, via the + :attr:`.ORMExecuteState.local_execution_options` and + :attr:`.ORMExecuteState.merged_execution_options` methods. - For SQLAlchemy-2.0 style usage, the :class:`_orm.Query` object - is not used at all, and this attribute will return None. """ - load_opts = self.load_options - if load_opts._orm_query: - return load_opts._orm_query + raise AttributeError( + "Please use .local_execution_options or " + ".merged_execution_options" + ) - opts = self._orm_compile_options() - if opts is not None: - return opts._orm_query - else: - return None + @property + def local_execution_options(self): + """Dictionary view of the execution options passed to the + :meth:`.Session.execute` method. This does not include options + that may be associated with the statement being invoked. + + """ + return util.immutabledict(self._execution_options) + + @property + def merged_execution_options(self): + """Dictionary view of all execution options merged together; + this includes those of the statement as well as those passed to + :meth:`.Session.execute`, with the local options taking precedence. + + """ + return self._merged_execution_options + + def _memoized_attr__merged_execution_options(self): + return self.statement._execution_options.union(self._execution_options) def _orm_compile_options(self): opts = self.statement.compile_options @@ -218,6 +238,21 @@ class ORMExecuteState(object): return None @property + def lazy_loaded_from(self): + """An :class:`.InstanceState` that is using this statement execution + for a lazy load operation. + + The primary rationale for this attribute is to support the horizontal + sharding extension, where it is available within specific query + execution time hooks created by this extension. To that end, the + attribute is only intended to be meaningful at **query execution + time**, and importantly not any time prior to that, including query + compilation time. + + """ + return self.load_options._lazy_loaded_from + + @property def loader_strategy_path(self): """Return the :class:`.PathRegistry` for the current load path. @@ -235,7 +270,7 @@ class ORMExecuteState(object): def load_options(self): """Return the load_options that will be used for this execution.""" - return self.execution_options.get( + return self._execution_options.get( "_sa_orm_load_options", context.QueryContext.default_load_options ) @@ -1407,7 +1442,6 @@ class Session(_SessionClassMethods): in order to execute the statement. """ - statement = coercions.expect(roles.CoerceTextStatementRole, statement) if not bind_arguments: @@ -1415,12 +1449,19 @@ class Session(_SessionClassMethods): elif kw: bind_arguments.update(kw) - compile_state_cls = statement._get_plugin_compile_state_cls("orm") - if compile_state_cls: + if ( + statement._propagate_attrs.get("compile_state_plugin", None) + == "orm" + ): + compile_state_cls = CompileState._get_plugin_class_for_plugin( + statement, "orm" + ) + compile_state_cls.orm_pre_session_exec( self, statement, execution_options, bind_arguments ) else: + compile_state_cls = None bind_arguments.setdefault("clause", statement) if statement._is_future: execution_options = util.immutabledict().merge_with( @@ -1694,9 +1735,19 @@ class Session(_SessionClassMethods): :meth:`.Session.bind_table` """ + + # this function is documented as a subclassing hook, so we have + # to call this method even if the return is simple if bind: return bind + elif not self.__binds and self.bind: + # simplest and most common case, we have a bind and no + # per-mapper/table binds, we're done + return self.bind + # we don't have self.bind and either have self.__binds + # or we don't have self.__binds (which is legacy). Look at the + # mapper and the clause if mapper is clause is None: if self.bind: return self.bind @@ -1707,6 +1758,7 @@ class Session(_SessionClassMethods): "a binding." ) + # look more closely at the mapper. if mapper is not None: try: mapper = inspect(mapper) @@ -1718,6 +1770,7 @@ class Session(_SessionClassMethods): else: raise + # match up the mapper or clause in the __binds if self.__binds: # matching mappers and selectables to entries in the # binds dictionary; supported use case. @@ -1733,7 +1786,8 @@ class Session(_SessionClassMethods): if obj in self.__binds: return self.__binds[obj] - # session has a single bind; supported use case. + # none of the __binds matched, but we have a fallback bind. + # return that if self.bind: return self.bind @@ -1745,16 +1799,10 @@ class Session(_SessionClassMethods): if clause is not None: if clause.bind: return clause.bind - # for obj in visitors.iterate(clause): - # if obj.bind: - # return obj.bind if mapper: if mapper.persist_selectable.bind: return mapper.persist_selectable.bind - # for obj in visitors.iterate(mapper.persist_selectable): - # if obj.bind: - # return obj.bind context = [] if mapper is not None: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index a7d501b53..626018997 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -12,12 +12,12 @@ from __future__ import absolute_import import collections import itertools -from sqlalchemy.orm import query from . import attributes from . import exc as orm_exc from . import interfaces from . import loading from . import properties +from . import query from . import relationships from . import unitofwork from . import util as orm_util @@ -1143,7 +1143,7 @@ class SubqueryLoader(PostLoader): ) = self._get_leftmost(subq_path) orig_query = compile_state.attributes.get( - ("orig_query", SubqueryLoader), compile_state.orm_query + ("orig_query", SubqueryLoader), compile_state.query ) # generate a new Query from the original, then @@ -1168,9 +1168,7 @@ class SubqueryLoader(PostLoader): def set_state_options(compile_state): compile_state.attributes.update( { - ("orig_query", SubqueryLoader): orig_query.with_session( - None - ), + ("orig_query", SubqueryLoader): orig_query, ("subquery_path", None): subq_path, } ) @@ -1236,6 +1234,19 @@ class SubqueryLoader(PostLoader): # to look only for significant columns q = orig_query._clone().correlate(None) + # LEGACY: make a Query back from the select() !! + # This suits at least two legacy cases: + # 1. applications which expect before_compile() to be called + # below when we run .subquery() on this query (Keystone) + # 2. applications which are doing subqueryload with complex + # from_self() queries, as query.subquery() / .statement + # has to do the full compile context for multiply-nested + # from_self() (Neutron) - see test_subqload_from_self + # for demo. + q2 = query.Query.__new__(query.Query) + q2.__dict__.update(q.__dict__) + q = q2 + # set the query's "FROM" list explicitly to what the # FROM list would be in any case, as we will be limiting # the columns in the SELECT list which may no longer include @@ -1251,15 +1262,6 @@ class SubqueryLoader(PostLoader): } ) - # NOTE: keystone has a test which is counting before_compile - # events. That test is in one case dependent on an extra - # call that was occurring here within the subqueryloader setup - # process, probably when the subquery() method was called. - # Ultimately that call will not be occurring here. - # the event has already been called on the original query when - # we are here in any case, so keystone will need to adjust that - # test. - # for column information, look to the compile state that is # already being passed through compile_state = orig_compile_state @@ -1304,7 +1306,8 @@ class SubqueryLoader(PostLoader): # the original query now becomes a subquery # which we'll join onto. - + # LEGACY: as "q" is a Query, the before_compile() event is invoked + # here. embed_q = q.apply_labels().subquery() left_alias = orm_util.AliasedClass( leftmost_mapper, embed_q, use_mapper_path=True @@ -1416,8 +1419,6 @@ class SubqueryLoader(PostLoader): # these will fire relative to subq_path. q = q._with_current_path(subq_path) q = q.options(*orig_query._with_options) - if orig_query.load_options._populate_existing: - q.load_options += {"_populate_existing": True} return q @@ -1475,8 +1476,11 @@ class SubqueryLoader(PostLoader): ) q = q.with_session(self.session) + if self.load_options._populate_existing: + q = q.populate_existing() # to work with baked query, the parameters may have been # updated since this query was created, so take these into account + rows = list(q.params(self.load_options._params)) for k, v in itertools.groupby(rows, lambda x: x[1:]): self._data[k].extend(vv[0] for vv in v) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 273841588..5fca41ba6 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -861,7 +861,14 @@ class _UnboundLoad(Load): # we just located, then go through the rest of our path # tokens and populate into the Load(). loader = Load(path_element) + if context is not None: + # TODO: this creates a cycle with context.attributes. + # the current approach to mitigating this is the context / + # compile_state attributes are cleared out when a result + # is fetched. However, it would be nice if these attributes + # could be passed to all methods so that all the state + # gets set up without ever creating any assignments. loader.context = context else: context = loader.context diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index bb606a4d6..6415d4b37 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -470,12 +470,7 @@ class CompileState(object): return None @classmethod - def _get_plugin_compile_state_cls(cls, statement, plugin_name): - statement_plugin_name = statement._propagate_attrs.get( - "compile_state_plugin", "default" - ) - if statement_plugin_name != plugin_name: - return None + def _get_plugin_class_for_plugin(cls, statement, plugin_name): try: return cls.plugins[(plugin_name, statement.__visit_name__)] except KeyError: @@ -607,9 +602,6 @@ class Executable(Generative): def _disable_caching(self): self._cache_enable = HasCacheKey() - def _get_plugin_compile_state_cls(self, plugin_name): - return CompileState._get_plugin_compile_state_cls(self, plugin_name) - @_generative def options(self, *options): """Apply options to this statement. @@ -735,7 +727,9 @@ class Executable(Generative): "to execute this construct." % label ) raise exc.UnboundExecutionError(msg) - return e._execute_clauseelement(self, multiparams, params) + return e._execute_clauseelement( + self, multiparams, params, util.immutabledict() + ) @util.deprecated_20( ":meth:`.Executable.scalar`", diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index d8ef0222a..7503faf5b 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -50,16 +50,50 @@ def _document_text_coercion(paramname, meth_rst, param_rst): ) -def expect(role, element, apply_propagate_attrs=None, **kw): +def expect(role, element, apply_propagate_attrs=None, argname=None, **kw): # major case is that we are given a ClauseElement already, skip more # elaborate logic up front if possible impl = _impl_lookup[role] + original_element = element + if not isinstance( element, (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue), ): - resolved = impl._resolve_for_clause_element(element, **kw) + resolved = None + + if impl._resolve_string_only: + resolved = impl._literal_coercion(element, **kw) + else: + + original_element = element + + is_clause_element = False + + while hasattr(element, "__clause_element__"): + is_clause_element = True + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break + + if not is_clause_element: + if impl._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + insp._post_inspect + try: + resolved = insp.__clause_element__() + except AttributeError: + impl._raise_for_expected(original_element, argname) + + if resolved is None: + resolved = impl._literal_coercion( + element, argname=argname, **kw + ) + else: + resolved = element else: resolved = element @@ -72,10 +106,12 @@ def expect(role, element, apply_propagate_attrs=None, **kw): if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: - resolved = impl._post_coercion(resolved, **kw) + resolved = impl._post_coercion(resolved, argname=argname, **kw) return resolved else: - return impl._implicit_coercions(element, resolved, **kw) + return impl._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) def expect_as_key(role, element, **kw): @@ -107,51 +143,13 @@ class RoleImpl(object): raise NotImplementedError() _post_coercion = None + _resolve_string_only = False def __init__(self, role_class): self._role_class = role_class self.name = role_class._role_name self._use_inspection = issubclass(role_class, roles.UsesInspection) - def _resolve_for_clause_element(self, element, argname=None, **kw): - original_element = element - - is_clause_element = False - - while hasattr(element, "__clause_element__"): - is_clause_element = True - if not getattr(element, "is_clause_element", False): - element = element.__clause_element__() - else: - return element - - if not is_clause_element: - if self._use_inspection: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - insp._post_inspect - try: - element = insp.__clause_element__() - except AttributeError: - self._raise_for_expected(original_element, argname) - else: - return element - - return self._literal_coercion(element, argname=argname, **kw) - else: - return element - - if self._use_inspection: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - insp._post_inspect - try: - element = insp.__clause_element__() - except AttributeError: - self._raise_for_expected(original_element, argname) - - return self._literal_coercion(element, argname=argname, **kw) - def _implicit_coercions(self, element, resolved, argname=None, **kw): self._raise_for_expected(element, argname, resolved) @@ -191,8 +189,7 @@ class _Deannotate(object): class _StringOnly(object): __slots__ = () - def _resolve_for_clause_element(self, element, argname=None, **kw): - return self._literal_coercion(element, **kw) + _resolve_string_only = True class _ReturnsStringKey(object): @@ -465,7 +462,7 @@ class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole): class OrderByImpl(ByOfImpl, RoleImpl): __slots__ = () - def _post_coercion(self, resolved): + def _post_coercion(self, resolved, **kw): if ( isinstance(resolved, self._role_class) and resolved._order_by_label_element is not None @@ -490,7 +487,7 @@ class GroupByImpl(ByOfImpl, RoleImpl): class DMLColumnImpl(_ReturnsStringKey, RoleImpl): __slots__ = () - def _post_coercion(self, element, as_key=False): + def _post_coercion(self, element, as_key=False, **kw): if as_key: return element.key else: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index d0f4fef60..5a55fe5f2 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -5,6 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from .. import util + class SQLRole(object): """Define a "role" within a SQL statement structure. @@ -145,16 +147,12 @@ class CoerceTextStatementRole(SQLRole): _role_name = "Executable SQL or text() construct" -# _executable_statement = None - - class StatementRole(CoerceTextStatementRole): _role_name = "Executable SQL or text() construct" _is_future = False - def _get_plugin_compile_state_cls(self, name): - return None + _propagate_attrs = util.immutabledict() class ReturnsRowsRole(StatementRole): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 008959aec..170e016a5 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3428,7 +3428,14 @@ class DeprecatedSelectGenerations(object): @CompileState.plugin_for("default", "select") -class SelectState(CompileState): +class SelectState(util.MemoizedSlots, CompileState): + __slots__ = ( + "from_clauses", + "froms", + "columns_plus_names", + "_label_resolve_dict", + ) + class default_select_compile_options(CacheableOptions): _cache_key_traversal = [] @@ -3547,8 +3554,7 @@ class SelectState(CompileState): return froms - @util.memoized_property - def _label_resolve_dict(self): + def _memoized_attr__label_resolve_dict(self): with_cols = dict( (c._resolve_label or c._label or c.key, c) for c in _select_iterables(self.statement._raw_columns) diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 482248ada..a38088a27 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -179,7 +179,7 @@ class HasCacheKey(object): if NO_CACHE in _anon_map: return None else: - return CacheKey(key, bindparams, self) + return CacheKey(key, bindparams) @classmethod def _generate_cache_key_for_object(cls, obj): @@ -190,7 +190,7 @@ class HasCacheKey(object): if NO_CACHE in _anon_map: return None else: - return CacheKey(key, bindparams, obj) + return CacheKey(key, bindparams) class MemoizedHasCacheKey(HasCacheKey, HasMemoized): @@ -199,13 +199,13 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized): return HasCacheKey._generate_cache_key(self) -class CacheKey(namedtuple("CacheKey", ["key", "bindparams", "statement"])): +class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): def __hash__(self): """CacheKey itself is not hashable - hash the .key portion""" return None - def to_offline_string(self, statement_cache, parameters): + def to_offline_string(self, statement_cache, statement, parameters): """generate an "offline string" form of this :class:`.CacheKey` The "offline string" is basically the string SQL for the @@ -222,7 +222,7 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams", "statement"])): """ if self.key not in statement_cache: - statement_cache[self.key] = sql_str = str(self.statement) + statement_cache[self.key] = sql_str = str(statement) else: sql_str = statement_cache[self.key] diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index f6fefc244..52debc517 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -942,6 +942,10 @@ class HasMemoized(object): for elem in self._memoized_keys: assert elem not in self.__dict__ + def _set_memoized_attribute(self, key, value): + self.__dict__[key] = value + self._memoized_keys |= {key} + class memoized_attribute(object): """A read-only @property that is only evaluated once.""" @@ -1260,15 +1264,20 @@ class classproperty(property): class hybridproperty(object): def __init__(self, func): self.func = func + self.clslevel = func def __get__(self, instance, owner): if instance is None: - clsval = self.func(owner) + clsval = self.clslevel(owner) clsval.__doc__ = self.func.__doc__ return clsval else: return self.func(instance) + def classlevel(self, func): + self.clslevel = func + return self + class hybridmethod(object): """Decorate a function as cls- or instance- level.""" |