diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 140 |
1 files changed, 115 insertions, 25 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1f183b5c1..14f4bda8c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -26,6 +26,7 @@ To generate user-defined SQL strings, see import collections import contextlib import itertools +import operator import re from . import base @@ -39,6 +40,7 @@ from . import schema from . import selectable from . import sqltypes from .base import NO_ARG +from .elements import quoted_name from .. import exc from .. import util @@ -369,6 +371,8 @@ class Compiled(object): _cached_metadata = None + schema_translate_map = None + execution_options = util.immutabledict() """ Execution options propagated from the statement. In some cases, @@ -381,6 +385,7 @@ class Compiled(object): statement, bind=None, schema_translate_map=None, + render_schema_translate=False, compile_kwargs=util.immutabledict(), ): """Construct a new :class:`.Compiled` object. @@ -411,6 +416,7 @@ class Compiled(object): self.bind = bind self.preparer = self.dialect.identifier_preparer if schema_translate_map: + self.schema_translate_map = schema_translate_map self.preparer = self.preparer._with_schema_translate( schema_translate_map ) @@ -422,6 +428,11 @@ class Compiled(object): self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) + if render_schema_translate: + self.string = self.preparer._render_schema_translates( + self.string, schema_translate_map + ) + @util.deprecated( "0.7", "The :meth:`.Compiled.compile` method is deprecated and will be " @@ -2281,6 +2292,46 @@ class SQLCompiler(Compiled): return text + def visit_values(self, element, asfrom=False, from_linter=None, **kw): + v = "VALUES %s" % ", ".join( + self.process(elem, literal_binds=element.literal_binds) + for elem in element._data + ) + + if isinstance(element.name, elements._truncated_label): + name = self._truncated_identifier("values", element.name) + else: + name = element.name + + if element._is_lateral: + lateral = "LATERAL " + else: + lateral = "" + + if asfrom: + if from_linter: + from_linter.froms[element] = ( + name if name is not None else "(unnamed VALUES element)" + ) + + if name: + v = "%s(%s)%s (%s)" % ( + lateral, + v, + self.get_render_as_alias_suffix(self.preparer.quote(name)), + ( + ", ".join( + c._compiler_dispatch( + self, include_table=False, **kw + ) + for c in element.columns + ) + ), + ) + else: + v = "%s(%s)" % (lateral, v) + return v + def get_render_as_alias_suffix(self, alias_name_text): return " AS " + alias_name_text @@ -3365,18 +3416,18 @@ class DDLCompiler(Compiled): return self.sql_compiler.post_process_text(ddl.statement % context) - def visit_create_schema(self, create): + def visit_create_schema(self, create, **kw): schema = self.preparer.format_schema(create.element) return "CREATE SCHEMA " + schema - def visit_drop_schema(self, drop): + def visit_drop_schema(self, drop, **kw): schema = self.preparer.format_schema(drop.element) text = "DROP SCHEMA " + schema if drop.cascade: text += " CASCADE" return text - def visit_create_table(self, create): + def visit_create_table(self, create, **kw): table = create.element preparer = self.preparer @@ -3426,7 +3477,7 @@ class DDLCompiler(Compiled): text += "\n)%s\n\n" % self.post_create_table(table) return text - def visit_create_column(self, create, first_pk=False): + def visit_create_column(self, create, first_pk=False, **kw): column = create.element if column.system: @@ -3442,7 +3493,7 @@ class DDLCompiler(Compiled): return text def create_table_constraints( - self, table, _include_foreign_key_constraints=None + self, table, _include_foreign_key_constraints=None, **kw ): # On some DB order is significant: visit PK first, then the @@ -3482,10 +3533,10 @@ class DDLCompiler(Compiled): if p is not None ) - def visit_drop_table(self, drop): + def visit_drop_table(self, drop, **kw): return "\nDROP TABLE " + self.preparer.format_table(drop.element) - def visit_drop_view(self, drop): + def visit_drop_view(self, drop, **kw): return "\nDROP VIEW " + self.preparer.format_table(drop.element) def _verify_index_table(self, index): @@ -3495,7 +3546,7 @@ class DDLCompiler(Compiled): ) def visit_create_index( - self, create, include_schema=False, include_table_schema=True + self, create, include_schema=False, include_table_schema=True, **kw ): index = create.element self._verify_index_table(index) @@ -3521,7 +3572,7 @@ class DDLCompiler(Compiled): ) return text - def visit_drop_index(self, drop): + def visit_drop_index(self, drop, **kw): index = drop.element if index.name is None: @@ -3548,13 +3599,13 @@ class DDLCompiler(Compiled): index_name = schema_name + "." + index_name return index_name - def visit_add_constraint(self, create): + def visit_add_constraint(self, create, **kw): return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), self.process(create.element), ) - def visit_set_table_comment(self, create): + def visit_set_table_comment(self, create, **kw): return "COMMENT ON TABLE %s IS %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( @@ -3562,12 +3613,12 @@ class DDLCompiler(Compiled): ), ) - def visit_drop_table_comment(self, drop): + def visit_drop_table_comment(self, drop, **kw): return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( drop.element ) - def visit_set_column_comment(self, create): + def visit_set_column_comment(self, create, **kw): return "COMMENT ON COLUMN %s IS %s" % ( self.preparer.format_column( create.element, use_table=True, use_schema=True @@ -3577,12 +3628,12 @@ class DDLCompiler(Compiled): ), ) - def visit_drop_column_comment(self, drop): + def visit_drop_column_comment(self, drop, **kw): return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( drop.element, use_table=True ) - def visit_create_sequence(self, create): + def visit_create_sequence(self, create, **kw): text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( create.element ) @@ -3606,10 +3657,10 @@ class DDLCompiler(Compiled): text += " CYCLE" return text - def visit_drop_sequence(self, drop): + def visit_drop_sequence(self, drop, **kw): return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) - def visit_drop_constraint(self, drop): + def visit_drop_constraint(self, drop, **kw): constraint = drop.element if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3671,7 +3722,7 @@ class DDLCompiler(Compiled): else: return self.visit_check_constraint(constraint) - def visit_check_constraint(self, constraint): + def visit_check_constraint(self, constraint, **kw): text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3683,7 +3734,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_column_check_constraint(self, constraint): + def visit_column_check_constraint(self, constraint, **kw): text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3695,7 +3746,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -3715,7 +3766,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_foreign_key_constraint(self, constraint): + def visit_foreign_key_constraint(self, constraint, **kw): preparer = self.preparer text = "" if constraint.name is not None: @@ -3744,7 +3795,7 @@ class DDLCompiler(Compiled): return preparer.format_table(table) - def visit_unique_constraint(self, constraint): + def visit_unique_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -3789,7 +3840,7 @@ class DDLCompiler(Compiled): text += " MATCH %s" % constraint.match return text - def visit_computed_column(self, generated): + def visit_computed_column(self, generated, **kw): text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( generated.sqltext, include_table=False, literal_binds=True ) @@ -3975,7 +4026,16 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - schema_for_object = schema._schema_getter(None) + schema_for_object = operator.attrgetter("schema") + """Return the .schema attribute for an object. + + For the default IdentifierPreparer, the schema for an object is always + the value of the ".schema" attribute. if the preparer is replaced + with one that has a non-empty schema_translate_map, the value of the + ".schema" attribute is rendered a symbol that will be converted to a + real schema name from the mapping post-compile. + + """ def __init__( self, @@ -4016,9 +4076,39 @@ class IdentifierPreparer(object): def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) prep.__dict__.update(self.__dict__) - prep.schema_for_object = schema._schema_getter(schema_translate_map) + + def symbol_getter(obj): + name = obj.schema + if name in schema_translate_map and obj._use_schema_map: + return quoted_name( + "[SCHEMA_%s]" % (name or "_none"), quote=False + ) + else: + return obj.schema + + prep.schema_for_object = symbol_getter return prep + def _render_schema_translates(self, statement, schema_translate_map): + d = schema_translate_map + if None in d: + d["_none"] = d[None] + + def replace(m): + name = m.group(2) + effective_schema = d[name] + if not effective_schema: + effective_schema = self.dialect.default_schema_name + if not effective_schema: + # TODO: no coverage here + raise exc.CompileError( + "Dialect has no default schema name; can't " + "use None as dynamic schema target." + ) + return self.quote(effective_schema) + + return re.sub(r"(\[SCHEMA_([\w\d_]+)\])", replace, statement) + def _escape_identifier(self, value): """Escape an identifier. |