summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2016-01-08 22:11:09 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2016-01-08 22:12:25 -0500
commit89facbed8855d1443dbe37919ff0645aea640ed0 (patch)
tree33e7ab15470a5f3a76b748418e6be0c62aa1eaba /lib/sqlalchemy
parent777e25694f1567ff61655d86a91be6264186c13e (diff)
downloadsqlalchemy-89facbed8855d1443dbe37919ff0645aea640ed0.tar.gz
- Multi-tenancy schema translation for :class:`.Table` objects is added.
This supports the use case of an application that uses the same set of :class:`.Table` objects in many schemas, such as schema-per-user. A new execution option :paramref:`.Connection.execution_options.schema_translate_map` is added. fixes #2685 - latest tox doesn't like the {posargs} in the profile rerunner
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py18
-rw-r--r--lib/sqlalchemy/engine/base.py42
-rw-r--r--lib/sqlalchemy/engine/default.py8
-rw-r--r--lib/sqlalchemy/engine/reflection.py3
-rw-r--r--lib/sqlalchemy/engine/strategies.py3
-rw-r--r--lib/sqlalchemy/sql/compiler.py91
-rw-r--r--lib/sqlalchemy/sql/ddl.py21
-rw-r--r--lib/sqlalchemy/testing/assertions.py6
-rw-r--r--lib/sqlalchemy/testing/assertsql.py9
9 files changed, 157 insertions, 44 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 3f9fcb27f..3b3d65155 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1481,8 +1481,11 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
raise exc.CompileError("Postgresql ENUM type requires a name.")
name = self.quote(type_.name)
- if not self.omit_schema and use_schema and type_.schema is not None:
- name = self.quote_schema(type_.schema) + "." + name
+ effective_schema = self._get_effective_schema(type_)
+
+ if not self.omit_schema and use_schema and \
+ effective_schema is not None:
+ name = self.quote_schema(effective_schema) + "." + name
return name
@@ -1575,10 +1578,15 @@ class PGExecutionContext(default.DefaultExecutionContext):
name = "%s_%s_seq" % (tab, col)
column._postgresql_seq_name = seq_name = name
- sch = column.table.schema
- if sch is not None:
+ if column.table is not None:
+ effective_schema = self.connection._get_effective_schema(
+ column.table)
+ else:
+ effective_schema = None
+
+ if effective_schema is not None:
exc = "select nextval('\"%s\".\"%s\"')" % \
- (sch, seq_name)
+ (effective_schema, seq_name)
else:
exc = "select nextval('\"%s\"')" % \
(seq_name, )
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 31e253eed..88f53abcf 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -44,6 +44,8 @@ class Connection(Connectable):
"""
+ _schema_translate_map = None
+
def __init__(self, engine, connection=None, close_with_result=False,
_branch_from=None, _execution_options=None,
_dispatch=None,
@@ -140,6 +142,13 @@ class Connection(Connectable):
c.__dict__ = self.__dict__.copy()
return c
+ def _get_effective_schema(self, table):
+ effective_schema = table.schema
+ if self._schema_translate_map:
+ effective_schema = self._schema_translate_map.get(
+ effective_schema, effective_schema)
+ return effective_schema
+
def __enter__(self):
return self
@@ -277,6 +286,19 @@ class Connection(Connectable):
of many DBAPIs. The flag is currently understood only by the
psycopg2 dialect.
+ :param schema_translate_map: Available on: Connection, Engine.
+ A dictionary mapping schema names to schema names, that will be
+ applied to the :paramref:`.Table.schema` element of each
+ :class:`.Table` encountered when SQL or DDL expression elements
+ are compiled into strings; the resulting schema name will be
+ converted based on presence in the map of the original name.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`schema_translating`
+
"""
c = self._clone()
c._execution_options = c._execution_options.union(opt)
@@ -959,7 +981,9 @@ class Connection(Connectable):
dialect = self.dialect
- compiled = ddl.compile(dialect=dialect)
+ compiled = ddl.compile(
+ dialect=dialect,
+ schema_translate_map=self._schema_translate_map)
ret = self._execute_context(
dialect,
dialect.execution_ctx_cls._init_ddl,
@@ -990,17 +1014,27 @@ class Connection(Connectable):
dialect = self.dialect
if 'compiled_cache' in self._execution_options:
- key = dialect, elem, tuple(sorted(keys)), len(distilled_params) > 1
+ key = (
+ dialect, elem, tuple(sorted(keys)),
+ tuple(
+ (k, self._schema_translate_map[k])
+ for k in sorted(self._schema_translate_map)
+ ) if self._schema_translate_map else None,
+ len(distilled_params) > 1
+ )
compiled_sql = self._execution_options['compiled_cache'].get(key)
if compiled_sql is None:
compiled_sql = elem.compile(
dialect=dialect, column_keys=keys,
- inline=len(distilled_params) > 1)
+ inline=len(distilled_params) > 1,
+ schema_translate_map=self._schema_translate_map
+ )
self._execution_options['compiled_cache'][key] = compiled_sql
else:
compiled_sql = elem.compile(
dialect=dialect, column_keys=keys,
- inline=len(distilled_params) > 1)
+ inline=len(distilled_params) > 1,
+ schema_translate_map=self._schema_translate_map)
ret = self._execute_context(
dialect,
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 87278c2be..160fe545e 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -398,10 +398,18 @@ class DefaultDialect(interfaces.Dialect):
if not branch:
self._set_connection_isolation(connection, isolation_level)
+ if 'schema_translate_map' in opts:
+ @event.listens_for(engine, "engine_connect")
+ def set_schema_translate_map(connection, branch):
+ connection._schema_translate_map = opts['schema_translate_map']
+
def set_connection_execution_options(self, connection, opts):
if 'isolation_level' in opts:
self._set_connection_isolation(connection, opts['isolation_level'])
+ if 'schema_translate_map' in opts:
+ connection._schema_translate_map = opts['schema_translate_map']
+
def _set_connection_isolation(self, connection, level):
if connection.in_transaction():
util.warn(
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index 59eed51ec..dca99e1ce 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -529,7 +529,8 @@ class Inspector(object):
"""
dialect = self.bind.dialect
- schema = table.schema
+ schema = self.bind._get_effective_schema(table)
+
table_name = table.name
# get table-level arguments that are specifically
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index 0d0414ed1..cb3e6fa8a 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -233,6 +233,9 @@ class MockEngineStrategy(EngineStrategy):
dialect = property(attrgetter('_dialect'))
name = property(lambda s: s._dialect.name)
+ def _get_effective_schema(self, table):
+ return table.schema
+
def contextual_connect(self, **kwargs):
return self
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 98ab60aaa..4068d18be 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -167,6 +167,7 @@ class Compiled(object):
_cached_metadata = None
def __init__(self, dialect, statement, bind=None,
+ schema_translate_map=None,
compile_kwargs=util.immutabledict()):
"""Construct a new :class:`.Compiled` object.
@@ -177,15 +178,24 @@ class Compiled(object):
:param bind: Optional Engine or Connection to compile this
statement against.
+ :param schema_translate_map: dictionary of schema names to be
+ translated when forming the resultant SQL
+
+ .. versionadded:: 1.1
+
:param compile_kwargs: additional kwargs that will be
passed to the initial call to :meth:`.Compiled.process`.
- .. versionadded:: 0.8
"""
self.dialect = dialect
self.bind = bind
+ self.preparer = self.dialect.identifier_preparer
+ if schema_translate_map:
+ self.preparer = self.preparer._with_schema_translate(
+ schema_translate_map)
+
if statement is not None:
self.statement = statement
self.can_execute = statement.supports_execution
@@ -385,8 +395,6 @@ class SQLCompiler(Compiled):
self.ctes = None
- # an IdentifierPreparer that formats the quoting of identifiers
- self.preparer = dialect.identifier_preparer
self.label_length = dialect.label_length \
or dialect.max_identifier_length
@@ -653,8 +661,16 @@ class SQLCompiler(Compiled):
if table is None or not include_table or not table.named_with_column:
return name
else:
- if table.schema:
- schema_prefix = self.preparer.quote_schema(table.schema) + '.'
+
+ # inlining of preparer._get_effective_schema
+ effective_schema = table.schema
+ if self.preparer.schema_translate_map:
+ effective_schema = self.preparer.schema_translate_map.get(
+ effective_schema, effective_schema)
+
+ if effective_schema:
+ schema_prefix = self.preparer.quote_schema(
+ effective_schema) + '.'
else:
schema_prefix = ''
tablename = table.name
@@ -1814,8 +1830,15 @@ class SQLCompiler(Compiled):
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
if asfrom or ashint:
- if use_schema and getattr(table, "schema", None):
- ret = self.preparer.quote_schema(table.schema) + \
+
+ # inlining of preparer._get_effective_schema
+ effective_schema = table.schema
+ if self.preparer.schema_translate_map:
+ effective_schema = self.preparer.schema_translate_map.get(
+ effective_schema, effective_schema)
+
+ if use_schema and effective_schema:
+ ret = self.preparer.quote_schema(effective_schema) + \
"." + self.preparer.quote(table.name)
else:
ret = self.preparer.quote(table.name)
@@ -2103,10 +2126,6 @@ class DDLCompiler(Compiled):
def type_compiler(self):
return self.dialect.type_compiler
- @property
- def preparer(self):
- return self.dialect.identifier_preparer
-
def construct_params(self, params=None):
return None
@@ -2116,7 +2135,7 @@ class DDLCompiler(Compiled):
if isinstance(ddl.target, schema.Table):
context = context.copy()
- preparer = self.dialect.identifier_preparer
+ preparer = self.preparer
path = preparer.format_table_seq(ddl.target)
if len(path) == 1:
table, sch = path[0], ''
@@ -2142,7 +2161,7 @@ class DDLCompiler(Compiled):
def visit_create_table(self, create):
table = create.element
- preparer = self.dialect.identifier_preparer
+ preparer = self.preparer
text = "\nCREATE "
if table._prefixes:
@@ -2269,9 +2288,12 @@ class DDLCompiler(Compiled):
index, include_schema=True)
def _prepared_index_name(self, index, include_schema=False):
- if include_schema and index.table is not None and index.table.schema:
- schema = index.table.schema
- schema_name = self.preparer.quote_schema(schema)
+ if index.table is not None:
+ effective_schema = self.preparer._get_effective_schema(index.table)
+ else:
+ effective_schema = None
+ if include_schema and effective_schema:
+ schema_name = self.preparer.quote_schema(effective_schema)
else:
schema_name = None
@@ -2399,7 +2421,7 @@ class DDLCompiler(Compiled):
return text
def visit_foreign_key_constraint(self, constraint):
- preparer = self.dialect.identifier_preparer
+ preparer = self.preparer
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
@@ -2626,6 +2648,8 @@ class IdentifierPreparer(object):
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
+ schema_translate_map = util.immutabledict()
+
def __init__(self, dialect, initial_quote='"',
final_quote=None, escape_quote='"', omit_schema=False):
"""Construct a new ``IdentifierPreparer`` object.
@@ -2650,6 +2674,12 @@ class IdentifierPreparer(object):
self.omit_schema = omit_schema
self._strings = {}
+ def _with_schema_translate(self, schema_translate_map):
+ prep = self.__class__.__new__(self.__class__)
+ prep.__dict__.update(self.__dict__)
+ prep.schema_translate_map = schema_translate_map
+ return prep
+
def _escape_identifier(self, value):
"""Escape an identifier.
@@ -2722,9 +2752,12 @@ class IdentifierPreparer(object):
def format_sequence(self, sequence, use_schema=True):
name = self.quote(sequence.name)
+
+ effective_schema = self._get_effective_schema(sequence)
+
if (not self.omit_schema and use_schema and
- sequence.schema is not None):
- name = self.quote_schema(sequence.schema) + "." + name
+ effective_schema is not None):
+ name = self.quote_schema(effective_schema) + "." + name
return name
def format_label(self, label, name=None):
@@ -2747,15 +2780,25 @@ class IdentifierPreparer(object):
return None
return self.quote(constraint.name)
+ def _get_effective_schema(self, table):
+ effective_schema = table.schema
+ if self.schema_translate_map:
+ effective_schema = self.schema_translate_map.get(
+ effective_schema, effective_schema)
+ return effective_schema
+
def format_table(self, table, use_schema=True, name=None):
"""Prepare a quoted table and schema name."""
if name is None:
name = table.name
result = self.quote(name)
+
+ effective_schema = self._get_effective_schema(table)
+
if not self.omit_schema and use_schema \
- and getattr(table, "schema", None):
- result = self.quote_schema(table.schema) + "." + result
+ and effective_schema:
+ result = self.quote_schema(effective_schema) + "." + result
return result
def format_schema(self, name, quote=None):
@@ -2794,9 +2837,11 @@ class IdentifierPreparer(object):
# ('database', 'owner', etc.) could override this and return
# a longer sequence.
+ effective_schema = self._get_effective_schema(table)
+
if not self.omit_schema and use_schema and \
- getattr(table, 'schema', None):
- return (self.quote_schema(table.schema),
+ effective_schema:
+ return (self.quote_schema(effective_schema),
self.format_table(table, use_schema=False))
else:
return (self.format_table(table, use_schema=False), )
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 71018f132..7225da551 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -679,13 +679,16 @@ class SchemaGenerator(DDLBase):
def _can_create_table(self, table):
self.dialect.validate_identifier(table.name)
- if table.schema:
- self.dialect.validate_identifier(table.schema)
+ effective_schema = self.connection._get_effective_schema(table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
return not self.checkfirst or \
not self.dialect.has_table(self.connection,
- table.name, schema=table.schema)
+ table.name, schema=effective_schema)
def _can_create_sequence(self, sequence):
+ effective_schema = self.connection._get_effective_schema(sequence)
+
return self.dialect.supports_sequences and \
(
(not self.dialect.sequences_optional or
@@ -695,7 +698,7 @@ class SchemaGenerator(DDLBase):
not self.dialect.has_sequence(
self.connection,
sequence.name,
- schema=sequence.schema)
+ schema=effective_schema)
)
)
@@ -882,12 +885,14 @@ class SchemaDropper(DDLBase):
def _can_drop_table(self, table):
self.dialect.validate_identifier(table.name)
- if table.schema:
- self.dialect.validate_identifier(table.schema)
+ effective_schema = self.connection._get_effective_schema(table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
return not self.checkfirst or self.dialect.has_table(
- self.connection, table.name, schema=table.schema)
+ self.connection, table.name, schema=effective_schema)
def _can_drop_sequence(self, sequence):
+ effective_schema = self.connection._get_effective_schema(sequence)
return self.dialect.supports_sequences and \
((not self.dialect.sequences_optional or
not sequence.optional) and
@@ -895,7 +900,7 @@ class SchemaDropper(DDLBase):
self.dialect.has_sequence(
self.connection,
sequence.name,
- schema=sequence.schema))
+ schema=effective_schema))
)
def visit_index(self, index):
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 63667654d..ad0aa4362 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -273,7 +273,8 @@ class AssertsCompiledSQL(object):
check_prefetch=None,
use_default_dialect=False,
allow_dialect_select=False,
- literal_binds=False):
+ literal_binds=False,
+ schema_translate_map=None):
if use_default_dialect:
dialect = default.DefaultDialect()
elif allow_dialect_select:
@@ -292,6 +293,9 @@ class AssertsCompiledSQL(object):
kw = {}
compile_kwargs = {}
+ if schema_translate_map:
+ kw['schema_translate_map'] = schema_translate_map
+
if params is not None:
kw['column_keys'] = list(params)
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 39d078985..904149c16 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -87,13 +87,18 @@ class CompiledSQL(SQLMatchRule):
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
compiled = \
- context.compiled.statement.compile(dialect=compare_dialect)
+ context.compiled.statement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=context.
+ compiled.preparer.schema_translate_map)
else:
compiled = (
context.compiled.statement.compile(
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
- inline=context.compiled.inline)
+ inline=context.compiled.inline,
+ schema_translate_map=context.
+ compiled.preparer.schema_translate_map)
)
_received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
parameters = execute_observed.parameters