# compiler.py
# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""SQL expression compilation routines and DDL implementations."""

import re
import string

from sqlalchemy import schema, engine, util, exceptions
from sqlalchemy.sql import operators, visitors
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import expression as sql

ANSI_FUNCS = util.Set([
    'CURRENT_DATE',
    'CURRENT_TIME',
    'CURRENT_TIMESTAMP',
    'CURRENT_USER',
    'LOCALTIME',
    'LOCALTIMESTAMP',
    'SESSION_USER',
    'USER'])

RESERVED_WORDS = util.Set([
    'all', 'analyse', 'analyze', 'and', 'any', 'array',
    'as', 'asc', 'asymmetric', 'authorization', 'between',
    'binary', 'both', 'case', 'cast', 'check', 'collate',
    'column', 'constraint', 'create', 'cross', 'current_date',
    'current_role', 'current_time', 'current_timestamp',
    'current_user', 'default', 'deferrable', 'desc',
    'distinct', 'do', 'else', 'end', 'except', 'false',
    'for', 'foreign', 'freeze', 'from', 'full', 'grant',
    'group', 'having', 'ilike', 'in', 'initially', 'inner',
    'intersect', 'into', 'is', 'isnull', 'join', 'leading',
    'left', 'like', 'limit', 'localtime', 'localtimestamp',
    'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
    'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
    'placing', 'primary', 'references', 'right', 'select',
    'session_user', 'set', 'similar', 'some', 'symmetric', 'table',
    'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
    'using', 'verbose', 'when', 'where'])

LEGAL_CHARACTERS = util.Set(string.ascii_lowercase +
                            string.ascii_uppercase +
                            string.digits + '_$')
ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')

# NOTE(review): the definitions of BIND_PARAMS, BIND_PARAMS_ESC and
# ANONYMOUS_LABEL, and the leading OPERATORS entries (and_ through gt), were
# lost from the copy of this file this version was recovered from.  They were
# reconstructed from how they are used below -- confirm against upstream.

# Matches a ":name" bind parameter; group(1) is the name.  The negative
# lookbehind rejects a preceding ':' (so PostgreSQL "::text" casts are not
# treated as binds), a word character, '$', or a backslash (an escaped
# "\:name", which BIND_PARAMS_ESC handles).
BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)

# Matches an escaped "\:name"; group(1) is ":name" without the backslash.
BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE)

# Matches an anonymous-name token.  group(1) is the unique identifier and
# group(2) the base name; see DefaultCompiler._process_anon().
ANONYMOUS_LABEL = re.compile(r'{ANON (-?\d+) (.*)}')

# Default SQL rendering of each operator; DefaultCompiler.operator_string()
# falls back to str(operator) for anything not listed here.
OPERATORS = {
    operators.and_ : 'AND',
    operators.or_ : 'OR',
    operators.inv : 'NOT',
    operators.add : '+',
    operators.mul : '*',
    operators.sub : '-',
    operators.div : '/',
    operators.mod : '%',
    operators.truediv : '/',
    operators.lt : '<',
    operators.le : '<=',
    operators.ne : '!=',
    operators.gt : '>',
    operators.ge : '>=',
    operators.eq : '=',
    operators.distinct_op : 'DISTINCT',
    operators.concat_op : '||',
    operators.like_op : 'LIKE',
    operators.notlike_op : 'NOT LIKE',
    operators.ilike_op : 'ILIKE',
    operators.notilike_op : 'NOT ILIKE',
    operators.between_op : 'BETWEEN',
    operators.in_op : 'IN',
    operators.notin_op : 'NOT IN',
    operators.comma_op : ', ',
    operators.desc_op : 'DESC',
    operators.asc_op : 'ASC',
    operators.from_ : 'FROM',
    operators.as_ : 'AS',
    operators.exists : 'EXISTS',
    operators.is_ : 'IS',
    operators.isnot : 'IS NOT'
}


