diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-04-18 22:54:40 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-04-18 22:54:40 +0000 |
commit | 4fffc21c87cbdfc538fe2924f82bf1591823856d (patch) | |
tree | f200a79e608f9e901baf515ce3d0e1b3b21b8bf6 /lib/sqlalchemy/sql.py | |
parent | 7efd23b23cbbd1d714cc31e44e776b7e1e9af319 (diff) | |
download | sqlalchemy-4fffc21c87cbdfc538fe2924f82bf1591823856d.tar.gz |
- the "where" criterion of an update() and delete() now correlates
embedded select() statements against the table being updated or
deleted. this works the same as nested select() statement
correlation, and can be disabled via the correlate=False flag on
the embedded select().
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 94b618491..a8663ed4c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -2301,6 +2301,7 @@ class Select(_SelectBaseMixin, FromClause): use_labels=False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): + # TODO: docstring ! _SelectBaseMixin.__init__(self) self.__froms = util.OrderedSet() self.__hide_froms = util.Set([self]) @@ -2319,7 +2320,7 @@ class Select(_SelectBaseMixin, FromClause): self.is_scalar = scalar # indicates if this select statement, as a subquery, should automatically correlate - # its FROM clause to that of an enclosing select statement. + # its FROM clause to that of an enclosing select, update, or delete statement. # note that the "correlate" method can be used to explicitly add a value to be correlated. self.should_correlate = correlate @@ -2560,6 +2561,20 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True + class _SelectCorrelator(NoColumnVisitor): + def __init__(self, table): + NoColumnVisitor.__init__(self) + self.table = table + + def visit_select(self, select): + if select.should_correlate: + select.correlate(self.table) + + def _process_whereclause(self, whereclause): + if whereclause is not None: + _UpdateBase._SelectCorrelator(self.table).traverse(whereclause) + return whereclause + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. @@ -2576,10 +2591,11 @@ class _UpdateBase(ClauseElement): i +=1 parameters = pp + correlator = _UpdateBase._SelectCorrelator(self.table) for key in parameters.keys(): value = parameters[key] - if isinstance(value, Select): - value.correlate(self.table) + if isinstance(value, ClauseElement): + correlator.traverse(value) elif _is_literal(value): if _is_literal(key): col = self.table.c[key] @@ -2611,7 +2627,7 @@ class _Insert(_UpdateBase): class _Update(_UpdateBase): def __init__(self, table, whereclause, values=None): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) self.parameters = self._process_colparams(values) def get_children(self, **kwargs): @@ -2625,7 +2641,7 @@ class _Update(_UpdateBase): class _Delete(_UpdateBase): def __init__(self, table, whereclause): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) def get_children(self, **kwargs): if self.whereclause is not None: |