summaryrefslogtreecommitdiff
path: root/alembic/operations/schemaobj.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/operations/schemaobj.py')
-rw-r--r--alembic/operations/schemaobj.py157
1 files changed, 157 insertions, 0 deletions
diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py
new file mode 100644
index 0000000..b590aca
--- /dev/null
+++ b/alembic/operations/schemaobj.py
@@ -0,0 +1,157 @@
+from sqlalchemy import schema as sa_schema
+from sqlalchemy.types import NULLTYPE, Integer
+from ..util.compat import string_types
+from .. import util
+
+
+class SchemaObjects(object):
+
+ def __init__(self, migration_context=None):
+ self.migration_context = migration_context
+
+ def primary_key_constraint(self, name, table_name, cols, schema=None):
+ m = self.metadata()
+ columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
+ t1 = sa_schema.Table(table_name, m,
+ *columns,
+ schema=schema)
+ p = sa_schema.PrimaryKeyConstraint(*columns, name=name)
+ t1.append_constraint(p)
+ return p
+
+ def foreign_key_constraint(
+ self, name, source, referent,
+ local_cols, remote_cols,
+ onupdate=None, ondelete=None,
+ deferrable=None, source_schema=None,
+ referent_schema=None, initially=None,
+ match=None, **dialect_kw):
+ m = self.metadata()
+ if source == referent:
+ t1_cols = local_cols + remote_cols
+ else:
+ t1_cols = local_cols
+ sa_schema.Table(
+ referent, m,
+ *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
+ schema=referent_schema)
+
+ t1 = sa_schema.Table(
+ source, m,
+ *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
+ schema=source_schema)
+
+ tname = "%s.%s" % (referent_schema, referent) if referent_schema \
+ else referent
+
+ if util.sqla_08:
+ # "match" kw unsupported in 0.7
+ dialect_kw['match'] = match
+
+ f = sa_schema.ForeignKeyConstraint(local_cols,
+ ["%s.%s" % (tname, n)
+ for n in remote_cols],
+ name=name,
+ onupdate=onupdate,
+ ondelete=ondelete,
+ deferrable=deferrable,
+ initially=initially,
+ **dialect_kw
+ )
+ t1.append_constraint(f)
+
+ return f
+
+ def unique_constraint(self, name, source, local_cols, schema=None, **kw):
+ t = sa_schema.Table(
+ source, self.metadata(),
+ *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
+ schema=schema)
+ kw['name'] = name
+ uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
+ # TODO: need event tests to ensure the event
+ # is fired off here
+ t.append_constraint(uq)
+ return uq
+
+ def check_constraint(self, name, source, condition, schema=None, **kw):
+ t = sa_schema.Table(source, self.metadata(),
+ sa_schema.Column('x', Integer), schema=schema)
+ ck = sa_schema.CheckConstraint(condition, name=name, **kw)
+ t.append_constraint(ck)
+ return ck
+
+ def generic_constraint(self, name, table_name, type_, schema=None, **kw):
+ t = self.table(table_name, schema=schema)
+ types = {
+ 'foreignkey': lambda name: sa_schema.ForeignKeyConstraint(
+ [], [], name=name),
+ 'primary': sa_schema.PrimaryKeyConstraint,
+ 'unique': sa_schema.UniqueConstraint,
+ 'check': lambda name: sa_schema.CheckConstraint("", name=name),
+ None: sa_schema.Constraint
+ }
+ try:
+ const = types[type_]
+ except KeyError:
+ raise TypeError("'type' can be one of %s" %
+ ", ".join(sorted(repr(x) for x in types)))
+ else:
+ const = const(name=name)
+ t.append_constraint(const)
+ return const
+
+ def metadata(self):
+ kw = {}
+ if self.migration_context is not None and \
+ 'target_metadata' in self.migration_context.opts:
+ mt = self.migration_context.opts['target_metadata']
+ if hasattr(mt, 'naming_convention'):
+ kw['naming_convention'] = mt.naming_convention
+ return sa_schema.MetaData(**kw)
+
+ def table(self, name, *columns, **kw):
+ m = self.metadata()
+ t = sa_schema.Table(name, m, *columns, **kw)
+ for f in t.foreign_keys:
+ self._ensure_table_for_fk(m, f)
+ return t
+
+ def column(self, name, type_, **kw):
+ return sa_schema.Column(name, type_, **kw)
+
+ def index(self, name, tablename, columns, schema=None, **kw):
+ t = sa_schema.Table(
+ tablename or 'no_table', self.metadata(),
+ schema=schema
+ )
+ idx = sa_schema.Index(
+ name,
+ *[util.sqla_compat._textual_index_column(t, n) for n in columns],
+ **kw)
+ return idx
+
+ def _parse_table_key(self, table_key):
+ if '.' in table_key:
+ tokens = table_key.split('.')
+ sname = ".".join(tokens[0:-1])
+ tname = tokens[-1]
+ else:
+ tname = table_key
+ sname = None
+ return (sname, tname)
+
+ def _ensure_table_for_fk(self, metadata, fk):
+ """create a placeholder Table object for the referent of a
+ ForeignKey.
+
+ """
+ if isinstance(fk._colspec, string_types):
+ table_key, cname = fk._colspec.rsplit('.', 1)
+ sname, tname = self._parse_table_key(table_key)
+ if table_key not in metadata.tables:
+ rel_t = sa_schema.Table(tname, metadata, schema=sname)
+ else:
+ rel_t = metadata.tables[table_key]
+ if cname not in rel_t.c:
+ rel_t.append_column(sa_schema.Column(cname, NULLTYPE))