diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 101 |
1 files changed, 48 insertions, 53 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1fe9ef062..78bb4e31c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -19,7 +19,7 @@ is otherwise internal to SQLAlchemy. """ import string, re, itertools -from sqlalchemy import schema, engine, util, exceptions +from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions from sqlalchemy.sql import expression as sql @@ -115,8 +115,6 @@ class DefaultCompiler(engine.Compiled): paradigm as visitors.ClauseVisitor but implements its own traversal. """ - __traverse_options__ = {'column_collections':False, 'entry':True} - operators = OPERATORS functions = FUNCTIONS @@ -162,17 +160,12 @@ class DefaultCompiler(engine.Compiled): # for aliases self.generated_ids = {} - # paramstyle from the dialect (comes from DB-API) - self.paramstyle = self.dialect.paramstyle - # true if the paramstyle is positional self.positional = self.dialect.positional + if self.positional: + self.positiontup = [] - self.bindtemplate = BIND_TEMPLATES[self.paramstyle] - - # a list of the compiled's bind parameter names, used to help - # formulate a positional argument list - self.positiontup = [] + self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle] # an IdentifierPreparer that formats the quoting of identifiers self.preparer = self.dialect.identifier_preparer @@ -230,15 +223,18 @@ class DefaultCompiler(engine.Compiled): return "" def visit_grouping(self, grouping, **kwargs): - return "(" + self.process(grouping.elem) + ")" + return "(" + self.process(grouping.element) + ")" - def visit_label(self, label, result_map=None): + def visit_label(self, label, result_map=None, render_labels=False): + if not render_labels: + return self.process(label.element) + labelname = self._truncated_identifier("colident", label.name) if result_map is not None: - result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type) + result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type) - return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) + return " ".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) def visit_column(self, column, result_map=None, **kwargs): @@ -261,16 +257,16 @@ class DefaultCompiler(engine.Compiled): if getattr(column, "is_literal", False): name = self.escape_literal_column(name) else: - name = self.preparer.quote(column, name) + name = self.preparer.quote(name, column.quote) if column.table is None or not column.table.named_with_column: return name else: if getattr(column.table, 'schema', None): - schema_prefix = self.preparer.quote(column.table, column.table.schema) + '.' + schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.' else: schema_prefix = '' - return schema_prefix + self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name + return schema_prefix + self.preparer.quote(ANONYMOUS_LABEL.sub(self._process_anon, column.table.name), column.table.quote) + "." + name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -387,7 +383,7 @@ class DefaultCompiler(engine.Compiled): if name in self.binds: existing = self.binds[name] if existing is not bindparam and (existing.unique or bindparam.unique): - raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) @@ -418,7 +414,7 @@ class DefaultCompiler(engine.Compiled): return truncname def _process_anon(self, match): - (ident, derived) = match.group(1,2) + (ident, derived) = match.group(1, 2) key = ('anonymous', ident) if key in self.generated_ids: @@ -436,8 +432,9 @@ class DefaultCompiler(engine.Compiled): def bindparam_string(self, name): if self.positional: self.positiontup.append(name) - - return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} + return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} + else: + return self.bindtemplate % {'name':name} def visit_alias(self, alias, asfrom=False, **kwargs): if asfrom: @@ -490,7 +487,7 @@ class DefaultCompiler(engine.Compiled): froms = select._get_display_froms(existingfroms) - correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms]))) + correlate_froms = util.Set(sql._from_objects(*froms)) # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost @@ -504,6 +501,7 @@ class DefaultCompiler(engine.Compiled): [c for c in [ self.process( self.label_select_column(select, co, asfrom=asfrom), + render_labels=True, **column_clause_args) for co in select.inner_columns ] @@ -580,9 +578,9 @@ class DefaultCompiler(engine.Compiled): def visit_table(self, table, asfrom=False, **kwargs): if asfrom: if getattr(table, "schema", None): - return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name) + return self.preparer.quote(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote) else: - return self.preparer.quote(table, table.name) + return self.preparer.quote(table.name, table.quote) else: return "" @@ -603,7 +601,7 @@ class DefaultCompiler(engine.Compiled): return (insert + " INTO %s (%s) VALUES (%s)" % (preparer.format_table(insert_stmt.table), - ', '.join([preparer.quote(c[0], c[0].name) + ', '.join([preparer.quote(c[0].name, c[0].quote) for c in colparams]), ', '.join([c[1] for c in colparams]))) @@ -613,7 +611,7 @@ class DefaultCompiler(engine.Compiled): self.isupdate = True colparams = self._get_colparams(update_stmt) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ') + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0].name, c[0].quote), c[1]) for c in colparams], ', ') if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) @@ -837,7 +835,7 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint])) + self.append("(%s)" % ', '.join([self.preparer.quote(c.name, c.quote) for c in constraint])) self.define_constraint_deferrability(constraint) def visit_foreign_key_constraint(self, constraint): @@ -858,9 +856,9 @@ class SchemaGenerator(DDLBase): preparer.format_constraint(constraint)) table = list(constraint.elements)[0].column.table self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]), + ', '.join([preparer.quote(f.parent.name, f.parent.quote) for f in constraint.elements]), preparer.format_table(table), - ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements]) + ', '.join([preparer.quote(f.column.name, f.column.quote) for f in constraint.elements]) )) if constraint.ondelete is not None: self.append(" ON DELETE %s" % constraint.ondelete) @@ -873,7 +871,7 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint]))) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c.name, c.quote) for c in constraint]))) self.define_constraint_deferrability(constraint) def define_constraint_deferrability(self, constraint): @@ -896,7 +894,7 @@ class SchemaGenerator(DDLBase): self.append("INDEX %s ON %s (%s)" \ % (preparer.format_index(index), preparer.format_table(index.table), - string.join([preparer.quote(c, c.name) for c in index.columns], ', '))) + string.join([preparer.quote(c.name, c.quote) for c in index.columns], ', '))) self.execute() @@ -1005,9 +1003,12 @@ class IdentifierPreparer(object): or not self.legal_characters.match(unicode(value)) or (lc_value != value)) - def quote(self, obj, ident): - if getattr(obj, 'quote', False): + def quote(self, ident, force): + if force: return self.quote_identifier(ident) + elif force is False: + return ident + if ident in self.__strings: return self.__strings[ident] else: @@ -1017,53 +1018,47 @@ class IdentifierPreparer(object): self.__strings[ident] = ident return self.__strings[ident] - def should_quote(self, object): - return object.quote or self._requires_quotes(object.name) - def format_sequence(self, sequence, use_schema=True): - name = self.quote(sequence, sequence.name) + name = self.quote(sequence.name, sequence.quote) if not self.omit_schema and use_schema and sequence.schema is not None: - name = self.quote(sequence, sequence.schema) + "." + name + name = self.quote(sequence.schema, sequence.quote) + "." + name return name def format_label(self, label, name=None): - return self.quote(label, name or label.name) + return self.quote(name or label.name, label.quote) def format_alias(self, alias, name=None): - return self.quote(alias, name or alias.name) + return self.quote(name or alias.name, alias.quote) def format_savepoint(self, savepoint, name=None): - return self.quote(savepoint, name or savepoint.ident) + return self.quote(name or savepoint.ident, savepoint.quote) def format_constraint(self, constraint): - return self.quote(constraint, constraint.name) + return self.quote(constraint.name, constraint.quote) def format_index(self, index): - return self.quote(index, index.name) + return self.quote(index.name, index.quote) def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" if name is None: name = table.name - result = self.quote(table, name) + result = self.quote(name, table.quote) if not self.omit_schema and use_schema and getattr(table, "schema", None): - result = self.quote(table, table.schema) + "." + result + result = self.quote(table.schema, table.quote_schema) + "." + result return result def format_column(self, column, use_table=False, name=None, table_name=None): - """Prepare a quoted column name. - - deprecated. use preparer.quote(col, column.name) or combine with format_table() - """ + """Prepare a quoted column name.""" if name is None: name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name) + return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote) else: - return self.quote(column, name) + return self.quote(name, column.quote) else: # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted if use_table: @@ -1079,7 +1074,7 @@ class IdentifierPreparer(object): # a longer sequence. if not self.omit_schema and use_schema and getattr(table, 'schema', None): - return (self.quote_identifier(table.schema), + return (self.quote(table.schema, table.quote_schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) |