diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 212 |
1 files changed, 106 insertions, 106 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 07ef0f50a..39d320ede 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -152,15 +152,15 @@ class _CompileLabel(visitors.Visitable): __visit_name__ = 'label' __slots__ = 'element', 'name' - + def __init__(self, col, name): self.element = col self.name = name - + @property def type(self): return self.element.type - + @property def quote(self): return self.element.quote @@ -176,28 +176,28 @@ class SQLCompiler(engine.Compiled): extract_map = EXTRACT_MAP compound_keywords = COMPOUND_KEYWORDS - + # class-level defaults which can be set at the instance # level to define if this Compiled instance represents # INSERT/UPDATE/DELETE isdelete = isinsert = isupdate = False - + # holds the "returning" collection of columns if # the statement is CRUD and defines returning columns # either implicitly or explicitly returning = None - + # set to True classwide to generate RETURNING # clauses before the VALUES or WHERE clause (i.e. MSSQL) returning_precedes_values = False - + # SQL 92 doesn't allow bind parameters to be used # in the columns clause of a SELECT, nor does it allow # ambiguous expressions like "? = ?". A compiler # subclass can set this flag to False if the target # driver/DB enforces this ansi_bind_rules = False - + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -256,7 +256,7 @@ class SQLCompiler(engine.Compiled): self.truncated_names = {} engine.Compiled.__init__(self, dialect, statement, **kwargs) - + @util.memoized_property def _bind_processors(self): @@ -267,14 +267,14 @@ class SQLCompiler(engine.Compiled): for bindparam in self.bind_names ) if value is not None ) - + def is_subquery(self): return len(self.stack) > 1 @property def sql_compiler(self): return self - + def construct_params(self, params=None, _group_number=None): """return a dictionary of bind parameter keys and values""" @@ -353,25 +353,25 @@ class SQLCompiler(engine.Compiled): return label.element._compiler_dispatch(self, within_columns_clause=False, **kw) - + def visit_column(self, column, result_map=None, **kwargs): name = column.name if name is None: raise exc.CompileError("Cannot compile Column object until " "it's 'name' is assigned.") - + is_literal = column.is_literal if not is_literal and isinstance(name, sql._generated_label): name = self._truncated_identifier("colident", name) if result_map is not None: result_map[name.lower()] = (name, (column, ), column.type) - + if is_literal: name = self.escape_literal_column(name) else: name = self.preparer.quote(name, column.quote) - + table = column.table if table is None or not table.named_with_column: return name @@ -385,7 +385,7 @@ class SQLCompiler(engine.Compiled): tablename = table.name if isinstance(tablename, sql._generated_label): tablename = self._truncated_identifier("alias", tablename) - + return schema_prefix + \ self.preparer.quote(tablename, table.quote) + \ "." + name @@ -407,7 +407,7 @@ class SQLCompiler(engine.Compiled): def post_process_text(self, text): return text - + def visit_textclause(self, textclause, **kwargs): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): @@ -486,14 +486,14 @@ class SQLCompiler(engine.Compiled): self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) keyword = self.compound_keywords.get(cs.keyword) - + text = (" " + keyword + " ").join( (c._compiler_dispatch(self, asfrom=asfrom, parens=False, compound_index=i, **kwargs) for i, c in enumerate(cs.selects)) ) - + group_by = cs._group_by_clause._compiler_dispatch( self, asfrom=asfrom, **kwargs) if group_by: @@ -523,7 +523,7 @@ class SQLCompiler(engine.Compiled): isinstance(binary.left, sql._BindParamClause) and \ isinstance(binary.right, sql._BindParamClause): kw['literal_binds'] = True - + return self._operator_dispatch(binary.operator, binary, lambda opstr: binary.left._compiler_dispatch(self, **kw) + @@ -550,7 +550,7 @@ class SQLCompiler(engine.Compiled): + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') - + def visit_ilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) LIKE lower(%s)' % ( @@ -559,7 +559,7 @@ class SQLCompiler(engine.Compiled): + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') - + def visit_notilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) NOT LIKE lower(%s)' % ( @@ -568,7 +568,7 @@ class SQLCompiler(engine.Compiled): + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') - + def _operator_dispatch(self, operator, element, fn, **kw): if util.callable(operator): disp = getattr(self, "visit_%s" % operator.__name__, None) @@ -578,7 +578,7 @@ class SQLCompiler(engine.Compiled): return fn(OPERATORS[operator]) else: return fn(" " + operator + " ") - + def visit_bindparam(self, bindparam, within_columns_clause=False, literal_binds=False, **kwargs): if literal_binds or \ @@ -589,7 +589,7 @@ class SQLCompiler(engine.Compiled): "renderable value not allowed here.") return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs) - + name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] @@ -610,26 +610,26 @@ class SQLCompiler(engine.Compiled): "with insert() or update() (for example, 'b_%s')." % (bindparam.key, bindparam.key) ) - + self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) - + def render_literal_bindparam(self, bindparam, **kw): value = bindparam.value processor = bindparam.type._cached_bind_processor(self.dialect) if processor: value = processor(value) return self.render_literal_value(value, bindparam.type) - + def render_literal_value(self, value, type_): """Render the value of a bind parameter as a quoted literal. - + This is used for statement sections that do not accept bind paramters on the target driver/database. - + This should be implemented by subclasses using the quoting services of the DBAPI. - + """ if isinstance(value, basestring): value = value.replace("'", "''") @@ -643,7 +643,7 @@ class SQLCompiler(engine.Compiled): else: raise NotImplementedError( "Don't know how to literal-quote value %r" % value) - + def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] @@ -672,10 +672,10 @@ class SQLCompiler(engine.Compiled): truncname = anonname self.truncated_names[(ident_class, name)] = truncname return truncname - + def _anonymize(self, name): return name % self.anon_map - + def _process_anon(self, key): (ident, derived) = key.split(' ', 1) anonymous_counter = self.anon_map.get(derived, 1) @@ -705,12 +705,12 @@ class SQLCompiler(engine.Compiled): asfrom=True, **kwargs) + \ " AS " + \ self.preparer.format_alias(alias, alias_name) - + if fromhints and alias in fromhints: hinttext = self.get_from_hint_text(alias, fromhints[alias]) if hinttext: ret += " " + hinttext - + return ret else: return alias.original._compiler_dispatch(self, **kwargs) @@ -742,16 +742,16 @@ class SQLCompiler(engine.Compiled): def get_select_hint_text(self, byfroms): return None - + def get_from_hint_text(self, table, text): return None - + def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, fromhints=None, compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} - + existingfroms = entry.get('from', None) froms = select._get_display_froms(existingfroms) @@ -782,7 +782,7 @@ class SQLCompiler(engine.Compiled): ] if c is not None ] - + text = "SELECT " # we're off to a good start ! if select._hints: @@ -798,7 +798,7 @@ class SQLCompiler(engine.Compiled): hint_text = self.get_select_hint_text(byfrom) if hint_text: text += hint_text + " " - + if select._prefixes: text += " ".join( x._compiler_dispatch(self, **kwargs) @@ -808,7 +808,7 @@ class SQLCompiler(engine.Compiled): if froms: text += " \nFROM " - + if select._hints: text += ', '.join([f._compiler_dispatch(self, asfrom=True, fromhints=byfrom, @@ -854,7 +854,7 @@ class SQLCompiler(engine.Compiled): def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list. - + """ return select._distinct and "DISTINCT " or "" @@ -924,15 +924,15 @@ class SQLCompiler(engine.Compiled): preparer = self.preparer supports_default_values = self.dialect.supports_default_values - + text = "INSERT" - + prefixes = [self.process(x) for x in insert_stmt._prefixes] if prefixes: text += " " + " ".join(prefixes) - + text += " INTO " + preparer.format_table(insert_stmt.table) - + if colparams or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) for c in colparams]) @@ -941,7 +941,7 @@ class SQLCompiler(engine.Compiled): self.returning = self.returning or insert_stmt._returning returning_clause = self.returning_clause( insert_stmt, self.returning) - + if self.returning_precedes_values: text += " " + returning_clause @@ -950,12 +950,12 @@ class SQLCompiler(engine.Compiled): else: text += " VALUES (%s)" % \ ', '.join([c[1] for c in colparams]) - + if self.returning and not self.returning_precedes_values: text += " " + returning_clause - + return text - + def visit_update(self, update_stmt): self.stack.append({'from': set([update_stmt.table])}) @@ -963,7 +963,7 @@ class SQLCompiler(engine.Compiled): colparams = self._get_colparams(update_stmt) text = "UPDATE " + self.preparer.format_table(update_stmt.table) - + text += ' SET ' + \ ', '.join( self.preparer.quote(c[0].name, c[0].quote) + @@ -976,14 +976,14 @@ class SQLCompiler(engine.Compiled): if self.returning_precedes_values: text += " " + self.returning_clause( update_stmt, update_stmt._returning) - + if update_stmt._whereclause is not None: text += " WHERE " + self.process(update_stmt._whereclause) if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( update_stmt, update_stmt._returning) - + self.stack.pop(-1) return text @@ -1001,10 +1001,10 @@ class SQLCompiler(engine.Compiled): "with insert() or update() (for example, 'b_%s')." % (col.key, col.key) ) - + self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) - + def _get_colparams(self, stmt): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1030,7 +1030,7 @@ class SQLCompiler(engine.Compiled): ] required = object() - + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: @@ -1047,17 +1047,17 @@ class SQLCompiler(engine.Compiled): # create a list of column assignment clauses as tuples values = [] - + need_pks = self.isinsert and \ not self.inline and \ not stmt._returning - + implicit_returning = need_pks and \ self.dialect.implicit_returning and \ stmt.table.implicit_returning - + postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid - + # iterating through columns at the top to maintain ordering. # otherwise we might iterate through individual sets of # "defaults", "primary key cols", etc. @@ -1071,7 +1071,7 @@ class SQLCompiler(engine.Compiled): self.postfetch.append(c) value = self.process(value.self_group()) values.append((c, value)) - + elif self.isinsert: if c.primary_key and \ need_pks and \ @@ -1080,7 +1080,7 @@ class SQLCompiler(engine.Compiled): not postfetch_lastrowid or c is not stmt.table._autoincrement_column ): - + if implicit_returning: if c.default is not None: if c.default.is_sequence: @@ -1115,7 +1115,7 @@ class SQLCompiler(engine.Compiled): (c, self._create_crud_bind_param(c, None)) ) self.prefetch.append(c) - + elif c.default is not None: if c.default.is_sequence: proc = self.process(c.default) @@ -1127,7 +1127,7 @@ class SQLCompiler(engine.Compiled): values.append( (c, self.process(c.default.arg.self_group())) ) - + if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) @@ -1139,7 +1139,7 @@ class SQLCompiler(engine.Compiled): elif c.server_default is not None: if not c.primary_key: self.postfetch.append(c) - + elif self.isupdate: if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: @@ -1167,14 +1167,14 @@ class SQLCompiler(engine.Compiled): if self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) - + if delete_stmt._whereclause is not None: text += " WHERE " + self.process(delete_stmt._whereclause) if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) - + self.stack.pop(-1) return text @@ -1192,18 +1192,18 @@ class SQLCompiler(engine.Compiled): class DDLCompiler(engine.Compiled): - + @util.memoized_property def sql_compiler(self): return self.dialect.statement_compiler(self.dialect, None) - + @property def preparer(self): return self.dialect.identifier_preparer def construct_params(self, params=None): return None - + def visit_ddl(self, ddl, **kwargs): # table events can substitute table and schema name context = ddl.context @@ -1220,7 +1220,7 @@ class DDLCompiler(engine.Compiled): context.setdefault('table', table) context.setdefault('schema', sch) context.setdefault('fullname', preparer.format_table(ddl.target)) - + return self.sql_compiler.post_process_text(ddl.statement % context) def visit_create_table(self, create): @@ -1259,16 +1259,16 @@ class DDLCompiler(engine.Compiled): return text def create_table_constraints(self, table): - + # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) constraints = [] if table.primary_key: constraints.append(table.primary_key) - + constraints.extend([c for c in table.constraints if c is not table.primary_key]) - + return ", \n\t".join(p for p in (self.process(constraint) for constraint in constraints @@ -1280,7 +1280,7 @@ class DDLCompiler(engine.Compiled): not getattr(constraint, 'use_alter', False) )) if p is not None ) - + def visit_drop_table(self, drop): return "\nDROP TABLE " + self.preparer.format_table(drop.element) @@ -1302,7 +1302,7 @@ class DDLCompiler(engine.Compiled): preparer = self.preparer text = "CREATE " if index.unique: - text += "UNIQUE " + text += "UNIQUE " text += "INDEX %s ON %s (%s)" \ % (preparer.quote(self._index_identifier(index.name), index.quote), @@ -1332,7 +1332,7 @@ class DDLCompiler(engine.Compiled): if create.element.start is not None: text += " START WITH %d" % create.element.start return text - + def visit_drop_sequence(self, drop): return "DROP SEQUENCE %s" % \ self.preparer.format_sequence(drop.element) @@ -1344,7 +1344,7 @@ class DDLCompiler(engine.Compiled): self.preparer.format_constraint(drop.element), drop.cascade and " CASCADE" or "" ) - + def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ self.dialect.type_compiler.process(column.type) @@ -1417,7 +1417,7 @@ class DDLCompiler(engine.Compiled): def define_constraint_remote_table(self, constraint, table, preparer): """Format the remote table clause of a CREATE CONSTRAINT clause.""" - + return preparer.format_table(table) def visit_unique_constraint(self, constraint): @@ -1438,7 +1438,7 @@ class DDLCompiler(engine.Compiled): if constraint.onupdate is not None: text += " ON UPDATE %s" % constraint.onupdate return text - + def define_constraint_deferrability(self, constraint): text = "" if constraint.deferrable is not None: @@ -1449,15 +1449,15 @@ class DDLCompiler(engine.Compiled): if constraint.initially is not None: text += " INITIALLY %s" % constraint.initially return text - - + + class GenericTypeCompiler(engine.TypeCompiler): def visit_CHAR(self, type_): return "CHAR" + (type_.length and "(%d)" % type_.length or "") def visit_NCHAR(self, type_): return "NCHAR" + (type_.length and "(%d)" % type_.length or "") - + def visit_FLOAT(self, type_): return "FLOAT" @@ -1474,7 +1474,7 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_DECIMAL(self, type_): return "DECIMAL" - + def visit_INTEGER(self, type_): return "INTEGER" @@ -1516,46 +1516,46 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_VARBINARY(self, type_): return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - + def visit_BOOLEAN(self, type_): return "BOOLEAN" - + def visit_TEXT(self, type_): return "TEXT" - + def visit_large_binary(self, type_): return self.visit_BLOB(type_) - + def visit_boolean(self, type_): return self.visit_BOOLEAN(type_) - + def visit_time(self, type_): return self.visit_TIME(type_) - + def visit_datetime(self, type_): return self.visit_DATETIME(type_) - + def visit_date(self, type_): return self.visit_DATE(type_) def visit_big_integer(self, type_): return self.visit_BIGINT(type_) - + def visit_small_integer(self, type_): return self.visit_SMALLINT(type_) - + def visit_integer(self, type_): return self.visit_INTEGER(type_) - + def visit_float(self, type_): return self.visit_FLOAT(type_) - + def visit_numeric(self, type_): return self.visit_NUMERIC(type_) - + def visit_string(self, type_): return self.visit_VARCHAR(type_) - + def visit_unicode(self, type_): return self.visit_VARCHAR(type_) @@ -1564,19 +1564,19 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_unicode_text(self, type_): return self.visit_TEXT(type_) - + def visit_enum(self, type_): return self.visit_VARCHAR(type_) - + def visit_null(self, type_): raise NotImplementedError("Can't generate DDL for the null type") - + def visit_type_decorator(self, type_): return self.process(type_.type_engine(self.dialect)) - + def visit_user_defined(self, type_): return type_.get_col_spec() - + class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" @@ -1609,7 +1609,7 @@ class IdentifierPreparer(object): self.escape_to_quote = self.escape_quote * 2 self.omit_schema = omit_schema self._strings = {} - + def _escape_identifier(self, value): """Escape an identifier. @@ -1689,7 +1689,7 @@ class IdentifierPreparer(object): def format_constraint(self, constraint): return self.quote(constraint.name, constraint.quote) - + def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" @@ -1754,7 +1754,7 @@ class IdentifierPreparer(object): 'final': final, 'escaped': escaped_final }) return r - + def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" |