class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
    """Default implementation of Compiled.

    Compiles ClauseElements into SQL strings.
    """

    __traverse_options__ = {'column_collections':False, 'entry':True}

    operators = OPERATORS

    def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
        """Construct a new ``DefaultCompiler`` object.

        dialect
          Dialect to be used

        statement
          ClauseElement to be compiled

        column_keys
          a list of column names to be compiled into an INSERT or UPDATE
          statement.

        inline
          compile INSERT/UPDATE defaults and sequences inline (no
          pre-execution); also enabled when ``statement.inline`` is true.
        """

        super(DefaultCompiler, self).__init__(dialect, statement, column_keys, **kwargs)

        # if we are insert/update.  set to true when we visit an INSERT or UPDATE
        self.isinsert = self.isupdate = False

        # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
        self.inline = inline or getattr(statement, 'inline', False)

        # a dictionary of bind parameter keys to _BindParamClause instances.
        self.binds = {}

        # a dictionary of _BindParamClause instances to "compiled" names that are
        # actually present in the generated SQL
        self.bind_names = {}

        # a stack of dicts describing the enclosing statements ('select',
        # 'from', 'is_subquery', 'is_selected_from' keys).
        self.stack = []

        # a dictionary of result-set column names (strings) to TypeEngine instances,
        # which will be passed to a ResultProxy and used for resultset-level value conversion
        self.typemap = {}

        # a dictionary of select columns labels mapped to their "generated" label
        self.column_labels = {}

        # name-generation state shared by _truncated_identifier() and
        # _process_anon(): an identifier-class string ('colident',
        # 'bindparam', 'anonymous') maps to the next counter value, and an
        # (identifier-class, original name) tuple maps to the name already
        # generated for it.
        self.generated_ids = {}

        # default formatting style for bind parameters
        self.bindtemplate = ":%s"

        # paramstyle from the dialect (comes from DB-API)
        self.paramstyle = self.dialect.paramstyle

        # true if the paramstyle is positional
        self.positional = self.dialect.positional

        # a list of the compiled's bind parameter names, used to help
        # formulate a positional argument list
        self.positiontup = []

        # an IdentifierPreparer that formats the quoting of identifiers
        self.preparer = self.dialect.identifier_preparer

    def after_compile(self):
        """Rewrite the ``:name`` bind markers in ``self.string`` for the
        dialect's paramstyle.

        'pyformat' becomes ``%(name)s``.  For positional paramstyles the
        parameter names are appended to ``self.positiontup`` in order of
        appearance and the markers become ``?`` (qmark), ``%s`` (format) or
        ``:1``, ``:2``, ... (numeric).  Any other paramstyle keeps ``:name``.
        Finally, escaped ``\\:name`` tokens are un-escaped to ``:name``.
        """
        # BIND_PARAMS searches for params like :param.  It has a negative
        # lookbehind for an extra ':' so that it doesn't match postgres
        # '::text' tokens.
        text = self.string
        if ':' not in text:
            return

        if self.paramstyle=='pyformat':
            text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text)
        elif self.positional:
            params = BIND_PARAMS.finditer(text)
            for p in params:
                self.positiontup.append(p.group(1))
            if self.paramstyle=='qmark':
                text = BIND_PARAMS.sub('?', text)
            elif self.paramstyle=='format':
                text = BIND_PARAMS.sub('%s', text)
            elif self.paramstyle=='numeric':
                i = [0]
                def getnum(x):
                    i[0] += 1
                    # DB-API 'numeric' placeholders are ":1", ":2", ...; the
                    # whole ":name" match is replaced, so the colon must be
                    # re-emitted or the result would be a bare literal number.
                    return ':' + str(i[0])
                text = BIND_PARAMS.sub(getnum, text)
        # un-escape any \:params
        text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
        self.string = text

    def compile(self):
        """Compile ``self.statement`` into ``self.string``."""
        self.string = self.process(self.statement)
        self.after_compile()

    def process(self, obj, stack=None, **kwargs):
        """Visit ``obj`` and return its SQL string.

        If ``stack`` is given it is pushed onto ``self.stack`` for the
        duration of the visit and popped afterwards, even on error.
        """
        if stack:
            self.stack.append(stack)
        try:
            return self.traverse_single(obj, **kwargs)
        finally:
            if stack:
                self.stack.pop(-1)

    def is_subquery(self, select):
        """Return a true value if the innermost stack entry is a subquery."""
        return self.stack and self.stack[-1].get('is_subquery')

    def get_whereclause(self, obj):
        """given a FROM clause, return an additional WHERE condition that should be
        applied to a SELECT.

        Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN
        constructs in non-ansi mode.
        """

        return None

    def construct_params(self, params=None):
        """Return a sql.util.ClauseParameters object.

        Combines the given bind parameter dictionary (string keys to object values)
        with the _BindParamClause objects stored within this Compiled object
        to produce a ClauseParameters structure, representing the bind arguments
        for a single statement execution, or one element of an executemany execution.
        """

        d = sql_util.ClauseParameters(self.dialect, self.positiontup)

        pd = params or {}

        bind_names = self.bind_names
        for key, bind in self.binds.iteritems():
            # the following is an inlined ClauseParameters.set_parameter()
            name = bind_names[bind]
            d._binds[name] = [bind, name, pd.get(key, bind.value)]
        return d

    params = property(lambda self:self.construct_params(), doc="""Return the `ClauseParameters` corresponding to this compiled object.
        A shortcut for `construct_params()`.""")

    def default_from(self):
        """Called when a SELECT statement has no froms, and no FROM clause is to be appended.

        Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.
        """

        return ""

    def visit_grouping(self, grouping, **kwargs):
        return "(" + self.process(grouping.elem) + ")"

    def visit_label(self, label):
        """Render ``<expr> AS <label>``, recording the label in the typemap
        and column_labels when inside a SELECT."""
        labelname = self._truncated_identifier("colident", label.name)

        if self.stack and self.stack[-1].get('select'):
            self.typemap.setdefault(labelname.lower(), label.obj.type)
            if isinstance(label.obj, sql._ColumnClause):
                self.column_labels[label.obj._label] = labelname
            self.column_labels[label.name] = labelname
        return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])

    def visit_column(self, column, **kwargs):
        """Render a column reference, qualified by its table name when the
        table is named."""
        # there is actually somewhat of a ruleset when you would *not* necessarily
        # want to truncate a column identifier, if its mapped to the name of a
        # physical column.  but that's very hard to identify at this point, and
        # the identifier length should be greater than the id lengths of any physical
        # columns so should not matter.
        if not column.is_literal:
            name = self._truncated_identifier("colident", column.name)
        else:
            name = column.name

        if self.stack and self.stack[-1].get('select'):
            # if we are within a visit to a Select, set up the "typemap"
            # for this column which is used to translate result set values
            self.typemap.setdefault(name.lower(), column.type)
            self.column_labels.setdefault(column._label, name.lower())

        if column.table is None or not column.table.named_with_column():
            return self.preparer.format_column(column, name=name)
        else:
            if column.table.oid_column is column:
                n = self.dialect.oid_column_name(column)
                if n is not None:
                    return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
                elif len(column.table.primary_key) != 0:
                    # no OID column on this dialect; stand in with the
                    # table's first primary key column
                    pk = list(column.table.primary_key)[0]
                    pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
                    return self.preparer.format_column_with_table(pk, column_name=pkname, table_name=self._anonymize(column.table.name))
                else:
                    return None
            else:
                return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))

    def visit_fromclause(self, fromclause, **kwargs):
        return fromclause.name

    def visit_index(self, index, **kwargs):
        return index.name

    def visit_typeclause(self, typeclause, **kwargs):
        return typeclause.type.dialect_impl(self.dialect).get_col_spec()

    def visit_textclause(self, textclause, **kwargs):
        """Return the literal text, registering its bind params and typemap."""
        for bind in textclause.bindparams.values():
            self.process(bind)
        if textclause.typemap is not None:
            self.typemap.update(textclause.typemap)
        return textclause.text

    def visit_null(self, null, **kwargs):
        return 'NULL'

    def visit_clauselist(self, clauselist, **kwargs):
        """Join the rendered clauses with the list's operator, skipping any
        clause that renders as None."""
        sep = clauselist.operator
        if sep is None:
            sep = " "
        elif sep == operators.comma_op:
            sep = ', '
        else:
            sep = " " + self.operator_string(clauselist.operator) + " "
        return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None])

    def apply_function_parens(self, func):
        """Return True unless ``func`` is an argument-less ANSI function
        (e.g. CURRENT_TIMESTAMP) that must be rendered without parentheses."""
        return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0

    def visit_calculatedclause(self, clause, **kwargs):
        return self.process(clause.clause_expr)

    def visit_cast(self, cast, **kwargs):
        if self.stack and self.stack[-1].get('select'):
            # not sure if we want to set the typemap here...
            self.typemap.setdefault("CAST", cast.type)
        return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))

    def visit_function(self, func, **kwargs):
        if self.stack and self.stack[-1].get('select'):
            self.typemap.setdefault(func.name, func.type)
        if not self.apply_function_parens(func):
            return ".".join(func.packagenames + [func.name])
        else:
            return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)

    def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
        """Render a UNION/INTERSECT/EXCEPT style compound SELECT."""
        stack_entry = {'select':cs}

        if asfrom:
            stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
        elif self.stack and self.stack[-1].get('select'):
            stack_entry['is_subquery'] = True

        self.stack.append(stack_entry)

        text = (" " + cs.keyword + " ").join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects])

        group_by = self.process(cs._group_by_clause, asfrom=asfrom)
        if group_by:
            text += " GROUP BY " + group_by

        text += self.order_by_clause(cs)
        text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""

        self.stack.pop(-1)

        if asfrom and parens:
            return "(" + text + ")"
        else:
            return text

    def visit_unary(self, unary, **kwargs):
        s = self.process(unary.element)
        if unary.operator:
            s = self.operator_string(unary.operator) + " " + s
        if unary.modifier:
            s = s + " " + self.operator_string(unary.modifier)
        return s

    def visit_binary(self, binary, **kwargs):
        """Render ``left <op> right``; an operator mapped to a callable
        renders the expression itself from the two operand strings."""
        op = self.operator_string(binary.operator)
        if callable(op):
            return op(self.process(binary.left), self.process(binary.right))
        else:
            return self.process(binary.left) + " " + op + " " + self.process(binary.right)

    def operator_string(self, operator):
        return self.operators.get(operator, str(operator))

    def visit_bindparam(self, bindparam, **kwargs):
        """Register ``bindparam`` in ``self.binds`` and render its marker.

        A unique bind param whose key is already taken by another object is
        renamed ``<key>_1``, ``<key>_2``, ...; a non-unique param colliding
        with a unique one raises CompileError.
        """
        # apply truncation to the ultimate generated name

        if bindparam.shortname != bindparam.key:
            self.binds.setdefault(bindparam.shortname, bindparam)

        if bindparam.unique:
            count = 1
            key = bindparam.key

            # redefine the generated name of the bind param in the case
            # that we have multiple conflicting bind parameters.
            while self.binds.setdefault(key, bindparam) is not bindparam:
                tag = "_%d" % count
                key = bindparam.key + tag
                count += 1
            bindparam.key = key
            return self.bindparam_string(self._truncate_bindparam(bindparam))
        else:
            existing = self.binds.get(bindparam.key)
            if existing is not None and existing.unique:
                raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
            self.binds[bindparam.key] = bindparam
            return self.bindparam_string(self._truncate_bindparam(bindparam))

    def _truncate_bindparam(self, bindparam):
        """Return (and memoize in ``self.bind_names``) the possibly truncated
        name rendered for ``bindparam``."""
        if bindparam in self.bind_names:
            return self.bind_names[bindparam]

        bind_name = bindparam.key
        bind_name = self._truncated_identifier("bindparam", bind_name)
        # add to bind_names for translation
        self.bind_names[bindparam] = bind_name

        return bind_name

    def _truncated_identifier(self, ident_class, name):
        """Return ``name`` with anonymous tokens resolved, shortened to the
        dialect's ``max_identifier_length`` if needed.

        Truncated names get a hex counter suffix, counted per
        ``ident_class``; results are memoized per (ident_class, name).
        """
        if (ident_class, name) in self.generated_ids:
            return self.generated_ids[(ident_class, name)]

        anonname = self._anonymize(name)

        if len(anonname) > self.dialect.max_identifier_length:
            counter = self.generated_ids.get(ident_class, 1)
            # truncate the *resolved* name: slicing the raw name could cut an
            # {ANON ...} token in half and leak it into the SQL.
            truncname = anonname[0:self.dialect.max_identifier_length - 6] + "_" + hex(counter)[2:]
            self.generated_ids[ident_class] = counter + 1
        else:
            truncname = anonname
        self.generated_ids[(ident_class, name)] = truncname
        return truncname

    def _process_anon(self, match):
        """``re.sub`` callback for ANONYMOUS_LABEL: map each distinct
        anonymous identifier to ``<derived>_<n>``, numbering in order of
        first appearance."""
        (ident, derived) = match.group(1,2)
        if ('anonymous', ident) in self.generated_ids:
            return self.generated_ids[('anonymous', ident)]
        else:
            anonymous_counter = self.generated_ids.get('anonymous', 1)
            newname = derived + "_" + str(anonymous_counter)
            self.generated_ids['anonymous'] = anonymous_counter + 1
            self.generated_ids[('anonymous', ident)] = newname
            return newname

    def _anonymize(self, name):
        """Resolve every anonymous token in ``name``."""
        return ANONYMOUS_LABEL.sub(self._process_anon, name)

    def bindparam_string(self, name):
        return self.bindtemplate % name

    def visit_alias(self, alias, asfrom=False, **kwargs):
        if asfrom:
            return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
        else:
            return self.process(alias.original, **kwargs)

    def label_select_column(self, select, column):
        """convert a column from a select's "columns" clause.

        given a select() and a column element from its inner_columns collection, return a
        Label object if this column should be labeled in the columns clause.  Otherwise,
        return None and the column will be used as-is.

        The calling method will traverse the returned label to acquire its string
        representation.
        """

        # SQLite doesn't like selecting from a subquery where the column
        # names look like table.colname. so if column is in a "selected from"
        # subquery, label it synonymously with its column name
        if \
            (self.stack and self.stack[-1].get('is_selected_from')) and \
            isinstance(column, sql._ColumnClause) and \
            not column.is_literal and \
            column.table is not None and \
            not isinstance(column.table, sql.Select):
            return column.label(column.name)
        else:
            return None

    def visit_select(self, select, asfrom=False, parens=True, **kwargs):
        """Render a SELECT statement.

        ``asfrom`` marks the select as appearing inside a FROM clause; the
        result is then parenthesized when ``parens`` is true.
        """
        stack_entry = {'select':select}

        if asfrom:
            stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
        elif self.stack and self.stack[-1].get('select'):
            stack_entry['is_subquery'] = True

        if self.stack and self.stack[-1].get('from'):
            existingfroms = self.stack[-1]['from']
        else:
            existingfroms = None

        froms = select._get_display_froms(existingfroms)

        correlate_froms = util.Set()
        for f in froms:
            correlate_froms.add(f)
            for f2 in f._get_from_objects():
                correlate_froms.add(f2)

        # TODO: might want to propagate existing froms for select(select(select))
        # where innermost select should correlate to outermost, i.e. union
        # existingfroms into correlate_froms.
        stack_entry['from'] = correlate_froms
        self.stack.append(stack_entry)

        # the actual list of columns to print in the SELECT column list.
        inner_columns = util.OrderedSet()

        for co in select.inner_columns:
            if select.use_labels:
                labelname = co._label
                if labelname is not None:
                    l = co.label(labelname)
                    inner_columns.add(self.process(l))
                else:
                    inner_columns.add(self.process(co))
            else:
                l = self.label_select_column(select, co)
                if l is not None:
                    inner_columns.add(self.process(l))
                else:
                    inner_columns.add(self.process(co))

        collist = ', '.join(inner_columns.difference(util.Set([None])))

        text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " "
        text += self.get_select_precolumns(select)
        text += collist

        whereclause = select._whereclause

        from_strings = []
        for f in froms:
            from_strings.append(self.process(f, asfrom=True))

            # dialects (e.g. Oracle non-ansi joins) may contribute extra
            # WHERE criteria per FROM element
            w = self.get_whereclause(f)
            if w is not None:
                if whereclause is not None:
                    whereclause = sql.and_(w, whereclause)
                else:
                    whereclause = w

        if froms:
            text += " \nFROM "
            text += ', '.join(from_strings)
        else:
            text += self.default_from()

        if whereclause is not None:
            t = self.process(whereclause)
            if t:
                text += " \nWHERE " + t

        group_by = self.process(select._group_by_clause)
        if group_by:
            text += " GROUP BY " + group_by

        if select._having is not None:
            t = self.process(select._having)
            if t:
                text += " \nHAVING " + t

        text += self.order_by_clause(select)
        text += (select._limit or select._offset) and self.limit_clause(select) or ""
        text += self.for_update_clause(select)

        self.stack.pop(-1)

        if asfrom and parens:
            return "(" + text + ")"
        else:
            return text

    def get_select_precolumns(self, select):
        """Called when building a ``SELECT`` statement, position is just before column list."""
        return select._distinct and "DISTINCT " or ""

    def order_by_clause(self, select):
        order_by = self.process(select._order_by_clause)
        if order_by:
            return " ORDER BY " + order_by
        else:
            return ""

    def for_update_clause(self, select):
        if select.for_update:
            return " FOR UPDATE"
        else:
            return ""

    def limit_clause(self, select):
        """Render LIMIT/OFFSET; an OFFSET without a LIMIT is emitted as
        ``LIMIT -1 OFFSET n``."""
        text = ""
        if select._limit is not None:
            text +=  " \n LIMIT " + str(select._limit)
        if select._offset is not None:
            if select._limit is None:
                text += " \n LIMIT -1"
            text += " OFFSET " + str(select._offset)
        return text

    def visit_table(self, table, asfrom=False, **kwargs):
        if asfrom:
            return self.preparer.format_table(table)
        else:
            return ""

    def visit_join(self, join, asfrom=False, **kwargs):
        return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
            self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))

    def visit_sequence(self, seq):
        return None

    def visit_insert(self, insert_stmt):
        self.isinsert = True
        colparams = self._get_colparams(insert_stmt)

        return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + ', '.join([self.preparer.format_column(c[0]) for c in colparams]) + ")" +
         " VALUES (" + ', '.join([c[1] for c in colparams]) + ")")

    def visit_update(self, update_stmt):
        self.stack.append({'from':util.Set([update_stmt.table])})

        self.isupdate = True
        colparams = self._get_colparams(update_stmt)

        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])

        # compare against None rather than relying on the clause's truth
        # value, and skip clauses that render empty (same as visit_select)
        if update_stmt._whereclause is not None:
            t = self.process(update_stmt._whereclause)
            if t:
                text += " WHERE " + t

        self.stack.pop(-1)

        return text

    def _get_colparams(self, stmt):
        """create a set of tuples representing column/string pairs for use
        in an INSERT or UPDATE statement.

        Also (re)initializes ``self.prefetch`` (columns whose values are
        supplied as bind params computed before execution) and
        ``self.postfetch`` (columns whose values are produced by SQL
        expressions or server-side defaults).
        """

        def create_bind_param(col, value):
            bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True)
            self.binds[col.key] = bindparam
            return self.bindparam_string(self._truncate_bindparam(bindparam))

        self.postfetch = util.Set()
        self.prefetch = util.Set()

        # no parameters in the statement, no parameters in the
        # compiled params - return binds for all columns
        if self.column_keys is None and stmt.parameters is None:
            return [(c, create_bind_param(c, None)) for c in stmt.table.columns]

        # if we have statement parameters - set defaults in the
        # compiled params
        if self.column_keys is None:
            parameters = {}
        else:
            parameters = dict([(getattr(key, 'key', key), None) for key in self.column_keys])

        if stmt.parameters is not None:
            for k, v in stmt.parameters.iteritems():
                parameters.setdefault(getattr(k, 'key', k), v)

        # create a list of column assignment clauses as tuples
        values = []
        for c in stmt.table.columns:
            if c.key in parameters:
                value = parameters[c.key]
                if sql._is_literal(value):
                    value = create_bind_param(c, value)
                else:
                    self.postfetch.add(c)
                    value = self.process(value.self_group())
                values.append((c, value))
            elif isinstance(c, schema.Column):
                if self.isinsert:
                    if c.primary_key and self.dialect.preexecute_sequences and not self.inline:
                        values.append((c, create_bind_param(c, None)))
                        self.prefetch.add(c)
                    elif isinstance(c.default, schema.ColumnDefault):
                        if isinstance(c.default.arg, sql.ClauseElement):
                            values.append((c, self.process(c.default.arg)))
                            self.postfetch.add(c)
                        else:
                            values.append((c, create_bind_param(c, None)))
                            self.prefetch.add(c)
                    elif isinstance(c.default, schema.PassiveDefault):
                        self.postfetch.add(c)
                    elif isinstance(c.default, schema.Sequence):
                        proc = self.process(c.default)
                        if proc is not None:
                            values.append((c, proc))
                            self.postfetch.add(c)
                elif self.isupdate:
                    if isinstance(c.onupdate, schema.ColumnDefault):
                        if isinstance(c.onupdate.arg, sql.ClauseElement):
                            values.append((c, self.process(c.onupdate.arg)))
                            self.postfetch.add(c)
                        else:
                            values.append((c, create_bind_param(c, None)))
                            self.prefetch.add(c)
                    elif isinstance(c.onupdate, schema.PassiveDefault):
                        self.postfetch.add(c)
        return values

    def visit_delete(self, delete_stmt):
        self.stack.append({'from':util.Set([delete_stmt.table])})

        text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)

        # same None check / empty-render guard as visit_update
        if delete_stmt._whereclause is not None:
            t = self.process(delete_stmt._whereclause)
            if t:
                text += " WHERE " + t

        self.stack.pop(-1)

        return text

    def visit_savepoint(self, savepoint_stmt):
        return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def visit_rollback_to_savepoint(self, savepoint_stmt):
        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def visit_release_savepoint(self, savepoint_stmt):
        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def __str__(self):
        return self.string


