diff options
-rw-r--r-- | CHANGES | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 26 | ||||
-rw-r--r-- | test/sql/select.py | 7 |
3 files changed, 32 insertions, 6 deletions
@@ -32,6 +32,11 @@ of unicode situations that occur in db's such as MS-SQL to be better handled and allows subclassing of the Unicode datatype. [ticket:522] + - the "where" criterion of an update() and delete() now correlates + embedded select() statements against the table being updated or + deleted. this works the same as nested select() statement + correlation, and can be disabled via the correlate=False flag on + the embedded select(). - column labels are now generated in the compilation phase, which means their lengths are dialect-dependent. So on oracle a label that gets truncated to 30 chars will go out to 63 characters diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 94b618491..a8663ed4c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -2301,6 +2301,7 @@ class Select(_SelectBaseMixin, FromClause): use_labels=False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): + # TODO: docstring ! _SelectBaseMixin.__init__(self) self.__froms = util.OrderedSet() self.__hide_froms = util.Set([self]) @@ -2319,7 +2320,7 @@ class Select(_SelectBaseMixin, FromClause): self.is_scalar = scalar # indicates if this select statement, as a subquery, should automatically correlate - # its FROM clause to that of an enclosing select statement. + # its FROM clause to that of an enclosing select, update, or delete statement. # note that the "correlate" method can be used to explicitly add a value to be correlated. self.should_correlate = correlate @@ -2560,6 +2561,20 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True + class _SelectCorrelator(NoColumnVisitor): + def __init__(self, table): + NoColumnVisitor.__init__(self) + self.table = table + + def visit_select(self, select): + if select.should_correlate: + select.correlate(self.table) + + def _process_whereclause(self, whereclause): + if whereclause is not None: + _UpdateBase._SelectCorrelator(self.table).traverse(whereclause) + return whereclause + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. @@ -2576,10 +2591,11 @@ class _UpdateBase(ClauseElement): i +=1 parameters = pp + correlator = _UpdateBase._SelectCorrelator(self.table) for key in parameters.keys(): value = parameters[key] - if isinstance(value, Select): - value.correlate(self.table) + if isinstance(value, ClauseElement): + correlator.traverse(value) elif _is_literal(value): if _is_literal(key): col = self.table.c[key] @@ -2611,7 +2627,7 @@ class _Insert(_UpdateBase): class _Update(_UpdateBase): def __init__(self, table, whereclause, values=None): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) self.parameters = self._process_colparams(values) def get_children(self, **kwargs): @@ -2625,7 +2641,7 @@ class _Update(_UpdateBase): class _Delete(_UpdateBase): def __init__(self, table, whereclause): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) def get_children(self, **kwargs): if self.whereclause is not None: diff --git a/test/sql/select.py b/test/sql/select.py index 91b293cbe..1d0a63e2f 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -828,10 +828,15 @@ class CRUDTest(SQLTest): u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s}) self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name") - # test a correlated WHERE clause + # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) u = update(table1, table1.c.name==s) self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)") + + # test one that is actually correlated... + s = select([table2.c.othername], table2.c.otherid == table1.c.myid) + u = table1.update(table1.c.name==s) + self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") def testdelete(self): self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") |