summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-08-20 15:22:08 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-08-20 15:22:08 +0000
commit544ef23cd36b0ea30a13c5158121ba5ea7573f03 (patch)
treebda3830e46dc395e5ebd9592cc6cc77ef67bbbea /lib/sqlalchemy
parent62347923754754f93adbf0c3888208be77f26e70 (diff)
parenta1939719a652774a437f69f8d4788b3f08650089 (diff)
downloadsqlalchemy-544ef23cd36b0ea30a13c5158121ba5ea7573f03.tar.gz
Merge "normalize execute style for events, 2.0"
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/cextension/utils.c27
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py20
-rw-r--r--lib/sqlalchemy/engine/base.py187
-rw-r--r--lib/sqlalchemy/engine/util.py73
-rw-r--r--lib/sqlalchemy/testing/assertsql.py6
5 files changed, 237 insertions, 76 deletions
diff --git a/lib/sqlalchemy/cextension/utils.c b/lib/sqlalchemy/cextension/utils.c
index ab8b39335..c612094dc 100644
--- a/lib/sqlalchemy/cextension/utils.c
+++ b/lib/sqlalchemy/cextension/utils.c
@@ -26,12 +26,13 @@ distill_params(PyObject *self, PyObject *args)
// TODO: pass the Connection in so that there can be a standard
// method for warning on parameter format
- PyObject *multiparams, *params;
+ PyObject *connection, *multiparams, *params;
PyObject *enclosing_list, *double_enclosing_list;
PyObject *zero_element, *zero_element_item;
+ PyObject *tmp;
Py_ssize_t multiparam_size, zero_element_length;
- if (!PyArg_UnpackTuple(args, "_distill_params", 2, 2, &multiparams, &params)) {
+ if (!PyArg_UnpackTuple(args, "_distill_params", 3, 3, &connection, &multiparams, &params)) {
return NULL;
}
@@ -47,8 +48,12 @@ distill_params(PyObject *self, PyObject *args)
if (multiparam_size == 0) {
if (params != Py_None && PyMapping_Size(params) != 0) {
- // TODO: this is keyword parameters, emit parameter format
- // deprecation warning
+
+ tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", "");
+ if (tmp == NULL) {
+ return NULL;
+ }
+
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
@@ -102,6 +107,7 @@ distill_params(PyObject *self, PyObject *args)
* execute(stmt, ("value", "value"))
*/
Py_XDECREF(zero_element_item);
+
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
@@ -131,6 +137,11 @@ distill_params(PyObject *self, PyObject *args)
}
return enclosing_list;
} else {
+ tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", "");
+ if (tmp == NULL) {
+ return NULL;
+ }
+
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
@@ -157,8 +168,12 @@ distill_params(PyObject *self, PyObject *args)
}
}
else {
- // TODO: this is multiple positional params, emit parameter format
- // deprecation warning
+
+ tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", "");
+ if (tmp == NULL) {
+ return NULL;
+ }
+
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyObject_HasAttrString(zero_element, "__iter__") &&
!PyObject_HasAttrString(zero_element, "strip")
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 2db079799..c56cccd8d 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2903,7 +2903,11 @@ class PGDialect(default.DefaultDialect):
"JOIN pg_namespace n ON n.oid = c.relnamespace "
"WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
).columns(relname=sqltypes.Unicode),
- schema=schema if schema is not None else self.default_schema_name,
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
)
return [name for name, in result]
@@ -3018,7 +3022,7 @@ class PGDialect(default.DefaultDialect):
.bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
.columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
)
- c = connection.execute(s, table_oid=table_oid)
+ c = connection.execute(s, dict(table_oid=table_oid))
rows = c.fetchall()
# dictionary with (name, ) if default search path or (schema, name)
@@ -3260,7 +3264,7 @@ class PGDialect(default.DefaultDialect):
ORDER BY k.ord
"""
t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode)
- c = connection.execute(t, table_oid=table_oid)
+ c = connection.execute(t, dict(table_oid=table_oid))
cols = [r[0] for r in c.fetchall()]
PK_CONS_SQL = """
@@ -3270,7 +3274,7 @@ class PGDialect(default.DefaultDialect):
ORDER BY 1
"""
t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode)
- c = connection.execute(t, table_oid=table_oid)
+ c = connection.execute(t, dict(table_oid=table_oid))
name = c.scalar()
return {"constrained_columns": cols, "name": name}
@@ -3318,7 +3322,7 @@ class PGDialect(default.DefaultDialect):
t = sql.text(FK_SQL).columns(
conname=sqltypes.Unicode, condef=sqltypes.Unicode
)
- c = connection.execute(t, table=table_oid)
+ c = connection.execute(t, dict(table=table_oid))
fkeys = []
for conname, condef, conschema in c.fetchall():
m = re.search(FK_REGEX, condef).groups()
@@ -3490,7 +3494,7 @@ class PGDialect(default.DefaultDialect):
t = sql.text(IDX_SQL).columns(
relname=sqltypes.Unicode, attname=sqltypes.Unicode
)
- c = connection.execute(t, table_oid=table_oid)
+ c = connection.execute(t, dict(table_oid=table_oid))
indexes = defaultdict(lambda: defaultdict(dict))
@@ -3632,7 +3636,7 @@ class PGDialect(default.DefaultDialect):
"""
t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode)
- c = connection.execute(t, table_oid=table_oid)
+ c = connection.execute(t, dict(table_oid=table_oid))
uniques = defaultdict(lambda: defaultdict(dict))
for row in c.fetchall():
@@ -3683,7 +3687,7 @@ class PGDialect(default.DefaultDialect):
cons.contype = 'c'
"""
- c = connection.execute(sql.text(CHECK_SQL), table_oid=table_oid)
+ c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
ret = []
for name, src in c:
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 34bf720b7..0eaa1fae1 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -935,6 +935,17 @@ class Connection(Connectable):
if not self.in_transaction():
self._rollback_impl()
+ def _warn_for_legacy_exec_format(self):
+ util.warn_deprecated_20(
+ "The connection.execute() method in "
+ "SQLAlchemy 2.0 will accept parameters as a single "
+ "dictionary or a "
+ "single sequence of dictionaries only. "
+ "Parameters passed as keyword arguments, tuples or positionally "
+ "oriened dictionaries and/or tuples "
+ "will no longer be accepted."
+ )
+
def close(self):
"""Close this :class:`_engine.Connection`.
@@ -1073,14 +1084,13 @@ class Connection(Connectable):
"or the Connection.exec_driver_sql() method to invoke a "
"driver-level SQL string."
)
- distilled_parameters = _distill_params(multiparams, params)
return self._exec_driver_sql(
object_,
multiparams,
params,
- distilled_parameters,
_EMPTY_EXECUTION_OPTS,
+ future=False,
)
try:
@@ -1113,11 +1123,16 @@ class Connection(Connectable):
execution_options
)
+ distilled_parameters = _distill_params(self, multiparams, params)
+
if self._has_events or self.engine._has_events:
- for fn in self.dispatch.before_execute:
- default, multiparams, params = fn(
- self, default, multiparams, params, execution_options
- )
+ (
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ default, distilled_parameters, execution_options
+ )
try:
conn = self._dbapi_connection
@@ -1139,7 +1154,12 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
self.dispatch.after_execute(
- self, default, multiparams, params, execution_options, ret
+ self,
+ default,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
)
return ret
@@ -1151,11 +1171,16 @@ class Connection(Connectable):
self._execution_options, execution_options
)
+ distilled_parameters = _distill_params(self, multiparams, params)
+
if self._has_events or self.engine._has_events:
- for fn in self.dispatch.before_execute:
- ddl, multiparams, params = fn(
- self, ddl, multiparams, params, execution_options
- )
+ (
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ ddl, distilled_parameters, execution_options
+ )
exec_opts = self._execution_options.merge_with(execution_options)
schema_translate_map = exec_opts.get("schema_translate_map", None)
@@ -1175,10 +1200,43 @@ class Connection(Connectable):
)
if self._has_events or self.engine._has_events:
self.dispatch.after_execute(
- self, ddl, multiparams, params, execution_options, ret
+ self,
+ ddl,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
)
return ret
+ def _invoke_before_exec_event(
+ self, elem, distilled_params, execution_options
+ ):
+
+ if len(distilled_params) == 1:
+ event_multiparams, event_params = [], distilled_params[0]
+ else:
+ event_multiparams, event_params = distilled_params, {}
+
+ for fn in self.dispatch.before_execute:
+ elem, event_multiparams, event_params = fn(
+ self, elem, event_multiparams, event_params, execution_options,
+ )
+
+ if event_multiparams:
+ distilled_params = list(event_multiparams)
+ if event_params:
+ raise exc.InvalidRequestError(
+ "Event handler can't return non-empty multiparams "
+ "and params at the same time"
+ )
+ elif event_params:
+ distilled_params = [event_params]
+ else:
+ distilled_params = []
+
+ return distilled_params, event_multiparams, event_params
+
def _execute_clauseelement(
self, elem, multiparams, params, execution_options
):
@@ -1188,14 +1246,18 @@ class Connection(Connectable):
self._execution_options, execution_options
)
+ distilled_params = _distill_params(self, multiparams, params)
+
has_events = self._has_events or self.engine._has_events
if has_events:
- for fn in self.dispatch.before_execute:
- elem, multiparams, params = fn(
- self, elem, multiparams, params, execution_options
- )
+ (
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ elem, distilled_params, execution_options
+ )
- distilled_params = _distill_params(multiparams, params)
if distilled_params:
# ensure we don't retain a link to the view object for keys()
# which links to the values, which we don't want to cache
@@ -1237,7 +1299,12 @@ class Connection(Connectable):
)
if has_events:
self.dispatch.after_execute(
- self, elem, multiparams, params, execution_options, ret
+ self,
+ elem,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
)
return ret
@@ -1257,49 +1324,58 @@ class Connection(Connectable):
execution_options = compiled.execution_options.merge_with(
self._execution_options, execution_options
)
+ distilled_parameters = _distill_params(self, multiparams, params)
if self._has_events or self.engine._has_events:
- for fn in self.dispatch.before_execute:
- compiled, multiparams, params = fn(
- self, compiled, multiparams, params, execution_options
- )
+ (
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ compiled, distilled_parameters, execution_options
+ )
dialect = self.dialect
- parameters = _distill_params(multiparams, params)
ret = self._execute_context(
dialect,
dialect.execution_ctx_cls._init_compiled,
compiled,
- parameters,
+ distilled_parameters,
execution_options,
compiled,
- parameters,
+ distilled_parameters,
None,
None,
)
if self._has_events or self.engine._has_events:
self.dispatch.after_execute(
- self, compiled, multiparams, params, execution_options, ret
+ self,
+ compiled,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
)
return ret
def _exec_driver_sql(
- self,
- statement,
- multiparams,
- params,
- distilled_parameters,
- execution_options,
+ self, statement, multiparams, params, execution_options, future
):
execution_options = self._execution_options.merge_with(
execution_options
)
- if self._has_events or self.engine._has_events:
- for fn in self.dispatch.before_execute:
- statement, multiparams, params = fn(
- self, statement, multiparams, params, execution_options
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if not future:
+ if self._has_events or self.engine._has_events:
+ (
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ statement, distilled_parameters, execution_options
)
dialect = self.dialect
@@ -1312,10 +1388,17 @@ class Connection(Connectable):
statement,
distilled_parameters,
)
- if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(
- self, statement, multiparams, params, execution_options, ret
- )
+
+ if not future:
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ statement,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
return ret
def _execute_20(
@@ -1324,9 +1407,7 @@ class Connection(Connectable):
parameters=None,
execution_options=_EMPTY_EXECUTION_OPTS,
):
- multiparams, params, distilled_parameters = _distill_params_20(
- parameters
- )
+ args_10style, kwargs_10style = _distill_params_20(parameters)
try:
meth = statement._execute_on_connection
except AttributeError as err:
@@ -1334,7 +1415,7 @@ class Connection(Connectable):
exc.ObjectNotExecutableError(statement), replace_context=err
)
else:
- return meth(self, multiparams, params, execution_options)
+ return meth(self, args_10style, kwargs_10style, execution_options)
def exec_driver_sql(
self, statement, parameters=None, execution_options=None
@@ -1373,22 +1454,28 @@ class Connection(Connectable):
(1, 'v1')
)
+ .. note:: The :meth:`_engine.Connection.exec_driver_sql` method does
+ not participate in the
+ :meth:`_events.ConnectionEvents.before_execute` and
+ :meth:`_events.ConnectionEvents.after_execute` events. To
+ intercept calls to :meth:`_engine.Connection.exec_driver_sql`, use
+ :meth:`_events.ConnectionEvents.before_cursor_execute` and
+ :meth:`_events.ConnectionEvents.after_cursor_execute`.
+
.. seealso::
:pep:`249`
"""
- multiparams, params, distilled_parameters = _distill_params_20(
- parameters
- )
+ args_10style, kwargs_10style = _distill_params_20(parameters)
return self._exec_driver_sql(
statement,
- multiparams,
- params,
- distilled_parameters,
+ args_10style,
+ kwargs_10style,
execution_options,
+ future=True,
)
def _execute_context(
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py
index fc0260ae2..c1f6bad77 100644
--- a/lib/sqlalchemy/engine/util.py
+++ b/lib/sqlalchemy/engine/util.py
@@ -30,10 +30,14 @@ def connection_memoize(key):
return decorated
+_no_tuple = ()
+_no_kw = util.immutabledict()
+
+
def py_fallback():
# TODO: pass the Connection in so that there can be a standard
# method for warning on parameter format
- def _distill_params(multiparams, params): # noqa
+ def _distill_params(connection, multiparams, params): # noqa
r"""Given arguments from the calling form \*multiparams, \**params,
return a list of bind parameter structures, usually a list of
dictionaries.
@@ -43,9 +47,12 @@ def py_fallback():
"""
+ # C version will fail if this assertion is not true.
+ # assert isinstance(multiparams, tuple)
+
if not multiparams:
if params:
- # TODO: parameter format deprecation warning
+ connection._warn_for_legacy_exec_format()
return [params]
else:
return []
@@ -61,16 +68,22 @@ def py_fallback():
# execute(stmt, [(), (), (), ...])
return zero
else:
+ # this is used by exec_driver_sql only, so a deprecation
+ # warning would already be coming from passing a plain
+ # textual statement with positional parameters to
+ # execute().
# execute(stmt, ("value", "value"))
+
return [zero]
elif hasattr(zero, "keys"):
# execute(stmt, {"key":"value"})
return [zero]
else:
+ connection._warn_for_legacy_exec_format()
# execute(stmt, "value")
return [[zero]]
else:
- # TODO: parameter format deprecation warning
+ connection._warn_for_legacy_exec_format()
if hasattr(multiparams[0], "__iter__") and not hasattr(
multiparams[0], "strip"
):
@@ -81,14 +94,55 @@ def py_fallback():
return locals()
-_no_tuple = ()
-_no_kw = util.immutabledict()
+def _distill_cursor_params(connection, multiparams, params):
+ """_distill_params without any warnings. more appropriate for
+ "cursor" params that can include tuple arguments, lists of tuples,
+ etc.
+
+ """
+
+ if not multiparams:
+ if params:
+ return [params]
+ else:
+ return []
+ elif len(multiparams) == 1:
+ zero = multiparams[0]
+ if isinstance(zero, (list, tuple)):
+ if (
+ not zero
+ or hasattr(zero[0], "__iter__")
+ and not hasattr(zero[0], "strip")
+ ):
+ # execute(stmt, [{}, {}, {}, ...])
+ # execute(stmt, [(), (), (), ...])
+ return zero
+ else:
+ # this is used by exec_driver_sql only, so a deprecation
+ # warning would already be coming from passing a plain
+ # textual statement with positional parameters to
+ # execute().
+ # execute(stmt, ("value", "value"))
+
+ return [zero]
+ elif hasattr(zero, "keys"):
+ # execute(stmt, {"key":"value"})
+ return [zero]
+ else:
+ # execute(stmt, "value")
+ return [[zero]]
+ else:
+ if hasattr(multiparams[0], "__iter__") and not hasattr(
+ multiparams[0], "strip"
+ ):
+ return multiparams
+ else:
+ return [multiparams]
def _distill_params_20(params):
- # TODO: this has to be in C
if params is None:
- return _no_tuple, _no_kw, []
+ return _no_tuple, _no_kw
elif isinstance(params, list):
# collections_abc.MutableSequence): # avoid abc.__instancecheck__
if params and not isinstance(
@@ -98,15 +152,14 @@ def _distill_params_20(params):
"List argument must consist only of tuples or dictionaries"
)
- # the tuple is needed atm by the C version of _distill_params...
- return tuple(params), _no_kw, params
+ return (params,), _no_kw
elif isinstance(
params,
(tuple, dict, immutabledict),
# avoid abc.__instancecheck__
# (collections_abc.Sequence, collections_abc.Mapping),
):
- return _no_tuple, params, [params]
+ return (params,), _no_kw
else:
raise exc.ArgumentError("mapping or sequence expected for parameters")
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 73b062b96..c86e26ccf 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -13,7 +13,7 @@ from .. import event
from .. import util
from ..engine import url
from ..engine.default import DefaultDialect
-from ..engine.util import _distill_params
+from ..engine.util import _distill_cursor_params
from ..schema import _DDLCompiles
@@ -348,7 +348,9 @@ class SQLExecuteObserved(object):
def __init__(self, context, clauseelement, multiparams, params):
self.context = context
self.clauseelement = clauseelement
- self.parameters = _distill_params(multiparams, params)
+ self.parameters = _distill_cursor_params(
+ context.connection, tuple(multiparams), params
+ )
self.statements = []
def __repr__(self):