class DDLBase(engine.SchemaIterator):
    def find_alterables(self, tables):
        """Return the foreign key constraints of ``tables`` that have
        ``use_alter`` set and belong to one of ``tables``."""
        alterables = []
        class FindAlterables(schema.SchemaVisitor):
            def visit_foreign_key_constraint(self, constraint):
                if constraint.use_alter and constraint.table in tables:
                    alterables.append(constraint)
        findalterables = FindAlterables()
        for table in tables:
            for c in table.constraints:
                findalterables.traverse(c)
        return alterables


class SchemaGenerator(DDLBase):
    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaGenerator, self).__init__(connection, **kwargs)
        self.checkfirst = checkfirst
        self.tables = tables and util.Set(tables) or None
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def get_column_specification(self, column, first_pk=False):
        raise NotImplementedError()

    def visit_metadata(self, metadata):
        """Create each table (skipping existing ones when ``checkfirst``),
        then add ``use_alter`` foreign keys via ALTER if supported."""
        collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
        for table in collection:
            self.traverse_single(table)
        if self.dialect.supports_alter:
            for alterable in self.find_alterables(collection):
                self.add_foreignkey(alterable)

    def visit_table(self, table):
        for column in table.columns:
            if column.default is not None:
                self.traverse_single(column.default)

        self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")

        separator = "\n"

        # if only one primary key, specify it along with the column
        first_pk = False
        for column in table.columns:
            self.append(separator)
            separator = ", \n"
            self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
            if column.primary_key:
                first_pk = True
            for constraint in column.constraints:
                self.traverse_single(constraint)

        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        if table.primary_key:
            self.traverse_single(table.primary_key)
        for constraint in [c for c in table.constraints if c is not table.primary_key]:
            self.traverse_single(constraint)

        self.append("\n)%s\n\n" % self.post_create_table(table))
        self.execute()
        if hasattr(table, 'indexes'):
            for index in table.indexes:
                self.traverse_single(index)

    def post_create_table(self, table):
        return ''

    def get_column_default_string(self, column):
        """Return the DDL text of a column's ``PassiveDefault``, or None.

        A string default is rendered as a single-quoted SQL literal; any
        other default is compiled with this dialect's statement compiler.
        """
        if isinstance(column.default, schema.PassiveDefault):
            if isinstance(column.default.arg, basestring):
                # double embedded single quotes so the literal stays well-formed
                return "'%s'" % column.default.arg.replace("'", "''")
            else:
                return unicode(self._compile(column.default.arg, None))
        else:
            return None

    def _compile(self, tocompile, parameters):
        """compile the given string/parameters using this SchemaGenerator's dialect."""
        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
        compiler.compile()
        return compiler

    def visit_check_constraint(self, constraint):
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        self.preparer.format_constraint(constraint))
        self.append(" CHECK (%s)" % constraint.sqltext)

    def visit_column_check_constraint(self, constraint):
        self.append(" CHECK (%s)" % constraint.sqltext)

    def visit_primary_key_constraint(self, constraint):
        if len(constraint) == 0:
            return
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
        self.append("PRIMARY KEY ")
        self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))

    def visit_foreign_key_constraint(self, constraint):
        # use_alter constraints are added afterwards by visit_metadata()
        if constraint.use_alter and self.dialect.supports_alter:
            return
        self.append(", \n\t ")
        self.define_foreign_key(constraint)

    def add_foreignkey(self, constraint):
        self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
        self.define_foreign_key(constraint)
        self.execute()

    def define_foreign_key(self, constraint):
        preparer = self.preparer
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        preparer.format_constraint(constraint))
        self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ', '.join([preparer.format_column(f.parent) for f in constraint.elements]),
            preparer.format_table(list(constraint.elements)[0].column.table),
            ', '.join([preparer.format_column(f.column) for f in constraint.elements])
        ))
        if constraint.ondelete is not None:
            self.append(" ON DELETE %s" % constraint.ondelete)
        if constraint.onupdate is not None:
            self.append(" ON UPDATE %s" % constraint.onupdate)

    def visit_unique_constraint(self, constraint):
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        self.preparer.format_constraint(constraint))
        self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint])))

    def visit_column(self, column):
        pass

    def visit_index(self, index):
        preparer = self.preparer
        self.append("CREATE ")
        if index.unique:
            self.append("UNIQUE ")
        self.append("INDEX %s ON %s (%s)" \
                    % (preparer.format_index(index),
                       preparer.format_table(index.table),
                       ', '.join([preparer.format_column(c) for c in index.columns])))
        self.execute()


