diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 108 |
1 files changed, 56 insertions, 52 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6aab22a79..e8cc3378e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -90,7 +90,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): operators = OPERATORS - def __init__(self, dialect, statement, parameters=None, **kwargs): + def __init__(self, dialect, statement, parameters=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. dialect @@ -113,6 +113,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): # if we are insert/update. set to true when we visit an INSERT or UPDATE self.isinsert = self.isupdate = False + # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) + self.inline = inline or getattr(statement, 'inline', False) + # a dictionary of bind parameter keys to _BindParamClause instances. self.binds = {} @@ -151,12 +154,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): # an IdentifierPreparer that formats the quoting of identifiers self.preparer = self.dialect.identifier_preparer - # for UPDATE and INSERT statements, a set of columns whos values are being set - # from a SQL expression (i.e., not one of the bind parameter values). if present, - # default-value logic in the Dialect knows not to fire off column defaults - # and also knows postfetching will be needed to get the values represented by these - # parameters. - self.inline_params = None def after_compile(self): # this re will search for params like :param @@ -615,26 +612,14 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): def uses_sequences_for_inserts(self): return False - - def visit_insert(self, insert_stmt): - # search for columns who will be required to have an explicit bound value. - # for inserts, this includes Python-side defaults, columns with sequences for dialects - # that support sequences, and primary key columns for dialects that explicitly insert - # pre-generated primary key values - required_cols = [ - c for c in insert_stmt.table.c - if \ - isinstance(c, schema.SchemaItem) and \ - (self.parameters is None or self.parameters.get(c.key, None) is None) and \ - ( - ((c.primary_key or isinstance(c.default, schema.Sequence)) and self.uses_sequences_for_inserts()) or - isinstance(c.default, schema.ColumnDefault) - ) - ] + def visit_sequence(self, seq): + raise NotImplementedError() + + def visit_insert(self, insert_stmt): self.isinsert = True - colparams = self._get_colparams(insert_stmt, required_cols) + colparams = self._get_colparams(insert_stmt) return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") @@ -642,17 +627,8 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): def visit_update(self, update_stmt): self.stack.append({'from':util.Set([update_stmt.table])}) - # search for columns who will be required to have an explicit bound value. - # for updates, this includes Python-side "onupdate" defaults. - required_cols = [c for c in update_stmt.table.c - if - isinstance(c, schema.SchemaItem) and \ - (self.parameters is None or self.parameters.get(c.key, None) is None) and - isinstance(c.onupdate, schema.ColumnDefault) - ] - self.isupdate = True - colparams = self._get_colparams(update_stmt, required_cols) + colparams = self._get_colparams(update_stmt) text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ') @@ -663,13 +639,10 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): return text - def _get_colparams(self, stmt, required_cols): + def _get_colparams(self, stmt): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. - This method may generate new bind params within this compiled - based on the given set of "required columns", which are required - to have a value set in the statement. """ def create_bind_param(col, value): @@ -677,8 +650,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) - self.inline_params = util.Set() - + self.postfetch = util.Set() + self.prefetch = util.Set() + def to_col(key): if not isinstance(key, sql._ColumnClause): return stmt.table.columns.get(unicode(key), key) @@ -701,23 +675,53 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): for k, v in stmt.parameters.iteritems(): parameters.setdefault(getattr(k, 'key', k), v) - for col in required_cols: - parameters.setdefault(col.key, None) - # create a list of column assignment clauses as tuples values = [] for c in stmt.table.columns: if c.key in parameters: value = parameters[c.key] - else: - continue - if sql._is_literal(value): - value = create_bind_param(c, value) - else: - self.inline_params.add(c) - value = self.process(value) - values.append((c, value)) - + if sql._is_literal(value): + value = create_bind_param(c, value) + else: + self.postfetch.add(c) + value = self.process(value.self_group()) + values.append((c, value)) + elif isinstance(c, schema.Column): + if self.isinsert: + if isinstance(c.default, schema.ColumnDefault): + if self.inline and isinstance(c.default.arg, sql.ClauseElement): + values.append((c, self.process(c.default.arg))) + self.postfetch.add(c) + else: + values.append((c, create_bind_param(c, None))) + self.prefetch.add(c) + elif isinstance(c.default, schema.PassiveDefault): + if c.primary_key and self.uses_sequences_for_inserts() and not self.inline: + values.append((c, create_bind_param(c, None))) + self.prefetch.add(c) + else: + self.postfetch.add(c) + elif (c.primary_key or isinstance(c.default, schema.Sequence)) and self.uses_sequences_for_inserts(): + if self.inline: + if c.default is not None: + proc = self.process(c.default) + if proc is not None: + values.append((c, proc)) + self.postfetch.add(c) + else: + print "ISINSERT, HAS A SEQUENCE, IS PRIMARY KEY, ADDING PREFETCH:", c.key + values.append((c, create_bind_param(c, None))) + self.prefetch.add(c) + elif self.isupdate: + if isinstance(c.onupdate, schema.ColumnDefault): + if self.inline and isinstance(c.onupdate.arg, sql.ClauseElement): + values.append((c, self.process(c.onupdate.arg))) + self.postfetch.add(c) + else: + values.append((c, create_bind_param(c, None))) + self.prefetch.add(c) + elif isinstance(c.onupdate, schema.PassiveDefault): + self.postfetch.add(c) return values def visit_delete(self, delete_stmt): |