diff options
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 94b618491..a8663ed4c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -2301,6 +2301,7 @@ class Select(_SelectBaseMixin, FromClause): use_labels=False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): + # TODO: docstring ! _SelectBaseMixin.__init__(self) self.__froms = util.OrderedSet() self.__hide_froms = util.Set([self]) @@ -2319,7 +2320,7 @@ class Select(_SelectBaseMixin, FromClause): self.is_scalar = scalar # indicates if this select statement, as a subquery, should automatically correlate - # its FROM clause to that of an enclosing select statement. + # its FROM clause to that of an enclosing select, update, or delete statement. # note that the "correlate" method can be used to explicitly add a value to be correlated. self.should_correlate = correlate @@ -2560,6 +2561,20 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True + class _SelectCorrelator(NoColumnVisitor): + def __init__(self, table): + NoColumnVisitor.__init__(self) + self.table = table + + def visit_select(self, select): + if select.should_correlate: + select.correlate(self.table) + + def _process_whereclause(self, whereclause): + if whereclause is not None: + _UpdateBase._SelectCorrelator(self.table).traverse(whereclause) + return whereclause + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. @@ -2576,10 +2591,11 @@ class _UpdateBase(ClauseElement): i +=1 parameters = pp + correlator = _UpdateBase._SelectCorrelator(self.table) for key in parameters.keys(): value = parameters[key] - if isinstance(value, Select): - value.correlate(self.table) + if isinstance(value, ClauseElement): + correlator.traverse(value) elif _is_literal(value): if _is_literal(key): col = self.table.c[key] @@ -2611,7 +2627,7 @@ class _Insert(_UpdateBase): class _Update(_UpdateBase): def __init__(self, table, whereclause, values=None): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) self.parameters = self._process_colparams(values) def get_children(self, **kwargs): @@ -2625,7 +2641,7 @@ class _Update(_UpdateBase): class _Delete(_UpdateBase): def __init__(self, table, whereclause): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) def get_children(self, **kwargs): if self.whereclause is not None: |