diff options
author | Jason Kirtland <jek@discorporate.us> | 2008-02-14 20:02:10 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2008-02-14 20:02:10 +0000 |
commit | 71e745e96b8c5be990b3dc949cb99310dd055609 (patch) | |
tree | 00c748e65e7e85e0231a1c7c504dec6cfcab8e87 /lib/sqlalchemy/sql/compiler.py | |
parent | 8dd5eb402ef65194af4c54a6fd33a181b7d5eaf0 (diff) | |
download | sqlalchemy-71e745e96b8c5be990b3dc949cb99310dd055609.tar.gz |
- Fixed a couple pyflakes, cleaned up imports & whitespace
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 172 |
1 files changed, 84 insertions, 88 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3f32778d6..8d8cfa38f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -7,21 +7,20 @@ """Base SQL and DDL compiler implementations. Provides the [sqlalchemy.sql.compiler#DefaultCompiler] class, which is -responsible for generating all SQL query strings, as well as +responsible for generating all SQL query strings, as well as [sqlalchemy.sql.compiler#SchemaGenerator] and [sqlalchemy.sql.compiler#SchemaDropper] which issue CREATE and DROP DDL for tables, sequences, and indexes. The elements in this module are used by public-facing constructs like [sqlalchemy.sql.expression#ClauseElement] and [sqlalchemy.engine#Engine]. While dialect authors will want to be familiar with this module for the purpose of -creating database-specific compilers and schema generators, the module +creating database-specific compilers and schema generators, the module is otherwise internal to SQLAlchemy. """ import string, re from sqlalchemy import schema, engine, util, exceptions -from sqlalchemy.sql import operators, visitors, functions -from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import operators, functions from sqlalchemy.sql import expression as sql RESERVED_WORDS = util.Set([ @@ -57,7 +56,7 @@ BIND_TEMPLATES = { 'numeric':"%(position)s", 'named':":%(name)s" } - + OPERATORS = { operators.and_ : 'AND', @@ -96,14 +95,14 @@ OPERATORS = { FUNCTIONS = { functions.coalesce : 'coalesce%(expr)s', - functions.current_date: 'CURRENT_DATE', - functions.current_time: 'CURRENT_TIME', + functions.current_date: 'CURRENT_DATE', + functions.current_time: 'CURRENT_TIME', functions.current_timestamp: 'CURRENT_TIMESTAMP', - functions.current_user: 'CURRENT_USER', - functions.localtime: 'LOCALTIME', + functions.current_user: 'CURRENT_USER', + functions.localtime: 'LOCALTIME', functions.localtimestamp: 'LOCALTIMESTAMP', functions.sysdate: 'sysdate', - functions.session_user :'SESSION_USER', + functions.session_user :'SESSION_USER', functions.user: 'USER' } @@ -118,7 +117,7 @@ class DefaultCompiler(engine.Compiled): operators = OPERATORS functions = FUNCTIONS - + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -132,35 +131,35 @@ class DefaultCompiler(engine.Compiled): a list of column names to be compiled into an INSERT or UPDATE statement. """ - + super(DefaultCompiler, self).__init__(dialect, statement, column_keys, **kwargs) # if we are insert/update/delete. set to true when we visit an INSERT, UPDATE or DELETE self.isdelete = self.isinsert = self.isupdate = False - + # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) self.inline = inline or getattr(statement, 'inline', False) - + # a dictionary of bind parameter keys to _BindParamClause instances. self.binds = {} - + # a dictionary of _BindParamClause instances to "compiled" names that are # actually present in the generated SQL self.bind_names = {} # a stack. what recursive compiler doesn't have a stack ? :) self.stack = [] - + # relates label names in the final SQL to # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine. # ResultProxy uses this for type processing and column targeting self.result_map = {} - + # a dictionary of ClauseElement subclasses to counters, which are used to # generate truncated identifier names or "anonymous" identifiers such as # for aliases self.generated_ids = {} - + # paramstyle from the dialect (comes from DB-API) self.paramstyle = self.dialect.paramstyle @@ -168,17 +167,17 @@ class DefaultCompiler(engine.Compiled): self.positional = self.dialect.positional self.bindtemplate = BIND_TEMPLATES[self.paramstyle] - + # a list of the compiled's bind parameter names, used to help # formulate a positional argument list self.positiontup = [] # an IdentifierPreparer that formats the quoting of identifiers self.preparer = self.dialect.identifier_preparer - + def compile(self): self.string = self.process(self.statement) - + def process(self, obj, stack=None, **kwargs): if stack: self.stack.append(stack) @@ -189,23 +188,23 @@ class DefaultCompiler(engine.Compiled): finally: if stack: self.stack.pop(-1) - + def is_subquery(self, select): return self.stack and self.stack[-1].get('is_subquery') - + def get_whereclause(self, obj): - """given a FROM clause, return an additional WHERE condition that should be - applied to a SELECT. - + """given a FROM clause, return an additional WHERE condition that should be + applied to a SELECT. + Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN constructs in non-ansi mode. """ - + return None def construct_params(self, params=None): """return a dictionary of bind parameter keys and values""" - + if params: pd = {} for bindparam, name in self.bind_names.iteritems(): @@ -218,9 +217,9 @@ class DefaultCompiler(engine.Compiled): return pd else: return dict([(self.bind_names[bindparam], bindparam.value) for bindparam in self.bind_names]) - + params = property(construct_params) - + def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -228,22 +227,22 @@ class DefaultCompiler(engine.Compiled): """ return "" - + def visit_grouping(self, grouping, **kwargs): return "(" + self.process(grouping.elem) + ")" - + def visit_label(self, label, result_map=None): labelname = self._truncated_identifier("colident", label.name) - + if result_map is not None: result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type) - + return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) - + def visit_column(self, column, result_map=None, use_schema=False, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily - # want to truncate a column identifier, if its mapped to the name of a - # physical column. but thats very hard to identify at this point, and + # want to truncate a column identifier, if its mapped to the name of a + # physical column. but thats very hard to identify at this point, and # the identifier length should be greater than the id lengths of any physical # columns so should not matter. @@ -259,7 +258,7 @@ class DefaultCompiler(engine.Compiled): if result_map is not None: result_map[name.lower()] = (name, (column, ), column.type) - + if column._is_oid: n = self.dialect.oid_column_name(column) if n is not None: @@ -288,7 +287,7 @@ class DefaultCompiler(engine.Compiled): # TODO: some dialects might need different behavior here return text.replace('%', '%%') - + def visit_fromclause(self, fromclause, **kwargs): return fromclause.name @@ -302,7 +301,7 @@ class DefaultCompiler(engine.Compiled): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): self.result_map[colname.lower()] = (colname, None, type_) - + def do_bindparam(m): name = m.group(1) if name in textclause.bindparams: @@ -311,7 +310,7 @@ class DefaultCompiler(engine.Compiled): return self.bindparam_string(name) # un-escape any \:params - return BIND_PARAMS_ESC.sub(lambda m: m.group(1), + return BIND_PARAMS_ESC.sub(lambda m: m.group(1), BIND_PARAMS.sub(do_bindparam, textclause.text) ) @@ -339,37 +338,37 @@ class DefaultCompiler(engine.Compiled): result_map[func.name.lower()] = (func.name, None, func.type) name = self.function_string(func) - + if callable(name): return name(*[self.process(x) for x in func.clause_expr]) else: return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} - + def function_argspec(self, func): return self.process(func.clause_expr) - + def function_string(self, func): return self.functions.get(func.__class__, func.name + "%(expr)s") def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): stack_entry = {'select':cs} - + if asfrom: stack_entry['is_subquery'] = True elif self.stack and self.stack[-1].get('select'): stack_entry['is_subquery'] = True self.stack.append(stack_entry) - + text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ") group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by - text += self.order_by_clause(cs) + text += self.order_by_clause(cs) text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" - + self.stack.pop(-1) - + if asfrom and parens: return "(" + text + ")" else: @@ -382,19 +381,17 @@ class DefaultCompiler(engine.Compiled): if unary.modifier: s = s + " " + self.operator_string(unary.modifier) return s - + def visit_binary(self, binary, **kwargs): op = self.operator_string(binary.operator) if callable(op): return op(self.process(binary.left), self.process(binary.right)) else: return self.process(binary.left) + " " + op + " " + self.process(binary.right) - - return ret - + def operator_string(self, operator): return self.operators.get(operator, str(operator)) - + def visit_bindparam(self, bindparam, **kwargs): name = self._truncate_bindparam(bindparam) if name in self.binds: @@ -403,22 +400,22 @@ class DefaultCompiler(engine.Compiled): raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) - + def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] - + bind_name = bindparam.key bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation self.bind_names[bindparam] = bind_name - + return bind_name - + def _truncated_identifier(self, ident_class, name): if (ident_class, name) in self.generated_ids: return self.generated_ids[(ident_class, name)] - + anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) if len(anonname) > self.dialect.max_identifier_length: @@ -441,14 +438,14 @@ class DefaultCompiler(engine.Compiled): self.generated_ids[('anon_counter', derived)] = anonymous_counter + 1 self.generated_ids[key] = newname return newname - + def _anonymize(self, name): return ANONYMOUS_LABEL.sub(self._process_anon, name) - + def bindparam_string(self, name): if self.positional: self.positiontup.append(name) - + return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} def visit_alias(self, alias, asfrom=False, **kwargs): @@ -459,13 +456,13 @@ class DefaultCompiler(engine.Compiled): def label_select_column(self, select, column, asfrom): """label columns present in a select().""" - + if isinstance(column, sql._Label): return column - + if select.use_labels and getattr(column, '_label', None): return column.label(column._label) - + if \ asfrom and \ isinstance(column, sql._ColumnClause) and \ @@ -494,12 +491,12 @@ class DefaultCompiler(engine.Compiled): stack_entry['iswrapper'] = True else: column_clause_args = {'result_map':self.result_map} - + if self.stack and 'from' in self.stack[-1]: existingfroms = self.stack[-1]['from'] else: existingfroms = None - + froms = select._get_display_froms(existingfroms) correlate_froms = util.Set() @@ -510,17 +507,17 @@ class DefaultCompiler(engine.Compiled): # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost # if existingfroms: -# correlate_froms = correlate_froms.union(existingfroms) +# correlate_froms = correlate_froms.union(existingfroms) stack_entry['from'] = correlate_froms self.stack.append(stack_entry) # the actual list of columns to print in the SELECT column list. inner_columns = util.OrderedSet() - + for co in select.inner_columns: l = self.label_select_column(select, co, asfrom=asfrom) inner_columns.add(self.process(l, **column_clause_args)) - + collist = string.join(inner_columns.difference(util.Set([None])), ', ') text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " @@ -539,7 +536,7 @@ class DefaultCompiler(engine.Compiled): whereclause = sql.and_(w, whereclause) else: whereclause = w - + if froms: text += " \nFROM " text += string.join(from_strings, ', ') @@ -559,7 +556,7 @@ class DefaultCompiler(engine.Compiled): t = self.process(select._having) if t: text += " \nHAVING " + t - + text += self.order_by_clause(select) text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) @@ -625,10 +622,10 @@ class DefaultCompiler(engine.Compiled): ', '.join([preparer.quote(c[0], c[0].name) for c in colparams]), ', '.join([c[1] for c in colparams]))) - + def visit_update(self, update_stmt): self.stack.append({'from':util.Set([update_stmt.table])}) - + self.isupdate = True colparams = self._get_colparams(update_stmt) @@ -636,15 +633,15 @@ class DefaultCompiler(engine.Compiled): if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) - + self.stack.pop(-1) - + return text def _get_colparams(self, stmt): - """create a set of tuples representing column/string pairs for use + """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. - + """ def create_bind_param(col, value): @@ -654,7 +651,7 @@ class DefaultCompiler(engine.Compiled): self.postfetch = [] self.prefetch = [] - + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: @@ -688,7 +685,7 @@ class DefaultCompiler(engine.Compiled): if (((isinstance(c.default, schema.Sequence) and not c.default.optional) or not self.dialect.supports_pk_autoincrement) or - (c.default is not None and + (c.default is not None and not isinstance(c.default, schema.Sequence))): values.append((c, create_bind_param(c, None))) self.prefetch.append(c) @@ -732,18 +729,18 @@ class DefaultCompiler(engine.Compiled): text += " WHERE " + self.process(delete_stmt._whereclause) self.stack.pop(-1) - + return text - + def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - + def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - + def __str__(self): return self.string or '' @@ -1072,10 +1069,10 @@ class IdentifierPreparer(object): def format_column(self, column, use_table=False, name=None, table_name=None): """Prepare a quoted column name. - + deprecated. use preparer.quote(col, column.name) or combine with format_table() """ - + if name is None: name = column.name if not getattr(column, 'is_literal', False): @@ -1121,7 +1118,6 @@ class IdentifierPreparer(object): 'final': final, 'escaped': escaped_final }) self._r_identifiers = r - + return [self._unescape_identifier(i) for i in [a or b for a, b in r.findall(identifiers)]] - |