summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-06-18 17:11:49 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-06-18 17:11:49 -0400
commit89a6fa3c1436a490d3c664ee5d37aa95e2d8a9a9 (patch)
tree20bc0285443babaf844b02a4625e7586d223df97
parent0c71f985e4885e4cd5c04eee46d73d7870bad6df (diff)
downloadalembic-89a6fa3c1436a490d3c664ee5d37aa95e2d8a9a9.tar.gz
- factor schema object creator functions into a separate object
-rw-r--r--alembic/operations/base.py194
-rw-r--r--alembic/operations/schemaobj.py137
-rw-r--r--tests/test_batch.py18
3 files changed, 185 insertions, 164 deletions
diff --git a/alembic/operations/base.py b/alembic/operations/base.py
index cda5aa2..59fa7af 100644
--- a/alembic/operations/base.py
+++ b/alembic/operations/base.py
@@ -1,11 +1,11 @@
from contextlib import contextmanager
-from sqlalchemy.types import NULLTYPE, Integer
+from sqlalchemy.types import NULLTYPE
from sqlalchemy import schema as sa_schema
from .. import util
from . import batch
-from ..util.compat import string_types
+from . import schemaobj
from ..ddl import impl
__all__ = ('Operations', 'BatchOperations')
@@ -56,6 +56,8 @@ class Operations(object):
else:
self.impl = impl
+ self.schema_obj = schemaobj.SchemaObjects(migration_context)
+
@classmethod
@contextmanager
def context(cls, migration_context):
@@ -65,131 +67,6 @@ class Operations(object):
yield op
_remove_proxy()
- def _primary_key_constraint(self, name, table_name, cols, schema=None):
- m = self._metadata()
- columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
- t1 = sa_schema.Table(table_name, m,
- *columns,
- schema=schema)
- p = sa_schema.PrimaryKeyConstraint(*columns, name=name)
- t1.append_constraint(p)
- return p
-
- def _foreign_key_constraint(self, name, source, referent,
- local_cols, remote_cols,
- onupdate=None, ondelete=None,
- deferrable=None, source_schema=None,
- referent_schema=None, initially=None,
- match=None, **dialect_kw):
- m = self._metadata()
- if source == referent:
- t1_cols = local_cols + remote_cols
- else:
- t1_cols = local_cols
- sa_schema.Table(
- referent, m,
- *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
- schema=referent_schema)
-
- t1 = sa_schema.Table(
- source, m,
- *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
- schema=source_schema)
-
- tname = "%s.%s" % (referent_schema, referent) if referent_schema \
- else referent
-
- if util.sqla_08:
- # "match" kw unsupported in 0.7
- dialect_kw['match'] = match
-
- f = sa_schema.ForeignKeyConstraint(local_cols,
- ["%s.%s" % (tname, n)
- for n in remote_cols],
- name=name,
- onupdate=onupdate,
- ondelete=ondelete,
- deferrable=deferrable,
- initially=initially,
- **dialect_kw
- )
- t1.append_constraint(f)
-
- return f
-
- def _unique_constraint(self, name, source, local_cols, schema=None, **kw):
- t = sa_schema.Table(
- source, self._metadata(),
- *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
- schema=schema)
- kw['name'] = name
- uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
- # TODO: need event tests to ensure the event
- # is fired off here
- t.append_constraint(uq)
- return uq
-
- def _check_constraint(self, name, source, condition, schema=None, **kw):
- t = sa_schema.Table(source, self._metadata(),
- sa_schema.Column('x', Integer), schema=schema)
- ck = sa_schema.CheckConstraint(condition, name=name, **kw)
- t.append_constraint(ck)
- return ck
-
- def _metadata(self):
- kw = {}
- if 'target_metadata' in self.migration_context.opts:
- mt = self.migration_context.opts['target_metadata']
- if hasattr(mt, 'naming_convention'):
- kw['naming_convention'] = mt.naming_convention
- return sa_schema.MetaData(**kw)
-
- def _table(self, name, *columns, **kw):
- m = self._metadata()
- t = sa_schema.Table(name, m, *columns, **kw)
- for f in t.foreign_keys:
- self._ensure_table_for_fk(m, f)
- return t
-
- def _column(self, name, type_, **kw):
- return sa_schema.Column(name, type_, **kw)
-
- def _index(self, name, tablename, columns, schema=None, **kw):
- t = sa_schema.Table(
- tablename or 'no_table', self._metadata(),
- schema=schema
- )
- idx = sa_schema.Index(
- name,
- *[impl._textual_index_column(t, n) for n in columns],
- **kw)
- return idx
-
- def _parse_table_key(self, table_key):
- if '.' in table_key:
- tokens = table_key.split('.')
- sname = ".".join(tokens[0:-1])
- tname = tokens[-1]
- else:
- tname = table_key
- sname = None
- return (sname, tname)
-
- def _ensure_table_for_fk(self, metadata, fk):
- """create a placeholder Table object for the referent of a
- ForeignKey.
-
- """
- if isinstance(fk._colspec, string_types):
- table_key, cname = fk._colspec.rsplit('.', 1)
- sname, tname = self._parse_table_key(table_key)
- if table_key not in metadata.tables:
- rel_t = sa_schema.Table(tname, metadata, schema=sname)
- else:
- rel_t = metadata.tables[table_key]
- if cname not in rel_t.c:
- rel_t.append_column(sa_schema.Column(cname, NULLTYPE))
-
@contextmanager
def batch_alter_table(
self, table_name, schema=None, recreate="auto", copy_from=None,
@@ -458,10 +335,11 @@ class Operations(object):
constraint._create_rule(compiler))
if existing_type and type_:
- t = self._table(table_name,
- sa_schema.Column(column_name, existing_type),
- schema=schema
- )
+ t = self.schema_obj.table(
+ table_name,
+ sa_schema.Column(column_name, existing_type),
+ schema=schema
+ )
for constraint in t.constraints:
if _count_constraint(constraint):
self.impl.drop_constraint(constraint)
@@ -480,10 +358,11 @@ class Operations(object):
)
if type_:
- t = self._table(table_name,
- sa_schema.Column(column_name, type_),
- schema=schema
- )
+ t = self.schema_obj.table(
+ table_name,
+ self.schema_obj.column(column_name, type_),
+ schema=schema
+ )
for constraint in t.constraints:
if _count_constraint(constraint):
self.impl.add_constraint(constraint)
@@ -589,7 +468,7 @@ class Operations(object):
"""
- t = self._table(table_name, column, schema=schema)
+ t = self.schema_obj.table(table_name, column, schema=schema)
self.impl.add_column(
table_name,
column,
@@ -647,7 +526,7 @@ class Operations(object):
self.impl.drop_column(
table_name,
- self._column(column_name, NULLTYPE),
+ self.schema_obj.column(column_name, NULLTYPE),
**kw
)
@@ -692,8 +571,8 @@ class Operations(object):
"""
self.impl.add_constraint(
- self._primary_key_constraint(name, table_name, cols,
- schema)
+ self.schema_obj.primary_key_constraint(
+ name, table_name, cols, schema)
)
def create_foreign_key(self, name, source, referent, local_cols,
@@ -747,14 +626,15 @@ class Operations(object):
"""
self.impl.add_constraint(
- self._foreign_key_constraint(name, source, referent,
- local_cols, remote_cols,
- onupdate=onupdate, ondelete=ondelete,
- deferrable=deferrable,
- source_schema=source_schema,
- referent_schema=referent_schema,
- initially=initially, match=match,
- **dialect_kw)
+ self.schema_obj.foreign_key_constraint(
+ name, source, referent,
+ local_cols, remote_cols,
+ onupdate=onupdate, ondelete=ondelete,
+ deferrable=deferrable,
+ source_schema=source_schema,
+ referent_schema=referent_schema,
+ initially=initially, match=match,
+ **dialect_kw)
)
def create_unique_constraint(self, name, source, local_cols,
@@ -802,8 +682,9 @@ class Operations(object):
"""
self.impl.add_constraint(
- self._unique_constraint(name, source, local_cols,
- schema=schema, **kw)
+ self.schema_obj.unique_constraint(
+ name, source, local_cols,
+ schema=schema, **kw)
)
def create_check_constraint(self, name, source, condition,
@@ -852,7 +733,7 @@ class Operations(object):
"""
self.impl.add_constraint(
- self._check_constraint(
+ self.schema_obj.check_constraint(
name, source, condition, schema=schema, **kw)
)
@@ -941,7 +822,7 @@ class Operations(object):
object is returned.
"""
- table = self._table(name, *columns, **kw)
+ table = self.schema_obj.table(name, *columns, **kw)
self.impl.create_table(table)
return table
@@ -968,7 +849,7 @@ class Operations(object):
"""
self.impl.drop_table(
- self._table(name, **kw)
+ self.schema_obj.table(name, **kw)
)
def create_index(self, name, table_name, columns, schema=None,
@@ -1024,8 +905,9 @@ class Operations(object):
"""
self.impl.create_index(
- self._index(name, table_name, columns, schema=schema,
- unique=unique, quote=quote, **kw)
+ self.schema_obj.index(
+ name, table_name, columns, schema=schema,
+ unique=unique, quote=quote, **kw)
)
@util._with_legacy_names([('tablename', 'table_name')])
@@ -1052,7 +934,7 @@ class Operations(object):
# need a dummy column name here since SQLAlchemy
# 0.7.6 and further raises on Index with no columns
self.impl.drop_index(
- self._index(name, table_name, ['x'], schema=schema)
+ self.schema_obj.index(name, table_name, ['x'], schema=schema)
)
@util._with_legacy_names([("type", "type_")])
@@ -1073,7 +955,7 @@ class Operations(object):
"""
- t = self._table(table_name, schema=schema)
+ t = self.schema_obj.table(table_name, schema=schema)
types = {
'foreignkey': lambda name: sa_schema.ForeignKeyConstraint(
[], [], name=name),
diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py
new file mode 100644
index 0000000..b5a8e08
--- /dev/null
+++ b/alembic/operations/schemaobj.py
@@ -0,0 +1,137 @@
+from sqlalchemy import schema as sa_schema
+from sqlalchemy.types import NULLTYPE, Integer
+from ..util.compat import string_types
+from .. import util
+from ..ddl import impl
+
+
+class SchemaObjects(object):
+
+ def __init__(self, migration_context):
+ self.migration_context = migration_context
+
+ def primary_key_constraint(self, name, table_name, cols, schema=None):
+ m = self.metadata()
+ columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
+ t1 = sa_schema.Table(table_name, m,
+ *columns,
+ schema=schema)
+ p = sa_schema.PrimaryKeyConstraint(*columns, name=name)
+ t1.append_constraint(p)
+ return p
+
+ def foreign_key_constraint(
+ self, name, source, referent,
+ local_cols, remote_cols,
+ onupdate=None, ondelete=None,
+ deferrable=None, source_schema=None,
+ referent_schema=None, initially=None,
+ match=None, **dialect_kw):
+ m = self.metadata()
+ if source == referent:
+ t1_cols = local_cols + remote_cols
+ else:
+ t1_cols = local_cols
+ sa_schema.Table(
+ referent, m,
+ *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
+ schema=referent_schema)
+
+ t1 = sa_schema.Table(
+ source, m,
+ *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
+ schema=source_schema)
+
+ tname = "%s.%s" % (referent_schema, referent) if referent_schema \
+ else referent
+
+ if util.sqla_08:
+ # "match" kw unsupported in 0.7
+ dialect_kw['match'] = match
+
+ f = sa_schema.ForeignKeyConstraint(local_cols,
+ ["%s.%s" % (tname, n)
+ for n in remote_cols],
+ name=name,
+ onupdate=onupdate,
+ ondelete=ondelete,
+ deferrable=deferrable,
+ initially=initially,
+ **dialect_kw
+ )
+ t1.append_constraint(f)
+
+ return f
+
+ def unique_constraint(self, name, source, local_cols, schema=None, **kw):
+ t = sa_schema.Table(
+ source, self.metadata(),
+ *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
+ schema=schema)
+ kw['name'] = name
+ uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
+ # TODO: need event tests to ensure the event
+ # is fired off here
+ t.append_constraint(uq)
+ return uq
+
+ def check_constraint(self, name, source, condition, schema=None, **kw):
+ t = sa_schema.Table(source, self.metadata(),
+ sa_schema.Column('x', Integer), schema=schema)
+ ck = sa_schema.CheckConstraint(condition, name=name, **kw)
+ t.append_constraint(ck)
+ return ck
+
+ def metadata(self):
+ kw = {}
+ if 'target_metadata' in self.migration_context.opts:
+ mt = self.migration_context.opts['target_metadata']
+ if hasattr(mt, 'naming_convention'):
+ kw['naming_convention'] = mt.naming_convention
+ return sa_schema.MetaData(**kw)
+
+ def table(self, name, *columns, **kw):
+ m = self.metadata()
+ t = sa_schema.Table(name, m, *columns, **kw)
+ for f in t.foreign_keys:
+ self._ensure_table_for_fk(m, f)
+ return t
+
+ def column(self, name, type_, **kw):
+ return sa_schema.Column(name, type_, **kw)
+
+ def index(self, name, tablename, columns, schema=None, **kw):
+ t = sa_schema.Table(
+ tablename or 'no_table', self.metadata(),
+ schema=schema
+ )
+ idx = sa_schema.Index(
+ name,
+ *[impl._textual_index_column(t, n) for n in columns],
+ **kw)
+ return idx
+
+ def _parse_table_key(self, table_key):
+ if '.' in table_key:
+ tokens = table_key.split('.')
+ sname = ".".join(tokens[0:-1])
+ tname = tokens[-1]
+ else:
+ tname = table_key
+ sname = None
+ return (sname, tname)
+
+ def _ensure_table_for_fk(self, metadata, fk):
+ """create a placeholder Table object for the referent of a
+ ForeignKey.
+
+ """
+ if isinstance(fk._colspec, string_types):
+ table_key, cname = fk._colspec.rsplit('.', 1)
+ sname, tname = self._parse_table_key(table_key)
+ if table_key not in metadata.tables:
+ rel_t = sa_schema.Table(tname, metadata, schema=sname)
+ else:
+ rel_t = metadata.tables[table_key]
+ if cname not in rel_t.c:
+ rel_t.append_column(sa_schema.Column(cname, NULLTYPE))
diff --git a/tests/test_batch.py b/tests/test_batch.py
index c827ac4..4226c8e 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -328,7 +328,7 @@ class BatchApplyTest(TestBase):
impl = self._simple_fixture()
col = Column('g', Integer)
# operations.add_column produces a table
- t = self.op._table('tname', col) # noqa
+ t = self.op.schema_obj.table('tname', col) # noqa
impl.add_column('tname', col)
new_table = self._assert_impl(impl, colnames=['id', 'x', 'y', 'g'])
eq_(new_table.c.g.name, 'g')
@@ -418,7 +418,7 @@ class BatchApplyTest(TestBase):
def test_add_fk(self):
impl = self._simple_fixture()
impl.add_column('tname', Column('user_id', Integer))
- fk = self.op._foreign_key_constraint(
+ fk = self.op.schema_obj.foreign_key_constraint(
'fk1', 'tname', 'user',
['user_id'], ['id'])
impl.add_constraint(fk)
@@ -445,7 +445,7 @@ class BatchApplyTest(TestBase):
def test_add_uq(self):
impl = self._simple_fixture()
- uq = self.op._unique_constraint(
+ uq = self.op.schema_obj.unique_constraint(
'uq1', 'tname', ['y']
)
@@ -457,7 +457,7 @@ class BatchApplyTest(TestBase):
def test_drop_uq(self):
impl = self._uq_fixture()
- uq = self.op._unique_constraint(
+ uq = self.op.schema_obj.unique_constraint(
'uq1', 'tname', ['y']
)
impl.drop_constraint(uq)
@@ -467,7 +467,7 @@ class BatchApplyTest(TestBase):
def test_create_index(self):
impl = self._simple_fixture()
- ix = self.op._index('ix1', 'tname', ['y'])
+ ix = self.op.schema_obj.index('ix1', 'tname', ['y'])
impl.create_index(ix)
self._assert_impl(
@@ -477,7 +477,7 @@ class BatchApplyTest(TestBase):
def test_drop_index(self):
impl = self._ix_fixture()
- ix = self.op._index('ix1', 'tname', ['y'])
+ ix = self.op.schema_obj.index('ix1', 'tname', ['y'])
impl.drop_index(ix)
self._assert_impl(
impl, colnames=['id', 'x', 'y'],
@@ -501,8 +501,10 @@ class BatchAPITest(TestBase):
batch = op.batch_alter_table(
'tname', recreate='never', schema=schema).__enter__()
- with mock.patch("alembic.operations.base.sa_schema") as mock_schema:
- yield batch
+ mock_schema = mock.MagicMock()
+ with mock.patch("alembic.operations.schemaobj.sa_schema", mock_schema):
+ with mock.patch("alembic.operations.base.sa_schema", mock_schema):
+ yield batch
batch.impl.flush()
self.mock_schema = mock_schema