diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-08-20 15:22:08 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-08-20 15:22:08 +0000 |
commit | 544ef23cd36b0ea30a13c5158121ba5ea7573f03 (patch) | |
tree | bda3830e46dc395e5ebd9592cc6cc77ef67bbbea /lib/sqlalchemy | |
parent | 62347923754754f93adbf0c3888208be77f26e70 (diff) | |
parent | a1939719a652774a437f69f8d4788b3f08650089 (diff) | |
download | sqlalchemy-544ef23cd36b0ea30a13c5158121ba5ea7573f03.tar.gz |
Merge "normalize execute style for events, 2.0"
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/cextension/utils.c | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 187 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/util.py | 73 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 6 |
5 files changed, 237 insertions, 76 deletions
diff --git a/lib/sqlalchemy/cextension/utils.c b/lib/sqlalchemy/cextension/utils.c index ab8b39335..c612094dc 100644 --- a/lib/sqlalchemy/cextension/utils.c +++ b/lib/sqlalchemy/cextension/utils.c @@ -26,12 +26,13 @@ distill_params(PyObject *self, PyObject *args) // TODO: pass the Connection in so that there can be a standard // method for warning on parameter format - PyObject *multiparams, *params; + PyObject *connection, *multiparams, *params; PyObject *enclosing_list, *double_enclosing_list; PyObject *zero_element, *zero_element_item; + PyObject *tmp; Py_ssize_t multiparam_size, zero_element_length; - if (!PyArg_UnpackTuple(args, "_distill_params", 2, 2, &multiparams, ¶ms)) { + if (!PyArg_UnpackTuple(args, "_distill_params", 3, 3, &connection, &multiparams, ¶ms)) { return NULL; } @@ -47,8 +48,12 @@ distill_params(PyObject *self, PyObject *args) if (multiparam_size == 0) { if (params != Py_None && PyMapping_Size(params) != 0) { - // TODO: this is keyword parameters, emit parameter format - // deprecation warning + + tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); + if (tmp == NULL) { + return NULL; + } + enclosing_list = PyList_New(1); if (enclosing_list == NULL) { return NULL; @@ -102,6 +107,7 @@ distill_params(PyObject *self, PyObject *args) * execute(stmt, ("value", "value")) */ Py_XDECREF(zero_element_item); + enclosing_list = PyList_New(1); if (enclosing_list == NULL) { return NULL; @@ -131,6 +137,11 @@ distill_params(PyObject *self, PyObject *args) } return enclosing_list; } else { + tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); + if (tmp == NULL) { + return NULL; + } + enclosing_list = PyList_New(1); if (enclosing_list == NULL) { return NULL; @@ -157,8 +168,12 @@ distill_params(PyObject *self, PyObject *args) } } else { - // TODO: this is multiple positional params, emit parameter format - // deprecation warning + + tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); + if (tmp == NULL) { + return NULL; + } + zero_element = PyTuple_GetItem(multiparams, 0); if (PyObject_HasAttrString(zero_element, "__iter__") && !PyObject_HasAttrString(zero_element, "strip") diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 2db079799..c56cccd8d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2903,7 +2903,11 @@ class PGDialect(default.DefaultDialect): "JOIN pg_namespace n ON n.oid = c.relnamespace " "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name, + dict( + schema=schema + if schema is not None + else self.default_schema_name + ), ) return [name for name, in result] @@ -3018,7 +3022,7 @@ class PGDialect(default.DefaultDialect): .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer)) .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode) ) - c = connection.execute(s, table_oid=table_oid) + c = connection.execute(s, dict(table_oid=table_oid)) rows = c.fetchall() # dictionary with (name, ) if default search path or (schema, name) @@ -3260,7 +3264,7 @@ class PGDialect(default.DefaultDialect): ORDER BY k.ord """ t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode) - c = connection.execute(t, table_oid=table_oid) + c = connection.execute(t, dict(table_oid=table_oid)) cols = [r[0] for r in c.fetchall()] PK_CONS_SQL = """ @@ -3270,7 +3274,7 @@ class PGDialect(default.DefaultDialect): ORDER BY 1 """ t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode) - c = connection.execute(t, table_oid=table_oid) + c = connection.execute(t, dict(table_oid=table_oid)) name = c.scalar() return {"constrained_columns": cols, "name": name} @@ -3318,7 +3322,7 @@ class PGDialect(default.DefaultDialect): t = sql.text(FK_SQL).columns( conname=sqltypes.Unicode, condef=sqltypes.Unicode ) - c = connection.execute(t, table=table_oid) + c = connection.execute(t, dict(table=table_oid)) fkeys = [] for conname, condef, conschema in c.fetchall(): m = re.search(FK_REGEX, condef).groups() @@ -3490,7 +3494,7 @@ class PGDialect(default.DefaultDialect): t = sql.text(IDX_SQL).columns( relname=sqltypes.Unicode, attname=sqltypes.Unicode ) - c = connection.execute(t, table_oid=table_oid) + c = connection.execute(t, dict(table_oid=table_oid)) indexes = defaultdict(lambda: defaultdict(dict)) @@ -3632,7 +3636,7 @@ class PGDialect(default.DefaultDialect): """ t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode) - c = connection.execute(t, table_oid=table_oid) + c = connection.execute(t, dict(table_oid=table_oid)) uniques = defaultdict(lambda: defaultdict(dict)) for row in c.fetchall(): @@ -3683,7 +3687,7 @@ class PGDialect(default.DefaultDialect): cons.contype = 'c' """ - c = connection.execute(sql.text(CHECK_SQL), table_oid=table_oid) + c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid)) ret = [] for name, src in c: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 34bf720b7..0eaa1fae1 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -935,6 +935,17 @@ class Connection(Connectable): if not self.in_transaction(): self._rollback_impl() + def _warn_for_legacy_exec_format(self): + util.warn_deprecated_20( + "The connection.execute() method in " + "SQLAlchemy 2.0 will accept parameters as a single " + "dictionary or a " + "single sequence of dictionaries only. " + "Parameters passed as keyword arguments, tuples or positionally " + "oriened dictionaries and/or tuples " + "will no longer be accepted." + ) + def close(self): """Close this :class:`_engine.Connection`. @@ -1073,14 +1084,13 @@ class Connection(Connectable): "or the Connection.exec_driver_sql() method to invoke a " "driver-level SQL string." ) - distilled_parameters = _distill_params(multiparams, params) return self._exec_driver_sql( object_, multiparams, params, - distilled_parameters, _EMPTY_EXECUTION_OPTS, + future=False, ) try: @@ -1113,11 +1123,16 @@ class Connection(Connectable): execution_options ) + distilled_parameters = _distill_params(self, multiparams, params) + if self._has_events or self.engine._has_events: - for fn in self.dispatch.before_execute: - default, multiparams, params = fn( - self, default, multiparams, params, execution_options - ) + ( + distilled_params, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + default, distilled_parameters, execution_options + ) try: conn = self._dbapi_connection @@ -1139,7 +1154,12 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: self.dispatch.after_execute( - self, default, multiparams, params, execution_options, ret + self, + default, + event_multiparams, + event_params, + execution_options, + ret, ) return ret @@ -1151,11 +1171,16 @@ class Connection(Connectable): self._execution_options, execution_options ) + distilled_parameters = _distill_params(self, multiparams, params) + if self._has_events or self.engine._has_events: - for fn in self.dispatch.before_execute: - ddl, multiparams, params = fn( - self, ddl, multiparams, params, execution_options - ) + ( + distilled_params, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + ddl, distilled_parameters, execution_options + ) exec_opts = self._execution_options.merge_with(execution_options) schema_translate_map = exec_opts.get("schema_translate_map", None) @@ -1175,10 +1200,43 @@ class Connection(Connectable): ) if self._has_events or self.engine._has_events: self.dispatch.after_execute( - self, ddl, multiparams, params, execution_options, ret + self, + ddl, + event_multiparams, + event_params, + execution_options, + ret, ) return ret + def _invoke_before_exec_event( + self, elem, distilled_params, execution_options + ): + + if len(distilled_params) == 1: + event_multiparams, event_params = [], distilled_params[0] + else: + event_multiparams, event_params = distilled_params, {} + + for fn in self.dispatch.before_execute: + elem, event_multiparams, event_params = fn( + self, elem, event_multiparams, event_params, execution_options, + ) + + if event_multiparams: + distilled_params = list(event_multiparams) + if event_params: + raise exc.InvalidRequestError( + "Event handler can't return non-empty multiparams " + "and params at the same time" + ) + elif event_params: + distilled_params = [event_params] + else: + distilled_params = [] + + return distilled_params, event_multiparams, event_params + def _execute_clauseelement( self, elem, multiparams, params, execution_options ): @@ -1188,14 +1246,18 @@ class Connection(Connectable): self._execution_options, execution_options ) + distilled_params = _distill_params(self, multiparams, params) + has_events = self._has_events or self.engine._has_events if has_events: - for fn in self.dispatch.before_execute: - elem, multiparams, params = fn( - self, elem, multiparams, params, execution_options - ) + ( + distilled_params, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + elem, distilled_params, execution_options + ) - distilled_params = _distill_params(multiparams, params) if distilled_params: # ensure we don't retain a link to the view object for keys() # which links to the values, which we don't want to cache @@ -1237,7 +1299,12 @@ class Connection(Connectable): ) if has_events: self.dispatch.after_execute( - self, elem, multiparams, params, execution_options, ret + self, + elem, + event_multiparams, + event_params, + execution_options, + ret, ) return ret @@ -1257,49 +1324,58 @@ class Connection(Connectable): execution_options = compiled.execution_options.merge_with( self._execution_options, execution_options ) + distilled_parameters = _distill_params(self, multiparams, params) if self._has_events or self.engine._has_events: - for fn in self.dispatch.before_execute: - compiled, multiparams, params = fn( - self, compiled, multiparams, params, execution_options - ) + ( + distilled_params, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + compiled, distilled_parameters, execution_options + ) dialect = self.dialect - parameters = _distill_params(multiparams, params) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_compiled, compiled, - parameters, + distilled_parameters, execution_options, compiled, - parameters, + distilled_parameters, None, None, ) if self._has_events or self.engine._has_events: self.dispatch.after_execute( - self, compiled, multiparams, params, execution_options, ret + self, + compiled, + event_multiparams, + event_params, + execution_options, + ret, ) return ret def _exec_driver_sql( - self, - statement, - multiparams, - params, - distilled_parameters, - execution_options, + self, statement, multiparams, params, execution_options, future ): execution_options = self._execution_options.merge_with( execution_options ) - if self._has_events or self.engine._has_events: - for fn in self.dispatch.before_execute: - statement, multiparams, params = fn( - self, statement, multiparams, params, execution_options + distilled_parameters = _distill_params(self, multiparams, params) + + if not future: + if self._has_events or self.engine._has_events: + ( + distilled_params, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + statement, distilled_parameters, execution_options ) dialect = self.dialect @@ -1312,10 +1388,17 @@ class Connection(Connectable): statement, distilled_parameters, ) - if self._has_events or self.engine._has_events: - self.dispatch.after_execute( - self, statement, multiparams, params, execution_options, ret - ) + + if not future: + if self._has_events or self.engine._has_events: + self.dispatch.after_execute( + self, + statement, + event_multiparams, + event_params, + execution_options, + ret, + ) return ret def _execute_20( @@ -1324,9 +1407,7 @@ class Connection(Connectable): parameters=None, execution_options=_EMPTY_EXECUTION_OPTS, ): - multiparams, params, distilled_parameters = _distill_params_20( - parameters - ) + args_10style, kwargs_10style = _distill_params_20(parameters) try: meth = statement._execute_on_connection except AttributeError as err: @@ -1334,7 +1415,7 @@ class Connection(Connectable): exc.ObjectNotExecutableError(statement), replace_context=err ) else: - return meth(self, multiparams, params, execution_options) + return meth(self, args_10style, kwargs_10style, execution_options) def exec_driver_sql( self, statement, parameters=None, execution_options=None @@ -1373,22 +1454,28 @@ class Connection(Connectable): (1, 'v1') ) + .. note:: The :meth:`_engine.Connection.exec_driver_sql` method does + not participate in the + :meth:`_events.ConnectionEvents.before_execute` and + :meth:`_events.ConnectionEvents.after_execute` events. To + intercept calls to :meth:`_engine.Connection.exec_driver_sql`, use + :meth:`_events.ConnectionEvents.before_cursor_execute` and + :meth:`_events.ConnectionEvents.after_cursor_execute`. + .. seealso:: :pep:`249` """ - multiparams, params, distilled_parameters = _distill_params_20( - parameters - ) + args_10style, kwargs_10style = _distill_params_20(parameters) return self._exec_driver_sql( statement, - multiparams, - params, - distilled_parameters, + args_10style, + kwargs_10style, execution_options, + future=True, ) def _execute_context( diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index fc0260ae2..c1f6bad77 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -30,10 +30,14 @@ def connection_memoize(key): return decorated +_no_tuple = () +_no_kw = util.immutabledict() + + def py_fallback(): # TODO: pass the Connection in so that there can be a standard # method for warning on parameter format - def _distill_params(multiparams, params): # noqa + def _distill_params(connection, multiparams, params): # noqa r"""Given arguments from the calling form \*multiparams, \**params, return a list of bind parameter structures, usually a list of dictionaries. @@ -43,9 +47,12 @@ def py_fallback(): """ + # C version will fail if this assertion is not true. + # assert isinstance(multiparams, tuple) + if not multiparams: if params: - # TODO: parameter format deprecation warning + connection._warn_for_legacy_exec_format() return [params] else: return [] @@ -61,16 +68,22 @@ def py_fallback(): # execute(stmt, [(), (), (), ...]) return zero else: + # this is used by exec_driver_sql only, so a deprecation + # warning would already be coming from passing a plain + # textual statement with positional parameters to + # execute(). # execute(stmt, ("value", "value")) + return [zero] elif hasattr(zero, "keys"): # execute(stmt, {"key":"value"}) return [zero] else: + connection._warn_for_legacy_exec_format() # execute(stmt, "value") return [[zero]] else: - # TODO: parameter format deprecation warning + connection._warn_for_legacy_exec_format() if hasattr(multiparams[0], "__iter__") and not hasattr( multiparams[0], "strip" ): @@ -81,14 +94,55 @@ def py_fallback(): return locals() -_no_tuple = () -_no_kw = util.immutabledict() +def _distill_cursor_params(connection, multiparams, params): + """_distill_params without any warnings. more appropriate for + "cursor" params that can include tuple arguments, lists of tuples, + etc. + + """ + + if not multiparams: + if params: + return [params] + else: + return [] + elif len(multiparams) == 1: + zero = multiparams[0] + if isinstance(zero, (list, tuple)): + if ( + not zero + or hasattr(zero[0], "__iter__") + and not hasattr(zero[0], "strip") + ): + # execute(stmt, [{}, {}, {}, ...]) + # execute(stmt, [(), (), (), ...]) + return zero + else: + # this is used by exec_driver_sql only, so a deprecation + # warning would already be coming from passing a plain + # textual statement with positional parameters to + # execute(). + # execute(stmt, ("value", "value")) + + return [zero] + elif hasattr(zero, "keys"): + # execute(stmt, {"key":"value"}) + return [zero] + else: + # execute(stmt, "value") + return [[zero]] + else: + if hasattr(multiparams[0], "__iter__") and not hasattr( + multiparams[0], "strip" + ): + return multiparams + else: + return [multiparams] def _distill_params_20(params): - # TODO: this has to be in C if params is None: - return _no_tuple, _no_kw, [] + return _no_tuple, _no_kw elif isinstance(params, list): # collections_abc.MutableSequence): # avoid abc.__instancecheck__ if params and not isinstance( @@ -98,15 +152,14 @@ def _distill_params_20(params): "List argument must consist only of tuples or dictionaries" ) - # the tuple is needed atm by the C version of _distill_params... - return tuple(params), _no_kw, params + return (params,), _no_kw elif isinstance( params, (tuple, dict, immutabledict), # avoid abc.__instancecheck__ # (collections_abc.Sequence, collections_abc.Mapping), ): - return _no_tuple, params, [params] + return (params,), _no_kw else: raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 73b062b96..c86e26ccf 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -13,7 +13,7 @@ from .. import event from .. import util from ..engine import url from ..engine.default import DefaultDialect -from ..engine.util import _distill_params +from ..engine.util import _distill_cursor_params from ..schema import _DDLCompiles @@ -348,7 +348,9 @@ class SQLExecuteObserved(object): def __init__(self, context, clauseelement, multiparams, params): self.context = context self.clauseelement = clauseelement - self.parameters = _distill_params(multiparams, params) + self.parameters = _distill_cursor_params( + context.connection, tuple(multiparams), params + ) self.statements = [] def __repr__(self): |