diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 749 |
1 files changed, 362 insertions, 387 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 9994d5288..22227d56a 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -10,9 +10,10 @@ Contains default implementations for the abstract objects in the sql module. """ -from sqlalchemy import schema, sql, engine, util, sql_util, exceptions +import string, re, sets, operator + +from sqlalchemy import schema, sql, engine, util, exceptions from sqlalchemy.engine import default -import string, re, sets, weakref, random ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', @@ -40,6 +41,41 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$') ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') +BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE) +BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE) + +OPERATORS = { + operator.and_ : 'AND', + operator.or_ : 'OR', + operator.inv : 'NOT', + operator.add : '+', + operator.mul : '*', + operator.sub : '-', + operator.div : '/', + operator.mod : '%', + operator.truediv : '/', + operator.lt : '<', + operator.le : '<=', + operator.ne : '!=', + operator.gt : '>', + operator.ge : '>=', + operator.eq : '=', + sql.ColumnOperators.concat_op : '||', + sql.ColumnOperators.like_op : 'LIKE', + sql.ColumnOperators.notlike_op : 'NOT LIKE', + sql.ColumnOperators.ilike_op : 'ILIKE', + sql.ColumnOperators.notilike_op : 'NOT ILIKE', + sql.ColumnOperators.between_op : 'BETWEEN', + sql.ColumnOperators.in_op : 'IN', + sql.ColumnOperators.notin_op : 'NOT IN', + sql.ColumnOperators.comma_op : ', ', + sql.Operators.from_ : 'FROM', + sql.Operators.as_ : 'AS', + sql.Operators.exists : 'EXISTS', + sql.Operators.is_ : 'IS', + sql.Operators.isnot : 'IS NOT' +} + class ANSIDialect(default.DefaultDialect): def __init__(self, cache_identifiers=True, **kwargs): super(ANSIDialect,self).__init__(**kwargs) @@ -66,14 +102,16 @@ class ANSIDialect(default.DefaultDialect): """ return ANSIIdentifierPreparer(self) -class ANSICompiler(sql.Compiled): +class ANSICompiler(engine.Compiled, sql.ClauseVisitor): """Default implementation of Compiled. Compiles ClauseElements into ANSI-compliant SQL strings. """ - __traverse_options__ = {'column_collections':False} + __traverse_options__ = {'column_collections':False, 'entry':True} + operators = OPERATORS + def __init__(self, dialect, statement, parameters=None, **kwargs): """Construct a new ``ANSICompiler`` object. @@ -92,7 +130,7 @@ class ANSICompiler(sql.Compiled): correspond to the keys present in the parameters. """ - sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs) + super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs) # if we are insert/update. set to true when we visit an INSERT or UPDATE self.isinsert = self.isupdate = False @@ -104,21 +142,6 @@ class ANSICompiler(sql.Compiled): # actually present in the generated SQL self.bind_names = {} - # a dictionary which stores the string representation for every ClauseElement - # processed by this compiler. - self.strings = {} - - # a dictionary which stores the string representation for ClauseElements - # processed by this compiler, which are to be used in the FROM clause - # of a select. items are often placed in "froms" as well as "strings" - # and sometimes with different representations. - self.froms = {} - - # slightly hacky. maps FROM clauses to WHERE clauses, and used in select - # generation to modify the WHERE clause of the select. currently a hack - # used by the oracle module. - self.wheres = {} - # when the compiler visits a SELECT statement, the clause object is appended # to this stack. various visit operations will check this stack to determine # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only @@ -137,12 +160,6 @@ class ANSICompiler(sql.Compiled): # for aliases self.generated_ids = {} - # True if this compiled represents an INSERT - self.isinsert = False - - # True if this compiled represents an UPDATE - self.isupdate = False - # default formatting style for bind parameters self.bindtemplate = ":%s" @@ -158,64 +175,76 @@ class ANSICompiler(sql.Compiled): # an ANSIIdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer - + + # a dictionary containing attributes about all select() + # elements located within the clause, regarding which are subqueries, which are + # selected from, and which elements should be correlated to an enclosing select. + # used mostly to determine the list of FROM elements for each select statement, as well + # as some dialect-specific rules regarding subqueries. + self.correlate_state = {} + # for UPDATE and INSERT statements, a set of columns whos values are being set # from a SQL expression (i.e., not one of the bind parameter values). if present, # default-value logic in the Dialect knows not to fire off column defaults # and also knows postfetching will be needed to get the values represented by these # parameters. self.inline_params = None - + def after_compile(self): # this re will search for params like :param # it has a negative lookbehind for an extra ':' so that it doesnt match # postgres '::text' tokens - match = re.compile(r'(?<!:):([\w_]+)', re.UNICODE) + text = self.string + if ':' not in text: + return + if self.paramstyle=='pyformat': - self.strings[self.statement] = match.sub(lambda m:'%(' + m.group(1) +')s', self.strings[self.statement]) + text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text) elif self.positional: - params = match.finditer(self.strings[self.statement]) + params = BIND_PARAMS.finditer(text) for p in params: self.positiontup.append(p.group(1)) if self.paramstyle=='qmark': - self.strings[self.statement] = match.sub('?', self.strings[self.statement]) + text = BIND_PARAMS.sub('?', text) elif self.paramstyle=='format': - self.strings[self.statement] = match.sub('%s', self.strings[self.statement]) + text = BIND_PARAMS.sub('%s', text) elif self.paramstyle=='numeric': i = [0] def getnum(x): i[0] += 1 return str(i[0]) - self.strings[self.statement] = match.sub(getnum, self.strings[self.statement]) - - def get_from_text(self, obj): - return self.froms.get(obj, None) - - def get_str(self, obj): - return self.strings[obj] - + text = BIND_PARAMS.sub(getnum, text) + # un-escape any \:params + text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text) + self.string = text + + def compile(self): + self.string = self.process(self.statement) + self.after_compile() + + def process(self, obj, **kwargs): + return self.traverse_single(obj, **kwargs) + + def is_subquery(self, select): + return self.correlate_state[select].get('is_subquery', False) + def get_whereclause(self, obj): - return self.wheres.get(obj, None) + """given a FROM clause, return an additional WHERE condition that should be + applied to a SELECT. + + Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN + constructs in non-ansi mode. + """ + + return None def construct_params(self, params): - """Return a structure of bind parameters for this compiled object. - - This includes bind parameters that might be compiled in via - the `values` argument of an ``Insert`` or ``Update`` statement - object, and also the given `**params`. The keys inside of - `**params` can be any key that matches the - ``BindParameterClause`` objects compiled within this object. - - The output is dependent on the paramstyle of the DBAPI being - used; if a named style, the return result will be a dictionary - with keynames matching the compiled statement. If a - positional style, the output will be a list, with an iterator - that will return parameter values in an order corresponding to - the bind positions in the compiled statement. - - For an executemany style of call, this method should be called - for each element in the list of parameter groups that will - ultimately be executed. + """Return a sql.ClauseParameters object. + + Combines the given bind parameter dictionary (string keys to object values) + with the _BindParamClause objects stored within this Compiled object + to produce a ClauseParameters structure, representing the bind arguments + for a single statement execution, or one element of an executemany execution. """ if self.parameters is not None: @@ -225,7 +254,7 @@ class ANSICompiler(sql.Compiled): bindparams.update(params) d = sql.ClauseParameters(self.dialect, self.positiontup) for b in self.binds.values(): - name = self.bind_names.get(b, b.key) + name = self.bind_names[b] d.set_parameter(b, b.value, name) for key, value in bindparams.iteritems(): @@ -233,7 +262,7 @@ class ANSICompiler(sql.Compiled): b = self.binds[key] except KeyError: continue - name = self.bind_names.get(b, b.key) + name = self.bind_names[b] d.set_parameter(b, value, name) return d @@ -246,8 +275,8 @@ class ANSICompiler(sql.Compiled): return "" - def visit_grouping(self, grouping): - self.strings[grouping] = "(" + self.strings[grouping.elem] + ")" + def visit_grouping(self, grouping, **kwargs): + return "(" + self.process(grouping.elem) + ")" def visit_label(self, label): labelname = self._truncated_identifier("colident", label.name) @@ -256,9 +285,10 @@ class ANSICompiler(sql.Compiled): self.typemap.setdefault(labelname.lower(), label.obj.type) if isinstance(label.obj, sql._ColumnClause): self.column_labels[label.obj._label] = labelname - self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname) + self.column_labels[label.name] = labelname + return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)]) - def visit_column(self, column): + def visit_column(self, column, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily # want to truncate a column identifier, if its mapped to the name of a # physical column. but thats very hard to identify at this point, and @@ -269,107 +299,110 @@ class ANSICompiler(sql.Compiled): else: name = column.name + if len(self.select_stack): + # if we are within a visit to a Select, set up the "typemap" + # for this column which is used to translate result set values + self.typemap.setdefault(name.lower(), column.type) + self.column_labels.setdefault(column._label, name.lower()) + if column.table is None or not column.table.named_with_column(): - self.strings[column] = self.preparer.format_column(column, name=name) + return self.preparer.format_column(column, name=name) else: if column.table.oid_column is column: n = self.dialect.oid_column_name(column) if n is not None: - self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n) + return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n) elif len(column.table.primary_key) != 0: pk = list(column.table.primary_key)[0] pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name)) - self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname) + return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name)) else: - self.strings[column] = None + return None else: - self.strings[column] = self.preparer.format_column_with_table(column, column_name=name) + return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name)) - if len(self.select_stack): - # if we are within a visit to a Select, set up the "typemap" - # for this column which is used to translate result set values - self.typemap.setdefault(name.lower(), column.type) - self.column_labels.setdefault(column._label, name.lower()) - def visit_fromclause(self, fromclause): - self.froms[fromclause] = fromclause.name + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name - def visit_index(self, index): - self.strings[index] = index.name + def visit_index(self, index, **kwargs): + return index.name - def visit_typeclause(self, typeclause): - self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec() + def visit_typeclause(self, typeclause, **kwargs): + return typeclause.type.dialect_impl(self.dialect).get_col_spec() - def visit_textclause(self, textclause): - self.strings[textclause] = textclause.text - self.froms[textclause] = textclause.text + def visit_textclause(self, textclause, **kwargs): + for bind in textclause.bindparams.values(): + self.process(bind) if textclause.typemap is not None: self.typemap.update(textclause.typemap) + return textclause.text - def visit_null(self, null): - self.strings[null] = 'NULL' + def visit_null(self, null, **kwargs): + return 'NULL' - def visit_clauselist(self, list): - sep = list.operator - if sep == ',': - sep = ', ' - elif sep is None or sep == " ": + def visit_clauselist(self, clauselist, **kwargs): + sep = clauselist.operator + if sep is None: sep = " " + elif sep == sql.ColumnOperators.comma_op: + sep = ', ' else: - sep = " " + sep + " " - self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep) + sep = " " + self.operator_string(clauselist.operator) + " " + return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep) def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - def visit_calculatedclause(self, clause): - self.strings[clause] = self.get_str(clause.clause_expr) + def visit_calculatedclause(self, clause, **kwargs): + return self.process(clause.clause_expr) - def visit_cast(self, cast): + def visit_cast(self, cast, **kwargs): if len(self.select_stack): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) - self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause]) + return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) - def visit_function(self, func): + def visit_function(self, func, **kwargs): if len(self.select_stack): self.typemap.setdefault(func.name, func.type) if not self.apply_function_parens(func): - self.strings[func] = ".".join(func.packagenames + [func.name]) - self.froms[func] = self.strings[func] + return ".".join(func.packagenames + [func.name]) else: - self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr) - self.froms[func] = self.strings[func] + return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) - def visit_compound_select(self, cs): - text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ") - group_by = self.get_str(cs.group_by_clause) + def visit_compound_select(self, cs, asfrom=False, **kwargs): + text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ") + group_by = self.process(cs._group_by_clause) if group_by: text += " GROUP BY " + group_by text += self.order_by_clause(cs) - text += self.visit_select_postclauses(cs) - self.strings[cs] = text - self.froms[cs] = "(" + text + ")" + text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" + + if asfrom: + return "(" + text + ")" + else: + return text - def visit_unary(self, unary): - s = self.get_str(unary.element) + def visit_unary(self, unary, **kwargs): + s = self.process(unary.element) if unary.operator: - s = unary.operator + " " + s + s = self.operator_string(unary.operator) + " " + s if unary.modifier: s = s + " " + unary.modifier - self.strings[unary] = s + return s - def visit_binary(self, binary): - result = self.get_str(binary.left) - if binary.operator is not None: - result += " " + self.binary_operator_string(binary) - result += " " + self.get_str(binary.right) - self.strings[binary] = result - - def binary_operator_string(self, binary): - return binary.operator + def visit_binary(self, binary, **kwargs): + op = self.operator_string(binary.operator) + if callable(op): + return op(self.process(binary.left), self.process(binary.right)) + else: + return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + def operator_string(self, operator): + return self.operators.get(operator, str(operator)) - def visit_bindparam(self, bindparam): + def visit_bindparam(self, bindparam, **kwargs): # apply truncation to the ultimate generated name if bindparam.shortname != bindparam.key: @@ -378,7 +411,6 @@ class ANSICompiler(sql.Compiled): if bindparam.unique: count = 1 key = bindparam.key - # redefine the generated name of the bind param in the case # that we have multiple conflicting bind parameters. while self.binds.setdefault(key, bindparam) is not bindparam: @@ -386,164 +418,167 @@ class ANSICompiler(sql.Compiled): key = bindparam.key + tag count += 1 bindparam.key = key - self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam)) + return self.bindparam_string(self._truncate_bindparam(bindparam)) else: existing = self.binds.get(bindparam.key) if existing is not None and existing.unique: raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) - self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam)) self.binds[bindparam.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] bind_name = bindparam.key - if len(bind_name) > self.dialect.max_identifier_length(): - bind_name = self._truncated_identifier("bindparam", bind_name) - # add to bind_names for translation - self.bind_names[bindparam] = bind_name + bind_name = self._truncated_identifier("bindparam", bind_name) + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + return bind_name def _truncated_identifier(self, ident_class, name): if (ident_class, name) in self.generated_ids: return self.generated_ids[(ident_class, name)] - if len(name) > self.dialect.max_identifier_length(): + + anonname = self._anonymize(name) + if len(anonname) > self.dialect.max_identifier_length(): counter = self.generated_ids.get(ident_class, 1) truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:] self.generated_ids[ident_class] = counter + 1 else: - truncname = name + truncname = anonname self.generated_ids[(ident_class, name)] = truncname return truncname + + def _anonymize(self, name): + def anon(match): + (ident, derived) = match.group(1,2) + if ('anonymous', ident) in self.generated_ids: + return self.generated_ids[('anonymous', ident)] + else: + anonymous_counter = self.generated_ids.get('anonymous', 1) + newname = derived + "_" + str(anonymous_counter) + self.generated_ids['anonymous'] = anonymous_counter + 1 + self.generated_ids[('anonymous', ident)] = newname + return newname + return re.sub(r'{ANON (-?\d+) (.*)}', anon, name) def bindparam_string(self, name): return self.bindtemplate % name - def visit_alias(self, alias): - self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias) - self.strings[alias] = self.get_str(alias.original) + def visit_alias(self, alias, asfrom=False, **kwargs): + if asfrom: + return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name)) + else: + return self.process(alias.original, **kwargs) - def visit_select(self, select): - # the actual list of columns to print in the SELECT column list. - inner_columns = util.OrderedDict() + def label_select_column(self, select, column): + """convert a column from a select's "columns" clause. + + given a select() and a column element from its inner_columns collection, return a + Label object if this column should be labeled in the columns clause. Otherwise, + return None and the column will be used as-is. + + The calling method will traverse the returned label to acquire its string + representation. + """ + + # SQLite doesnt like selecting from a subquery where the column + # names look like table.colname. so if column is in a "selected from" + # subquery, label it synoymously with its column name + if \ + self.correlate_state[select].get('is_selected_from', False) and \ + isinstance(column, sql._ColumnClause) and \ + not column.is_literal and \ + column.table is not None and \ + not isinstance(column.table, sql.Select): + return column.label(column.name) + else: + return None + + def visit_select(self, select, asfrom=False, **kwargs): + select._calculate_correlations(self.correlate_state) self.select_stack.append(select) - for c in select._raw_columns: - if hasattr(c, '_selectable'): - s = c._selectable() - else: - self.traverse(c) - inner_columns[self.get_str(c)] = c - continue - for co in s.columns: - if select.use_labels: - labelname = co._label - if labelname is not None: - l = co.label(labelname) - self.traverse(l) - inner_columns[labelname] = l - else: - self.traverse(co) - inner_columns[self.get_str(co)] = co - # TODO: figure this out, a ColumnClause with a select as a parent - # is different from any other kind of parent - elif select.is_selected_from and isinstance(co, sql._ColumnClause) and not co.is_literal and co.table is not None and not isinstance(co.table, sql.Select): - # SQLite doesnt like selecting from a subquery where the column - # names look like table.colname, so add a label synonomous with - # the column name - l = co.label(co.name) - self.traverse(l) - inner_columns[self.get_str(l.obj)] = l + + # the actual list of columns to print in the SELECT column list. + inner_columns = util.OrderedSet() + + froms = select._get_display_froms(self.correlate_state) + + for co in select.inner_columns: + if select.use_labels: + labelname = co._label + if labelname is not None: + l = co.label(labelname) + inner_columns.add(self.process(l)) else: self.traverse(co) - inner_columns[self.get_str(co)] = co + inner_columns.add(self.process(co)) + else: + l = self.label_select_column(select, co) + if l is not None: + inner_columns.add(self.process(l)) + else: + inner_columns.add(self.process(co)) + self.select_stack.pop(-1) - collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ') + collist = string.join(inner_columns.difference(util.Set([None])), ', ') - text = "SELECT " - text += self.visit_select_precolumns(select) + text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " + text += self.get_select_precolumns(select) text += collist - whereclause = select.whereclause - - froms = [] - for f in select.froms: - - if self.parameters is not None: - # TODO: whack this feature in 0.4 - # look at our own parameters, see if they - # are all present in the form of BindParamClauses. if - # not, then append to the above whereclause column conditions - # matching those keys - for c in f.columns: - if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key): - value = self.parameters[c.key] - else: - continue - clause = c==value - if whereclause is not None: - whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause])) - else: - whereclause = clause - self.traverse(whereclause) - - # special thingy used by oracle to redefine a join + whereclause = select._whereclause + + from_strings = [] + for f in froms: + from_strings.append(self.process(f, asfrom=True)) + w = self.get_whereclause(f) if w is not None: - # TODO: move this more into the oracle module if whereclause is not None: - whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w])) + whereclause = sql.and_(w, whereclause) else: whereclause = w - t = self.get_from_text(f) - if t is not None: - froms.append(t) - if len(froms): text += " \nFROM " - text += string.join(froms, ', ') + text += string.join(from_strings, ', ') else: text += self.default_from() if whereclause is not None: - t = self.get_str(whereclause) + t = self.process(whereclause) if t: text += " \nWHERE " + t - group_by = self.get_str(select.group_by_clause) + group_by = self.process(select._group_by_clause) if group_by: text += " GROUP BY " + group_by - if select.having is not None: - t = self.get_str(select.having) + if select._having is not None: + t = self.process(select._having) if t: text += " \nHAVING " + t text += self.order_by_clause(select) - text += self.visit_select_postclauses(select) + text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) - self.strings[select] = text - self.froms[select] = "(" + text + ")" + if asfrom: + return "(" + text + ")" + else: + return text - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list.""" - - return select.distinct and "DISTINCT " or "" - - def visit_select_postclauses(self, select): - """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses. - - Most DB syntaxes put ``LIMIT``/``OFFSET`` here. - """ - - return (select.limit or select.offset) and self.limit_clause(select) or "" + return select._distinct and "DISTINCT " or "" def order_by_clause(self, select): - order_by = self.get_str(select.order_by_clause) + order_by = self.process(select._order_by_clause) if order_by: return " ORDER BY " + order_by else: @@ -557,175 +592,103 @@ class ANSICompiler(sql.Compiled): def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT -1" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text - def visit_table(self, table): - self.froms[table] = self.preparer.format_table(table) - self.strings[table] = "" - - def visit_join(self, join): - righttext = self.get_from_text(join.right) - if join.right._group_parenthesized(): - righttext = "(" + righttext + ")" - if join.isouter: - self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + righttext + - " ON " + self.get_str(join.onclause)) + def visit_table(self, table, asfrom=False, **kwargs): + if asfrom: + return self.preparer.format_table(table) else: - self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext + - " ON " + self.get_str(join.onclause)) - self.strings[join] = self.froms[join] - - def visit_insert_column_default(self, column, default, parameters): - """Called when visiting an ``Insert`` statement. - - For each column in the table that contains a ``ColumnDefault`` - object, add a blank *placeholder* parameter so the ``Insert`` - gets compiled with this column's name in its column and - ``VALUES`` clauses. - """ - - parameters.setdefault(column.key, None) - - def visit_update_column_default(self, column, default, parameters): - """Called when visiting an ``Update`` statement. - - For each column in the table that contains a ``ColumnDefault`` - object as an onupdate, add a blank *placeholder* parameter so - the ``Update`` gets compiled with this column's name as one of - its ``SET`` clauses. - """ - - parameters.setdefault(column.key, None) - - def visit_insert_sequence(self, column, sequence, parameters): - """Called when visiting an ``Insert`` statement. - - This may be overridden compilers that support sequences to - place a blank *placeholder* parameter for each column in the - table that contains a Sequence object, so the Insert gets - compiled with this column's name in its column and ``VALUES`` - clauses. - """ - - pass - - def visit_insert_column(self, column, parameters): - """Called when visiting an ``Insert`` statement. - - This may be overridden by compilers who disallow NULL columns - being set in an ``Insert`` where there is a default value on - the column (i.e. postgres), to remove the column for which - there is a NULL insert from the parameter list. - """ + return "" - pass + def visit_join(self, join, asfrom=False, **kwargs): + return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + def uses_sequences_for_inserts(self): + return False + def visit_insert(self, insert_stmt): - # scan the table's columns for defaults that have to be pre-set for an INSERT - # add these columns to the parameter list via visit_insert_XXX methods - default_params = {} + + # search for columns who will be required to have an explicit bound value. + # for inserts, this includes Python-side defaults, columns with sequences for dialects + # that support sequences, and primary key columns for dialects that explicitly insert + # pre-generated primary key values + required_cols = util.Set() class DefaultVisitor(schema.SchemaVisitor): - def visit_column(s, c): - self.visit_insert_column(c, default_params) + def visit_column(s, cd): + if c.primary_key and self.uses_sequences_for_inserts(): + required_cols.add(c) def visit_column_default(s, cd): - self.visit_insert_column_default(c, cd, default_params) + required_cols.add(c) def visit_sequence(s, seq): - self.visit_insert_sequence(c, seq, default_params) + if self.uses_sequences_for_inserts(): + required_cols.add(c) vis = DefaultVisitor() for c in insert_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): vis.traverse(c) self.isinsert = True - colparams = self._get_colparams(insert_stmt, default_params) + colparams = self._get_colparams(insert_stmt, required_cols) - self.inline_params = util.Set() - def create_param(col, p): - if isinstance(p, sql._BindParamClause): - self.binds[p.key] = p - if p.shortname is not None: - self.binds[p.shortname] = p - return self.bindparam_string(self._truncate_bindparam(p)) - else: - self.inline_params.add(col) - self.traverse(p) - if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): - return "(" + self.get_str(p) + ")" - else: - return self.get_str(p) - - text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + - " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")") - - self.strings[insert_stmt] = text + return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + + " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") def visit_update(self, update_stmt): - # scan the table's columns for onupdates that have to be pre-set for an UPDATE - # add these columns to the parameter list via visit_update_XXX methods - default_params = {} + update_stmt._calculate_correlations(self.correlate_state) + + # search for columns who will be required to have an explicit bound value. + # for updates, this includes Python-side "onupdate" defaults. + required_cols = util.Set() class OnUpdateVisitor(schema.SchemaVisitor): def visit_column_onupdate(s, cd): - self.visit_update_column_default(c, cd, default_params) + required_cols.add(c) vis = OnUpdateVisitor() for c in update_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): vis.traverse(c) self.isupdate = True - colparams = self._get_colparams(update_stmt, default_params) - - self.inline_params = util.Set() - def create_param(col, p): - if isinstance(p, sql._BindParamClause): - self.binds[p.key] = p - self.binds[p.shortname] = p - return self.bindparam_string(self._truncate_bindparam(p)) - else: - self.traverse(p) - self.inline_params.add(col) - if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): - return "(" + self.get_str(p) + ")" - else: - return self.get_str(p) - - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ') - - if update_stmt.whereclause: - text += " WHERE " + self.get_str(update_stmt.whereclause) + colparams = self._get_colparams(update_stmt, required_cols) - self.strings[update_stmt] = text + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ') + if update_stmt._whereclause: + text += " WHERE " + self.process(update_stmt._whereclause) - def _get_colparams(self, stmt, default_params): - """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples. - - Each tuple will contain the ``Column`` and a ``ClauseElement`` - representing the value to be set (usually a ``_BindParamClause``, - but could also be other SQL expressions.) - - The list of tuples will determine the columns that are - actually rendered into the ``SET``/``VALUES`` clause of the - rendered ``UPDATE``/``INSERT`` statement. It will also - determine how to generate the list/dictionary of bind - parameters at execution time (i.e. ``get_params()``). + return text - This list takes into account the `values` keyword specified - to the statement, the parameters sent to this Compiled - instance, and the default bind parameter values corresponding - to the dialect's behavior for otherwise unspecified primary - key columns. + def _get_colparams(self, stmt, required_cols): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + This method may generate new bind params within this compiled + based on the given set of "required columns", which are required + to have a value set in the statement. """ + def create_bind_param(col, value): + bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True) + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.parameters is None and stmt.parameters is None: - return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns] + return [(c, create_bind_param(c, None)) for c in stmt.table.columns] + + def create_clause_param(col, value): + self.traverse(value) + self.inline_params.add(col) + return self.process(value) + + self.inline_params = util.Set() def to_col(key): if not isinstance(key, sql._ColumnClause): @@ -744,29 +707,43 @@ class ANSICompiler(sql.Compiled): for k, v in stmt.parameters.iteritems(): parameters.setdefault(to_col(k), v) - for k, v in default_params.iteritems(): - parameters.setdefault(to_col(k), v) + for col in required_cols: + parameters.setdefault(col, None) # create a list of column assignment clauses as tuples values = [] for c in stmt.table.columns: - if parameters.has_key(c): + if c in parameters: value = parameters[c] if sql._is_literal(value): - value = sql.bindparam(c.key, value, type=c.type, unique=True) + value = create_bind_param(c, value) + else: + value = create_clause_param(c, value) values.append((c, value)) + return values def visit_delete(self, delete_stmt): + delete_stmt._calculate_correlations(self.correlate_state) + text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) - if delete_stmt.whereclause: - text += " WHERE " + self.get_str(delete_stmt.whereclause) + if delete_stmt._whereclause: + text += " WHERE " + self.process(delete_stmt._whereclause) - self.strings[delete_stmt] = text + return text + + def visit_savepoint(self, savepoint_stmt): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + + def visit_release_savepoint(self, savepoint_stmt): + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + def __str__(self): - return self.get_str(self.statement) + return self.string class ANSISchemaBase(engine.SchemaIterator): def find_alterables(self, tables): @@ -795,7 +772,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_metadata(self, metadata): collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] for table in collection: - table.accept_visitor(self) + self.traverse_single(table) if self.dialect.supports_alter(): for alterable in self.find_alterables(collection): self.add_foreignkey(alterable) @@ -803,9 +780,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_table(self, table): for column in table.columns: if column.default is not None: - column.default.accept_visitor(self) - #if column.onupdate is not None: - # column.onupdate.accept_visitor(visitor) + self.traverse_single(column.default) self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") @@ -820,20 +795,20 @@ class ANSISchemaGenerator(ANSISchemaBase): if column.primary_key: first_pk = True for constraint in column.constraints: - constraint.accept_visitor(self) + self.traverse_single(constraint) # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) if len(table.primary_key): - table.primary_key.accept_visitor(self) + self.traverse_single(table.primary_key) for constraint in [c for c in table.constraints if c is not table.primary_key]: - constraint.accept_visitor(self) + self.traverse_single(constraint) self.append("\n)%s\n\n" % self.post_create_table(table)) self.execute() if hasattr(table, 'indexes'): for index in table.indexes: - index.accept_visitor(self) + self.traverse_single(index) def post_create_table(self, table): return '' @@ -870,7 +845,7 @@ class ANSISchemaGenerator(ANSISchemaBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append("PRIMARY KEY ") - self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint])) def visit_foreign_key_constraint(self, constraint): if constraint.use_alter and self.dialect.supports_alter(): @@ -889,9 +864,9 @@ class ANSISchemaGenerator(ANSISchemaBase): self.append("CONSTRAINT %s " % preparer.format_constraint(constraint)) self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - string.join([preparer.format_column(f.parent) for f in constraint.elements], ', '), + ', '.join([preparer.format_column(f.parent) for f in constraint.elements]), preparer.format_table(list(constraint.elements)[0].column.table), - string.join([preparer.format_column(f.column) for f in constraint.elements], ', ') + ', '.join([preparer.format_column(f.column) for f in constraint.elements]) )) if constraint.ondelete is not None: self.append(" ON DELETE %s" % constraint.ondelete) @@ -903,17 +878,17 @@ class ANSISchemaGenerator(ANSISchemaBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint]))) def visit_column(self, column): pass def visit_index(self, index): - preparer = self.preparer - self.append('CREATE ') + preparer = self.preparer + self.append("CREATE ") if index.unique: - self.append('UNIQUE ') - self.append('INDEX %s ON %s (%s)' \ + self.append("UNIQUE ") + self.append("INDEX %s ON %s (%s)" \ % (preparer.format_index(index), preparer.format_table(index.table), string.join([preparer.format_column(c) for c in index.columns], ', '))) @@ -933,7 +908,7 @@ class ANSISchemaDropper(ANSISchemaBase): for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: - table.accept_visitor(self) + self.traverse_single(table) def visit_index(self, index): self.append("\nDROP INDEX " + self.preparer.format_index(index)) @@ -948,7 +923,7 @@ class ANSISchemaDropper(ANSISchemaBase): def visit_table(self, table): for column in table.columns: if column.default is not None: - column.default.accept_visitor(self) + self.traverse_single(column.default) self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() @@ -1048,17 +1023,17 @@ class ANSIIdentifierPreparer(object): def should_quote(self, object): return object.quote or self._requires_quotes(object.name, object.case_sensitive) - def is_natural_case(self, object): - return object.quote or self._requires_quotes(object.name, object.case_sensitive) - def format_sequence(self, sequence): return self.__generic_obj_format(sequence, sequence.name) def format_label(self, label, name=None): return self.__generic_obj_format(label, name or label.name) - def format_alias(self, alias): - return self.__generic_obj_format(alias, alias.name) + def format_alias(self, alias, name=None): + return self.__generic_obj_format(alias, name or alias.name) + + def format_savepoint(self, savepoint): + return self.__generic_obj_format(savepoint, savepoint) def format_constraint(self, constraint): return self.__generic_obj_format(constraint, constraint.name) @@ -1076,25 +1051,25 @@ class ANSIIdentifierPreparer(object): result = self.__generic_obj_format(table, table.schema) + "." + result return result - def format_column(self, column, use_table=False, name=None): + def format_column(self, column, use_table=False, name=None, table_name=None): """Prepare a quoted column name.""" if name is None: name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name) + return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name) else: return self.__generic_obj_format(column, name) else: # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted if use_table: - return self.format_table(column.table, use_schema=False) + "." + name + return self.format_table(column.table, use_schema=False, name=table_name) + "." + name else: return name - def format_column_with_table(self, column, column_name=None): + def format_column_with_table(self, column, column_name=None, table_name=None): """Prepare a quoted column name with table name.""" - return self.format_column(column, use_table=True, name=column_name) + return self.format_column(column, use_table=True, name=column_name, table_name=table_name) dialect = ANSIDialect |