summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py101
1 files changed, 48 insertions, 53 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 1fe9ef062..78bb4e31c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -19,7 +19,7 @@ is otherwise internal to SQLAlchemy.
"""
import string, re, itertools
-from sqlalchemy import schema, engine, util, exceptions
+from sqlalchemy import schema, engine, util, exc
from sqlalchemy.sql import operators, functions
from sqlalchemy.sql import expression as sql
@@ -115,8 +115,6 @@ class DefaultCompiler(engine.Compiled):
paradigm as visitors.ClauseVisitor but implements its own traversal.
"""
- __traverse_options__ = {'column_collections':False, 'entry':True}
-
operators = OPERATORS
functions = FUNCTIONS
@@ -162,17 +160,12 @@ class DefaultCompiler(engine.Compiled):
# for aliases
self.generated_ids = {}
- # paramstyle from the dialect (comes from DB-API)
- self.paramstyle = self.dialect.paramstyle
-
# true if the paramstyle is positional
self.positional = self.dialect.positional
+ if self.positional:
+ self.positiontup = []
- self.bindtemplate = BIND_TEMPLATES[self.paramstyle]
-
- # a list of the compiled's bind parameter names, used to help
- # formulate a positional argument list
- self.positiontup = []
+ self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle]
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = self.dialect.identifier_preparer
@@ -230,15 +223,18 @@ class DefaultCompiler(engine.Compiled):
return ""
def visit_grouping(self, grouping, **kwargs):
- return "(" + self.process(grouping.elem) + ")"
+ return "(" + self.process(grouping.element) + ")"
- def visit_label(self, label, result_map=None):
+ def visit_label(self, label, result_map=None, render_labels=False):
+ if not render_labels:
+ return self.process(label.element)
+
labelname = self._truncated_identifier("colident", label.name)
if result_map is not None:
- result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type)
+ result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
- return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
+ return " ".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
def visit_column(self, column, result_map=None, **kwargs):
@@ -261,16 +257,16 @@ class DefaultCompiler(engine.Compiled):
if getattr(column, "is_literal", False):
name = self.escape_literal_column(name)
else:
- name = self.preparer.quote(column, name)
+ name = self.preparer.quote(name, column.quote)
if column.table is None or not column.table.named_with_column:
return name
else:
if getattr(column.table, 'schema', None):
- schema_prefix = self.preparer.quote(column.table, column.table.schema) + '.'
+ schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.'
else:
schema_prefix = ''
- return schema_prefix + self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name
+ return schema_prefix + self.preparer.quote(ANONYMOUS_LABEL.sub(self._process_anon, column.table.name), column.table.quote) + "." + name
def escape_literal_column(self, text):
"""provide escaping for the literal_column() construct."""
@@ -387,7 +383,7 @@ class DefaultCompiler(engine.Compiled):
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam and (existing.unique or bindparam.unique):
- raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
+ raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
@@ -418,7 +414,7 @@ class DefaultCompiler(engine.Compiled):
return truncname
def _process_anon(self, match):
- (ident, derived) = match.group(1,2)
+ (ident, derived) = match.group(1, 2)
key = ('anonymous', ident)
if key in self.generated_ids:
@@ -436,8 +432,9 @@ class DefaultCompiler(engine.Compiled):
def bindparam_string(self, name):
if self.positional:
self.positiontup.append(name)
-
- return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
+ return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
+ else:
+ return self.bindtemplate % {'name':name}
def visit_alias(self, alias, asfrom=False, **kwargs):
if asfrom:
@@ -490,7 +487,7 @@ class DefaultCompiler(engine.Compiled):
froms = select._get_display_froms(existingfroms)
- correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms])))
+ correlate_froms = util.Set(sql._from_objects(*froms))
# TODO: might want to propigate existing froms for select(select(select))
# where innermost select should correlate to outermost
@@ -504,6 +501,7 @@ class DefaultCompiler(engine.Compiled):
[c for c in [
self.process(
self.label_select_column(select, co, asfrom=asfrom),
+ render_labels=True,
**column_clause_args)
for co in select.inner_columns
]
@@ -580,9 +578,9 @@ class DefaultCompiler(engine.Compiled):
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
if getattr(table, "schema", None):
- return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name)
+ return self.preparer.quote(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
else:
- return self.preparer.quote(table, table.name)
+ return self.preparer.quote(table.name, table.quote)
else:
return ""
@@ -603,7 +601,7 @@ class DefaultCompiler(engine.Compiled):
return (insert + " INTO %s (%s) VALUES (%s)" %
(preparer.format_table(insert_stmt.table),
- ', '.join([preparer.quote(c[0], c[0].name)
+ ', '.join([preparer.quote(c[0].name, c[0].quote)
for c in colparams]),
', '.join([c[1] for c in colparams])))
@@ -613,7 +611,7 @@ class DefaultCompiler(engine.Compiled):
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0].name, c[0].quote), c[1]) for c in colparams], ', ')
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
@@ -837,7 +835,7 @@ class SchemaGenerator(DDLBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint]))
+ self.append("(%s)" % ', '.join([self.preparer.quote(c.name, c.quote) for c in constraint]))
self.define_constraint_deferrability(constraint)
def visit_foreign_key_constraint(self, constraint):
@@ -858,9 +856,9 @@ class SchemaGenerator(DDLBase):
preparer.format_constraint(constraint))
table = list(constraint.elements)[0].column.table
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]),
+ ', '.join([preparer.quote(f.parent.name, f.parent.quote) for f in constraint.elements]),
preparer.format_table(table),
- ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements])
+ ', '.join([preparer.quote(f.column.name, f.column.quote) for f in constraint.elements])
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
@@ -873,7 +871,7 @@ class SchemaGenerator(DDLBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " %
self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint])))
+ self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c.name, c.quote) for c in constraint])))
self.define_constraint_deferrability(constraint)
def define_constraint_deferrability(self, constraint):
@@ -896,7 +894,7 @@ class SchemaGenerator(DDLBase):
self.append("INDEX %s ON %s (%s)" \
% (preparer.format_index(index),
preparer.format_table(index.table),
- string.join([preparer.quote(c, c.name) for c in index.columns], ', ')))
+ string.join([preparer.quote(c.name, c.quote) for c in index.columns], ', ')))
self.execute()
@@ -1005,9 +1003,12 @@ class IdentifierPreparer(object):
or not self.legal_characters.match(unicode(value))
or (lc_value != value))
- def quote(self, obj, ident):
- if getattr(obj, 'quote', False):
+ def quote(self, ident, force):
+ if force:
return self.quote_identifier(ident)
+ elif force is False:
+ return ident
+
if ident in self.__strings:
return self.__strings[ident]
else:
@@ -1017,53 +1018,47 @@ class IdentifierPreparer(object):
self.__strings[ident] = ident
return self.__strings[ident]
- def should_quote(self, object):
- return object.quote or self._requires_quotes(object.name)
-
def format_sequence(self, sequence, use_schema=True):
- name = self.quote(sequence, sequence.name)
+ name = self.quote(sequence.name, sequence.quote)
if not self.omit_schema and use_schema and sequence.schema is not None:
- name = self.quote(sequence, sequence.schema) + "." + name
+ name = self.quote(sequence.schema, sequence.quote) + "." + name
return name
def format_label(self, label, name=None):
- return self.quote(label, name or label.name)
+ return self.quote(name or label.name, label.quote)
def format_alias(self, alias, name=None):
- return self.quote(alias, name or alias.name)
+ return self.quote(name or alias.name, alias.quote)
def format_savepoint(self, savepoint, name=None):
- return self.quote(savepoint, name or savepoint.ident)
+ return self.quote(name or savepoint.ident, savepoint.quote)
def format_constraint(self, constraint):
- return self.quote(constraint, constraint.name)
+ return self.quote(constraint.name, constraint.quote)
def format_index(self, index):
- return self.quote(index, index.name)
+ return self.quote(index.name, index.quote)
def format_table(self, table, use_schema=True, name=None):
"""Prepare a quoted table and schema name."""
if name is None:
name = table.name
- result = self.quote(table, name)
+ result = self.quote(name, table.quote)
if not self.omit_schema and use_schema and getattr(table, "schema", None):
- result = self.quote(table, table.schema) + "." + result
+ result = self.quote(table.schema, table.quote_schema) + "." + result
return result
def format_column(self, column, use_table=False, name=None, table_name=None):
- """Prepare a quoted column name.
-
- deprecated. use preparer.quote(col, column.name) or combine with format_table()
- """
+ """Prepare a quoted column name."""
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name)
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote)
else:
- return self.quote(column, name)
+ return self.quote(name, column.quote)
else:
# literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
if use_table:
@@ -1079,7 +1074,7 @@ class IdentifierPreparer(object):
# a longer sequence.
if not self.omit_schema and use_schema and getattr(table, 'schema', None):
- return (self.quote_identifier(table.schema),
+ return (self.quote(table.schema, table.quote_schema),
self.format_table(table, use_schema=False))
else:
return (self.format_table(table, use_schema=False), )