From 4fffc21c87cbdfc538fe2924f82bf1591823856d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 18 Apr 2007 22:54:40 +0000 Subject: - the "where" criterion of an update() and delete() now correlates embedded select() statements against the table being updated or deleted. this works the same as nested select() statement correlation, and can be disabled via the correlate=False flag on the embedded select(). --- lib/sqlalchemy/sql.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) (limited to 'lib/sqlalchemy/sql.py') diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 94b618491..a8663ed4c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -2301,6 +2301,7 @@ class Select(_SelectBaseMixin, FromClause): use_labels=False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): + # TODO: docstring ! _SelectBaseMixin.__init__(self) self.__froms = util.OrderedSet() self.__hide_froms = util.Set([self]) @@ -2319,7 +2320,7 @@ class Select(_SelectBaseMixin, FromClause): self.is_scalar = scalar # indicates if this select statement, as a subquery, should automatically correlate - # its FROM clause to that of an enclosing select statement. + # its FROM clause to that of an enclosing select, update, or delete statement. # note that the "correlate" method can be used to explicitly add a value to be correlated. self.should_correlate = correlate @@ -2560,6 +2561,20 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True + class _SelectCorrelator(NoColumnVisitor): + def __init__(self, table): + NoColumnVisitor.__init__(self) + self.table = table + + def visit_select(self, select): + if select.should_correlate: + select.correlate(self.table) + + def _process_whereclause(self, whereclause): + if whereclause is not None: + _UpdateBase._SelectCorrelator(self.table).traverse(whereclause) + return whereclause + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. @@ -2576,10 +2591,11 @@ class _UpdateBase(ClauseElement): i +=1 parameters = pp + correlator = _UpdateBase._SelectCorrelator(self.table) for key in parameters.keys(): value = parameters[key] - if isinstance(value, Select): - value.correlate(self.table) + if isinstance(value, ClauseElement): + correlator.traverse(value) elif _is_literal(value): if _is_literal(key): col = self.table.c[key] @@ -2611,7 +2627,7 @@ class _Insert(_UpdateBase): class _Update(_UpdateBase): def __init__(self, table, whereclause, values=None): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) self.parameters = self._process_colparams(values) def get_children(self, **kwargs): @@ -2625,7 +2641,7 @@ class _Update(_UpdateBase): class _Delete(_UpdateBase): def __init__(self, table, whereclause): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) def get_children(self, **kwargs): if self.whereclause is not None: -- cgit v1.2.1