diff options
-rw-r--r-- | doc/build/changelog/changelog_11.rst | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 59 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 35 | ||||
-rw-r--r-- | test/requirements.py | 8 | ||||
-rw-r--r-- | test/sql/test_defaults.py | 87 |
6 files changed, 170 insertions, 34 deletions
diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index 8ed600639..75c434da8 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -22,6 +22,15 @@ :version: 1.1.0b3 .. change:: + :tags: bug, sql + :tickets: 3745 + + Fixed bug in new CTE feature for update/insert/delete stated + as a CTE inside of an enclosing statement (typically SELECT) whereby + oninsert and onupdate values weren't called upon for the embedded + statement. + + .. change:: :tags: bug, ext sqlalchemy.ext.indexable will intercept IndexError as well diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 3ed2d5ee8..1bb575984 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -593,12 +593,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self._is_implicit_returning = bool( compiled.returning and not compiled.statement._returning) - if not self.isdelete: - if self.compiled.prefetch: - if self.executemany: - self._process_executemany_defaults() - else: - self._process_executesingle_defaults() + if self.compiled.insert_prefetch or self.compiled.update_prefetch: + if self.executemany: + self._process_executemany_defaults() + else: + self._process_executesingle_defaults() processors = compiled._bind_processors @@ -712,7 +711,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): @util.memoized_property def prefetch_cols(self): - return self.compiled.prefetch + if self.isinsert: + return self.compiled.insert_prefetch + elif self.isupdate: + return self.compiled.update_prefetch + else: + return () @util.memoized_property def returning_cols(self): @@ -1007,46 +1011,57 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _process_executemany_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - prefetch = self.compiled.prefetch scalar_defaults = {} + insert_prefetch = self.compiled.insert_prefetch + update_prefetch = self.compiled.update_prefetch + # pre-determine scalar Python-side defaults # to avoid many calls of get_insert_default()/ # get_update_default() - for c in prefetch: - if self.isinsert and c.default and c.default.is_scalar: + for c in insert_prefetch: + if c.default and c.default.is_scalar: scalar_defaults[c] = c.default.arg - elif self.isupdate and c.onupdate and c.onupdate.is_scalar: + for c in update_prefetch: + if c.onupdate and c.onupdate.is_scalar: scalar_defaults[c] = c.onupdate.arg for param in self.compiled_parameters: self.current_parameters = param - for c in prefetch: + for c in insert_prefetch: if c in scalar_defaults: val = scalar_defaults[c] - elif self.isinsert: + else: val = self.get_insert_default(c) + if val is not None: + param[key_getter(c)] = val + for c in update_prefetch: + if c in scalar_defaults: + val = scalar_defaults[c] else: val = self.get_update_default(c) if val is not None: param[key_getter(c)] = val + del self.current_parameters def _process_executesingle_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - prefetch = self.compiled.prefetch self.current_parameters = compiled_parameters = \ self.compiled_parameters[0] - for c in prefetch: - if self.isinsert: - if c.default and \ - not c.default.is_sequence and c.default.is_scalar: - val = c.default.arg - else: - val = self.get_insert_default(c) + for c in self.compiled.insert_prefetch: + if c.default and \ + not c.default.is_sequence and c.default.is_scalar: + val = c.default.arg else: - val = self.get_update_default(c) + val = self.get_insert_default(c) + + if val is not None: + compiled_parameters[key_getter(c)] = val + + for c in self.compiled.update_prefetch: + val = self.get_update_default(c) if val is not None: compiled_parameters[key_getter(c)] = val diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 16ca7f959..095c84f03 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -359,6 +359,8 @@ class SQLCompiler(Compiled): True unless using an unordered TextAsFrom. """ + insert_prefetch = update_prefetch = () + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new :class:`.SQLCompiler` object. @@ -428,6 +430,10 @@ class SQLCompiler(Compiled): if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() + @property + def prefetch(self): + return list(self.insert_prefetch + self.update_prefetch) + @util.memoized_instancemethod def _init_cte_state(self): """Initialize collections related to CTEs only if diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 70e03d220..f770fc513 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -11,6 +11,7 @@ within INSERT and UPDATE statements. """ from .. import util from .. import exc +from . import dml from . import elements import operator @@ -73,7 +74,8 @@ def _get_crud_params(compiler, stmt, **kw): """ compiler.postfetch = [] - compiler.prefetch = [] + compiler.insert_prefetch = [] + compiler.update_prefetch = [] compiler.returning = [] # no parameters in the statement, no parameters in the @@ -370,7 +372,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.returning.append(c) else: values.append( - (c, _create_prefetch_bind_param(compiler, c)) + (c, _create_insert_prefetch_bind_param(compiler, c)) ) elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) @@ -380,9 +382,15 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): _raise_pk_with_no_anticipated_value(c) -def _create_prefetch_bind_param(compiler, c, process=True, name=None): +def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None): param = _create_bind_param(compiler, c, None, process=process, name=name) - compiler.prefetch.append(c) + compiler.insert_prefetch.append(c) + return param + + +def _create_update_prefetch_bind_param(compiler, c, process=True, name=None): + param = _create_bind_param(compiler, c, None, process=process, name=name) + compiler.update_prefetch.append(c) return param @@ -399,7 +407,7 @@ class _multiparam_column(elements.ColumnElement): other.original == self.original -def _process_multiparam_default_bind(compiler, c, index, kw): +def _process_multiparam_default_bind(compiler, stmt, c, index, kw): if not c.default: raise exc.CompileError( @@ -410,7 +418,10 @@ def _process_multiparam_default_bind(compiler, c, index, kw): return compiler.process(c.default.arg.self_group(), **kw) else: col = _multiparam_column(c, index) - return _create_prefetch_bind_param(compiler, col) + if isinstance(stmt, dml.Insert): + return _create_insert_prefetch_bind_param(compiler, col) + else: + return _create_update_prefetch_bind_param(compiler, col) def _append_param_insert_pk(compiler, stmt, c, values, kw): @@ -448,7 +459,7 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): ) ): values.append( - (c, _create_prefetch_bind_param(compiler, c)) + (c, _create_insert_prefetch_bind_param(compiler, c)) ) elif c.default is None and c.server_default is None and not c.nullable: # no .default, no .server_default, not autoincrement, we have @@ -482,7 +493,7 @@ def _append_param_insert_hasdefault( compiler.postfetch.append(c) else: values.append( - (c, _create_prefetch_bind_param(compiler, c)) + (c, _create_insert_prefetch_bind_param(compiler, c)) ) @@ -500,7 +511,7 @@ def _append_param_insert_select_hasdefault( values.append((c, proc)) else: values.append( - (c, _create_prefetch_bind_param(compiler, c, process=False)) + (c, _create_insert_prefetch_bind_param(compiler, c, process=False)) ) @@ -520,7 +531,7 @@ def _append_param_update( compiler.postfetch.append(c) else: values.append( - (c, _create_prefetch_bind_param(compiler, c)) + (c, _create_update_prefetch_bind_param(compiler, c)) ) elif c.server_onupdate is not None: if implicit_return_defaults and \ @@ -575,7 +586,7 @@ def _get_multitable_params( compiler.postfetch.append(c) else: values.append( - (c, _create_prefetch_bind_param( + (c, _create_update_prefetch_bind_param( compiler, c, name=_col_bind_name(c))) ) elif c.server_onupdate is not None: @@ -597,7 +608,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): else compiler.process( row[c.key].self_group(), **kw)) if c.key in row else - _process_multiparam_default_bind(compiler, c, i, kw) + _process_multiparam_default_bind(compiler, stmt, c, i, kw) ) for (c, param) in values_0 ] diff --git a/test/requirements.py b/test/requirements.py index d31088e16..3c7a3fbb4 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -351,6 +351,14 @@ class DefaultRequirements(SuiteRequirements): return skip_if(exclude('mysql', '<', (4, 1, 1)), 'no subquery support') @property + def ctes(self): + """Target database supports CTEs""" + + return only_if( + ['postgresql', 'mssql'] + ) + + @property def mod_operator_as_percent_sign(self): """target database must use a plain percent '%' as the 'modulus' operator.""" diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index db19e145b..57af1e536 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -539,6 +539,93 @@ class DefaultTest(fixtures.TestBase): eq_(55, l['col3']) +class CTEDefaultTest(fixtures.TablesTest): + __requires__ = ('ctes',) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + 'q', metadata, + Column('x', Integer, default=2), + Column('y', Integer, onupdate=5), + Column('z', Integer) + ) + + Table( + 'p', metadata, + Column('s', Integer), + Column('t', Integer), + Column('u', Integer, onupdate=1) + ) + + def _test_a_in_b(self, a, b): + q = self.tables.q + p = self.tables.p + + with testing.db.connect() as conn: + if a == 'delete': + conn.execute(q.insert().values(y=10, z=1)) + cte = q.delete().\ + where(q.c.z == 1).returning(q.c.z).cte('c') + expected = None + elif a == "insert": + cte = q.insert().values(z=1, y=10).returning(q.c.z).cte('c') + expected = (2, 10) + elif a == "update": + conn.execute(q.insert().values(x=5, y=10, z=1)) + cte = q.update().\ + where(q.c.z == 1).values(x=7).returning(q.c.z).cte('c') + expected = (7, 5) + elif a == "select": + conn.execute(q.insert().values(x=5, y=10, z=1)) + cte = sa.select([q.c.z]).cte('c') + expected = (5, 10) + + if b == "select": + conn.execute(p.insert().values(s=1)) + stmt = select([p.c.s, cte.c.z]) + elif b == "insert": + sel = select([1, cte.c.z, ]) + stmt = p.insert().from_select(['s', 't'], sel).returning( + p.c.s, p.c.t) + elif b == "delete": + stmt = p.insert().values(s=1, t=cte.c.z).returning( + p.c.s, cte.c.z) + elif b == "update": + conn.execute(p.insert().values(s=1)) + stmt = p.update().values(t=5).\ + where(p.c.s == cte.c.z).\ + returning(p.c.u, cte.c.z) + eq_( + conn.execute(stmt).fetchall(), + [(1, 1)] + ) + + eq_( + conn.execute(select([q.c.x, q.c.y])).fetchone(), + expected + ) + + def test_update_in_select(self): + self._test_a_in_b("update", "select") + + def test_delete_in_select(self): + self._test_a_in_b("update", "select") + + def test_insert_in_select(self): + self._test_a_in_b("update", "select") + + def test_select_in_update(self): + self._test_a_in_b("select", "update") + + def test_select_in_insert(self): + self._test_a_in_b("select", "insert") + + # TODO: updates / inserts can be run in one statement w/ CTE ? + # deletes? + + class PKDefaultTest(fixtures.TablesTest): __requires__ = ('subqueries',) __backend__ = True |