diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-01-08 22:11:09 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-01-08 22:12:25 -0500 |
commit | 89facbed8855d1443dbe37919ff0645aea640ed0 (patch) | |
tree | 33e7ab15470a5f3a76b748418e6be0c62aa1eaba /lib/sqlalchemy | |
parent | 777e25694f1567ff61655d86a91be6264186c13e (diff) | |
download | sqlalchemy-89facbed8855d1443dbe37919ff0645aea640ed0.tar.gz |
- Multi-tenancy schema translation for :class:`.Table` objects is added.
This supports the use case of an application that uses the same set of
:class:`.Table` objects in many schemas, such as schema-per-user.
A new execution option
:paramref:`.Connection.execution_options.schema_translate_map` is
added. fixes #2685
- latest tox doesn't like the {posargs} in the profile rerunner
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 18 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 42 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/reflection.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/strategies.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 91 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 21 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 9 |
9 files changed, 157 insertions, 44 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 3f9fcb27f..3b3d65155 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1481,8 +1481,11 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): raise exc.CompileError("Postgresql ENUM type requires a name.") name = self.quote(type_.name) - if not self.omit_schema and use_schema and type_.schema is not None: - name = self.quote_schema(type_.schema) + "." + name + effective_schema = self._get_effective_schema(type_) + + if not self.omit_schema and use_schema and \ + effective_schema is not None: + name = self.quote_schema(effective_schema) + "." + name return name @@ -1575,10 +1578,15 @@ class PGExecutionContext(default.DefaultExecutionContext): name = "%s_%s_seq" % (tab, col) column._postgresql_seq_name = seq_name = name - sch = column.table.schema - if sch is not None: + if column.table is not None: + effective_schema = self.connection._get_effective_schema( + column.table) + else: + effective_schema = None + + if effective_schema is not None: exc = "select nextval('\"%s\".\"%s\"')" % \ - (sch, seq_name) + (effective_schema, seq_name) else: exc = "select nextval('\"%s\"')" % \ (seq_name, ) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 31e253eed..88f53abcf 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -44,6 +44,8 @@ class Connection(Connectable): """ + _schema_translate_map = None + def __init__(self, engine, connection=None, close_with_result=False, _branch_from=None, _execution_options=None, _dispatch=None, @@ -140,6 +142,13 @@ class Connection(Connectable): c.__dict__ = self.__dict__.copy() return c + def _get_effective_schema(self, table): + effective_schema = table.schema + if self._schema_translate_map: + effective_schema = self._schema_translate_map.get( + effective_schema, effective_schema) + return effective_schema + def __enter__(self): return self @@ -277,6 +286,19 @@ class Connection(Connectable): of many DBAPIs. The flag is currently understood only by the psycopg2 dialect. + :param schema_translate_map: Available on: Connection, Engine. + A dictionary mapping schema names to schema names, that will be + applied to the :paramref:`.Table.schema` element of each + :class:`.Table` encountered when SQL or DDL expression elements + are compiled into strings; the resulting schema name will be + converted based on presence in the map of the original name. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + """ c = self._clone() c._execution_options = c._execution_options.union(opt) @@ -959,7 +981,9 @@ class Connection(Connectable): dialect = self.dialect - compiled = ddl.compile(dialect=dialect) + compiled = ddl.compile( + dialect=dialect, + schema_translate_map=self._schema_translate_map) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_ddl, @@ -990,17 +1014,27 @@ class Connection(Connectable): dialect = self.dialect if 'compiled_cache' in self._execution_options: - key = dialect, elem, tuple(sorted(keys)), len(distilled_params) > 1 + key = ( + dialect, elem, tuple(sorted(keys)), + tuple( + (k, self._schema_translate_map[k]) + for k in sorted(self._schema_translate_map) + ) if self._schema_translate_map else None, + len(distilled_params) > 1 + ) compiled_sql = self._execution_options['compiled_cache'].get(key) if compiled_sql is None: compiled_sql = elem.compile( dialect=dialect, column_keys=keys, - inline=len(distilled_params) > 1) + inline=len(distilled_params) > 1, + schema_translate_map=self._schema_translate_map + ) self._execution_options['compiled_cache'][key] = compiled_sql else: compiled_sql = elem.compile( dialect=dialect, column_keys=keys, - inline=len(distilled_params) > 1) + inline=len(distilled_params) > 1, + schema_translate_map=self._schema_translate_map) ret = self._execute_context( dialect, diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 87278c2be..160fe545e 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -398,10 +398,18 @@ class DefaultDialect(interfaces.Dialect): if not branch: self._set_connection_isolation(connection, isolation_level) + if 'schema_translate_map' in opts: + @event.listens_for(engine, "engine_connect") + def set_schema_translate_map(connection, branch): + connection._schema_translate_map = opts['schema_translate_map'] + def set_connection_execution_options(self, connection, opts): if 'isolation_level' in opts: self._set_connection_isolation(connection, opts['isolation_level']) + if 'schema_translate_map' in opts: + connection._schema_translate_map = opts['schema_translate_map'] + def _set_connection_isolation(self, connection, level): if connection.in_transaction(): util.warn( diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 59eed51ec..dca99e1ce 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -529,7 +529,8 @@ class Inspector(object): """ dialect = self.bind.dialect - schema = table.schema + schema = self.bind._get_effective_schema(table) + table_name = table.name # get table-level arguments that are specifically diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 0d0414ed1..cb3e6fa8a 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -233,6 +233,9 @@ class MockEngineStrategy(EngineStrategy): dialect = property(attrgetter('_dialect')) name = property(lambda s: s._dialect.name) + def _get_effective_schema(self, table): + return table.schema + def contextual_connect(self, **kwargs): return self diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 98ab60aaa..4068d18be 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -167,6 +167,7 @@ class Compiled(object): _cached_metadata = None def __init__(self, dialect, statement, bind=None, + schema_translate_map=None, compile_kwargs=util.immutabledict()): """Construct a new :class:`.Compiled` object. @@ -177,15 +178,24 @@ class Compiled(object): :param bind: Optional Engine or Connection to compile this statement against. + :param schema_translate_map: dictionary of schema names to be + translated when forming the resultant SQL + + .. versionadded:: 1.1 + :param compile_kwargs: additional kwargs that will be passed to the initial call to :meth:`.Compiled.process`. - .. versionadded:: 0.8 """ self.dialect = dialect self.bind = bind + self.preparer = self.dialect.identifier_preparer + if schema_translate_map: + self.preparer = self.preparer._with_schema_translate( + schema_translate_map) + if statement is not None: self.statement = statement self.can_execute = statement.supports_execution @@ -385,8 +395,6 @@ class SQLCompiler(Compiled): self.ctes = None - # an IdentifierPreparer that formats the quoting of identifiers - self.preparer = dialect.identifier_preparer self.label_length = dialect.label_length \ or dialect.max_identifier_length @@ -653,8 +661,16 @@ class SQLCompiler(Compiled): if table is None or not include_table or not table.named_with_column: return name else: - if table.schema: - schema_prefix = self.preparer.quote_schema(table.schema) + '.' + + # inlining of preparer._get_effective_schema + effective_schema = table.schema + if self.preparer.schema_translate_map: + effective_schema = self.preparer.schema_translate_map.get( + effective_schema, effective_schema) + + if effective_schema: + schema_prefix = self.preparer.quote_schema( + effective_schema) + '.' else: schema_prefix = '' tablename = table.name @@ -1814,8 +1830,15 @@ class SQLCompiler(Compiled): def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, fromhints=None, use_schema=True, **kwargs): if asfrom or ashint: - if use_schema and getattr(table, "schema", None): - ret = self.preparer.quote_schema(table.schema) + \ + + # inlining of preparer._get_effective_schema + effective_schema = table.schema + if self.preparer.schema_translate_map: + effective_schema = self.preparer.schema_translate_map.get( + effective_schema, effective_schema) + + if use_schema and effective_schema: + ret = self.preparer.quote_schema(effective_schema) + \ "." + self.preparer.quote(table.name) else: ret = self.preparer.quote(table.name) @@ -2103,10 +2126,6 @@ class DDLCompiler(Compiled): def type_compiler(self): return self.dialect.type_compiler - @property - def preparer(self): - return self.dialect.identifier_preparer - def construct_params(self, params=None): return None @@ -2116,7 +2135,7 @@ class DDLCompiler(Compiled): if isinstance(ddl.target, schema.Table): context = context.copy() - preparer = self.dialect.identifier_preparer + preparer = self.preparer path = preparer.format_table_seq(ddl.target) if len(path) == 1: table, sch = path[0], '' @@ -2142,7 +2161,7 @@ class DDLCompiler(Compiled): def visit_create_table(self, create): table = create.element - preparer = self.dialect.identifier_preparer + preparer = self.preparer text = "\nCREATE " if table._prefixes: @@ -2269,9 +2288,12 @@ class DDLCompiler(Compiled): index, include_schema=True) def _prepared_index_name(self, index, include_schema=False): - if include_schema and index.table is not None and index.table.schema: - schema = index.table.schema - schema_name = self.preparer.quote_schema(schema) + if index.table is not None: + effective_schema = self.preparer._get_effective_schema(index.table) + else: + effective_schema = None + if include_schema and effective_schema: + schema_name = self.preparer.quote_schema(effective_schema) else: schema_name = None @@ -2399,7 +2421,7 @@ class DDLCompiler(Compiled): return text def visit_foreign_key_constraint(self, constraint): - preparer = self.dialect.identifier_preparer + preparer = self.preparer text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -2626,6 +2648,8 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + schema_translate_map = util.immutabledict() + def __init__(self, dialect, initial_quote='"', final_quote=None, escape_quote='"', omit_schema=False): """Construct a new ``IdentifierPreparer`` object. @@ -2650,6 +2674,12 @@ class IdentifierPreparer(object): self.omit_schema = omit_schema self._strings = {} + def _with_schema_translate(self, schema_translate_map): + prep = self.__class__.__new__(self.__class__) + prep.__dict__.update(self.__dict__) + prep.schema_translate_map = schema_translate_map + return prep + def _escape_identifier(self, value): """Escape an identifier. @@ -2722,9 +2752,12 @@ class IdentifierPreparer(object): def format_sequence(self, sequence, use_schema=True): name = self.quote(sequence.name) + + effective_schema = self._get_effective_schema(sequence) + if (not self.omit_schema and use_schema and - sequence.schema is not None): - name = self.quote_schema(sequence.schema) + "." + name + effective_schema is not None): + name = self.quote_schema(effective_schema) + "." + name return name def format_label(self, label, name=None): @@ -2747,15 +2780,25 @@ class IdentifierPreparer(object): return None return self.quote(constraint.name) + def _get_effective_schema(self, table): + effective_schema = table.schema + if self.schema_translate_map: + effective_schema = self.schema_translate_map.get( + effective_schema, effective_schema) + return effective_schema + def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" if name is None: name = table.name result = self.quote(name) + + effective_schema = self._get_effective_schema(table) + if not self.omit_schema and use_schema \ - and getattr(table, "schema", None): - result = self.quote_schema(table.schema) + "." + result + and effective_schema: + result = self.quote_schema(effective_schema) + "." + result return result def format_schema(self, name, quote=None): @@ -2794,9 +2837,11 @@ class IdentifierPreparer(object): # ('database', 'owner', etc.) could override this and return # a longer sequence. + effective_schema = self._get_effective_schema(table) + if not self.omit_schema and use_schema and \ - getattr(table, 'schema', None): - return (self.quote_schema(table.schema), + effective_schema: + return (self.quote_schema(effective_schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 71018f132..7225da551 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -679,13 +679,16 @@ class SchemaGenerator(DDLBase): def _can_create_table(self, table): self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) + effective_schema = self.connection._get_effective_schema(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) return not self.checkfirst or \ not self.dialect.has_table(self.connection, - table.name, schema=table.schema) + table.name, schema=effective_schema) def _can_create_sequence(self, sequence): + effective_schema = self.connection._get_effective_schema(sequence) + return self.dialect.supports_sequences and \ ( (not self.dialect.sequences_optional or @@ -695,7 +698,7 @@ class SchemaGenerator(DDLBase): not self.dialect.has_sequence( self.connection, sequence.name, - schema=sequence.schema) + schema=effective_schema) ) ) @@ -882,12 +885,14 @@ class SchemaDropper(DDLBase): def _can_drop_table(self, table): self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) + effective_schema = self.connection._get_effective_schema(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) return not self.checkfirst or self.dialect.has_table( - self.connection, table.name, schema=table.schema) + self.connection, table.name, schema=effective_schema) def _can_drop_sequence(self, sequence): + effective_schema = self.connection._get_effective_schema(sequence) return self.dialect.supports_sequences and \ ((not self.dialect.sequences_optional or not sequence.optional) and @@ -895,7 +900,7 @@ class SchemaDropper(DDLBase): self.dialect.has_sequence( self.connection, sequence.name, - schema=sequence.schema)) + schema=effective_schema)) ) def visit_index(self, index): diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 63667654d..ad0aa4362 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -273,7 +273,8 @@ class AssertsCompiledSQL(object): check_prefetch=None, use_default_dialect=False, allow_dialect_select=False, - literal_binds=False): + literal_binds=False, + schema_translate_map=None): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: @@ -292,6 +293,9 @@ class AssertsCompiledSQL(object): kw = {} compile_kwargs = {} + if schema_translate_map: + kw['schema_translate_map'] = schema_translate_map + if params is not None: kw['column_keys'] = list(params) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 39d078985..904149c16 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -87,13 +87,18 @@ class CompiledSQL(SQLMatchRule): compare_dialect = self._compile_dialect(execute_observed) if isinstance(context.compiled.statement, _DDLCompiles): compiled = \ - context.compiled.statement.compile(dialect=compare_dialect) + context.compiled.statement.compile( + dialect=compare_dialect, + schema_translate_map=context. + compiled.preparer.schema_translate_map) else: compiled = ( context.compiled.statement.compile( dialect=compare_dialect, column_keys=context.compiled.column_keys, - inline=context.compiled.inline) + inline=context.compiled.inline, + schema_translate_map=context. + compiled.preparer.schema_translate_map) ) _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) parameters = execute_observed.parameters |