class SchemaDropper(DDLBase):
    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaDropper, self).__init__(connection, **kwargs)
        self.checkfirst = checkfirst
        self.tables = tables
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def visit_metadata(self, metadata):
        """Drop ``use_alter`` foreign keys via ALTER if supported, then drop
        each table (only existing ones when ``checkfirst``)."""
        collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or  self.dialect.has_table(self.connection, t.name, schema=t.schema))]
        if self.dialect.supports_alter:
            for alterable in self.find_alterables(collection):
                self.drop_foreignkey(alterable)
        for table in collection:
            self.traverse_single(table)

    def visit_index(self, index):
        self.append("\nDROP INDEX " + self.preparer.format_index(index))
        self.execute()

    def drop_foreignkey(self, constraint):
        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (
            self.preparer.format_table(constraint.table),
            self.preparer.format_constraint(constraint)))
        self.execute()

    def visit_table(self, table):
        for column in table.columns:
            if column.default is not None:
                self.traverse_single(column.default)

        self.append("\nDROP TABLE " + self.preparer.format_table(table))
        self.execute()


class IdentifierPreparer(object):
    """Handle quoting and case-folding of identifiers based on options."""

    reserved_words = RESERVED_WORDS

    legal_characters = LEGAL_CHARACTERS

    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

    def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
        """Construct a new ``IdentifierPreparer`` object.

        initial_quote
          Character that begins a delimited identifier.

        final_quote
          Character that ends a delimited identifier. Defaults to `initial_quote`.

        omit_schema
          Prevent prepending schema name. Useful for databases that do
          not support schemae.
        """

        self.dialect = dialect
        self.initial_quote = initial_quote
        self.final_quote = final_quote or self.initial_quote
        self.omit_schema = omit_schema
        # cache of identifier -> formatted (possibly quoted) identifier
        self.__strings = {}

    def _escape_identifier(self, value):
        """Escape an identifier.

        Subclasses should override this to provide database-dependent
        escaping behavior.
        """

        return value.replace('"', '""')

    def _unescape_identifier(self, value):
        """Canonicalize an escaped identifier.

        Subclasses should override this to provide database-dependent
        unescaping behavior that reverses _escape_identifier.
        """

        return value.replace('""', '"')

    def quote_identifier(self, value):
        """Quote an identifier.

        Subclasses should override this to provide database-dependent
        quoting behavior.
        """

        return self.initial_quote + self._escape_identifier(value) + self.final_quote

    def _fold_identifier_case(self, value):
        """Fold the case of an identifier.

        Subclasses should override this to provide database-dependent
        case folding behavior.
        """

        return value
        # ANSI SQL calls for the case of all unquoted identifiers to be folded to UPPER.
        # some tests would need to be rewritten if this is done.
        #return value.upper()

    def _requires_quotes(self, value):
        """Return True if the given identifier requires quoting."""
        return \
            value in self.reserved_words \
            or (value[0] in self.illegal_initial_characters) \
            or bool(len([x for x in unicode(value) if x not in self.legal_characters])) \
            or (value.lower() != value)

    def __generic_obj_format(self, obj, ident):
        # objects flagged with quote=True are always quoted (uncached);
        # otherwise quote only when required, caching the result per ident.
        if getattr(obj, 'quote', False):
            return self.quote_identifier(ident)
        try:
            return self.__strings[ident]
        except KeyError:
            if self._requires_quotes(ident):
                self.__strings[ident] = self.quote_identifier(ident)
            else:
                self.__strings[ident] = ident
            return self.__strings[ident]

    def should_quote(self, object):
        return object.quote or self._requires_quotes(object.name)

    def format_sequence(self, sequence):
        return self.__generic_obj_format(sequence, sequence.name)

    def format_label(self, label, name=None):
        return self.__generic_obj_format(label, name or label.name)

    def format_alias(self, alias, name=None):
        return self.__generic_obj_format(alias, name or alias.name)

    def format_savepoint(self, savepoint, name=None):
        return self.__generic_obj_format(savepoint, name or savepoint.ident)

    def format_constraint(self, constraint):
        return self.__generic_obj_format(constraint, constraint.name)

    def format_index(self, index):
        return self.__generic_obj_format(index, index.name)

    def format_table(self, table, use_schema=True, name=None):
        """Prepare a quoted table and schema name."""

        if name is None:
            name = table.name
        result = self.__generic_obj_format(table, name)
        if use_schema and getattr(table, "schema", None):
            result = self.__generic_obj_format(table, table.schema) + "." + result
        return result

    def format_column(self, column, use_table=False, name=None, table_name=None):
        """Prepare a quoted column name."""
        if name is None:
            name = column.name
        if not getattr(column, 'is_literal', False):
            if use_table:
                return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name)
            else:
                return self.__generic_obj_format(column, name)
        else:
            # literal textual elements get stuck into ColumnClause a lot, which shouldn't get quoted
            if use_table:
                return self.format_table(column.table, use_schema=False, name=table_name) + "." + name
            else:
                return name

    def format_column_with_table(self, column, column_name=None, table_name=None):
        """Prepare a quoted column name with table name."""

        return self.format_column(column, use_table=True, name=column_name, table_name=table_name)

    def format_table_seq(self, table, use_schema=True):
        """Format table name and schema as a tuple."""

        # Dialects with more levels in their fully qualified references
        # ('database', 'owner', etc.) could override this and return
        # a longer sequence.

        if use_schema and getattr(table, 'schema', None):
            return (self.quote_identifier(table.schema),
                    self.format_table(table, use_schema=False))
        else:
            return (self.format_table(table, use_schema=False), )

    def unformat_identifiers(self, identifiers):
        """Unpack 'schema.table.column'-like strings into components."""

        try:
            r = self._r_identifiers
        except AttributeError:
            initial, final, escaped_final = \
                     [re.escape(s) for s in
                      (self.initial_quote, self.final_quote,
                       self._escape_identifier(self.final_quote))]
            r = re.compile(
                r'(?:'
                r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
                r'|([^\.]+))(?=\.|$))+' %
                { 'initial': initial,
                  'final': final,
                  'escaped': escaped_final })
            self._r_identifiers = r

        return [self._unescape_identifier(i)
                for i in [a or b for a, b in r.findall(identifiers)]]