summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py749
1 files changed, 362 insertions, 387 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 9994d5288..22227d56a 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -10,9 +10,10 @@ Contains default implementations for the abstract objects in the sql
module.
"""
-from sqlalchemy import schema, sql, engine, util, sql_util, exceptions
+import string, re, sets, operator
+
+from sqlalchemy import schema, sql, engine, util, exceptions
from sqlalchemy.engine import default
-import string, re, sets, weakref, random
ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
@@ -40,6 +41,41 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array',
LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$')
ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')
+BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE)
+
+OPERATORS = {
+ operator.and_ : 'AND',
+ operator.or_ : 'OR',
+ operator.inv : 'NOT',
+ operator.add : '+',
+ operator.mul : '*',
+ operator.sub : '-',
+ operator.div : '/',
+ operator.mod : '%',
+ operator.truediv : '/',
+ operator.lt : '<',
+ operator.le : '<=',
+ operator.ne : '!=',
+ operator.gt : '>',
+ operator.ge : '>=',
+ operator.eq : '=',
+ sql.ColumnOperators.concat_op : '||',
+ sql.ColumnOperators.like_op : 'LIKE',
+ sql.ColumnOperators.notlike_op : 'NOT LIKE',
+ sql.ColumnOperators.ilike_op : 'ILIKE',
+ sql.ColumnOperators.notilike_op : 'NOT ILIKE',
+ sql.ColumnOperators.between_op : 'BETWEEN',
+ sql.ColumnOperators.in_op : 'IN',
+ sql.ColumnOperators.notin_op : 'NOT IN',
+ sql.ColumnOperators.comma_op : ', ',
+ sql.Operators.from_ : 'FROM',
+ sql.Operators.as_ : 'AS',
+ sql.Operators.exists : 'EXISTS',
+ sql.Operators.is_ : 'IS',
+ sql.Operators.isnot : 'IS NOT'
+}
+
class ANSIDialect(default.DefaultDialect):
def __init__(self, cache_identifiers=True, **kwargs):
super(ANSIDialect,self).__init__(**kwargs)
@@ -66,14 +102,16 @@ class ANSIDialect(default.DefaultDialect):
"""
return ANSIIdentifierPreparer(self)
-class ANSICompiler(sql.Compiled):
+class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
"""Default implementation of Compiled.
Compiles ClauseElements into ANSI-compliant SQL strings.
"""
- __traverse_options__ = {'column_collections':False}
+ __traverse_options__ = {'column_collections':False, 'entry':True}
+ operators = OPERATORS
+
def __init__(self, dialect, statement, parameters=None, **kwargs):
"""Construct a new ``ANSICompiler`` object.
@@ -92,7 +130,7 @@ class ANSICompiler(sql.Compiled):
correspond to the keys present in the parameters.
"""
- sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
+ super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs)
# if we are insert/update. set to true when we visit an INSERT or UPDATE
self.isinsert = self.isupdate = False
@@ -104,21 +142,6 @@ class ANSICompiler(sql.Compiled):
# actually present in the generated SQL
self.bind_names = {}
- # a dictionary which stores the string representation for every ClauseElement
- # processed by this compiler.
- self.strings = {}
-
- # a dictionary which stores the string representation for ClauseElements
- # processed by this compiler, which are to be used in the FROM clause
- # of a select. items are often placed in "froms" as well as "strings"
- # and sometimes with different representations.
- self.froms = {}
-
- # slightly hacky. maps FROM clauses to WHERE clauses, and used in select
- # generation to modify the WHERE clause of the select. currently a hack
- # used by the oracle module.
- self.wheres = {}
-
# when the compiler visits a SELECT statement, the clause object is appended
# to this stack. various visit operations will check this stack to determine
# additional choices (TODO: it seems to be all typemap stuff. shouldnt this only
@@ -137,12 +160,6 @@ class ANSICompiler(sql.Compiled):
# for aliases
self.generated_ids = {}
- # True if this compiled represents an INSERT
- self.isinsert = False
-
- # True if this compiled represents an UPDATE
- self.isupdate = False
-
# default formatting style for bind parameters
self.bindtemplate = ":%s"
@@ -158,64 +175,76 @@ class ANSICompiler(sql.Compiled):
# an ANSIIdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
-
+
+ # a dictionary containing attributes about all select()
+ # elements located within the clause, regarding which are subqueries, which are
+ # selected from, and which elements should be correlated to an enclosing select.
+ # used mostly to determine the list of FROM elements for each select statement, as well
+ # as some dialect-specific rules regarding subqueries.
+ self.correlate_state = {}
+
# for UPDATE and INSERT statements, a set of columns whos values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
# default-value logic in the Dialect knows not to fire off column defaults
# and also knows postfetching will be needed to get the values represented by these
# parameters.
self.inline_params = None
-
+
def after_compile(self):
# this re will search for params like :param
# it has a negative lookbehind for an extra ':' so that it doesnt match
# postgres '::text' tokens
- match = re.compile(r'(?<!:):([\w_]+)', re.UNICODE)
+ text = self.string
+ if ':' not in text:
+ return
+
if self.paramstyle=='pyformat':
- self.strings[self.statement] = match.sub(lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
+ text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text)
elif self.positional:
- params = match.finditer(self.strings[self.statement])
+ params = BIND_PARAMS.finditer(text)
for p in params:
self.positiontup.append(p.group(1))
if self.paramstyle=='qmark':
- self.strings[self.statement] = match.sub('?', self.strings[self.statement])
+ text = BIND_PARAMS.sub('?', text)
elif self.paramstyle=='format':
- self.strings[self.statement] = match.sub('%s', self.strings[self.statement])
+ text = BIND_PARAMS.sub('%s', text)
elif self.paramstyle=='numeric':
i = [0]
def getnum(x):
i[0] += 1
return str(i[0])
- self.strings[self.statement] = match.sub(getnum, self.strings[self.statement])
-
- def get_from_text(self, obj):
- return self.froms.get(obj, None)
-
- def get_str(self, obj):
- return self.strings[obj]
-
+ text = BIND_PARAMS.sub(getnum, text)
+ # un-escape any \:params
+ text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
+ self.string = text
+
+ def compile(self):
+ self.string = self.process(self.statement)
+ self.after_compile()
+
+ def process(self, obj, **kwargs):
+ return self.traverse_single(obj, **kwargs)
+
+ def is_subquery(self, select):
+ return self.correlate_state[select].get('is_subquery', False)
+
def get_whereclause(self, obj):
- return self.wheres.get(obj, None)
+ """given a FROM clause, return an additional WHERE condition that should be
+ applied to a SELECT.
+
+ Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN
+ constructs in non-ansi mode.
+ """
+
+ return None
def construct_params(self, params):
- """Return a structure of bind parameters for this compiled object.
-
- This includes bind parameters that might be compiled in via
- the `values` argument of an ``Insert`` or ``Update`` statement
- object, and also the given `**params`. The keys inside of
- `**params` can be any key that matches the
- ``BindParameterClause`` objects compiled within this object.
-
- The output is dependent on the paramstyle of the DBAPI being
- used; if a named style, the return result will be a dictionary
- with keynames matching the compiled statement. If a
- positional style, the output will be a list, with an iterator
- that will return parameter values in an order corresponding to
- the bind positions in the compiled statement.
-
- For an executemany style of call, this method should be called
- for each element in the list of parameter groups that will
- ultimately be executed.
+ """Return a sql.ClauseParameters object.
+
+ Combines the given bind parameter dictionary (string keys to object values)
+ with the _BindParamClause objects stored within this Compiled object
+ to produce a ClauseParameters structure, representing the bind arguments
+ for a single statement execution, or one element of an executemany execution.
"""
if self.parameters is not None:
@@ -225,7 +254,7 @@ class ANSICompiler(sql.Compiled):
bindparams.update(params)
d = sql.ClauseParameters(self.dialect, self.positiontup)
for b in self.binds.values():
- name = self.bind_names.get(b, b.key)
+ name = self.bind_names[b]
d.set_parameter(b, b.value, name)
for key, value in bindparams.iteritems():
@@ -233,7 +262,7 @@ class ANSICompiler(sql.Compiled):
b = self.binds[key]
except KeyError:
continue
- name = self.bind_names.get(b, b.key)
+ name = self.bind_names[b]
d.set_parameter(b, value, name)
return d
@@ -246,8 +275,8 @@ class ANSICompiler(sql.Compiled):
return ""
- def visit_grouping(self, grouping):
- self.strings[grouping] = "(" + self.strings[grouping.elem] + ")"
+ def visit_grouping(self, grouping, **kwargs):
+ return "(" + self.process(grouping.elem) + ")"
def visit_label(self, label):
labelname = self._truncated_identifier("colident", label.name)
@@ -256,9 +285,10 @@ class ANSICompiler(sql.Compiled):
self.typemap.setdefault(labelname.lower(), label.obj.type)
if isinstance(label.obj, sql._ColumnClause):
self.column_labels[label.obj._label] = labelname
- self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname)
+ self.column_labels[label.name] = labelname
+ return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
- def visit_column(self, column):
+ def visit_column(self, column, **kwargs):
# there is actually somewhat of a ruleset when you would *not* necessarily
# want to truncate a column identifier, if its mapped to the name of a
# physical column. but thats very hard to identify at this point, and
@@ -269,107 +299,110 @@ class ANSICompiler(sql.Compiled):
else:
name = column.name
+ if len(self.select_stack):
+ # if we are within a visit to a Select, set up the "typemap"
+ # for this column which is used to translate result set values
+ self.typemap.setdefault(name.lower(), column.type)
+ self.column_labels.setdefault(column._label, name.lower())
+
if column.table is None or not column.table.named_with_column():
- self.strings[column] = self.preparer.format_column(column, name=name)
+ return self.preparer.format_column(column, name=name)
else:
if column.table.oid_column is column:
n = self.dialect.oid_column_name(column)
if n is not None:
- self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
+ return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
elif len(column.table.primary_key) != 0:
pk = list(column.table.primary_key)[0]
pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
- self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname)
+ return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
else:
- self.strings[column] = None
+ return None
else:
- self.strings[column] = self.preparer.format_column_with_table(column, column_name=name)
+ return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
- if len(self.select_stack):
- # if we are within a visit to a Select, set up the "typemap"
- # for this column which is used to translate result set values
- self.typemap.setdefault(name.lower(), column.type)
- self.column_labels.setdefault(column._label, name.lower())
- def visit_fromclause(self, fromclause):
- self.froms[fromclause] = fromclause.name
+ def visit_fromclause(self, fromclause, **kwargs):
+ return fromclause.name
- def visit_index(self, index):
- self.strings[index] = index.name
+ def visit_index(self, index, **kwargs):
+ return index.name
- def visit_typeclause(self, typeclause):
- self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
+ def visit_typeclause(self, typeclause, **kwargs):
+ return typeclause.type.dialect_impl(self.dialect).get_col_spec()
- def visit_textclause(self, textclause):
- self.strings[textclause] = textclause.text
- self.froms[textclause] = textclause.text
+ def visit_textclause(self, textclause, **kwargs):
+ for bind in textclause.bindparams.values():
+ self.process(bind)
if textclause.typemap is not None:
self.typemap.update(textclause.typemap)
+ return textclause.text
- def visit_null(self, null):
- self.strings[null] = 'NULL'
+ def visit_null(self, null, **kwargs):
+ return 'NULL'
- def visit_clauselist(self, list):
- sep = list.operator
- if sep == ',':
- sep = ', '
- elif sep is None or sep == " ":
+ def visit_clauselist(self, clauselist, **kwargs):
+ sep = clauselist.operator
+ if sep is None:
sep = " "
+ elif sep == sql.ColumnOperators.comma_op:
+ sep = ', '
else:
- sep = " " + sep + " "
- self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep)
+ sep = " " + self.operator_string(clauselist.operator) + " "
+ return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
- def visit_calculatedclause(self, clause):
- self.strings[clause] = self.get_str(clause.clause_expr)
+ def visit_calculatedclause(self, clause, **kwargs):
+ return self.process(clause.clause_expr)
- def visit_cast(self, cast):
+ def visit_cast(self, cast, **kwargs):
if len(self.select_stack):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
- self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause])
+ return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
- def visit_function(self, func):
+ def visit_function(self, func, **kwargs):
if len(self.select_stack):
self.typemap.setdefault(func.name, func.type)
if not self.apply_function_parens(func):
- self.strings[func] = ".".join(func.packagenames + [func.name])
- self.froms[func] = self.strings[func]
+ return ".".join(func.packagenames + [func.name])
else:
- self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr)
- self.froms[func] = self.strings[func]
+ return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
- def visit_compound_select(self, cs):
- text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
- group_by = self.get_str(cs.group_by_clause)
+ def visit_compound_select(self, cs, asfrom=False, **kwargs):
+ text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ")
+ group_by = self.process(cs._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
text += self.order_by_clause(cs)
- text += self.visit_select_postclauses(cs)
- self.strings[cs] = text
- self.froms[cs] = "(" + text + ")"
+ text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
+
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
- def visit_unary(self, unary):
- s = self.get_str(unary.element)
+ def visit_unary(self, unary, **kwargs):
+ s = self.process(unary.element)
if unary.operator:
- s = unary.operator + " " + s
+ s = self.operator_string(unary.operator) + " " + s
if unary.modifier:
s = s + " " + unary.modifier
- self.strings[unary] = s
+ return s
- def visit_binary(self, binary):
- result = self.get_str(binary.left)
- if binary.operator is not None:
- result += " " + self.binary_operator_string(binary)
- result += " " + self.get_str(binary.right)
- self.strings[binary] = result
-
- def binary_operator_string(self, binary):
- return binary.operator
+ def visit_binary(self, binary, **kwargs):
+ op = self.operator_string(binary.operator)
+ if callable(op):
+ return op(self.process(binary.left), self.process(binary.right))
+ else:
+ return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+
+ def operator_string(self, operator):
+ return self.operators.get(operator, str(operator))
- def visit_bindparam(self, bindparam):
+ def visit_bindparam(self, bindparam, **kwargs):
# apply truncation to the ultimate generated name
if bindparam.shortname != bindparam.key:
@@ -378,7 +411,6 @@ class ANSICompiler(sql.Compiled):
if bindparam.unique:
count = 1
key = bindparam.key
-
# redefine the generated name of the bind param in the case
# that we have multiple conflicting bind parameters.
while self.binds.setdefault(key, bindparam) is not bindparam:
@@ -386,164 +418,167 @@ class ANSICompiler(sql.Compiled):
key = bindparam.key + tag
count += 1
bindparam.key = key
- self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
else:
existing = self.binds.get(bindparam.key)
if existing is not None and existing.unique:
raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
- self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
self.binds[bindparam.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
return self.bind_names[bindparam]
bind_name = bindparam.key
- if len(bind_name) > self.dialect.max_identifier_length():
- bind_name = self._truncated_identifier("bindparam", bind_name)
- # add to bind_names for translation
- self.bind_names[bindparam] = bind_name
+ bind_name = self._truncated_identifier("bindparam", bind_name)
+ # add to bind_names for translation
+ self.bind_names[bindparam] = bind_name
+
return bind_name
def _truncated_identifier(self, ident_class, name):
if (ident_class, name) in self.generated_ids:
return self.generated_ids[(ident_class, name)]
- if len(name) > self.dialect.max_identifier_length():
+
+ anonname = self._anonymize(name)
+ if len(anonname) > self.dialect.max_identifier_length():
counter = self.generated_ids.get(ident_class, 1)
truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:]
self.generated_ids[ident_class] = counter + 1
else:
- truncname = name
+ truncname = anonname
self.generated_ids[(ident_class, name)] = truncname
return truncname
+
+ def _anonymize(self, name):
+ def anon(match):
+ (ident, derived) = match.group(1,2)
+ if ('anonymous', ident) in self.generated_ids:
+ return self.generated_ids[('anonymous', ident)]
+ else:
+ anonymous_counter = self.generated_ids.get('anonymous', 1)
+ newname = derived + "_" + str(anonymous_counter)
+ self.generated_ids['anonymous'] = anonymous_counter + 1
+ self.generated_ids[('anonymous', ident)] = newname
+ return newname
+ return re.sub(r'{ANON (-?\d+) (.*)}', anon, name)
def bindparam_string(self, name):
return self.bindtemplate % name
- def visit_alias(self, alias):
- self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias)
- self.strings[alias] = self.get_str(alias.original)
+ def visit_alias(self, alias, asfrom=False, **kwargs):
+ if asfrom:
+ return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
+ else:
+ return self.process(alias.original, **kwargs)
- def visit_select(self, select):
- # the actual list of columns to print in the SELECT column list.
- inner_columns = util.OrderedDict()
+ def label_select_column(self, select, column):
+ """convert a column from a select's "columns" clause.
+
+ given a select() and a column element from its inner_columns collection, return a
+ Label object if this column should be labeled in the columns clause. Otherwise,
+ return None and the column will be used as-is.
+
+ The calling method will traverse the returned label to acquire its string
+ representation.
+ """
+
+ # SQLite doesnt like selecting from a subquery where the column
+ # names look like table.colname. so if column is in a "selected from"
+ # subquery, label it synoymously with its column name
+ if \
+ self.correlate_state[select].get('is_selected_from', False) and \
+ isinstance(column, sql._ColumnClause) and \
+ not column.is_literal and \
+ column.table is not None and \
+ not isinstance(column.table, sql.Select):
+ return column.label(column.name)
+ else:
+ return None
+
+ def visit_select(self, select, asfrom=False, **kwargs):
+ select._calculate_correlations(self.correlate_state)
self.select_stack.append(select)
- for c in select._raw_columns:
- if hasattr(c, '_selectable'):
- s = c._selectable()
- else:
- self.traverse(c)
- inner_columns[self.get_str(c)] = c
- continue
- for co in s.columns:
- if select.use_labels:
- labelname = co._label
- if labelname is not None:
- l = co.label(labelname)
- self.traverse(l)
- inner_columns[labelname] = l
- else:
- self.traverse(co)
- inner_columns[self.get_str(co)] = co
- # TODO: figure this out, a ColumnClause with a select as a parent
- # is different from any other kind of parent
- elif select.is_selected_from and isinstance(co, sql._ColumnClause) and not co.is_literal and co.table is not None and not isinstance(co.table, sql.Select):
- # SQLite doesnt like selecting from a subquery where the column
- # names look like table.colname, so add a label synonomous with
- # the column name
- l = co.label(co.name)
- self.traverse(l)
- inner_columns[self.get_str(l.obj)] = l
+
+ # the actual list of columns to print in the SELECT column list.
+ inner_columns = util.OrderedSet()
+
+ froms = select._get_display_froms(self.correlate_state)
+
+ for co in select.inner_columns:
+ if select.use_labels:
+ labelname = co._label
+ if labelname is not None:
+ l = co.label(labelname)
+ inner_columns.add(self.process(l))
else:
self.traverse(co)
- inner_columns[self.get_str(co)] = co
+ inner_columns.add(self.process(co))
+ else:
+ l = self.label_select_column(select, co)
+ if l is not None:
+ inner_columns.add(self.process(l))
+ else:
+ inner_columns.add(self.process(co))
+
self.select_stack.pop(-1)
- collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
+ collist = string.join(inner_columns.difference(util.Set([None])), ', ')
- text = "SELECT "
- text += self.visit_select_precolumns(select)
+ text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " "
+ text += self.get_select_precolumns(select)
text += collist
- whereclause = select.whereclause
-
- froms = []
- for f in select.froms:
-
- if self.parameters is not None:
- # TODO: whack this feature in 0.4
- # look at our own parameters, see if they
- # are all present in the form of BindParamClauses. if
- # not, then append to the above whereclause column conditions
- # matching those keys
- for c in f.columns:
- if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
- value = self.parameters[c.key]
- else:
- continue
- clause = c==value
- if whereclause is not None:
- whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause]))
- else:
- whereclause = clause
- self.traverse(whereclause)
-
- # special thingy used by oracle to redefine a join
+ whereclause = select._whereclause
+
+ from_strings = []
+ for f in froms:
+ from_strings.append(self.process(f, asfrom=True))
+
w = self.get_whereclause(f)
if w is not None:
- # TODO: move this more into the oracle module
if whereclause is not None:
- whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w]))
+ whereclause = sql.and_(w, whereclause)
else:
whereclause = w
- t = self.get_from_text(f)
- if t is not None:
- froms.append(t)
-
if len(froms):
text += " \nFROM "
- text += string.join(froms, ', ')
+ text += string.join(from_strings, ', ')
else:
text += self.default_from()
if whereclause is not None:
- t = self.get_str(whereclause)
+ t = self.process(whereclause)
if t:
text += " \nWHERE " + t
- group_by = self.get_str(select.group_by_clause)
+ group_by = self.process(select._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
- if select.having is not None:
- t = self.get_str(select.having)
+ if select._having is not None:
+ t = self.process(select._having)
if t:
text += " \nHAVING " + t
text += self.order_by_clause(select)
- text += self.visit_select_postclauses(select)
+ text += (select._limit or select._offset) and self.limit_clause(select) or ""
text += self.for_update_clause(select)
- self.strings[select] = text
- self.froms[select] = "(" + text + ")"
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just before column list."""
-
- return select.distinct and "DISTINCT " or ""
-
- def visit_select_postclauses(self, select):
- """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses.
-
- Most DB syntaxes put ``LIMIT``/``OFFSET`` here.
- """
-
- return (select.limit or select.offset) and self.limit_clause(select) or ""
+ return select._distinct and "DISTINCT " or ""
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.process(select._order_by_clause)
if order_by:
return " ORDER BY " + order_by
else:
@@ -557,175 +592,103 @@ class ANSICompiler(sql.Compiled):
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT -1"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def visit_table(self, table):
- self.froms[table] = self.preparer.format_table(table)
- self.strings[table] = ""
-
- def visit_join(self, join):
- righttext = self.get_from_text(join.right)
- if join.right._group_parenthesized():
- righttext = "(" + righttext + ")"
- if join.isouter:
- self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + righttext +
- " ON " + self.get_str(join.onclause))
+ def visit_table(self, table, asfrom=False, **kwargs):
+ if asfrom:
+ return self.preparer.format_table(table)
else:
- self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
- " ON " + self.get_str(join.onclause))
- self.strings[join] = self.froms[join]
-
- def visit_insert_column_default(self, column, default, parameters):
- """Called when visiting an ``Insert`` statement.
-
- For each column in the table that contains a ``ColumnDefault``
- object, add a blank *placeholder* parameter so the ``Insert``
- gets compiled with this column's name in its column and
- ``VALUES`` clauses.
- """
-
- parameters.setdefault(column.key, None)
-
- def visit_update_column_default(self, column, default, parameters):
- """Called when visiting an ``Update`` statement.
-
- For each column in the table that contains a ``ColumnDefault``
- object as an onupdate, add a blank *placeholder* parameter so
- the ``Update`` gets compiled with this column's name as one of
- its ``SET`` clauses.
- """
-
- parameters.setdefault(column.key, None)
-
- def visit_insert_sequence(self, column, sequence, parameters):
- """Called when visiting an ``Insert`` statement.
-
- This may be overridden compilers that support sequences to
- place a blank *placeholder* parameter for each column in the
- table that contains a Sequence object, so the Insert gets
- compiled with this column's name in its column and ``VALUES``
- clauses.
- """
-
- pass
-
- def visit_insert_column(self, column, parameters):
- """Called when visiting an ``Insert`` statement.
-
- This may be overridden by compilers who disallow NULL columns
- being set in an ``Insert`` where there is a default value on
- the column (i.e. postgres), to remove the column for which
- there is a NULL insert from the parameter list.
- """
+ return ""
- pass
+ def visit_join(self, join, asfrom=False, **kwargs):
+ return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
+ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
+ def uses_sequences_for_inserts(self):
+ return False
+
def visit_insert(self, insert_stmt):
- # scan the table's columns for defaults that have to be pre-set for an INSERT
- # add these columns to the parameter list via visit_insert_XXX methods
- default_params = {}
+
+ # search for columns who will be required to have an explicit bound value.
+ # for inserts, this includes Python-side defaults, columns with sequences for dialects
+ # that support sequences, and primary key columns for dialects that explicitly insert
+ # pre-generated primary key values
+ required_cols = util.Set()
class DefaultVisitor(schema.SchemaVisitor):
- def visit_column(s, c):
- self.visit_insert_column(c, default_params)
+ def visit_column(s, cd):
+ if c.primary_key and self.uses_sequences_for_inserts():
+ required_cols.add(c)
def visit_column_default(s, cd):
- self.visit_insert_column_default(c, cd, default_params)
+ required_cols.add(c)
def visit_sequence(s, seq):
- self.visit_insert_sequence(c, seq, default_params)
+ if self.uses_sequences_for_inserts():
+ required_cols.add(c)
vis = DefaultVisitor()
for c in insert_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
vis.traverse(c)
self.isinsert = True
- colparams = self._get_colparams(insert_stmt, default_params)
+ colparams = self._get_colparams(insert_stmt, required_cols)
- self.inline_params = util.Set()
- def create_param(col, p):
- if isinstance(p, sql._BindParamClause):
- self.binds[p.key] = p
- if p.shortname is not None:
- self.binds[p.shortname] = p
- return self.bindparam_string(self._truncate_bindparam(p))
- else:
- self.inline_params.add(col)
- self.traverse(p)
- if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
- return "(" + self.get_str(p) + ")"
- else:
- return self.get_str(p)
-
- text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
- " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")")
-
- self.strings[insert_stmt] = text
+ return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
+ " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
def visit_update(self, update_stmt):
- # scan the table's columns for onupdates that have to be pre-set for an UPDATE
- # add these columns to the parameter list via visit_update_XXX methods
- default_params = {}
+ update_stmt._calculate_correlations(self.correlate_state)
+
+ # search for columns who will be required to have an explicit bound value.
+ # for updates, this includes Python-side "onupdate" defaults.
+ required_cols = util.Set()
class OnUpdateVisitor(schema.SchemaVisitor):
def visit_column_onupdate(s, cd):
- self.visit_update_column_default(c, cd, default_params)
+ required_cols.add(c)
vis = OnUpdateVisitor()
for c in update_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
vis.traverse(c)
self.isupdate = True
- colparams = self._get_colparams(update_stmt, default_params)
-
- self.inline_params = util.Set()
- def create_param(col, p):
- if isinstance(p, sql._BindParamClause):
- self.binds[p.key] = p
- self.binds[p.shortname] = p
- return self.bindparam_string(self._truncate_bindparam(p))
- else:
- self.traverse(p)
- self.inline_params.add(col)
- if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
- return "(" + self.get_str(p) + ")"
- else:
- return self.get_str(p)
-
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
-
- if update_stmt.whereclause:
- text += " WHERE " + self.get_str(update_stmt.whereclause)
+ colparams = self._get_colparams(update_stmt, required_cols)
- self.strings[update_stmt] = text
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
+ if update_stmt._whereclause:
+ text += " WHERE " + self.process(update_stmt._whereclause)
- def _get_colparams(self, stmt, default_params):
- """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples.
-
- Each tuple will contain the ``Column`` and a ``ClauseElement``
- representing the value to be set (usually a ``_BindParamClause``,
- but could also be other SQL expressions.)
-
- The list of tuples will determine the columns that are
- actually rendered into the ``SET``/``VALUES`` clause of the
- rendered ``UPDATE``/``INSERT`` statement. It will also
- determine how to generate the list/dictionary of bind
- parameters at execution time (i.e. ``get_params()``).
+ return text
- This list takes into account the `values` keyword specified
- to the statement, the parameters sent to this Compiled
- instance, and the default bind parameter values corresponding
- to the dialect's behavior for otherwise unspecified primary
- key columns.
+ def _get_colparams(self, stmt, required_cols):
+ """create a set of tuples representing column/string pairs for use
+ in an INSERT or UPDATE statement.
+
+ This method may generate new bind params within this compiled
+ based on the given set of "required columns", which are required
+ to have a value set in the statement.
"""
+ def create_bind_param(col, value):
+ bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True)
+ self.binds[col.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.parameters is None and stmt.parameters is None:
- return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns]
+ return [(c, create_bind_param(c, None)) for c in stmt.table.columns]
+
+ def create_clause_param(col, value):
+ self.traverse(value)
+ self.inline_params.add(col)
+ return self.process(value)
+
+ self.inline_params = util.Set()
def to_col(key):
if not isinstance(key, sql._ColumnClause):
@@ -744,29 +707,43 @@ class ANSICompiler(sql.Compiled):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(to_col(k), v)
- for k, v in default_params.iteritems():
- parameters.setdefault(to_col(k), v)
+ for col in required_cols:
+ parameters.setdefault(col, None)
# create a list of column assignment clauses as tuples
values = []
for c in stmt.table.columns:
- if parameters.has_key(c):
+ if c in parameters:
value = parameters[c]
if sql._is_literal(value):
- value = sql.bindparam(c.key, value, type=c.type, unique=True)
+ value = create_bind_param(c, value)
+ else:
+ value = create_clause_param(c, value)
values.append((c, value))
+
return values
def visit_delete(self, delete_stmt):
+ delete_stmt._calculate_correlations(self.correlate_state)
+
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
- if delete_stmt.whereclause:
- text += " WHERE " + self.get_str(delete_stmt.whereclause)
+ if delete_stmt._whereclause:
+ text += " WHERE " + self.process(delete_stmt._whereclause)
- self.strings[delete_stmt] = text
+ return text
+
+ def visit_savepoint(self, savepoint_stmt):
+ return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+
+ def visit_release_savepoint(self, savepoint_stmt):
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+
def __str__(self):
- return self.get_str(self.statement)
+ return self.string
class ANSISchemaBase(engine.SchemaIterator):
def find_alterables(self, tables):
@@ -795,7 +772,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def visit_metadata(self, metadata):
collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
for table in collection:
- table.accept_visitor(self)
+ self.traverse_single(table)
if self.dialect.supports_alter():
for alterable in self.find_alterables(collection):
self.add_foreignkey(alterable)
@@ -803,9 +780,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
- column.default.accept_visitor(self)
- #if column.onupdate is not None:
- # column.onupdate.accept_visitor(visitor)
+ self.traverse_single(column.default)
self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
@@ -820,20 +795,20 @@ class ANSISchemaGenerator(ANSISchemaBase):
if column.primary_key:
first_pk = True
for constraint in column.constraints:
- constraint.accept_visitor(self)
+ self.traverse_single(constraint)
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
if len(table.primary_key):
- table.primary_key.accept_visitor(self)
+ self.traverse_single(table.primary_key)
for constraint in [c for c in table.constraints if c is not table.primary_key]:
- constraint.accept_visitor(self)
+ self.traverse_single(constraint)
self.append("\n)%s\n\n" % self.post_create_table(table))
self.execute()
if hasattr(table, 'indexes'):
for index in table.indexes:
- index.accept_visitor(self)
+ self.traverse_single(index)
def post_create_table(self, table):
return ''
@@ -870,7 +845,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
self.append("PRIMARY KEY ")
- self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+ self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
def visit_foreign_key_constraint(self, constraint):
if constraint.use_alter and self.dialect.supports_alter():
@@ -889,9 +864,9 @@ class ANSISchemaGenerator(ANSISchemaBase):
self.append("CONSTRAINT %s " %
preparer.format_constraint(constraint))
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- string.join([preparer.format_column(f.parent) for f in constraint.elements], ', '),
+ ', '.join([preparer.format_column(f.parent) for f in constraint.elements]),
preparer.format_table(list(constraint.elements)[0].column.table),
- string.join([preparer.format_column(f.column) for f in constraint.elements], ', ')
+ ', '.join([preparer.format_column(f.column) for f in constraint.elements])
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
@@ -903,17 +878,17 @@ class ANSISchemaGenerator(ANSISchemaBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " %
self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+ self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint])))
def visit_column(self, column):
pass
def visit_index(self, index):
- preparer = self.preparer
- self.append('CREATE ')
+ preparer = self.preparer
+ self.append("CREATE ")
if index.unique:
- self.append('UNIQUE ')
- self.append('INDEX %s ON %s (%s)' \
+ self.append("UNIQUE ")
+ self.append("INDEX %s ON %s (%s)" \
% (preparer.format_index(index),
preparer.format_table(index.table),
string.join([preparer.format_column(c) for c in index.columns], ', ')))
@@ -933,7 +908,7 @@ class ANSISchemaDropper(ANSISchemaBase):
for alterable in self.find_alterables(collection):
self.drop_foreignkey(alterable)
for table in collection:
- table.accept_visitor(self)
+ self.traverse_single(table)
def visit_index(self, index):
self.append("\nDROP INDEX " + self.preparer.format_index(index))
@@ -948,7 +923,7 @@ class ANSISchemaDropper(ANSISchemaBase):
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
- column.default.accept_visitor(self)
+ self.traverse_single(column.default)
self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
@@ -1048,17 +1023,17 @@ class ANSIIdentifierPreparer(object):
def should_quote(self, object):
return object.quote or self._requires_quotes(object.name, object.case_sensitive)
- def is_natural_case(self, object):
- return object.quote or self._requires_quotes(object.name, object.case_sensitive)
-
def format_sequence(self, sequence):
return self.__generic_obj_format(sequence, sequence.name)
def format_label(self, label, name=None):
return self.__generic_obj_format(label, name or label.name)
- def format_alias(self, alias):
- return self.__generic_obj_format(alias, alias.name)
+ def format_alias(self, alias, name=None):
+ return self.__generic_obj_format(alias, name or alias.name)
+
+ def format_savepoint(self, savepoint):
+ return self.__generic_obj_format(savepoint, savepoint)
def format_constraint(self, constraint):
return self.__generic_obj_format(constraint, constraint.name)
@@ -1076,25 +1051,25 @@ class ANSIIdentifierPreparer(object):
result = self.__generic_obj_format(table, table.schema) + "." + result
return result
- def format_column(self, column, use_table=False, name=None):
+ def format_column(self, column, use_table=False, name=None, table_name=None):
"""Prepare a quoted column name."""
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name)
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name)
else:
return self.__generic_obj_format(column, name)
else:
# literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
if use_table:
- return self.format_table(column.table, use_schema=False) + "." + name
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + name
else:
return name
- def format_column_with_table(self, column, column_name=None):
+ def format_column_with_table(self, column, column_name=None, table_name=None):
"""Prepare a quoted column name with table name."""
- return self.format_column(column, use_table=True, name=column_name)
+ return self.format_column(column, use_table=True, name=column_name, table_name=table_name)
dialect = ANSIDialect