diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 1114 |
1 files changed, 1114 insertions, 0 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py new file mode 100644 index 000000000..6053c72be --- /dev/null +++ b/lib/sqlalchemy/sql/compiler.py @@ -0,0 +1,1114 @@ +# compiler.py +# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQL expression compilation routines and DDL implementations.""" + +import string, re +from sqlalchemy import schema, engine, util, exceptions +from sqlalchemy.sql import operators, visitors +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import expression as sql + +ANSI_FUNCS = util.Set([ + 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', + 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', + 'SESSION_USER', 'USER']) + +RESERVED_WORDS = util.Set([ + 'all', 'analyse', 'analyze', 'and', 'any', 'array', + 'as', 'asc', 'asymmetric', 'authorization', 'between', + 'binary', 'both', 'case', 'cast', 'check', 'collate', + 'column', 'constraint', 'create', 'cross', 'current_date', + 'current_role', 'current_time', 'current_timestamp', + 'current_user', 'default', 'deferrable', 'desc', + 'distinct', 'do', 'else', 'end', 'except', 'false', + 'for', 'foreign', 'freeze', 'from', 'full', 'grant', + 'group', 'having', 'ilike', 'in', 'initially', 'inner', + 'intersect', 'into', 'is', 'isnull', 'join', 'leading', + 'left', 'like', 'limit', 'localtime', 'localtimestamp', + 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', + 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', + 'placing', 'primary', 'references', 'right', 'select', + 'session_user', 'set', 'similar', 'some', 'symmetric', 'table', + 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', + 'using', 'verbose', 'when', 'where']) + +LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + + string.ascii_uppercase + + string.digits + '_$') +ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') + +BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE) +BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE) +ANONYMOUS_LABEL = re.compile(r'{ANON (-?\d+) (.*)}') + +OPERATORS = { + operators.and_ : 'AND', + operators.or_ : 'OR', + operators.inv : 'NOT', + operators.add : '+', + operators.mul : '*', + operators.sub : '-', + operators.div : '/', + operators.mod : '%', + operators.truediv : '/', + operators.lt : '<', + operators.le : '<=', + operators.ne : '!=', + operators.gt : '>', + operators.ge : '>=', + operators.eq : '=', + operators.distinct_op : 'DISTINCT', + operators.concat_op : '||', + operators.like_op : 'LIKE', + operators.notlike_op : 'NOT LIKE', + operators.ilike_op : 'ILIKE', + operators.notilike_op : 'NOT ILIKE', + operators.between_op : 'BETWEEN', + operators.in_op : 'IN', + operators.notin_op : 'NOT IN', + operators.comma_op : ', ', + operators.desc_op : 'DESC', + operators.asc_op : 'ASC', + operators.from_ : 'FROM', + operators.as_ : 'AS', + operators.exists : 'EXISTS', + operators.is_ : 'IS', + operators.isnot : 'IS NOT' +} + +class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): + """Default implementation of Compiled. + + Compiles ClauseElements into SQL strings. + """ + + __traverse_options__ = {'column_collections':False, 'entry':True} + + operators = OPERATORS + + def __init__(self, dialect, statement, parameters=None, **kwargs): + """Construct a new ``DefaultCompiler`` object. + + dialect + Dialect to be used + + statement + ClauseElement to be compiled + + parameters + optional dictionary indicating a set of bind parameters + specified with this Compiled object. These parameters are + the *default* key/value pairs when the Compiled is executed, + and also may affect the actual compilation, as in the case + of an INSERT where the actual columns inserted will + correspond to the keys present in the parameters. + """ + + super(DefaultCompiler, self).__init__(dialect, statement, parameters, **kwargs) + + # if we are insert/update. set to true when we visit an INSERT or UPDATE + self.isinsert = self.isupdate = False + + # a dictionary of bind parameter keys to _BindParamClause instances. + self.binds = {} + + # a dictionary of _BindParamClause instances to "compiled" names that are + # actually present in the generated SQL + self.bind_names = {} + + # a stack. what recursive compiler doesn't have a stack ? :) + self.stack = [] + + # a dictionary of result-set column names (strings) to TypeEngine instances, + # which will be passed to a ResultProxy and used for resultset-level value conversion + self.typemap = {} + + # a dictionary of select columns labels mapped to their "generated" label + self.column_labels = {} + + # a dictionary of ClauseElement subclasses to counters, which are used to + # generate truncated identifier names or "anonymous" identifiers such as + # for aliases + self.generated_ids = {} + + # default formatting style for bind parameters + self.bindtemplate = ":%s" + + # paramstyle from the dialect (comes from DBAPI) + self.paramstyle = self.dialect.paramstyle + + # true if the paramstyle is positional + self.positional = self.dialect.positional + + # a list of the compiled's bind parameter names, used to help + # formulate a positional argument list + self.positiontup = [] + + # an IdentifierPreparer that formats the quoting of identifiers + self.preparer = self.dialect.identifier_preparer + + # for UPDATE and INSERT statements, a set of columns whos values are being set + # from a SQL expression (i.e., not one of the bind parameter values). if present, + # default-value logic in the Dialect knows not to fire off column defaults + # and also knows postfetching will be needed to get the values represented by these + # parameters. + self.inline_params = None + + def after_compile(self): + # this re will search for params like :param + # it has a negative lookbehind for an extra ':' so that it doesnt match + # postgres '::text' tokens + text = self.string + if ':' not in text: + return + + if self.paramstyle=='pyformat': + text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text) + elif self.positional: + params = BIND_PARAMS.finditer(text) + for p in params: + self.positiontup.append(p.group(1)) + if self.paramstyle=='qmark': + text = BIND_PARAMS.sub('?', text) + elif self.paramstyle=='format': + text = BIND_PARAMS.sub('%s', text) + elif self.paramstyle=='numeric': + i = [0] + def getnum(x): + i[0] += 1 + return str(i[0]) + text = BIND_PARAMS.sub(getnum, text) + # un-escape any \:params + text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text) + self.string = text + + def compile(self): + self.string = self.process(self.statement) + self.after_compile() + + def process(self, obj, stack=None, **kwargs): + if stack: + self.stack.append(stack) + try: + return self.traverse_single(obj, **kwargs) + finally: + if stack: + self.stack.pop(-1) + + def is_subquery(self, select): + return self.stack and self.stack[-1].get('is_subquery') + + def get_whereclause(self, obj): + """given a FROM clause, return an additional WHERE condition that should be + applied to a SELECT. + + Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN + constructs in non-ansi mode. + """ + + return None + + def construct_params(self, params): + """Return a sql.util.ClauseParameters object. + + Combines the given bind parameter dictionary (string keys to object values) + with the _BindParamClause objects stored within this Compiled object + to produce a ClauseParameters structure, representing the bind arguments + for a single statement execution, or one element of an executemany execution. + """ + + d = sql_util.ClauseParameters(self.dialect, self.positiontup) + + pd = self.parameters or {} + pd.update(params) + + for key, bind in self.binds.iteritems(): + d.set_parameter(bind, pd.get(key, bind.value), self.bind_names[bind]) + + return d + + params = property(lambda self:self.construct_params({}), doc="""Return the `ClauseParameters` corresponding to this compiled object. + A shortcut for `construct_params()`.""") + + def default_from(self): + """Called when a SELECT statement has no froms, and no FROM clause is to be appended. + + Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. + """ + + return "" + + def visit_grouping(self, grouping, **kwargs): + return "(" + self.process(grouping.elem) + ")" + + def visit_label(self, label): + labelname = self._truncated_identifier("colident", label.name) + + if self.stack and self.stack[-1].get('select'): + self.typemap.setdefault(labelname.lower(), label.obj.type) + if isinstance(label.obj, sql._ColumnClause): + self.column_labels[label.obj._label] = labelname + self.column_labels[label.name] = labelname + return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) + + def visit_column(self, column, **kwargs): + # there is actually somewhat of a ruleset when you would *not* necessarily + # want to truncate a column identifier, if its mapped to the name of a + # physical column. but thats very hard to identify at this point, and + # the identifier length should be greater than the id lengths of any physical + # columns so should not matter. + if not column.is_literal: + name = self._truncated_identifier("colident", column.name) + else: + name = column.name + + if self.stack and self.stack[-1].get('select'): + # if we are within a visit to a Select, set up the "typemap" + # for this column which is used to translate result set values + self.typemap.setdefault(name.lower(), column.type) + self.column_labels.setdefault(column._label, name.lower()) + + if column.table is None or not column.table.named_with_column(): + return self.preparer.format_column(column, name=name) + else: + if column.table.oid_column is column: + n = self.dialect.oid_column_name(column) + if n is not None: + return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n) + elif len(column.table.primary_key) != 0: + pk = list(column.table.primary_key)[0] + pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name)) + return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name)) + else: + return None + else: + return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name)) + + + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name + + def visit_index(self, index, **kwargs): + return index.name + + def visit_typeclause(self, typeclause, **kwargs): + return typeclause.type.dialect_impl(self.dialect).get_col_spec() + + def visit_textclause(self, textclause, **kwargs): + for bind in textclause.bindparams.values(): + self.process(bind) + if textclause.typemap is not None: + self.typemap.update(textclause.typemap) + return textclause.text + + def visit_null(self, null, **kwargs): + return 'NULL' + + def visit_clauselist(self, clauselist, **kwargs): + sep = clauselist.operator + if sep is None: + sep = " " + elif sep == operators.comma_op: + sep = ', ' + else: + sep = " " + self.operator_string(clauselist.operator) + " " + return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep) + + def apply_function_parens(self, func): + return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 + + def visit_calculatedclause(self, clause, **kwargs): + return self.process(clause.clause_expr) + + def visit_cast(self, cast, **kwargs): + if self.stack and self.stack[-1].get('select'): + # not sure if we want to set the typemap here... + self.typemap.setdefault("CAST", cast.type) + return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) + + def visit_function(self, func, **kwargs): + if self.stack and self.stack[-1].get('select'): + self.typemap.setdefault(func.name, func.type) + if not self.apply_function_parens(func): + return ".".join(func.packagenames + [func.name]) + else: + return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) + + def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): + stack_entry = {'select':cs} + + if asfrom: + stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True + elif self.stack and self.stack[-1].get('select'): + stack_entry['is_subquery'] = True + self.stack.append(stack_entry) + + text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ") + group_by = self.process(cs._group_by_clause, asfrom=asfrom) + if group_by: + text += " GROUP BY " + group_by + + text += self.order_by_clause(cs) + text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def visit_unary(self, unary, **kwargs): + s = self.process(unary.element) + if unary.operator: + s = self.operator_string(unary.operator) + " " + s + if unary.modifier: + s = s + " " + self.operator_string(unary.modifier) + return s + + def visit_binary(self, binary, **kwargs): + op = self.operator_string(binary.operator) + if callable(op): + return op(self.process(binary.left), self.process(binary.right)) + else: + return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + def operator_string(self, operator): + return self.operators.get(operator, str(operator)) + + def visit_bindparam(self, bindparam, **kwargs): + # apply truncation to the ultimate generated name + + if bindparam.shortname != bindparam.key: + self.binds.setdefault(bindparam.shortname, bindparam) + + if bindparam.unique: + count = 1 + key = bindparam.key + # redefine the generated name of the bind param in the case + # that we have multiple conflicting bind parameters. + while self.binds.setdefault(key, bindparam) is not bindparam: + tag = "_%d" % count + key = bindparam.key + tag + count += 1 + bindparam.key = key + return self.bindparam_string(self._truncate_bindparam(bindparam)) + else: + existing = self.binds.get(bindparam.key) + if existing is not None and existing.unique: + raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + self.binds[bindparam.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + + def _truncate_bindparam(self, bindparam): + if bindparam in self.bind_names: + return self.bind_names[bindparam] + + bind_name = bindparam.key + bind_name = self._truncated_identifier("bindparam", bind_name) + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + + return bind_name + + def _truncated_identifier(self, ident_class, name): + if (ident_class, name) in self.generated_ids: + return self.generated_ids[(ident_class, name)] + + anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) + + if len(anonname) > self.dialect.max_identifier_length(): + counter = self.generated_ids.get(ident_class, 1) + truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:] + self.generated_ids[ident_class] = counter + 1 + else: + truncname = anonname + self.generated_ids[(ident_class, name)] = truncname + return truncname + + def _process_anon(self, match): + (ident, derived) = match.group(1,2) + if ('anonymous', ident) in self.generated_ids: + return self.generated_ids[('anonymous', ident)] + else: + anonymous_counter = self.generated_ids.get('anonymous', 1) + newname = derived + "_" + str(anonymous_counter) + self.generated_ids['anonymous'] = anonymous_counter + 1 + self.generated_ids[('anonymous', ident)] = newname + return newname + + def _anonymize(self, name): + return ANONYMOUS_LABEL.sub(self._process_anon, name) + + def bindparam_string(self, name): + return self.bindtemplate % name + + def visit_alias(self, alias, asfrom=False, **kwargs): + if asfrom: + return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name)) + else: + return self.process(alias.original, **kwargs) + + def label_select_column(self, select, column): + """convert a column from a select's "columns" clause. + + given a select() and a column element from its inner_columns collection, return a + Label object if this column should be labeled in the columns clause. Otherwise, + return None and the column will be used as-is. + + The calling method will traverse the returned label to acquire its string + representation. + """ + + # SQLite doesnt like selecting from a subquery where the column + # names look like table.colname. so if column is in a "selected from" + # subquery, label it synoymously with its column name + if \ + (self.stack and self.stack[-1].get('is_selected_from')) and \ + isinstance(column, sql._ColumnClause) and \ + not column.is_literal and \ + column.table is not None and \ + not isinstance(column.table, sql.Select): + return column.label(column.name) + else: + return None + + def visit_select(self, select, asfrom=False, parens=True, **kwargs): + + stack_entry = {'select':select} + + if asfrom: + stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True + elif self.stack and self.stack[-1].get('select'): + stack_entry['is_subquery'] = True + + if self.stack and self.stack[-1].get('from'): + existingfroms = self.stack[-1]['from'] + else: + existingfroms = None + froms = select._get_display_froms(existingfroms) + + correlate_froms = util.Set() + for f in froms: + correlate_froms.add(f) + for f2 in f._get_from_objects(): + correlate_froms.add(f2) + + # TODO: might want to propigate existing froms for select(select(select)) + # where innermost select should correlate to outermost +# if existingfroms: +# correlate_froms = correlate_froms.union(existingfroms) + stack_entry['from'] = correlate_froms + self.stack.append(stack_entry) + + # the actual list of columns to print in the SELECT column list. + inner_columns = util.OrderedSet() + + for co in select.inner_columns: + if select.use_labels: + labelname = co._label + if labelname is not None: + l = co.label(labelname) + inner_columns.add(self.process(l)) + else: + self.traverse(co) + inner_columns.add(self.process(co)) + else: + l = self.label_select_column(select, co) + if l is not None: + inner_columns.add(self.process(l)) + else: + inner_columns.add(self.process(co)) + + collist = string.join(inner_columns.difference(util.Set([None])), ', ') + + text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " + text += self.get_select_precolumns(select) + text += collist + + whereclause = select._whereclause + + from_strings = [] + for f in froms: + from_strings.append(self.process(f, asfrom=True)) + + w = self.get_whereclause(f) + if w is not None: + if whereclause is not None: + whereclause = sql.and_(w, whereclause) + else: + whereclause = w + + if froms: + text += " \nFROM " + text += string.join(from_strings, ', ') + else: + text += self.default_from() + + if whereclause is not None: + t = self.process(whereclause) + if t: + text += " \nWHERE " + t + + group_by = self.process(select._group_by_clause) + if group_by: + text += " GROUP BY " + group_by + + if select._having is not None: + t = self.process(select._having) + if t: + text += " \nHAVING " + t + + text += self.order_by_clause(select) + text += (select._limit or select._offset) and self.limit_clause(select) or "" + text += self.for_update_clause(select) + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def get_select_precolumns(self, select): + """Called when building a ``SELECT`` statement, position is just before column list.""" + return select._distinct and "DISTINCT " or "" + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def for_update_clause(self, select): + if select.for_update: + return " FOR UPDATE" + else: + return "" + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT -1" + text += " OFFSET " + str(select._offset) + return text + + def visit_table(self, table, asfrom=False, **kwargs): + if asfrom: + return self.preparer.format_table(table) + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + + def uses_sequences_for_inserts(self): + return False + + def visit_insert(self, insert_stmt): + + # search for columns who will be required to have an explicit bound value. + # for inserts, this includes Python-side defaults, columns with sequences for dialects + # that support sequences, and primary key columns for dialects that explicitly insert + # pre-generated primary key values + required_cols = util.Set() + class DefaultVisitor(schema.SchemaVisitor): + def visit_column(s, cd): + if c.primary_key and self.uses_sequences_for_inserts(): + required_cols.add(c) + def visit_column_default(s, cd): + required_cols.add(c) + def visit_sequence(s, seq): + if self.uses_sequences_for_inserts(): + required_cols.add(c) + vis = DefaultVisitor() + for c in insert_stmt.table.c: + if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): + vis.traverse(c) + + self.isinsert = True + colparams = self._get_colparams(insert_stmt, required_cols) + + return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + + " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") + + def visit_update(self, update_stmt): + self.stack.append({'from':util.Set([update_stmt.table])}) + + # search for columns who will be required to have an explicit bound value. + # for updates, this includes Python-side "onupdate" defaults. + required_cols = util.Set() + class OnUpdateVisitor(schema.SchemaVisitor): + def visit_column_onupdate(s, cd): + required_cols.add(c) + vis = OnUpdateVisitor() + for c in update_stmt.table.c: + if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): + vis.traverse(c) + + self.isupdate = True + colparams = self._get_colparams(update_stmt, required_cols) + + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ') + + if update_stmt._whereclause: + text += " WHERE " + self.process(update_stmt._whereclause) + + self.stack.pop(-1) + + return text + + def _get_colparams(self, stmt, required_cols): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + This method may generate new bind params within this compiled + based on the given set of "required columns", which are required + to have a value set in the statement. + """ + + def create_bind_param(col, value): + bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True) + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if self.parameters is None and stmt.parameters is None: + return [(c, create_bind_param(c, None)) for c in stmt.table.columns] + + def create_clause_param(col, value): + self.traverse(value) + self.inline_params.add(col) + return self.process(value) + + self.inline_params = util.Set() + + def to_col(key): + if not isinstance(key, sql._ColumnClause): + return stmt.table.columns.get(unicode(key), key) + else: + return key + + # if we have statement parameters - set defaults in the + # compiled params + if self.parameters is None: + parameters = {} + else: + parameters = dict([(to_col(k), v) for k, v in self.parameters.iteritems()]) + + if stmt.parameters is not None: + for k, v in stmt.parameters.iteritems(): + parameters.setdefault(to_col(k), v) + + for col in required_cols: + parameters.setdefault(col, None) + + # create a list of column assignment clauses as tuples + values = [] + for c in stmt.table.columns: + if c in parameters: + value = parameters[c] + if sql._is_literal(value): + value = create_bind_param(c, value) + else: + value = create_clause_param(c, value) + values.append((c, value)) + + return values + + def visit_delete(self, delete_stmt): + self.stack.append({'from':util.Set([delete_stmt.table])}) + + text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + + if delete_stmt._whereclause: + text += " WHERE " + self.process(delete_stmt._whereclause) + + self.stack.pop(-1) + + return text + + def visit_savepoint(self, savepoint_stmt): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_release_savepoint(self, savepoint_stmt): + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def __str__(self): + return self.string + +class DDLBase(engine.SchemaIterator): + def find_alterables(self, tables): + alterables = [] + class FindAlterables(schema.SchemaVisitor): + def visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and constraint.table in tables: + alterables.append(constraint) + findalterables = FindAlterables() + for table in tables: + for c in table.constraints: + findalterables.traverse(c) + return alterables + +class SchemaGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables and util.Set(tables) or None + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def get_column_specification(self, column, first_pk=False): + raise NotImplementedError() + + def visit_metadata(self, metadata): + collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] + for table in collection: + self.traverse_single(table) + if self.dialect.supports_alter(): + for alterable in self.find_alterables(collection): + self.add_foreignkey(alterable) + + def visit_table(self, table): + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") + + separator = "\n" + + # if only one primary key, specify it along with the column + first_pk = False + for column in table.columns: + self.append(separator) + separator = ", \n" + self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) + if column.primary_key: + first_pk = True + for constraint in column.constraints: + self.traverse_single(constraint) + + # On some DB order is significant: visit PK first, then the + # other constraints (engine.ReflectionTest.testbasic failed on FB2) + if table.primary_key: + self.traverse_single(table.primary_key) + for constraint in [c for c in table.constraints if c is not table.primary_key]: + self.traverse_single(constraint) + + self.append("\n)%s\n\n" % self.post_create_table(table)) + self.execute() + if hasattr(table, 'indexes'): + for index in table.indexes: + self.traverse_single(index) + + def post_create_table(self, table): + return '' + + def get_column_default_string(self, column): + if isinstance(column.default, schema.PassiveDefault): + if isinstance(column.default.arg, basestring): + return "'%s'" % column.default.arg + else: + return unicode(self._compile(column.default.arg, None)) + else: + return None + + def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" + compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) + compiler.compile() + return compiler + + def visit_check_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(constraint)) + self.append(" CHECK (%s)" % constraint.sqltext) + + def visit_column_check_constraint(self, constraint): + self.append(" CHECK (%s)" % constraint.sqltext) + + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) + self.append("PRIMARY KEY ") + self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint])) + + def visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and self.dialect.supports_alter(): + return + self.append(", \n\t ") + self.define_foreign_key(constraint) + + def add_foreignkey(self, constraint): + self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) + self.define_foreign_key(constraint) + self.execute() + + def define_foreign_key(self, constraint): + preparer = self.preparer + if constraint.name is not None: + self.append("CONSTRAINT %s " % + preparer.format_constraint(constraint)) + self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + ', '.join([preparer.format_column(f.parent) for f in constraint.elements]), + preparer.format_table(list(constraint.elements)[0].column.table), + ', '.join([preparer.format_column(f.column) for f in constraint.elements]) + )) + if constraint.ondelete is not None: + self.append(" ON DELETE %s" % constraint.ondelete) + if constraint.onupdate is not None: + self.append(" ON UPDATE %s" % constraint.onupdate) + + def visit_unique_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(constraint)) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint]))) + + def visit_column(self, column): + pass + + def visit_index(self, index): + preparer = self.preparer + self.append("CREATE ") + if index.unique: + self.append("UNIQUE ") + self.append("INDEX %s ON %s (%s)" \ + % (preparer.format_index(index), + preparer.format_table(index.table), + string.join([preparer.format_column(c) for c in index.columns], ', '))) + self.execute() + +class SchemaDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def visit_metadata(self, metadata): + collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))] + if self.dialect.supports_alter(): + for alterable in self.find_alterables(collection): + self.drop_foreignkey(alterable) + for table in collection: + self.traverse_single(table) + + def visit_index(self, index): + self.append("\nDROP INDEX " + self.preparer.format_index(index)) + self.execute() + + def drop_foreignkey(self, constraint): + self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( + self.preparer.format_table(constraint.table), + self.preparer.format_constraint(constraint))) + self.execute() + + def visit_table(self, table): + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.append("\nDROP TABLE " + self.preparer.format_table(table)) + self.execute() + +class IdentifierPreparer(object): + """Handle quoting and case-folding of identifiers based on options.""" + + def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): + """Construct a new ``IdentifierPreparer`` object. + + initial_quote + Character that begins a delimited identifier. + + final_quote + Character that ends a delimited identifier. Defaults to `initial_quote`. + + omit_schema + Prevent prepending schema name. Useful for databases that do + not support schemae. + """ + + self.dialect = dialect + self.initial_quote = initial_quote + self.final_quote = final_quote or self.initial_quote + self.omit_schema = omit_schema + self.__strings = {} + + def _escape_identifier(self, value): + """Escape an identifier. + + Subclasses should override this to provide database-dependent + escaping behavior. + """ + + return value.replace('"', '""') + + def _unescape_identifier(self, value): + """Canonicalize an escaped identifier. + + Subclasses should override this to provide database-dependent + unescaping behavior that reverses _escape_identifier. + """ + + return value.replace('""', '"') + + def quote_identifier(self, value): + """Quote an identifier. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + + return self.initial_quote + self._escape_identifier(value) + self.final_quote + + def _fold_identifier_case(self, value): + """Fold the case of an identifier. + + Subclasses should override this to provide database-dependent + case folding behavior. + """ + + return value + # ANSI SQL calls for the case of all unquoted identifiers to be folded to UPPER. + # some tests would need to be rewritten if this is done. + #return value.upper() + + def _reserved_words(self): + return RESERVED_WORDS + + def _legal_characters(self): + return LEGAL_CHARACTERS + + def _illegal_initial_characters(self): + return ILLEGAL_INITIAL_CHARACTERS + + def _requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + return \ + value in self._reserved_words() \ + or (value[0] in self._illegal_initial_characters()) \ + or bool(len([x for x in unicode(value) if x not in self._legal_characters()])) \ + or (value.lower() != value) + + def __generic_obj_format(self, obj, ident): + if getattr(obj, 'quote', False): + return self.quote_identifier(ident) + try: + return self.__strings[ident] + except KeyError: + if self._requires_quotes(ident): + self.__strings[ident] = self.quote_identifier(ident) + else: + self.__strings[ident] = ident + return self.__strings[ident] + + def should_quote(self, object): + return object.quote or self._requires_quotes(object.name) + + def format_sequence(self, sequence): + return self.__generic_obj_format(sequence, sequence.name) + + def format_label(self, label, name=None): + return self.__generic_obj_format(label, name or label.name) + + def format_alias(self, alias, name=None): + return self.__generic_obj_format(alias, name or alias.name) + + def format_savepoint(self, savepoint, name=None): + return self.__generic_obj_format(savepoint, name or savepoint.ident) + + def format_constraint(self, constraint): + return self.__generic_obj_format(constraint, constraint.name) + + def format_index(self, index): + return self.__generic_obj_format(index, index.name) + + def format_table(self, table, use_schema=True, name=None): + """Prepare a quoted table and schema name.""" + + if name is None: + name = table.name + result = self.__generic_obj_format(table, name) + if use_schema and getattr(table, "schema", None): + result = self.__generic_obj_format(table, table.schema) + "." + result + return result + + def format_column(self, column, use_table=False, name=None, table_name=None): + """Prepare a quoted column name.""" + if name is None: + name = column.name + if not getattr(column, 'is_literal', False): + if use_table: + return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name) + else: + return self.__generic_obj_format(column, name) + else: + # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted + if use_table: + return self.format_table(column.table, use_schema=False, name=table_name) + "." + name + else: + return name + + def format_column_with_table(self, column, column_name=None, table_name=None): + """Prepare a quoted column name with table name.""" + + return self.format_column(column, use_table=True, name=column_name, table_name=table_name) + + + def format_table_seq(self, table, use_schema=True): + """Format table name and schema as a tuple.""" + + # Dialects with more levels in their fully qualified references + # ('database', 'owner', etc.) could override this and return + # a longer sequence. + + if use_schema and getattr(table, 'schema', None): + return (self.quote_identifier(table.schema), + self.format_table(table, use_schema=False)) + else: + return (self.format_table(table, use_schema=False), ) + + def unformat_identifiers(self, identifiers): + """Unpack 'schema.table.column'-like strings into components.""" + + try: + r = self._r_identifiers + except AttributeError: + initial, final, escaped_final = \ + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] + r = re.compile( + r'(?:' + r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' + r'|([^\.]+))(?=\.|$))+' % + { 'initial': initial, + 'final': final, + 'escaped': escaped_final }) + self._r_identifiers = r + + return [self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)]] + |