diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-07-10 11:00:49 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-07-10 11:00:49 -0400 |
commit | 0e41673ed4e8551b892c058ffc6a607cf7aba71c (patch) | |
tree | 6914091d7cbb58331d147242e6efa83bd1345424 /lib/sqlalchemy/sql/compiler.py | |
parent | b297b40fca923a03e3c34094e5298d6524944c39 (diff) | |
download | sqlalchemy-0e41673ed4e8551b892c058ffc6a607cf7aba71c.tar.gz |
- [bug] Fixed more un-intuitivenesses in CTEs
which prevented referring to a CTE in a union
of itself without it being aliased.
CTEs now render uniquely
on name, rendering the outermost CTE of a given
name only - all other references are rendered
just as the name. This even includes other
CTE/SELECTs that refer to different versions
of the same CTE object, such as a SELECT
or a UNION ALL of that SELECT. We are
somewhat loosening the usual link between object
identity and lexical identity in this case.
A true name conflict between two unrelated
CTEs now raises an error.
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 317 |
1 files changed, 174 insertions, 143 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6fdb943d0..979c88e4b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -17,7 +17,7 @@ strings :class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders type specification strings. -To generate user-defined SQL strings, see +To generate user-defined SQL strings, see :module:`~sqlalchemy.ext.compiler`. """ @@ -215,7 +215,7 @@ class SQLCompiler(engine.Compiled): driver/DB enforces this """ - def __init__(self, dialect, statement, column_keys=None, + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -259,11 +259,7 @@ class SQLCompiler(engine.Compiled): self.positiontup = [] self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - # collect CTEs to tack on top of a SELECT - self.ctes = util.OrderedDict() - self.ctes_recursive = False - if self.positional: - self.cte_positional = [] + self.ctes = None # an IdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer @@ -282,11 +278,25 @@ class SQLCompiler(engine.Compiled): if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() + @util.memoized_instancemethod + def _init_cte_state(self): + """Initialize collections related to CTEs only if + a CTE is located, to save on the overhead of + these collections otherwise. + + """ + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_by_name = {} + self.ctes_recursive = False + if self.positional: + self.cte_positional = [] + def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r'\[_POSITION\]', - lambda m:str(util.next(poscount)), + r'\[_POSITION\]', + lambda m:str(util.next(poscount)), self.string) @util.memoized_property @@ -320,11 +330,11 @@ class SQLCompiler(engine.Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % + "in parameter group %d" % (bindparam.key, _group_number)) else: raise exc.InvalidRequestError( - "A value is required for bind parameter %r" + "A value is required for bind parameter %r" % bindparam.key) else: pd[name] = bindparam.effective_value @@ -336,18 +346,18 @@ class SQLCompiler(engine.Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % + "in parameter group %d" % (bindparam.key, _group_number)) else: raise exc.InvalidRequestError( - "A value is required for bind parameter %r" + "A value is required for bind parameter %r" % bindparam.key) pd[self.bind_names[bindparam]] = bindparam.effective_value return pd @property def params(self): - """Return the bind param dictionary embedded into this + """Return the bind param dictionary embedded into this compiled object, for those values that are present.""" return self.construct_params(_check=False) @@ -363,8 +373,8 @@ class SQLCompiler(engine.Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" - def visit_label(self, label, result_map=None, - within_label_clause=False, + def visit_label(self, label, result_map=None, + within_label_clause=False, within_columns_clause=False, **kw): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers @@ -376,23 +386,23 @@ class SQLCompiler(engine.Compiled): labelname = label.name if result_map is not None: - result_map[labelname - if self.dialect.case_sensitive + result_map[labelname + if self.dialect.case_sensitive else labelname.lower()] = ( - label.name, - (label, label.element, labelname, ) + + label.name, + (label, label.element, labelname, ) + label._alt_names, label.type) - return label.element._compiler_dispatch(self, + return label.element._compiler_dispatch(self, within_columns_clause=True, - within_label_clause=True, + within_label_clause=True, **kw) + \ OPERATORS[operators.as_] + \ self.preparer.format_label(label, labelname) else: - return label.element._compiler_dispatch(self, - within_columns_clause=False, + return label.element._compiler_dispatch(self, + within_columns_clause=False, **kw) def visit_column(self, column, result_map=None, **kwargs): @@ -406,10 +416,10 @@ class SQLCompiler(engine.Compiled): name = self._truncated_identifier("colident", name) if result_map is not None: - result_map[name - if self.dialect.case_sensitive - else name.lower()] = (orig_name, - (column, name, column.key), + result_map[name + if self.dialect.case_sensitive + else name.lower()] = (orig_name, + (column, name, column.key), column.type) if is_literal: @@ -423,7 +433,7 @@ class SQLCompiler(engine.Compiled): else: if table.schema: schema_prefix = self.preparer.quote_schema( - table.schema, + table.schema, table.quote_schema) + '.' else: schema_prefix = '' @@ -456,8 +466,8 @@ class SQLCompiler(engine.Compiled): def visit_textclause(self, textclause, **kwargs): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): - self.result_map[colname - if self.dialect.case_sensitive + self.result_map[colname + if self.dialect.case_sensitive else colname.lower()] = \ (colname, None, type_) @@ -490,8 +500,8 @@ class SQLCompiler(engine.Compiled): else: sep = OPERATORS[clauselist.operator] return sep.join( - s for s in - (c._compiler_dispatch(self, **kwargs) + s for s in + (c._compiler_dispatch(self, **kwargs) for c in clauselist.clauses) if s) @@ -531,13 +541,13 @@ class SQLCompiler(engine.Compiled): def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) - return "EXTRACT(%s FROM %s)" % (field, + return "EXTRACT(%s FROM %s)" % (field, extract.expr._compiler_dispatch(self, **kwargs)) def visit_function(self, func, result_map=None, **kwargs): if result_map is not None: - result_map[func.name - if self.dialect.case_sensitive + result_map[func.name + if self.dialect.case_sensitive else func.name.lower()] = \ (func.name, None, func.type) @@ -560,7 +570,7 @@ class SQLCompiler(engine.Compiled): def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) - def visit_compound_select(self, cs, asfrom=False, + def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) @@ -568,8 +578,8 @@ class SQLCompiler(engine.Compiled): keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( - (c._compiler_dispatch(self, - asfrom=asfrom, parens=False, + (c._compiler_dispatch(self, + asfrom=asfrom, parens=False, compound_index=i, **kwargs) for i, c in enumerate(cs.selects)) ) @@ -610,8 +620,8 @@ class SQLCompiler(engine.Compiled): return self._operator_dispatch(binary.operator, binary, - lambda opstr: binary.left._compiler_dispatch(self, **kw) + - opstr + + lambda opstr: binary.left._compiler_dispatch(self, **kw) + + opstr + binary.right._compiler_dispatch( self, **kw), **kw @@ -620,36 +630,36 @@ class SQLCompiler(engine.Compiled): def visit_like_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return '%s LIKE %s' % ( - binary.left._compiler_dispatch(self, **kw), + binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and + + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') def visit_notlike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return '%s NOT LIKE %s' % ( - binary.left._compiler_dispatch(self, **kw), + binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and + + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') def visit_ilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) LIKE lower(%s)' % ( - binary.left._compiler_dispatch(self, **kw), + binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and + + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') def visit_notilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) NOT LIKE lower(%s)' % ( - binary.left._compiler_dispatch(self, **kw), + binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and + + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') @@ -693,7 +703,7 @@ class SQLCompiler(engine.Compiled): "bindparam() name '%s' is reserved " "for automatic usage in the VALUES or SET " "clause of this " - "insert/update statement. Please use a " + "insert/update statement. Please use a " "name other than column name when using bindparam() " "with insert() or update() (for example, 'b_%s')." % (bindparam.key, bindparam.key) @@ -771,7 +781,7 @@ class SQLCompiler(engine.Compiled): self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) - def bindparam_string(self, name, quote=None, + def bindparam_string(self, name, quote=None, positional_names=None, **kw): if self.positional: if positional_names is not None: @@ -780,8 +790,10 @@ class SQLCompiler(engine.Compiled): self.positiontup.append(name) return self.bindtemplate % {'name':name} - def visit_cte(self, cte, asfrom=False, ashint=False, - fromhints=None, **kwargs): + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, + **kwargs): + self._init_cte_state() if self.positional: kwargs['positional_names'] = self.cte_positional @@ -790,6 +802,25 @@ class SQLCompiler(engine.Compiled): else: cte_name = cte.name + if cte_name in self.ctes_by_name: + existing_cte = self.ctes_by_name[cte_name] + # we've generated a same-named CTE that we are enclosed in, + # or this is the same CTE. just return the name. + if cte in existing_cte._restates or cte is existing_cte: + return cte_name + elif existing_cte in cte._restates: + # we've generated a same-named CTE that is + # enclosed in us - we take precedence, so + # discard the text for the "inner". + del self.ctes[existing_cte] + else: + raise exc.CompileError( + "Multiple, unrelated CTEs found with " + "the same name: %r" % + cte_name) + + self.ctes_by_name[cte_name] = cte + if cte.cte_alias: if isinstance(cte.cte_alias, sql._truncated_label): cte_alias = self._truncated_identifier("alias", cte.cte_alias) @@ -806,12 +837,12 @@ class SQLCompiler(engine.Compiled): col_source = cte.original.selects[0] else: assert False - recur_cols = [c for c in + recur_cols = [c for c in util.unique_list(col_source.inner_columns) if c is not None] text += "(%s)" % (", ".join( - self.preparer.format_column(ident) + self.preparer.format_column(ident) for ident in recur_cols)) text += " AS \n" + \ cte.original._compiler_dispatch( @@ -826,7 +857,7 @@ class SQLCompiler(engine.Compiled): return self.preparer.format_alias(cte, cte_name) return text - def visit_alias(self, alias, asfrom=False, ashint=False, + def visit_alias(self, alias, asfrom=False, ashint=False, iscrud=False, fromhints=None, **kwargs): if asfrom or ashint: @@ -838,13 +869,13 @@ class SQLCompiler(engine.Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = alias.original._compiler_dispatch(self, + ret = alias.original._compiler_dispatch(self, asfrom=True, **kwargs) + \ " AS " + \ self.preparer.format_alias(alias, alias_name) if fromhints and alias in fromhints: - ret = self.format_from_hint_text(ret, alias, + ret = self.format_from_hint_text(ret, alias, fromhints[alias], iscrud) return ret @@ -861,8 +892,8 @@ class SQLCompiler(engine.Compiled): select.use_labels and \ column._label: return _CompileLabel( - column, - column._label, + column, + column._label, alt_names=(column._key_label, ) ) @@ -872,9 +903,9 @@ class SQLCompiler(engine.Compiled): not column.is_literal and \ column.table is not None and \ not isinstance(column.table, sql.Select): - return _CompileLabel(column, sql._as_truncated(column.name), + return _CompileLabel(column, sql._as_truncated(column.name), alt_names=(column.key,)) - elif not isinstance(column, + elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) \ and (not hasattr(column, 'name') or \ isinstance(column, sql.Function)): @@ -897,9 +928,9 @@ class SQLCompiler(engine.Compiled): def get_crud_hint_text(self, table, text): return None - def visit_select(self, select, asfrom=False, parens=True, - iswrapper=False, fromhints=None, - compound_index=1, + def visit_select(self, select, asfrom=False, parens=True, + iswrapper=False, fromhints=None, + compound_index=1, positional_names=None, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -919,7 +950,7 @@ class SQLCompiler(engine.Compiled): : iswrapper}) if compound_index==1 and not entry or entry.get('iswrapper', False): - column_clause_args = {'result_map':self.result_map, + column_clause_args = {'result_map':self.result_map, 'positional_names':positional_names} else: column_clause_args = {'positional_names':positional_names} @@ -930,7 +961,7 @@ class SQLCompiler(engine.Compiled): self.label_select_column(select, co, asfrom=asfrom).\ _compiler_dispatch(self, within_columns_clause=True, - **column_clause_args) + **column_clause_args) for co in util.unique_list(select.inner_columns) ] if c is not None @@ -943,9 +974,9 @@ class SQLCompiler(engine.Compiled): (from_, hinttext % { 'name':from_._compiler_dispatch( self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.iteritems() + }) + for (from_, dialect), hinttext in + select._hints.iteritems() if dialect in ('*', self.dialect.name) ]) hint_text = self.get_select_hint_text(byfrom) @@ -954,7 +985,7 @@ class SQLCompiler(engine.Compiled): if select._prefixes: text += " ".join( - x._compiler_dispatch(self, **kwargs) + x._compiler_dispatch(self, **kwargs) for x in select._prefixes) + " " text += self.get_select_precolumns(select) text += ', '.join(inner_columns) @@ -963,13 +994,13 @@ class SQLCompiler(engine.Compiled): text += " \nFROM " if select._hints: - text += ', '.join([f._compiler_dispatch(self, - asfrom=True, fromhints=byfrom, - **kwargs) + text += ', '.join([f._compiler_dispatch(self, + asfrom=True, fromhints=byfrom, + **kwargs) for f in froms]) else: - text += ', '.join([f._compiler_dispatch(self, - asfrom=True, **kwargs) + text += ', '.join([f._compiler_dispatch(self, + asfrom=True, **kwargs) for f in froms]) else: text += self.default_from() @@ -1054,7 +1085,7 @@ class SQLCompiler(engine.Compiled): text += " OFFSET " + self.process(sql.literal(select._offset)) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, + def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: if getattr(table, "schema", None): @@ -1065,7 +1096,7 @@ class SQLCompiler(engine.Compiled): else: ret = self.preparer.quote(table.name, table.quote) if fromhints and table in fromhints: - ret = self.format_from_hint_text(ret, table, + ret = self.format_from_hint_text(ret, table, fromhints[table], iscrud) return ret else: @@ -1073,10 +1104,10 @@ class SQLCompiler(engine.Compiled): def visit_join(self, join, asfrom=False, **kwargs): return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + - (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + - join.right._compiler_dispatch(self, asfrom=True, **kwargs) + - " ON " + + join.left._compiler_dispatch(self, asfrom=True, **kwargs) + + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + " ON " + join.onclause._compiler_dispatch(self, **kwargs) ) @@ -1088,7 +1119,7 @@ class SQLCompiler(engine.Compiled): not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: raise exc.CompileError("The version of %s you are using does " - "not support empty inserts." % + "not support empty inserts." % self.dialect.name) preparer = self.preparer @@ -1107,14 +1138,14 @@ class SQLCompiler(engine.Compiled): if insert_stmt._hints: dialect_hints = dict([ (table, hint_text) - for (table, dialect), hint_text in + for (table, dialect), hint_text in insert_stmt._hints.items() if dialect in ('*', self.dialect.name) ]) if insert_stmt.table in dialect_hints: table_text = self.format_from_hint_text( table_text, - insert_stmt.table, + insert_stmt.table, dialect_hints[insert_stmt.table], True ) @@ -1148,7 +1179,7 @@ class SQLCompiler(engine.Compiled): """Provide a hook for MySQL to add LIMIT to the UPDATE""" return None - def update_tables_clause(self, update_stmt, from_table, + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. @@ -1156,22 +1187,22 @@ class SQLCompiler(engine.Compiled): MySQL overrides this. """ - return from_table._compiler_dispatch(self, asfrom=True, + return from_table._compiler_dispatch(self, asfrom=True, iscrud=True, **kw) - def update_from_clause(self, update_stmt, - from_table, extra_froms, + def update_from_clause(self, update_stmt, + from_table, extra_froms, from_hints, **kw): - """Provide a hook to override the generation of an + """Provide a hook to override the generation of an UPDATE..FROM clause. MySQL and MSSQL override this. """ return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) for t in extra_froms) def visit_update(self, update_stmt, **kw): @@ -1190,14 +1221,14 @@ class SQLCompiler(engine.Compiled): if update_stmt._hints: dialect_hints = dict([ (table, hint_text) - for (table, dialect), hint_text in + for (table, dialect), hint_text in update_stmt._hints.items() if dialect in ('*', self.dialect.name) ]) if update_stmt.table in dialect_hints: table_text = self.format_from_hint_text( table_text, - update_stmt.table, + update_stmt.table, dialect_hints[update_stmt.table], True ) @@ -1209,12 +1240,12 @@ class SQLCompiler(engine.Compiled): text += ' SET ' if extra_froms and self.render_table_with_column_in_update_from: text += ', '.join( - self.visit_column(c[0]) + + self.visit_column(c[0]) + '=' + c[1] for c in colparams ) else: text += ', '.join( - self.preparer.quote(c[0].name, c[0].quote) + + self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] for c in colparams ) @@ -1226,9 +1257,9 @@ class SQLCompiler(engine.Compiled): if extra_froms: extra_from_text = self.update_from_clause( - update_stmt, - update_stmt.table, - extra_froms, + update_stmt, + update_stmt.table, + extra_froms, dialect_hints, **kw) if extra_from_text: text += " " + extra_from_text @@ -1249,7 +1280,7 @@ class SQLCompiler(engine.Compiled): return text def _create_crud_bind_param(self, col, value, required=False): - bindparam = sql.bindparam(col.key, value, + bindparam = sql.bindparam(col.key, value, type_=col.type, required=required, quote=col.quote) bindparam._is_crud = True @@ -1275,8 +1306,8 @@ class SQLCompiler(engine.Compiled): # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: return [ - (c, self._create_crud_bind_param(c, - None, required=True)) + (c, self._create_crud_bind_param(c, + None, required=True)) for c in stmt.table.columns ] @@ -1288,8 +1319,8 @@ class SQLCompiler(engine.Compiled): parameters = {} else: parameters = dict((sql._column_as_key(key), required) - for key in self.column_keys - if not stmt.parameters or + for key in self.column_keys + if not stmt.parameters or key not in stmt.parameters) if stmt.parameters is not None: @@ -1310,7 +1341,7 @@ class SQLCompiler(engine.Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} - # special logic that only occurs for multi-table UPDATE + # special logic that only occurs for multi-table UPDATE # statements if extra_tables and stmt.parameters: assert self.isupdate @@ -1329,7 +1360,7 @@ class SQLCompiler(engine.Compiled): value = self.process(value.self_group()) values.append((c, value)) # determine tables which are actually - # to be updated - process onupdate and + # to be updated - process onupdate and # server_onupdate for these for t in affected_tables: for c in t.c: @@ -1350,7 +1381,7 @@ class SQLCompiler(engine.Compiled): self.postfetch.append(c) # iterating through columns at the top to maintain ordering. - # otherwise we might iterate through individual sets of + # otherwise we might iterate through individual sets of # "defaults", "primary key cols", etc. for c in stmt.table.columns: if c.key in parameters and c.key not in check_columns: @@ -1370,8 +1401,8 @@ class SQLCompiler(engine.Compiled): if c.primary_key and \ need_pks and \ ( - implicit_returning or - not postfetch_lastrowid or + implicit_returning or + not postfetch_lastrowid or c is not stmt.table._autoincrement_column ): @@ -1457,7 +1488,7 @@ class SQLCompiler(engine.Compiled): ).difference(check_columns) if check: raise exc.CompileError( - "Unconsumed column names: %s" % + "Unconsumed column names: %s" % (", ".join(check)) ) @@ -1468,13 +1499,13 @@ class SQLCompiler(engine.Compiled): self.isdelete = True text = "DELETE FROM " - table_text = delete_stmt.table._compiler_dispatch(self, + table_text = delete_stmt.table._compiler_dispatch(self, asfrom=True, iscrud=True) if delete_stmt._hints: dialect_hints = dict([ (table, hint_text) - for (table, dialect), hint_text in + for (table, dialect), hint_text in delete_stmt._hints.items() if dialect in ('*', self.dialect.name) ]) @@ -1498,8 +1529,8 @@ class SQLCompiler(engine.Compiled): delete_stmt, delete_stmt._returning) if delete_stmt._whereclause is not None: - text += " WHERE " - text += delete_stmt._whereclause._compiler_dispatch(self) + text += " WHERE " + text += delete_stmt._whereclause._compiler_dispatch(self) if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( @@ -1580,7 +1611,7 @@ class DDLCompiler(engine.Compiled): text += separator separator = ", \n" text += "\t" + self.get_column_specification( - column, + column, first_pk=column.primary_key and \ not first_pk ) @@ -1592,16 +1623,16 @@ class DDLCompiler(engine.Compiled): text += " " + const except exc.CompileError, ce: # Py3K - #raise exc.CompileError("(in table '%s', column '%s'): %s" + #raise exc.CompileError("(in table '%s', column '%s'): %s" # % ( - # table.description, - # column.name, + # table.description, + # column.name, # ce.args[0] # )) from ce # Py2K - raise exc.CompileError("(in table '%s', column '%s'): %s" + raise exc.CompileError("(in table '%s', column '%s'): %s" % ( - table.description, + table.description, column.name, ce.args[0] )), None, sys.exc_info()[2] @@ -1622,17 +1653,17 @@ class DDLCompiler(engine.Compiled): if table.primary_key: constraints.append(table.primary_key) - constraints.extend([c for c in table._sorted_constraints + constraints.extend([c for c in table._sorted_constraints if c is not table.primary_key]) return ", \n\t".join(p for p in - (self.process(constraint) - for constraint in constraints + (self.process(constraint) + for constraint in constraints if ( constraint._create_rule is None or constraint._create_rule(self)) and ( - not self.dialect.supports_alter or + not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False) )) if p is not None ) @@ -1660,7 +1691,7 @@ class DDLCompiler(engine.Compiled): if index.unique: text += "UNIQUE " text += "INDEX %s ON %s (%s)" \ - % (preparer.quote(self._index_identifier(index.name), + % (preparer.quote(self._index_identifier(index.name), index.quote), preparer.format_table(index.table), ', '.join(preparer.quote(c.name, c.quote) @@ -1787,7 +1818,7 @@ class DDLCompiler(engine.Compiled): text += "CONSTRAINT %s " % \ self.preparer.format_constraint(constraint) text += "UNIQUE (%s)" % ( - ', '.join(self.preparer.quote(c.name, c.quote) + ', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) text += self.define_constraint_deferrability(constraint) return text @@ -1839,7 +1870,7 @@ class GenericTypeCompiler(engine.TypeCompiler): {'precision': type_.precision} else: return "NUMERIC(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, + {'precision': type_.precision, 'scale' : type_.scale} def visit_DECIMAL(self, type_): @@ -1896,25 +1927,25 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_large_binary(self, type_): return self.visit_BLOB(type_) - def visit_boolean(self, type_): + def visit_boolean(self, type_): return self.visit_BOOLEAN(type_) - def visit_time(self, type_): + def visit_time(self, type_): return self.visit_TIME(type_) - def visit_datetime(self, type_): + def visit_datetime(self, type_): return self.visit_DATETIME(type_) - def visit_date(self, type_): + def visit_date(self, type_): return self.visit_DATE(type_) - def visit_big_integer(self, type_): + def visit_big_integer(self, type_): return self.visit_BIGINT(type_) - def visit_small_integer(self, type_): + def visit_small_integer(self, type_): return self.visit_SMALLINT(type_) - def visit_integer(self, type_): + def visit_integer(self, type_): return self.visit_INTEGER(type_) def visit_real(self, type_): @@ -1923,19 +1954,19 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_float(self, type_): return self.visit_FLOAT(type_) - def visit_numeric(self, type_): + def visit_numeric(self, type_): return self.visit_NUMERIC(type_) - def visit_string(self, type_): + def visit_string(self, type_): return self.visit_VARCHAR(type_) - def visit_unicode(self, type_): + def visit_unicode(self, type_): return self.visit_VARCHAR(type_) - def visit_text(self, type_): + def visit_text(self, type_): return self.visit_TEXT(type_) - def visit_unicode_text(self, type_): + def visit_unicode_text(self, type_): return self.visit_TEXT(type_) def visit_enum(self, type_): @@ -1959,7 +1990,7 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - def __init__(self, dialect, initial_quote='"', + def __init__(self, dialect, initial_quote='"', final_quote=None, escape_quote='"', omit_schema=False): """Construct a new ``IdentifierPreparer`` object. @@ -2023,7 +2054,7 @@ class IdentifierPreparer(object): def quote_schema(self, schema, force): """Quote a schema. - Subclasses should override this to provide database-dependent + Subclasses should override this to provide database-dependent quoting behavior. """ return self.quote(schema, force) @@ -2080,7 +2111,7 @@ class IdentifierPreparer(object): return self.quote(name, quote) - def format_column(self, column, use_table=False, + def format_column(self, column, use_table=False, name=None, table_name=None): """Prepare a quoted column name.""" @@ -2089,7 +2120,7 @@ class IdentifierPreparer(object): if not getattr(column, 'is_literal', False): if use_table: return self.format_table( - column.table, use_schema=False, + column.table, use_schema=False, name=table_name) + "." + \ self.quote(name, column.quote) else: |