summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py98
1 files changed, 43 insertions, 55 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 59eb3cdb3..59964178c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -421,9 +421,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
anonname = ANONYMOUS_LABEL.sub(self._process_anon, name)
- if len(anonname) > self.dialect.max_identifier_length():
+ if len(anonname) > self.dialect.max_identifier_length:
counter = self.generated_ids.get(ident_class, 1)
- truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:]
+ truncname = name[0:self.dialect.max_identifier_length - 6] + "_" + hex(counter)[2:]
self.generated_ids[ident_class] = counter + 1
else:
truncname = anonname
@@ -515,7 +515,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
l = co.label(labelname)
inner_columns.add(self.process(l))
else:
- self.traverse(co)
inner_columns.add(self.process(co))
else:
l = self.label_select_column(select, co)
@@ -620,20 +619,16 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
# for inserts, this includes Python-side defaults, columns with sequences for dialects
# that support sequences, and primary key columns for dialects that explicitly insert
# pre-generated primary key values
- required_cols = util.Set()
- class DefaultVisitor(schema.SchemaVisitor):
- def visit_column(s, cd):
- if c.primary_key and self.uses_sequences_for_inserts():
- required_cols.add(c)
- def visit_column_default(s, cd):
- required_cols.add(c)
- def visit_sequence(s, seq):
- if self.uses_sequences_for_inserts():
- required_cols.add(c)
- vis = DefaultVisitor()
- for c in insert_stmt.table.c:
- if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
- vis.traverse(c)
+ required_cols = [
+ c for c in insert_stmt.table.c
+ if \
+ isinstance(c, schema.SchemaItem) and \
+ (self.parameters is None or self.parameters.get(c.key, None) is None) and \
+ (
+ ((c.primary_key or isinstance(c.default, schema.Sequence)) and self.uses_sequences_for_inserts()) or
+ isinstance(c.default, schema.ColumnDefault)
+ )
+ ]
self.isinsert = True
colparams = self._get_colparams(insert_stmt, required_cols)
@@ -646,14 +641,12 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
# search for columns who will be required to have an explicit bound value.
# for updates, this includes Python-side "onupdate" defaults.
- required_cols = util.Set()
- class OnUpdateVisitor(schema.SchemaVisitor):
- def visit_column_onupdate(s, cd):
- required_cols.add(c)
- vis = OnUpdateVisitor()
- for c in update_stmt.table.c:
- if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
- vis.traverse(c)
+ required_cols = [c for c in update_stmt.table.c
+ if
+ isinstance(c, schema.SchemaItem) and \
+ (self.parameters is None or self.parameters.get(c.key, None) is None) and
+ isinstance(c.onupdate, schema.ColumnDefault)
+ ]
self.isupdate = True
colparams = self._get_colparams(update_stmt, required_cols)
@@ -681,11 +674,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
self.binds[col.key] = bindparam
return self.bindparam_string(self._truncate_bindparam(bindparam))
- def create_clause_param(col, value):
- self.traverse(value)
- self.inline_params.add(col)
- return self.process(value)
-
self.inline_params = util.Set()
def to_col(key):
@@ -704,25 +692,28 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
if self.parameters is None:
parameters = {}
else:
- parameters = dict([(to_col(k), v) for k, v in self.parameters.iteritems()])
+ parameters = dict([(getattr(k, 'key', k), v) for k, v in self.parameters.iteritems()])
if stmt.parameters is not None:
for k, v in stmt.parameters.iteritems():
- parameters.setdefault(to_col(k), v)
+ parameters.setdefault(getattr(k, 'key', k), v)
for col in required_cols:
- parameters.setdefault(col, None)
+ parameters.setdefault(col.key, None)
# create a list of column assignment clauses as tuples
values = []
for c in stmt.table.columns:
- if c in parameters:
- value = parameters[c]
- if sql._is_literal(value):
- value = create_bind_param(c, value)
- else:
- value = create_clause_param(c, value)
- values.append((c, value))
+ if c.key in parameters:
+ value = parameters[c.key]
+ else:
+ continue
+ if sql._is_literal(value):
+ value = create_bind_param(c, value)
+ else:
+ self.inline_params.add(c)
+ value = self.process(value)
+ values.append((c, value))
return values
@@ -778,7 +769,7 @@ class SchemaGenerator(DDLBase):
collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
for table in collection:
self.traverse_single(table)
- if self.dialect.supports_alter():
+ if self.dialect.supports_alter:
for alterable in self.find_alterables(collection):
self.add_foreignkey(alterable)
@@ -853,7 +844,7 @@ class SchemaGenerator(DDLBase):
self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
def visit_foreign_key_constraint(self, constraint):
- if constraint.use_alter and self.dialect.supports_alter():
+ if constraint.use_alter and self.dialect.supports_alter:
return
self.append(", \n\t ")
self.define_foreign_key(constraint)
@@ -909,7 +900,7 @@ class SchemaDropper(DDLBase):
def visit_metadata(self, metadata):
collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))]
- if self.dialect.supports_alter():
+ if self.dialect.supports_alter:
for alterable in self.find_alterables(collection):
self.drop_foreignkey(alterable)
for table in collection:
@@ -936,6 +927,12 @@ class SchemaDropper(DDLBase):
class IdentifierPreparer(object):
"""Handle quoting and case-folding of identifiers based on options."""
+ reserved_words = RESERVED_WORDS
+
+ legal_characters = LEGAL_CHARACTERS
+
+ illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
+
def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
"""Construct a new ``IdentifierPreparer`` object.
@@ -995,21 +992,12 @@ class IdentifierPreparer(object):
# some tests would need to be rewritten if this is done.
#return value.upper()
- def _reserved_words(self):
- return RESERVED_WORDS
-
- def _legal_characters(self):
- return LEGAL_CHARACTERS
-
- def _illegal_initial_characters(self):
- return ILLEGAL_INITIAL_CHARACTERS
-
def _requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
return \
- value in self._reserved_words() \
- or (value[0] in self._illegal_initial_characters()) \
- or bool(len([x for x in unicode(value) if x not in self._legal_characters()])) \
+ value in self.reserved_words \
+ or (value[0] in self.illegal_initial_characters) \
+ or bool(len([x for x in unicode(value) if x not in self.legal_characters])) \
or (value.lower() != value)
def __generic_obj_format(self, obj, ident):