diff options
author | Carlos Rivas <carlos@twobitcoder.com> | 2016-01-26 13:45:31 -0800 |
---|---|---|
committer | Carlos Rivas <carlos@twobitcoder.com> | 2016-01-26 13:45:31 -0800 |
commit | c6d630ca819239bf1b18bd6e51f265fb1be951c9 (patch) | |
tree | e30838e4e462d7994cc69d0c281a2d4a88b89edf /lib/sqlalchemy/sql/compiler.py | |
parent | 28365040ace29c9ceea28946ed19f07c3a4fcefc (diff) | |
parent | 8163de4cc9e01460d3476b9fb3ed14a5b3e70bae (diff) | |
download | sqlalchemy-c6d630ca819239bf1b18bd6e51f265fb1be951c9.tar.gz |
Merged zzzeek/sqlalchemy into master
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 190 |
1 files changed, 139 insertions, 51 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6766c99b7..492999d16 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -167,25 +167,39 @@ class Compiled(object): _cached_metadata = None def __init__(self, dialect, statement, bind=None, + schema_translate_map=None, compile_kwargs=util.immutabledict()): - """Construct a new ``Compiled`` object. + """Construct a new :class:`.Compiled` object. - :param dialect: ``Dialect`` to compile against. + :param dialect: :class:`.Dialect` to compile against. - :param statement: ``ClauseElement`` to be compiled. + :param statement: :class:`.ClauseElement` to be compiled. :param bind: Optional Engine or Connection to compile this statement against. + :param schema_translate_map: dictionary of schema names to be + translated when forming the resultant SQL + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + :param compile_kwargs: additional kwargs that will be passed to the initial call to :meth:`.Compiled.process`. - .. versionadded:: 0.8 """ self.dialect = dialect self.bind = bind + self.preparer = self.dialect.identifier_preparer + if schema_translate_map: + self.preparer = self.preparer._with_schema_translate( + schema_translate_map) + if statement is not None: self.statement = statement self.can_execute = statement.supports_execution @@ -286,12 +300,11 @@ class _CompileLabel(visitors.Visitable): def self_group(self, **kw): return self -class SQLCompiler(Compiled): - """Default implementation of Compiled. +class SQLCompiler(Compiled): + """Default implementation of :class:`.Compiled`. - Compiles ClauseElements into SQL strings. Uses a similar visit - paradigm as visitors.ClauseVisitor but implements its own traversal. + Compiles :class:`.ClauseElement` objects into SQL strings. """ @@ -305,6 +318,8 @@ class SQLCompiler(Compiled): INSERT/UPDATE/DELETE """ + isplaintext = False + returning = None """holds the "returning" collection of columns if the statement is CRUD and defines returning columns @@ -330,19 +345,34 @@ class SQLCompiler(Compiled): driver/DB enforces this """ + _textual_ordered_columns = False + """tell the result object that the column names as rendered are important, + but they are also "ordered" vs. what is in the compiled object here. + """ + + _ordered_columns = True + """ + if False, means we can't be sure the list of entries + in _result_columns is actually the rendered order. Usually + True unless using an unordered TextAsFrom. + """ + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): - """Construct a new ``DefaultCompiler`` object. + """Construct a new :class:`.SQLCompiler` object. - dialect - Dialect to be used + :param dialect: :class:`.Dialect` to be used - statement - ClauseElement to be compiled + :param statement: :class:`.ClauseElement` to be compiled - column_keys - a list of column names to be compiled into an INSERT or UPDATE - statement. + :param column_keys: a list of column names to be compiled into an + INSERT or UPDATE statement. + + :param inline: whether to generate INSERT statements as "inline", e.g. + not formatted to return any generated defaults + + :param kwargs: additional keyword arguments to be consumed by the + superclass. """ self.column_keys = column_keys @@ -368,11 +398,6 @@ class SQLCompiler(Compiled): # column targeting self._result_columns = [] - # if False, means we can't be sure the list of entries - # in _result_columns is actually the rendered order. This - # gets flipped when we use TextAsFrom, for example. - self._ordered_columns = True - # true if the paramstyle is positional self.positional = dialect.positional if self.positional: @@ -381,8 +406,6 @@ class SQLCompiler(Compiled): self.ctes = None - # an IdentifierPreparer that formats the quoting of identifiers - self.preparer = dialect.identifier_preparer self.label_length = dialect.label_length \ or dialect.max_identifier_length @@ -649,8 +672,11 @@ class SQLCompiler(Compiled): if table is None or not include_table or not table.named_with_column: return name else: - if table.schema: - schema_prefix = self.preparer.quote_schema(table.schema) + '.' + effective_schema = self.preparer.schema_for_object(table) + + if effective_schema: + schema_prefix = self.preparer.quote_schema( + effective_schema) + '.' else: schema_prefix = '' tablename = table.name @@ -688,6 +714,9 @@ class SQLCompiler(Compiled): else: return self.bindparam_string(name, **kw) + if not self.stack: + self.isplaintext = True + # un-escape any \:params return BIND_PARAMS_ESC.sub( lambda m: m.group(1), @@ -711,7 +740,8 @@ class SQLCompiler(Compiled): ) or entry.get('need_result_map_for_nested', False) if populate_result_map: - self._ordered_columns = False + self._ordered_columns = \ + self._textual_ordered_columns = taf.positional for c in taf.column_args: self.process(c, within_columns_clause=True, add_to_result_map=self._add_to_result_map) @@ -873,22 +903,28 @@ class SQLCompiler(Compiled): else: return text + def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): + attrname = "visit_%s_%s%s" % ( + operator_.__name__, qualifier1, + "_" + qualifier2 if qualifier2 else "") + return getattr(self, attrname, None) + def visit_unary(self, unary, **kw): if unary.operator: if unary.modifier: raise exc.CompileError( "Unary expression does not support operator " "and modifier simultaneously") - disp = getattr(self, "visit_%s_unary_operator" % - unary.operator.__name__, None) + disp = self._get_operator_dispatch( + unary.operator, "unary", "operator") if disp: return disp(unary, unary.operator, **kw) else: return self._generate_generic_unary_operator( unary, OPERATORS[unary.operator], **kw) elif unary.modifier: - disp = getattr(self, "visit_%s_unary_modifier" % - unary.modifier.__name__, None) + disp = self._get_operator_dispatch( + unary.modifier, "unary", "modifier") if disp: return disp(unary, unary.modifier, **kw) else: @@ -922,7 +958,7 @@ class SQLCompiler(Compiled): kw['literal_binds'] = True operator_ = override_operator or binary.operator - disp = getattr(self, "visit_%s_binary" % operator_.__name__, None) + disp = self._get_operator_dispatch(operator_, "binary", None) if disp: return disp(binary, operator_, **kw) else: @@ -1298,7 +1334,7 @@ class SQLCompiler(Compiled): add_to_result_map = lambda keyname, name, objects, type_: \ self._add_to_result_map( keyname, name, - objects + (column,), type_) + (column,) + objects, type_) else: col_expr = column if populate_result_map: @@ -1386,7 +1422,7 @@ class SQLCompiler(Compiled): """Rewrite any "a JOIN (b JOIN c)" expression as "a JOIN (select * from b JOIN c) AS anon", to support databases that can't parse a parenthesized join correctly - (i.e. sqlite the main one). + (i.e. sqlite < 3.7.16). """ cloned = {} @@ -1801,8 +1837,10 @@ class SQLCompiler(Compiled): def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, fromhints=None, use_schema=True, **kwargs): if asfrom or ashint: - if use_schema and getattr(table, "schema", None): - ret = self.preparer.quote_schema(table.schema) + \ + effective_schema = self.preparer.schema_for_object(table) + + if use_schema and effective_schema: + ret = self.preparer.quote_schema(effective_schema) + \ "." + self.preparer.quote(table.name) else: ret = self.preparer.quote(table.name) @@ -2080,6 +2118,30 @@ class SQLCompiler(Compiled): self.preparer.format_savepoint(savepoint_stmt) +class StrSQLCompiler(SQLCompiler): + """"a compiler subclass with a few non-standard SQL features allowed. + + Used for stringification of SQL statements when a real dialect is not + available. + + """ + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def returning_clause(self, stmt, returning_cols): + + columns = [ + self._label_select_column(None, c, True, False, {}) + for c in elements._select_iterables(returning_cols) + ] + + return 'RETURNING ' + ', '.join(columns) + + class DDLCompiler(Compiled): @util.memoized_property @@ -2090,10 +2152,6 @@ class DDLCompiler(Compiled): def type_compiler(self): return self.dialect.type_compiler - @property - def preparer(self): - return self.dialect.identifier_preparer - def construct_params(self, params=None): return None @@ -2103,7 +2161,7 @@ class DDLCompiler(Compiled): if isinstance(ddl.target, schema.Table): context = context.copy() - preparer = self.dialect.identifier_preparer + preparer = self.preparer path = preparer.format_table_seq(ddl.target) if len(path) == 1: table, sch = path[0], '' @@ -2129,7 +2187,7 @@ class DDLCompiler(Compiled): def visit_create_table(self, create): table = create.element - preparer = self.dialect.identifier_preparer + preparer = self.preparer text = "\nCREATE " if table._prefixes: @@ -2256,9 +2314,12 @@ class DDLCompiler(Compiled): index, include_schema=True) def _prepared_index_name(self, index, include_schema=False): - if include_schema and index.table is not None and index.table.schema: - schema = index.table.schema - schema_name = self.preparer.quote_schema(schema) + if index.table is not None: + effective_schema = self.preparer.schema_for_object(index.table) + else: + effective_schema = None + if include_schema and effective_schema: + schema_name = self.preparer.quote_schema(effective_schema) else: schema_name = None @@ -2386,7 +2447,7 @@ class DDLCompiler(Compiled): return text def visit_foreign_key_constraint(self, constraint): - preparer = self.dialect.identifier_preparer + preparer = self.preparer text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -2603,6 +2664,17 @@ class GenericTypeCompiler(TypeCompiler): return type_.get_col_spec(**kw) +class StrSQLTypeCompiler(GenericTypeCompiler): + def __getattr__(self, key): + if key.startswith("visit_"): + return self._visit_unknown + else: + raise AttributeError(key) + + def _visit_unknown(self, type_, **kw): + return "%s" % type_.__class__.__name__ + + class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" @@ -2613,6 +2685,8 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + schema_for_object = schema._schema_getter(None) + def __init__(self, dialect, initial_quote='"', final_quote=None, escape_quote='"', omit_schema=False): """Construct a new ``IdentifierPreparer`` object. @@ -2637,6 +2711,12 @@ class IdentifierPreparer(object): self.omit_schema = omit_schema self._strings = {} + def _with_schema_translate(self, schema_translate_map): + prep = self.__class__.__new__(self.__class__) + prep.__dict__.update(self.__dict__) + prep.schema_for_object = schema._schema_getter(schema_translate_map) + return prep + def _escape_identifier(self, value): """Escape an identifier. @@ -2709,9 +2789,12 @@ class IdentifierPreparer(object): def format_sequence(self, sequence, use_schema=True): name = self.quote(sequence.name) + + effective_schema = self.schema_for_object(sequence) + if (not self.omit_schema and use_schema and - sequence.schema is not None): - name = self.quote_schema(sequence.schema) + "." + name + effective_schema is not None): + name = self.quote_schema(effective_schema) + "." + name return name def format_label(self, label, name=None): @@ -2740,9 +2823,12 @@ class IdentifierPreparer(object): if name is None: name = table.name result = self.quote(name) + + effective_schema = self.schema_for_object(table) + if not self.omit_schema and use_schema \ - and getattr(table, "schema", None): - result = self.quote_schema(table.schema) + "." + result + and effective_schema: + result = self.quote_schema(effective_schema) + "." + result return result def format_schema(self, name, quote=None): @@ -2781,9 +2867,11 @@ class IdentifierPreparer(object): # ('database', 'owner', etc.) could override this and return # a longer sequence. + effective_schema = self.schema_for_object(table) + if not self.omit_schema and use_schema and \ - getattr(table, 'schema', None): - return (self.quote_schema(table.schema), + effective_schema: + return (self.quote_schema(effective_schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) |