summary refs log tree commit diff
path: root/migrate
diff options
context:
space:
mode:
author    iElectric <unknown>    2010-06-11 02:04:20 +0200
committer iElectric <unknown>    2010-06-11 02:04:20 +0200
commit    59e83b76b94eb7522c0ea109cf8259a643a1ea2f (patch)
tree      0b981511b165373975195326326ef6538f85b8ab /migrate
parent    922bd7c08c3ab9ac4ebde792be600413abcc7c72 (diff)
parent    fd8d3136832601a1835ff4e2eb1d2969ef9af165 (diff)
download  sqlalchemy-migrate-59e83b76b94eb7522c0ea109cf8259a643a1ea2f.tar.gz
merge with https://robertanthonyfarrell-0point6-release-patches.googlecode.com/hg/
Diffstat (limited to 'migrate')
-rw-r--r--  migrate/tests/__init__.py                        4
-rw-r--r--  migrate/tests/changeset/__init__.py              0
-rw-r--r--  migrate/tests/changeset/test_changeset.py        774
-rw-r--r--  migrate/tests/changeset/test_constraint.py       282
-rw-r--r--  migrate/tests/fixture/__init__.py                66
-rw-r--r--  migrate/tests/fixture/base.py                    32
-rw-r--r--  migrate/tests/fixture/database.py                170
-rw-r--r--  migrate/tests/fixture/models.py                  14
-rw-r--r--  migrate/tests/fixture/pathed.py                  77
-rw-r--r--  migrate/tests/fixture/shell.py                   32
-rw-r--r--  migrate/tests/integrated/__init__.py             0
-rw-r--r--  migrate/tests/integrated/test_docs.py            18
-rw-r--r--  migrate/tests/versioning/__init__.py             0
-rw-r--r--  migrate/tests/versioning/test_api.py             120
-rw-r--r--  migrate/tests/versioning/test_cfgparse.py        27
-rw-r--r--  migrate/tests/versioning/test_database.py        11
-rw-r--r--  migrate/tests/versioning/test_genmodel.py        13
-rw-r--r--  migrate/tests/versioning/test_keyedinstance.py   45
-rw-r--r--  migrate/tests/versioning/test_pathed.py          51
-rw-r--r--  migrate/tests/versioning/test_repository.py      198
-rw-r--r--  migrate/tests/versioning/test_runchangeset.py    52
-rw-r--r--  migrate/tests/versioning/test_schema.py          202
-rw-r--r--  migrate/tests/versioning/test_schemadiff.py      170
-rw-r--r--  migrate/tests/versioning/test_script.py          244
-rw-r--r--  migrate/tests/versioning/test_shell.py           547
-rw-r--r--  migrate/tests/versioning/test_template.py        70
-rw-r--r--  migrate/tests/versioning/test_util.py            90
-rw-r--r--  migrate/tests/versioning/test_version.py         165
-rw-r--r--  migrate/versioning/migrate_repository.py         13
29 files changed, 3482 insertions, 5 deletions
diff --git a/migrate/tests/__init__.py b/migrate/tests/__init__.py
new file mode 100644
index 0000000..f5c0c5b
--- /dev/null
+++ b/migrate/tests/__init__.py
@@ -0,0 +1,4 @@
+# make this package available during imports as long as we support <python2.5
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
diff --git a/migrate/tests/changeset/__init__.py b/migrate/tests/changeset/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/migrate/tests/changeset/__init__.py
diff --git a/migrate/tests/changeset/test_changeset.py b/migrate/tests/changeset/test_changeset.py
new file mode 100644
index 0000000..beeb509
--- /dev/null
+++ b/migrate/tests/changeset/test_changeset.py
@@ -0,0 +1,774 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import sqlalchemy
+from sqlalchemy import *
+
+from migrate import changeset
+from migrate.changeset import *
+from migrate.changeset.schema import ColumnDelta
+from migrate.tests import fixture
+
+
+class TestAddDropColumn(fixture.DB):
+ """Test add/drop column through all possible interfaces
+ also test for constraints
+ """
+ level = fixture.DB.CONNECT
+ table_name = 'tmp_adddropcol'
+ table_int = 0
+
+ def _setup(self, url):
+ super(TestAddDropColumn, self)._setup(url)
+ self.meta = MetaData()
+ self.table = Table(self.table_name, self.meta,
+ Column('id', Integer, unique=True),
+ )
+ self.meta.bind = self.engine
+ if self.engine.has_table(self.table.name):
+ self.table.drop()
+ self.table.create()
+
+ def _teardown(self):
+ if self.engine.has_table(self.table.name):
+ self.table.drop()
+ self.meta.clear()
+ super(TestAddDropColumn,self)._teardown()
+
+ def run_(self, create_column_func, drop_column_func, *col_p, **col_k):
+ col_name = 'data'
+
+ def assert_numcols(num_of_expected_cols):
+ # number of cols should be correct in table object and in database
+ self.refresh_table(self.table_name)
+ result = len(self.table.c)
+
+ self.assertEquals(result, num_of_expected_cols),
+ if col_k.get('primary_key', None):
+ # new primary key: check its length too
+ result = len(self.table.primary_key)
+ self.assertEquals(result, num_of_expected_cols)
+
+ # we have 1 columns and there is no data column
+ assert_numcols(1)
+ self.assertTrue(getattr(self.table.c, 'data', None) is None)
+ if len(col_p) == 0:
+ col_p = [String(40)]
+ col = Column(col_name, *col_p, **col_k)
+ create_column_func(col)
+ assert_numcols(2)
+ # data column exists
+ self.assert_(self.table.c.data.type.length, 40)
+
+ col2 = self.table.c.data
+ drop_column_func(col2)
+ assert_numcols(1)
+
+ @fixture.usedb()
+ def test_undefined(self):
+ """Add/drop columns not yet defined in the table"""
+ def add_func(col):
+ return create_column(col, self.table)
+ def drop_func(col):
+ return drop_column(col, self.table)
+ return self.run_(add_func, drop_func)
+
+ @fixture.usedb()
+ def test_defined(self):
+ """Add/drop columns already defined in the table"""
+ def add_func(col):
+ self.meta.clear()
+ self.table = Table(self.table_name, self.meta,
+ Column('id', Integer, primary_key=True),
+ col,
+ )
+ return create_column(col)
+ def drop_func(col):
+ return drop_column(col)
+ return self.run_(add_func, drop_func)
+
+ @fixture.usedb()
+ def test_method_bound(self):
+ """Add/drop columns via column methods; columns bound to a table
+ ie. no table parameter passed to function
+ """
+ def add_func(col):
+ self.assert_(col.table is None, col.table)
+ self.table.append_column(col)
+ return col.create()
+ def drop_func(col):
+ #self.assert_(col.table is None,col.table)
+ #self.table.append_column(col)
+ return col.drop()
+ return self.run_(add_func, drop_func)
+
+ @fixture.usedb()
+ def test_method_notbound(self):
+ """Add/drop columns via column methods; columns not bound to a table"""
+ def add_func(col):
+ return col.create(self.table)
+ def drop_func(col):
+ return col.drop(self.table)
+ return self.run_(add_func, drop_func)
+
+ @fixture.usedb()
+ def test_tablemethod_obj(self):
+ """Add/drop columns via table methods; by column object"""
+ def add_func(col):
+ return self.table.create_column(col)
+ def drop_func(col):
+ return self.table.drop_column(col)
+ return self.run_(add_func, drop_func)
+
+ @fixture.usedb()
+ def test_tablemethod_name(self):
+ """Add/drop columns via table methods; by column name"""
+ def add_func(col):
+ # must be bound to table
+ self.table.append_column(col)
+ return self.table.create_column(col.name)
+ def drop_func(col):
+ # Not necessarily bound to table
+ return self.table.drop_column(col.name)
+ return self.run_(add_func, drop_func)
+
+ @fixture.usedb()
+ def test_byname(self):
+ """Add/drop columns via functions; by table object and column name"""
+ def add_func(col):
+ self.table.append_column(col)
+ return create_column(col.name, self.table)
+ def drop_func(col):
+ return drop_column(col.name, self.table)
+ return self.run_(add_func, drop_func)
+
+ @fixture.usedb()
+ def test_drop_column_not_in_table(self):
+ """Drop column by name"""
+ def add_func(col):
+ return self.table.create_column(col)
+ def drop_func(col):
+ self.table.c.remove(col)
+ return self.table.drop_column(col.name)
+ self.run_(add_func, drop_func)
+
+ @fixture.usedb()
+ def test_fk(self):
+ """Can create columns with foreign keys"""
+ # create FK's target
+ reftable = Table('tmp_ref', self.meta,
+ Column('id', Integer, primary_key=True),
+ )
+ if self.engine.has_table(reftable.name):
+ reftable.drop()
+ reftable.create()
+
+ # create column with fk
+ col = Column('data', Integer, ForeignKey(reftable.c.id))
+ if self.url.startswith('sqlite'):
+ self.assertRaises(changeset.exceptions.NotSupportedError,
+ col.create, self.table)
+ else:
+ col.create(self.table)
+
+ # check if constraint is added
+ for cons in self.table.constraints:
+ if isinstance(cons, sqlalchemy.schema.ForeignKeyConstraint):
+ break
+ else:
+ self.fail('No constraint found')
+
+ # TODO: test on db level if constraints work
+
+ self.assertEqual(reftable.c.id.name, col.foreign_keys[0].column.name)
+ col.drop(self.table)
+
+ if self.engine.has_table(reftable.name):
+ reftable.drop()
+
+ @fixture.usedb(not_supported='sqlite')
+ def test_pk(self):
+ """Can create columns with primary key"""
+ col = Column('data', Integer, nullable=False)
+ self.assertRaises(changeset.exceptions.InvalidConstraintError,
+ col.create, self.table, primary_key_name=True)
+ col.create(self.table, primary_key_name='data_pkey')
+
+ # check if constraint was added (cannot test on objects)
+ self.table.insert(values={'data': 4}).execute()
+ try:
+ self.table.insert(values={'data': 4}).execute()
+ except (sqlalchemy.exc.IntegrityError,
+ sqlalchemy.exc.ProgrammingError):
+ pass
+ else:
+ self.fail()
+
+ col.drop()
+
+ @fixture.usedb(not_supported='mysql')
+ def test_check(self):
+ """Can create columns with check constraint"""
+ col = Column('data',
+ Integer,
+ sqlalchemy.schema.CheckConstraint('data > 4'))
+ col.create(self.table)
+
+ # check if constraint was added (cannot test on objects)
+ self.table.insert(values={'data': 5}).execute()
+ try:
+ self.table.insert(values={'data': 3}).execute()
+ except (sqlalchemy.exc.IntegrityError,
+ sqlalchemy.exc.ProgrammingError):
+ pass
+ else:
+ self.fail()
+
+ col.drop()
+
+ @fixture.usedb(not_supported='sqlite')
+ def test_unique(self):
+ """Can create columns with unique constraint"""
+ self.assertRaises(changeset.exceptions.InvalidConstraintError,
+ Column('data', Integer, unique=True).create, self.table)
+ col = Column('data', Integer)
+ col.create(self.table, unique_name='data_unique')
+
+ # check if constraint was added (cannot test on objects)
+ self.table.insert(values={'data': 5}).execute()
+ try:
+ self.table.insert(values={'data': 5}).execute()
+ except (sqlalchemy.exc.IntegrityError,
+ sqlalchemy.exc.ProgrammingError):
+ pass
+ else:
+ self.fail()
+
+ col.drop(self.table)
+
+# TODO: remove already attached columns with indexes, uniques, pks, fks ..
+ @fixture.usedb()
+ def test_index(self):
+ """Can create columns with indexes"""
+ self.assertRaises(changeset.exceptions.InvalidConstraintError,
+ Column('data', Integer).create, self.table, index_name=True)
+ col = Column('data', Integer)
+ col.create(self.table, index_name='ix_data')
+
+ # check if index was added
+ self.table.insert(values={'data': 5}).execute()
+ try:
+ self.table.insert(values={'data': 5}).execute()
+ except (sqlalchemy.exc.IntegrityError,
+ sqlalchemy.exc.ProgrammingError):
+ pass
+ else:
+ self.fail()
+
+ Index('ix_data', col).drop(bind=self.engine)
+ col.drop()
+
+ @fixture.usedb()
+ def test_server_defaults(self):
+ """Can create columns with server_default values"""
+ col = Column('data', String(244), server_default='foobar')
+ col.create(self.table)
+
+ self.table.insert(values={'id': 10}).execute()
+ row = self._select_row()
+ self.assertEqual(u'foobar', row['data'])
+
+ col.drop()
+
+ @fixture.usedb()
+ def test_populate_default(self):
+ """Test populate_default=True"""
+ def default():
+ return 'foobar'
+ col = Column('data', String(244), default=default)
+ col.create(self.table, populate_default=True)
+
+ self.table.insert(values={'id': 10}).execute()
+ row = self._select_row()
+ self.assertEqual(u'foobar', row['data'])
+
+ col.drop()
+
+ # TODO: test sequence
+ # TODO: test quoting
+ # TODO: test non-autoname constraints
+
+
+class TestRename(fixture.DB):
+ """Tests for table and index rename methods"""
+ level = fixture.DB.CONNECT
+ meta = MetaData()
+
+ def _setup(self, url):
+ super(TestRename, self)._setup(url)
+ self.meta.bind = self.engine
+
+ @fixture.usedb(not_supported='firebird')
+ def test_rename_table(self):
+ """Tables can be renamed"""
+ c_name = 'col_1'
+ table_name1 = 'name_one'
+ table_name2 = 'name_two'
+ index_name1 = 'x' + table_name1
+ index_name2 = 'x' + table_name2
+
+ self.meta.clear()
+ self.column = Column(c_name, Integer)
+ self.table = Table(table_name1, self.meta, self.column)
+ self.index = Index(index_name1, self.column, unique=False)
+
+ if self.engine.has_table(self.table.name):
+ self.table.drop()
+ if self.engine.has_table(table_name2):
+ tmp = Table(table_name2, self.meta, autoload=True)
+ tmp.drop()
+ tmp.deregister()
+ del tmp
+ self.table.create()
+
+ def assert_table_name(expected, skip_object_check=False):
+ """Refresh a table via autoload
+ SA has changed some since this test was written; we now need to do
+ meta.clear() upon reloading a table - clear all rather than a
+ select few. So, this works only if we're working with one table at
+ a time (else, others will vanish too).
+ """
+ if not skip_object_check:
+ # Table object check
+ self.assertEquals(self.table.name,expected)
+ newname = self.table.name
+ else:
+ # we know the object's name isn't consistent: just assign it
+ newname = expected
+ # Table DB check
+ self.meta.clear()
+ self.table = Table(newname, self.meta, autoload=True)
+ self.assertEquals(self.table.name, expected)
+
+ def assert_index_name(expected, skip_object_check=False):
+ if not skip_object_check:
+ # Index object check
+ self.assertEquals(self.index.name, expected)
+ else:
+ # object is inconsistent
+ self.index.name = expected
+ # TODO: Index DB check
+
+ try:
+ # Table renames
+ assert_table_name(table_name1)
+ rename_table(self.table, table_name2)
+ assert_table_name(table_name2)
+ self.table.rename(table_name1)
+ assert_table_name(table_name1)
+
+ # test by just the string
+ rename_table(table_name1, table_name2, engine=self.engine)
+ assert_table_name(table_name2, True) # object not updated
+
+ # Index renames
+ if self.url.startswith('sqlite') or self.url.startswith('mysql'):
+ self.assertRaises(changeset.exceptions.NotSupportedError,
+ self.index.rename, index_name2)
+ else:
+ assert_index_name(index_name1)
+ rename_index(self.index, index_name2, engine=self.engine)
+ assert_index_name(index_name2)
+ self.index.rename(index_name1)
+ assert_index_name(index_name1)
+
+ # test by just the string
+ rename_index(index_name1, index_name2, engine=self.engine)
+ assert_index_name(index_name2, True)
+
+ finally:
+ if self.table.exists():
+ self.table.drop()
+
+
+class TestColumnChange(fixture.DB):
+ level = fixture.DB.CONNECT
+ table_name = 'tmp_colchange'
+
+ def _setup(self, url):
+ super(TestColumnChange, self)._setup(url)
+ self.meta = MetaData(self.engine)
+ self.table = Table(self.table_name, self.meta,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(40), server_default=DefaultClause("tluafed"),
+ nullable=True),
+ )
+ if self.table.exists():
+ self.table.drop()
+ try:
+ self.table.create()
+ except sqlalchemy.exceptions.SQLError, e:
+ # SQLite: database schema has changed
+ if not self.url.startswith('sqlite://'):
+ raise
+
+ def _teardown(self):
+ if self.table.exists():
+ try:
+ self.table.drop(self.engine)
+ except sqlalchemy.exceptions.SQLError,e:
+ # SQLite: database schema has changed
+ if not self.url.startswith('sqlite://'):
+ raise
+ super(TestColumnChange, self)._teardown()
+
+ @fixture.usedb()
+ def test_rename(self):
+ """Can rename a column"""
+ def num_rows(col, content):
+ return len(list(self.table.select(col == content).execute()))
+ # Table content should be preserved in changed columns
+ content = "fgsfds"
+ self.engine.execute(self.table.insert(), data=content, id=42)
+ self.assertEquals(num_rows(self.table.c.data, content), 1)
+
+ # ...as a function, given a column object and the new name
+ alter_column('data', name='data2', table=self.table)
+ self.refresh_table()
+ alter_column(self.table.c.data2, name='atad')
+ self.refresh_table(self.table.name)
+ self.assert_('data' not in self.table.c.keys())
+ self.assert_('atad' in self.table.c.keys())
+ self.assertEquals(num_rows(self.table.c.atad, content), 1)
+
+ # ...as a method, given a new name
+ self.table.c.atad.alter(name='data')
+ self.refresh_table(self.table.name)
+ self.assert_('atad' not in self.table.c.keys())
+ self.table.c.data # Should not raise exception
+ self.assertEquals(num_rows(self.table.c.data, content), 1)
+
+ # ...as a function, given a new object
+ col = Column('atad', String(40), server_default=self.table.c.data.server_default)
+ alter_column(self.table.c.data, col)
+ self.refresh_table(self.table.name)
+ self.assert_('data' not in self.table.c.keys())
+ self.table.c.atad # Should not raise exception
+ self.assertEquals(num_rows(self.table.c.atad, content), 1)
+
+ # ...as a method, given a new object
+ col = Column('data', String(40), server_default=self.table.c.atad.server_default)
+ self.table.c.atad.alter(col)
+ self.refresh_table(self.table.name)
+ self.assert_('atad' not in self.table.c.keys())
+ self.table.c.data # Should not raise exception
+ self.assertEquals(num_rows(self.table.c.data,content), 1)
+
+ @fixture.usedb()
+ def test_type(self):
+ """Can change a column's type"""
+ # Entire column definition given
+ self.table.c.data.alter(Column('data', String(42)))
+ self.refresh_table(self.table.name)
+ self.assert_(isinstance(self.table.c.data.type, String))
+ self.assertEquals(self.table.c.data.type.length, 42)
+
+ # Just the new type
+ self.table.c.data.alter(type=String(43))
+ self.refresh_table(self.table.name)
+ self.assert_(isinstance(self.table.c.data.type, String))
+ self.assertEquals(self.table.c.data.type.length, 43)
+
+ # Different type
+ self.assert_(isinstance(self.table.c.id.type, Integer))
+ self.assertEquals(self.table.c.id.nullable, False)
+
+ if not self.engine.name == 'firebird':
+ self.table.c.id.alter(type=String(20))
+ self.assertEquals(self.table.c.id.nullable, False)
+ self.refresh_table(self.table.name)
+ self.assert_(isinstance(self.table.c.id.type, String))
+
+ @fixture.usedb()
+ def test_default(self):
+ """Can change a column's server_default value (DefaultClauses only)
+ Only DefaultClauses are changed here: others are managed by the
+ application / by SA
+ """
+ self.assertEquals(self.table.c.data.server_default.arg, 'tluafed')
+
+ # Just the new default
+ default = 'my_default'
+ self.table.c.data.alter(server_default=DefaultClause(default))
+ self.refresh_table(self.table.name)
+ #self.assertEquals(self.table.c.data.server_default.arg,default)
+ # TextClause returned by autoload
+ self.assert_(default in str(self.table.c.data.server_default.arg))
+ self.engine.execute(self.table.insert(), id=12)
+ row = self._select_row()
+ self.assertEqual(row['data'], default)
+
+ # Column object
+ default = 'your_default'
+ self.table.c.data.alter(Column('data', String(40), server_default=DefaultClause(default)))
+ self.refresh_table(self.table.name)
+ self.assert_(default in str(self.table.c.data.server_default.arg))
+
+ # Drop/remove default
+ self.table.c.data.alter(server_default=None)
+ self.assertEqual(self.table.c.data.server_default, None)
+
+ self.refresh_table(self.table.name)
+ # server_default isn't necessarily None for Oracle
+ #self.assert_(self.table.c.data.server_default is None,self.table.c.data.server_default)
+ self.engine.execute(self.table.insert(), id=11)
+ if SQLA_06:
+ row = self.table.select(self.table.c.id == 11).execution_options(autocommit=True).execute().fetchone()
+ else:
+ row = self.table.select(self.table.c.id == 11, autocommit=True).execute().fetchone()
+ self.assert_(row['data'] is None, row['data'])
+
+ @fixture.usedb(not_supported='firebird')
+ def test_null(self):
+ """Can change a column's null constraint"""
+ self.assertEquals(self.table.c.data.nullable, True)
+
+ # Column object
+ self.table.c.data.alter(Column('data', String(40), nullable=False))
+ self.table.nullable = None
+ self.refresh_table(self.table.name)
+ self.assertEquals(self.table.c.data.nullable, False)
+
+ # Just the new status
+ self.table.c.data.alter(nullable=True)
+ self.refresh_table(self.table.name)
+ self.assertEquals(self.table.c.data.nullable, True)
+
+ @fixture.usedb()
+ def test_alter_metadata(self):
+ """Test if alter_metadata is respected"""
+
+ self.table.c.data.alter(Column('data', String(100)))
+
+ self.assert_(isinstance(self.table.c.data.type, String))
+ self.assertEqual(self.table.c.data.type.length, 100)
+
+ # nothing should change
+ self.table.c.data.alter(Column('data', String(200)), alter_metadata=False)
+ self.assert_(isinstance(self.table.c.data.type, String))
+ self.assertEqual(self.table.c.data.type.length, 100)
+
+ @fixture.usedb()
+ def test_alter_returns_delta(self):
+ """Test if alter constructs return delta"""
+
+ delta = self.table.c.data.alter(Column('data', String(100)))
+ self.assert_('type' in delta)
+
+ @fixture.usedb()
+ def test_alter_all(self):
+ """Tests all alter changes at one time"""
+ # test for each db separately
+ # since currently some dont support everything
+
+ # test pre settings
+ self.assertEqual(self.table.c.data.nullable, True)
+ self.assertEqual(self.table.c.data.server_default.arg, 'tluafed')
+ self.assertEqual(self.table.c.data.name, 'data')
+ self.assertTrue(isinstance(self.table.c.data.type, String))
+ self.assertTrue(self.table.c.data.type.length, 40)
+
+ kw = dict(nullable=False,
+ server_default='foobar',
+ name='data_new',
+ type=String(50),
+ alter_metadata=True)
+ if self.engine.name == 'firebird':
+ del kw['nullable']
+ self.table.c.data.alter(**kw)
+
+ # test altered objects
+ self.assertEqual(self.table.c.data.server_default.arg, 'foobar')
+ if not self.engine.name == 'firebird':
+ self.assertEqual(self.table.c.data.nullable, False)
+ self.assertEqual(self.table.c.data.name, 'data_new')
+ self.assertEqual(self.table.c.data.type.length, 50)
+
+ self.refresh_table(self.table.name)
+
+ # test post settings
+ if not self.engine.name == 'firebird':
+ self.assertEqual(self.table.c.data_new.nullable, False)
+ self.assertEqual(self.table.c.data_new.name, 'data_new')
+ self.assertTrue(isinstance(self.table.c.data_new.type, String))
+ self.assertTrue(self.table.c.data_new.type.length, 50)
+
+ # insert data and assert default
+ self.table.insert(values={'id': 10}).execute()
+ row = self._select_row()
+ self.assertEqual(u'foobar', row['data_new'])
+
+
+class TestColumnDelta(fixture.DB):
+ """Tests ColumnDelta class"""
+
+ level = fixture.DB.CONNECT
+ table_name = 'tmp_coldelta'
+ table_int = 0
+
+ def _setup(self, url):
+ super(TestColumnDelta, self)._setup(url)
+ self.meta = MetaData()
+ self.table = Table(self.table_name, self.meta,
+ Column('ids', String(10)),
+ )
+ self.meta.bind = self.engine
+ if self.engine.has_table(self.table.name):
+ self.table.drop()
+ self.table.create()
+
+ def _teardown(self):
+ if self.engine.has_table(self.table.name):
+ self.table.drop()
+ self.meta.clear()
+ super(TestColumnDelta,self)._teardown()
+
+ def mkcol(self, name='id', type=String, *p, **k):
+ return Column(name, type, *p, **k)
+
+ def verify(self, expected, original, *p, **k):
+ self.delta = ColumnDelta(original, *p, **k)
+ result = self.delta.keys()
+ result.sort()
+ self.assertEquals(expected, result)
+ return self.delta
+
+ def test_deltas_two_columns(self):
+ """Testing ColumnDelta with two columns"""
+ col_orig = self.mkcol(primary_key=True)
+ col_new = self.mkcol(name='ids', primary_key=True)
+ self.verify([], col_orig, col_orig)
+ self.verify(['name'], col_orig, col_orig, 'ids')
+ self.verify(['name'], col_orig, col_orig, name='ids')
+ self.verify(['name'], col_orig, col_new)
+ self.verify(['name', 'type'], col_orig, col_new, type=String)
+
+ # Type comparisons
+ self.verify([], self.mkcol(type=String), self.mkcol(type=String))
+ self.verify(['type'], self.mkcol(type=String), self.mkcol(type=Integer))
+ self.verify(['type'], self.mkcol(type=String), self.mkcol(type=String(42)))
+ self.verify([], self.mkcol(type=String(42)), self.mkcol(type=String(42)))
+ self.verify(['type'], self.mkcol(type=String(24)), self.mkcol(type=String(42)))
+ self.verify(['type'], self.mkcol(type=String(24)), self.mkcol(type=Text(24)))
+
+ # Other comparisons
+ self.verify(['primary_key'], self.mkcol(nullable=False), self.mkcol(primary_key=True))
+
+ # PK implies nullable=False
+ self.verify(['nullable', 'primary_key'], self.mkcol(nullable=True), self.mkcol(primary_key=True))
+ self.verify([], self.mkcol(primary_key=True), self.mkcol(primary_key=True))
+ self.verify(['nullable'], self.mkcol(nullable=True), self.mkcol(nullable=False))
+ self.verify([], self.mkcol(nullable=True), self.mkcol(nullable=True))
+ self.verify([], self.mkcol(server_default=None), self.mkcol(server_default=None))
+ self.verify([], self.mkcol(server_default='42'), self.mkcol(server_default='42'))
+
+ # test server default
+ delta = self.verify(['server_default'], self.mkcol(), self.mkcol('id', String, DefaultClause('foobar')))
+ self.assertEqual(delta['server_default'].arg, 'foobar')
+
+ self.verify([], self.mkcol(server_default='foobar'), self.mkcol('id', String, DefaultClause('foobar')))
+ self.verify(['type'], self.mkcol(server_default='foobar'), self.mkcol('id', Text, DefaultClause('foobar')))
+
+ # test alter_metadata
+ col = self.mkcol(server_default='foobar')
+ self.verify(['type'], col, self.mkcol('id', Text, DefaultClause('foobar')), alter_metadata=True)
+ self.assert_(isinstance(col.type, Text))
+
+ col = self.mkcol()
+ self.verify(['name', 'server_default', 'type'], col, self.mkcol('beep', Text, DefaultClause('foobar')), alter_metadata=True)
+ self.assert_(isinstance(col.type, Text))
+ self.assertEqual(col.name, 'beep')
+ self.assertEqual(col.server_default.arg, 'foobar')
+
+ col = self.mkcol()
+ self.verify(['name', 'server_default', 'type'], col, self.mkcol('beep', Text, DefaultClause('foobar')), alter_metadata=False)
+ self.assertFalse(isinstance(col.type, Text))
+ self.assertNotEqual(col.name, 'beep')
+ self.assertFalse(col.server_default)
+
+ @fixture.usedb()
+ def test_deltas_zero_columns(self):
+ """Testing ColumnDelta with zero columns"""
+
+ self.verify(['name'], 'ids', table=self.table, name='hey')
+
+ # test reflection
+ self.verify(['type'], 'ids', table=self.table.name, type=String(80), engine=self.engine)
+ self.verify(['type'], 'ids', table=self.table.name, type=String(80), metadata=self.meta)
+
+ # check if alter_metadata is respected
+ self.meta.clear()
+ delta = self.verify(['type'], 'ids', table=self.table.name, type=String(80), alter_metadata=True, metadata=self.meta)
+ self.assert_(self.table.name in self.meta)
+ self.assertEqual(delta.result_column.type.length, 80)
+ self.assertEqual(self.meta.tables.get(self.table.name).c.ids.type.length, 80)
+
+ self.meta.clear()
+ self.verify(['type'], 'ids', table=self.table.name, type=String(80), alter_metadata=False, engine=self.engine)
+ self.assert_(self.table.name not in self.meta)
+
+ self.meta.clear()
+ self.verify(['type'], 'ids', table=self.table.name, type=String(80), alter_metadata=False, metadata=self.meta)
+ self.assert_(self.table.name not in self.meta)
+
+ # test defaults
+ self.meta.clear()
+ self.verify(['server_default'], 'ids', table=self.table.name, server_default='foobar', alter_metadata=True, metadata=self.meta)
+ self.meta.tables.get(self.table.name).c.ids.server_default.arg == 'foobar'
+
+ # test missing parameters
+ self.assertRaises(ValueError, ColumnDelta, table=self.table.name)
+ self.assertRaises(ValueError, ColumnDelta, 'ids', table=self.table.name, alter_metadata=True)
+ self.assertRaises(ValueError, ColumnDelta, 'ids', table=self.table.name, alter_metadata=False)
+
+ def test_deltas_one_column(self):
+ """Testing ColumnDelta with one column"""
+ col_orig = self.mkcol(primary_key=True)
+
+ self.verify([], col_orig)
+ self.verify(['name'], col_orig, 'ids')
+ # Parameters are always executed, even if they're 'unchanged'
+ # (We can't assume given column is up-to-date)
+ self.verify(['name', 'primary_key', 'type'], col_orig, 'id', Integer, primary_key=True)
+ self.verify(['name', 'primary_key', 'type'], col_orig, name='id', type=Integer, primary_key=True)
+
+ # Change name, given an up-to-date definition and the current name
+ delta = self.verify(['name'], col_orig, name='blah')
+ self.assertEquals(delta.get('name'), 'blah')
+ self.assertEquals(delta.current_name, 'id')
+
+ # check if alter_metadata is respected
+ col_orig = self.mkcol(primary_key=True)
+ self.verify(['name', 'type'], col_orig, name='id12', type=Text, alter_metadata=True)
+ self.assert_(isinstance(col_orig.type, Text))
+ self.assertEqual(col_orig.name, 'id12')
+
+ col_orig = self.mkcol(primary_key=True)
+ self.verify(['name', 'type'], col_orig, name='id12', type=Text, alter_metadata=False)
+ self.assert_(isinstance(col_orig.type, String))
+ self.assertEqual(col_orig.name, 'id')
+
+ # test server default
+ col_orig = self.mkcol(primary_key=True)
+ delta = self.verify(['server_default'], col_orig, DefaultClause('foobar'))
+ self.assertEqual(delta['server_default'].arg, 'foobar')
+
+ delta = self.verify(['server_default'], col_orig, server_default=DefaultClause('foobar'))
+ self.assertEqual(delta['server_default'].arg, 'foobar')
+
+ # no change
+ col_orig = self.mkcol(server_default=DefaultClause('foobar'))
+ delta = self.verify(['type'], col_orig, DefaultClause('foobar'), type=PickleType)
+ self.assert_(isinstance(delta.result_column.type, PickleType))
+
+ # TODO: test server on update
+ # TODO: test bind metadata
diff --git a/migrate/tests/changeset/test_constraint.py b/migrate/tests/changeset/test_constraint.py
new file mode 100644
index 0000000..e5717f2
--- /dev/null
+++ b/migrate/tests/changeset/test_constraint.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from sqlalchemy import *
+from sqlalchemy.util import *
+from sqlalchemy.exc import *
+
+from migrate.changeset import *
+from migrate.changeset.exceptions import *
+
+from migrate.tests import fixture
+
+
+class CommonTestConstraint(fixture.DB):
+    """helper functions to test constraints.
+
+    we just create a fresh new table and make sure everything is
+    as required.
+    """
+
+    def _setup(self, url):
+        # Connect via fixture.DB._setup, then build the scratch table.
+        super(CommonTestConstraint, self)._setup(url)
+        self._create_table()
+
+    def _teardown(self):
+        # Drop the scratch table if it survived the test, then disconnect.
+        if hasattr(self, 'table') and self.engine.has_table(self.table.name):
+            self.table.drop()
+        super(CommonTestConstraint, self)._teardown()
+
+    def _create_table(self):
+        """(Re)create 'mytable' with two NOT NULL columns and no PK."""
+        self._connect(self.url)
+        self.meta = MetaData(self.engine)
+        self.tablename = 'mytable'
+        # InnoDB is required on MySQL for foreign-key support.
+        self.table = Table(self.tablename, self.meta,
+            Column('id', Integer, nullable=False),
+            Column('fkey', Integer, nullable=False),
+            mysql_engine='InnoDB')
+        if self.engine.has_table(self.table.name):
+            self.table.drop()
+        self.table.create()
+
+        # make sure we start at zero
+        self.assertEquals(len(self.table.primary_key), 0)
+        self.assert_(isinstance(self.table.primary_key,
+            schema.PrimaryKeyConstraint), self.table.primary_key.__class__)
+
+
class TestConstraint(CommonTestConstraint):
    """Exercise creation and removal of PK/FK/CHECK constraints."""

    level = fixture.DB.CONNECT

    def _define_pk(self, *cols):
        """Create a PK over ``cols``, verify it, then drop it and verify
        the table ends up with no primary key again."""
        # Add a pk by creating a PK constraint
        if (self.engine.name in ('oracle', 'firebird')):
            # Can't drop Oracle PKs without an explicit name
            pk = PrimaryKeyConstraint(table=self.table, name='temp_pk_key', *cols)
        else:
            pk = PrimaryKeyConstraint(table=self.table, *cols)
        self.compare_columns_equal(pk.columns, cols)
        pk.create()
        self.refresh_table()
        if not self.url.startswith('sqlite'):
            self.compare_columns_equal(self.table.primary_key, cols, ['type', 'autoincrement'])

        # Drop the PK constraint
        #if (self.engine.name in ('oracle', 'firebird')):
        #    # Apparently Oracle PK names aren't introspected
        #    pk.name = self.table.primary_key.name
        pk.drop()
        self.refresh_table()
        self.assertEquals(len(self.table.primary_key), 0)
        self.assert_(isinstance(self.table.primary_key, schema.PrimaryKeyConstraint))
        return pk

    @fixture.usedb(not_supported='sqlite')
    def test_define_fk(self):
        """FK constraints can be defined, created, and dropped"""
        # FK target must be unique
        pk = PrimaryKeyConstraint(self.table.c.id, table=self.table, name="pkid")
        pk.create()

        # Add a FK by creating a FK constraint
        self.assertEquals(self.table.c.fkey.foreign_keys._list, [])
        fk = ForeignKeyConstraint([self.table.c.fkey],
                                  [self.table.c.id],
                                  name="fk_id_fkey",
                                  ondelete="CASCADE")
        # FIX: the original asserted ``x is not []`` -- an identity test
        # against a brand-new list that is always true; compare contents.
        self.assertNotEqual(self.table.c.fkey.foreign_keys._list, [])
        for key in fk.columns:
            self.assertEquals(key, self.table.c.fkey.name)
        self.assertEquals([e.column for e in fk.elements], [self.table.c.id])
        self.assertEquals(list(fk.referenced), [self.table.c.id])

        if self.url.startswith('mysql'):
            # MySQL FKs need an index
            index = Index('index_name', self.table.c.fkey)
            index.create()
        fk.create()

        # test for ondelete/onupdate
        fkey = self.table.c.fkey.foreign_keys._list[0]
        self.assertEquals(fkey.ondelete, "CASCADE")
        # TODO: test on real db if it was set

        self.refresh_table()
        # FIX: same always-true ``is not []`` identity test as above.
        self.assertNotEqual(self.table.c.fkey.foreign_keys._list, [])

        fk.drop()
        self.refresh_table()
        self.assertEquals(self.table.c.fkey.foreign_keys._list, [])

    @fixture.usedb()
    def test_define_pk(self):
        """PK constraints can be defined, created, and dropped"""
        self._define_pk(self.table.c.fkey)

    @fixture.usedb()
    def test_define_pk_multi(self):
        """Multicolumn PK constraints can be defined, created, and dropped"""
        self._define_pk(self.table.c.id, self.table.c.fkey)

    @fixture.usedb()
    def test_drop_cascade(self):
        """Drop constraint cascaded"""
        pk = PrimaryKeyConstraint('fkey', table=self.table, name="id_pkey")
        pk.create()
        self.refresh_table()

        # Drop the PK constraint forcing cascade
        try:
            pk.drop(cascade=True)
        except NotSupportedError:
            # FIX: only Firebird is expected to reject a cascaded drop;
            # the original swallowed the error for every backend.  Re-raise
            # for anything else so real failures surface.
            if self.engine.name != 'firebird':
                raise

        # TODO: add real assertion if it was added

    @fixture.usedb(supported=['mysql'])
    def test_fail_mysql_check_constraints(self):
        """Check constraints raise NotSupported for mysql on drop"""
        cons = CheckConstraint('id > 3', name="id_check", table=self.table)
        cons.create()
        self.refresh_table()

        try:
            cons.drop()
        except NotSupportedError:
            pass
        else:
            self.fail()

    @fixture.usedb(not_supported=['sqlite', 'mysql'])
    def test_named_check_constraints(self):
        """Check constraints can be defined, created, and dropped"""
        self.assertRaises(InvalidConstraintError, CheckConstraint, 'id > 3')
        cons = CheckConstraint('id > 3', name="id_check", table=self.table)
        cons.create()
        self.refresh_table()

        # id=4 satisfies the constraint, id=1 must be rejected
        self.table.insert(values={'id': 4, 'fkey': 1}).execute()
        try:
            self.table.insert(values={'id': 1, 'fkey': 1}).execute()
        except (IntegrityError, ProgrammingError):
            pass
        else:
            self.fail()

        # Remove the name, drop the constraint; it should succeed
        cons.drop()
        self.refresh_table()
        self.table.insert(values={'id': 2, 'fkey': 2}).execute()
        self.table.insert(values={'id': 1, 'fkey': 2}).execute()
+
+
class TestAutoname(CommonTestConstraint):
    """Every method tests for a type of constraint whether it can autoname
    itself and if you can pass object instance and names to classes.
    """
    level = fixture.DB.CONNECT

    @fixture.usedb(not_supported=['oracle', 'firebird'])
    def test_autoname_pk(self):
        """PrimaryKeyConstraints can guess their name if None is given"""
        # Don't supply a name; it should create one
        cons = PrimaryKeyConstraint(self.table.c.id)
        cons.create()
        self.refresh_table()
        if not self.url.startswith('sqlite'):
            # TODO: test for index for sqlite
            self.compare_columns_equal(cons.columns, self.table.primary_key, ['autoincrement', 'type'])

        # Remove the name, drop the constraint; it should succeed
        cons.name = None
        cons.drop()
        self.refresh_table()
        self.assertEquals(list(), list(self.table.primary_key))

        # test string names
        cons = PrimaryKeyConstraint('id', table=self.table)
        cons.create()
        self.refresh_table()
        if not self.url.startswith('sqlite'):
            # TODO: test for index for sqlite
            if SQLA_06:
                self.compare_columns_equal(cons.columns, self.table.primary_key)
            else:
                self.compare_columns_equal(cons.columns, self.table.primary_key, ['autoincrement'])
        cons.name = None
        cons.drop()

    @fixture.usedb(not_supported=['oracle', 'sqlite', 'firebird'])
    def test_autoname_fk(self):
        """ForeignKeyConstraints can guess their name if None is given"""
        cons = PrimaryKeyConstraint(self.table.c.id)
        cons.create()

        cons = ForeignKeyConstraint([self.table.c.fkey], [self.table.c.id])
        cons.create()
        self.refresh_table()
        # FIX: this line was a bare expression whose result was silently
        # discarded; actually assert the FK targets the id column.
        self.assert_(self.table.c.fkey.foreign_keys[0].column is self.table.c.id)

        # Remove the name, drop the constraint; it should succeed
        cons.name = None
        cons.drop()
        self.refresh_table()
        self.assertEquals(self.table.c.fkey.foreign_keys._list, list())

        # test string names
        cons = ForeignKeyConstraint(['fkey'], ['%s.id' % self.tablename], table=self.table)
        cons.create()
        self.refresh_table()
        # FIX: same discarded bare expression as above.
        self.assert_(self.table.c.fkey.foreign_keys[0].column is self.table.c.id)

        # Remove the name, drop the constraint; it should succeed
        cons.name = None
        cons.drop()

    @fixture.usedb(not_supported=['oracle', 'sqlite', 'mysql'])
    def test_autoname_check(self):
        """CheckConstraints can guess their name if None is given"""
        cons = CheckConstraint('id > 3', columns=[self.table.c.id])
        cons.create()
        self.refresh_table()

        if not self.engine.name == 'mysql':
            self.table.insert(values={'id': 4, 'fkey': 1}).execute()
            try:
                self.table.insert(values={'id': 1, 'fkey': 2}).execute()
            except (IntegrityError, ProgrammingError):
                pass
            else:
                self.fail()

        # Remove the name, drop the constraint; it should succeed
        cons.name = None
        cons.drop()
        self.refresh_table()
        self.table.insert(values={'id': 2, 'fkey': 2}).execute()
        self.table.insert(values={'id': 1, 'fkey': 3}).execute()

    @fixture.usedb(not_supported=['oracle', 'sqlite'])
    def test_autoname_unique(self):
        """UniqueConstraints can guess their name if None is given"""
        cons = UniqueConstraint(self.table.c.fkey)
        cons.create()
        self.refresh_table()

        self.table.insert(values={'fkey': 4, 'id': 1}).execute()
        try:
            self.table.insert(values={'fkey': 4, 'id': 2}).execute()
        # FIX: the original caught sqlalchemy.exc.* but the module name
        # 'sqlalchemy' is never bound (only 'from sqlalchemy import *' /
        # 'from sqlalchemy.exc import *' are used), which raises NameError;
        # the exception classes are already in scope by name.
        except (IntegrityError, ProgrammingError):
            pass
        else:
            self.fail()

        # Remove the name, drop the constraint; it should succeed
        cons.name = None
        cons.drop()
        self.refresh_table()
        self.table.insert(values={'fkey': 4, 'id': 2}).execute()
        self.table.insert(values={'fkey': 4, 'id': 1}).execute()
diff --git a/migrate/tests/fixture/__init__.py b/migrate/tests/fixture/__init__.py
new file mode 100644
index 0000000..2a40355
--- /dev/null
+++ b/migrate/tests/fixture/__init__.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import unittest
+import sys
+
+
+# Append test method name,etc. to descriptions automatically.
+# Yes, this is ugly, but it's the simplest way...
+def getDescription(self, test):
+    # Prefer the test's own short description; fall back to str(test).
+    ret = str(test)
+    if self.descriptions:
+        return test.shortDescription() or ret
+    return ret
+# Monkeypatch the stock TextTestResult so every result object (including
+# the Result subclass below) uses this description logic.
+unittest._TextTestResult.getDescription = getDescription
+
+
+class Result(unittest._TextTestResult):
+    # test description may be changed as we go; store the description at
+    # exception-time and print later
+    def __init__(self,*p,**k):
+        super(Result,self).__init__(*p,**k)
+        self.desc=dict()
+    def _addError(self,test,err,errs):
+        # Replace the (test, err) pair the superclass just appended with a
+        # triple that also carries the description captured right now.
+        # NOTE: the 'test'/'err' parameters are ignored; the values come
+        # from the popped entry instead.
+        test,err=errs.pop()
+        errdata=(test,err,self.getDescription(test))
+        errs.append(errdata)
+
+    def addFailure(self,test,err):
+        super(Result,self).addFailure(test,err)
+        self._addError(test,err,self.failures)
+    def addError(self,test,err):
+        super(Result,self).addError(test,err)
+        self._addError(test,err,self.errors)
+    def printErrorList(self, flavour, errors):
+        # Copied from unittest.py
+        #for test, err in errors:
+        for errdata in errors:
+            # unpack the (test, err, description) triples stored above
+            test,err,desc=errdata
+            self.stream.writeln(self.separator1)
+            #self.stream.writeln("%s: %s" % (flavour,self.getDescription(test)))
+            self.stream.writeln("%s: %s" % (flavour,desc or self.getDescription(test)))
+            self.stream.writeln(self.separator2)
+            self.stream.writeln("%s" % err)
+
+class Runner(unittest.TextTestRunner):
+    # Use the Result class above so descriptions are captured at
+    # exception time rather than at report time.
+    def _makeResult(self):
+        return Result(self.stream,self.descriptions,self.verbosity)
+
def suite(imports):
    """Build a TestSuite from a list of dotted test-module names."""
    return unittest.TestLoader().loadTestsFromNames(imports)


def main(imports=None):
    """Run the fixture test suite.

    When *imports* is given, the module-level ``suite`` name is rebound to
    the loaded suite so unittest can locate it as 'fixture.suite'.
    """
    if imports:
        global suite
        suite = suite(imports)
        defaultTest = 'fixture.suite'
    else:
        defaultTest = None
    return unittest.TestProgram(defaultTest=defaultTest,
                                testRunner=Runner(verbosity=1))

# Re-export the fixture helpers.  FIX: use absolute imports throughout --
# the implicit relative form ('from base import Base') was inconsistent
# with the absolute import already used for Pathed and breaks under
# Python 3 / 'from __future__ import absolute_import'.
from migrate.tests.fixture.base import Base
from migrate.tests.fixture.pathed import Pathed
from migrate.tests.fixture.shell import Shell
from migrate.tests.fixture.database import DB, usedb
diff --git a/migrate/tests/fixture/base.py b/migrate/tests/fixture/base.py
new file mode 100644
index 0000000..10e8f13
--- /dev/null
+++ b/migrate/tests/fixture/base.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import re
+import unittest
+
+class Base(unittest.TestCase):
+    """Common TestCase base: py.test-style hooks plus small helpers."""
+
+    def setup_method(self,func=None):
+        # py.test-style alias for unittest's setUp
+        self.setUp()
+
+    def teardown_method(self,func=None):
+        # py.test-style alias for unittest's tearDown
+        self.tearDown()
+
+    def assertEqualsIgnoreWhitespace(self, v1, v2):
+        """Compares two strings that should be\
+        identical except for whitespace
+        """
+        def strip_whitespace(s):
+            # remove ALL whitespace, not just leading/trailing
+            return re.sub(r'\s', '', s)
+
+        line1 = strip_whitespace(v1)
+        line2 = strip_whitespace(v2)
+
+        self.assertEquals(line1, line2, "%s != %s" % (v1, v2))
+
+    def ignoreErrors(self, func, *p,**k):
+        """Call a function, ignoring any exceptions"""
+        try:
+            func(*p,**k)
+        except:
+            pass
diff --git a/migrate/tests/fixture/database.py b/migrate/tests/fixture/database.py
new file mode 100644
index 0000000..7155b26
--- /dev/null
+++ b/migrate/tests/fixture/database.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+from decorator import decorator
+
+from sqlalchemy import create_engine, Table, MetaData
+from sqlalchemy.orm import create_session
+from sqlalchemy.pool import StaticPool
+
+from migrate.changeset import SQLA_06
+from migrate.changeset.schema import ColumnDelta
+from migrate.versioning.util import Memoize
+
+from migrate.tests.fixture.base import Base
+from migrate.tests.fixture.pathed import Pathed
+
+
@Memoize
def readurls():
    """Read database URLs from test_db.cfg and return them as a list.

    Lines starting with '#' are skipped; the literal token '__tmp__' is
    replaced with a fresh temporary file path (for sqlite databases).
    Raises IOError with setup instructions when the config file is missing.
    """
    # TODO: remove tmpfile since sqlite can store db in memory
    filename = 'test_db.cfg'
    ret = list()
    tmpfile = Pathed.tmp()
    fullpath = os.path.join(os.curdir, filename)

    try:
        fd = open(fullpath)
    except IOError:
        raise IOError("""You must specify the databases to use for testing!
Copy %(filename)s.tmpl to %(filename)s and edit your database URLs.""" % locals())

    # FIX: close the handle even if reading a line raises; the original
    # leaked the file object on error.
    try:
        for line in fd:
            if line.startswith('#'):
                continue
            line = line.replace('__tmp__', tmpfile).strip()
            ret.append(line)
    finally:
        fd.close()
    return ret
+
def is_supported(url, supported, not_supported):
    """Decide whether *url*'s dialect matches the supported/not_supported
    filters; with neither filter set, every database is accepted."""
    dialect = url.split(':', 1)[0]

    if supported is not None:
        # a single name may be given as a plain string
        if isinstance(supported, basestring):
            return dialect == supported
        return dialect in supported

    if not_supported is not None:
        if isinstance(not_supported, basestring):
            return dialect != not_supported
        return dialect not in not_supported

    return True
+
+
+def usedb(supported=None, not_supported=None):
+    """Decorates tests to be run with a database connection
+    These tests are run once for each available database
+
+    @param supported: run tests for ONLY these databases
+    @param not_supported: run tests for all databases EXCEPT these
+
+    If both supported and not_supported are empty, all dbs are assumed
+    to be supported
+    """
+    if supported is not None and not_supported is not None:
+        raise AssertionError("Can't specify both supported and not_supported in fixture.db()")
+
+    # URLs are read once (readurls is memoized) and filtered per decorator.
+    urls = readurls()
+    my_urls = [url for url in urls if is_supported(url, supported, not_supported)]
+
+    @decorator
+    def dec(f, self, *a, **kw):
+        # Run the wrapped test once per matching URL; _teardown always
+        # runs, even when _setup or the test body raises.
+        for url in my_urls:
+            try:
+                self._setup(url)
+                f(self, *a, **kw)
+            finally:
+                self._teardown()
+    return dec
+
+
+class DB(Base):
+    """Base class for database-backed tests; used with the usedb decorator."""
+
+    # Constants: connection level
+    NONE = 0 # No connection; just set self.url
+    CONNECT = 1 # Connect; no transaction
+    TXN = 2 # Everything in a transaction
+
+    level = TXN
+
+    def _engineInfo(self, url=None):
+        if url is None:
+            url = self.url
+        return url
+
+    def _setup(self, url):
+        self._connect(url)
+
+    def _teardown(self):
+        self._disconnect()
+
+    def _connect(self, url):
+        self.url = url
+        # TODO: seems like 0.5.x branch does not work with engine.dispose and staticpool
+        #self.engine = create_engine(url, echo=True, poolclass=StaticPool)
+        self.engine = create_engine(url, echo=True)
+        self.meta = MetaData(bind=self.engine)
+        if self.level < self.CONNECT:
+            return
+        #self.session = create_session(bind=self.engine)
+        if self.level < self.TXN:
+            return
+        #self.txn = self.session.begin()
+
+    def _disconnect(self):
+        # txn/session only exist when the (currently commented-out)
+        # session support in _connect is enabled.
+        if hasattr(self, 'txn'):
+            self.txn.rollback()
+        if hasattr(self, 'session'):
+            self.session.close()
+        #if hasattr(self,'conn'):
+        #    self.conn.close()
+        self.engine.dispose()
+
+    def _supported(self, url):
+        # Inspect the running test method for supported/not_supported markers.
+        db = url.split(':',1)[0]
+        func = getattr(self, self._TestCase__testMethodName)
+        if hasattr(func, 'supported'):
+            return db in func.supported
+        if hasattr(func, 'not_supported'):
+            return not (db in func.not_supported)
+        # Neither list assigned; assume all are supported
+        return True
+
+    def _not_supported(self, url):
+        return not self._supported(url)
+
+    def _select_row(self):
+        """Select rows, used in multiple tests"""
+        # the autocommit spelling changed between SQLAlchemy 0.5 and 0.6
+        if SQLA_06:
+            row = self.table.select().execution_options(autocommit=True).execute().fetchone()
+        else:
+            row = self.table.select(autocommit=True).execute().fetchone()
+        return row
+
+    def refresh_table(self, name=None):
+        """Reload the table from the database
+        Assumes we're working with only a single table, self.table, and
+        metadata self.meta
+
+        Working w/ multiple tables is not possible, as tables can only be
+        reloaded with meta.clear()
+        """
+        if name is None:
+            name = self.table.name
+        self.meta.clear()
+        self.table = Table(name, self.meta, autoload=True)
+
+    def compare_columns_equal(self, columns1, columns2, ignore=None):
+        """Loop through all columns and compare them"""
+        for c1, c2 in zip(list(columns1), list(columns2)):
+            diffs = ColumnDelta(c1, c2).diffs
+            if ignore:
+                # differences in ignored attributes don't count
+                for key in ignore:
+                    diffs.pop(key, None)
+            if diffs:
+                self.fail("Comparing %s to %s failed: %s" % (columns1, columns2, diffs))
+
+# TODO: document engine.dispose and write tests
diff --git a/migrate/tests/fixture/models.py b/migrate/tests/fixture/models.py
new file mode 100644
index 0000000..ee76429
--- /dev/null
+++ b/migrate/tests/fixture/models.py
@@ -0,0 +1,14 @@
+from sqlalchemy import *
+
+# test rundiffs in shell
+meta_old_rundiffs = MetaData()
+meta_rundiffs = MetaData()
+meta = MetaData()
+
+# Bound to meta_rundiffs; diffed against meta_old_rundiffs by the shell
+# 'rundiffs' tests.
+tmp_account_rundiffs = Table('tmp_account_rundiffs', meta_rundiffs,
+    Column('id', Integer, primary_key=True),
+    Column('login', Text()),
+    Column('passwd', Text()),
+)
+
+# Minimal single-column table used by the SQL-script tests.
+tmp_sql_table = Table('tmp_sql_table', meta, Column('id', Integer))
diff --git a/migrate/tests/fixture/pathed.py b/migrate/tests/fixture/pathed.py
new file mode 100644
index 0000000..b36aa08
--- /dev/null
+++ b/migrate/tests/fixture/pathed.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import sys
+import shutil
+import tempfile
+
+from migrate.tests.fixture import base
+
+
+class Pathed(base.Base):
+    """TestCase base that manages temporary files and directories."""
+
+    # Temporary files
+
+    _tmpdir = tempfile.mkdtemp()
+
+    def setUp(self):
+        super(Pathed, self).setUp()
+        # importable scratch directory; removed again in tearDown
+        self.temp_usable_dir = tempfile.mkdtemp()
+        sys.path.append(self.temp_usable_dir)
+
+    def tearDown(self):
+        super(Pathed, self).tearDown()
+        try:
+            sys.path.remove(self.temp_usable_dir)
+        except:
+            pass # w00t?
+        Pathed.purge(self.temp_usable_dir)
+
+    @classmethod
+    def _tmp(cls, prefix='', suffix=''):
+        """Generate a temporary file name that doesn't exist
+        All filenames are generated inside a temporary directory created by
+        tempfile.mkdtemp(); only the creating user has access to this directory.
+        It should be secure to return a nonexistent temp filename in this
+        directory, unless the user is messing with their own files.
+        """
+        file, ret = tempfile.mkstemp(suffix,prefix,cls._tmpdir)
+        os.close(file)
+        os.remove(ret)
+        return ret
+
+    @classmethod
+    def tmp(cls, *p, **k):
+        return cls._tmp(*p, **k)
+
+    @classmethod
+    def tmp_py(cls, *p, **k):
+        return cls._tmp(suffix='.py', *p, **k)
+
+    @classmethod
+    def tmp_sql(cls, *p, **k):
+        return cls._tmp(suffix='.sql', *p, **k)
+
+    @classmethod
+    def tmp_named(cls, name):
+        # fixed name inside the shared temp dir (no uniqueness guarantee)
+        return os.path.join(cls._tmpdir, name)
+
+    @classmethod
+    def tmp_repos(cls, *p, **k):
+        return cls._tmp(*p, **k)
+
+    @classmethod
+    def purge(cls, path):
+        """Removes this path if it exists, in preparation for tests
+        Careful - all tests should take place in /tmp.
+        We don't want to accidentally wipe stuff out...
+        """
+        if os.path.exists(path):
+            if os.path.isdir(path):
+                shutil.rmtree(path)
+            else:
+                os.remove(path)
+            # also drop the stale bytecode file, if any
+            if path.endswith('.py'):
+                pyc = path + 'c'
+                if os.path.exists(pyc):
+                    os.remove(pyc)
diff --git a/migrate/tests/fixture/shell.py b/migrate/tests/fixture/shell.py
new file mode 100644
index 0000000..92f9e33
--- /dev/null
+++ b/migrate/tests/fixture/shell.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import sys
+
+from scripttest import TestFileEnvironment
+
+from migrate.tests.fixture.pathed import *
+
+
+class Shell(Pathed):
+    """Base class for command line tests"""
+
+    def setUp(self):
+        super(Shell, self).setUp()
+        # isolated scripttest environment with its own virtualenv into
+        # which the package under test is installed
+        self.env = TestFileEnvironment(
+            base_path=os.path.join(self.temp_usable_dir, 'env'),
+            script_path=[os.path.dirname(sys.executable)], # PATH to migrate development script folder
+            environ={'PYTHONPATH':
+                os.path.join(os.getcwd(), 'migrate', 'tests')},
+        )
+        self.env.run("virtualenv %s" % self.env.base_path)
+        self.env.run("%s/bin/python setup.py install" % (self.env.base_path,), cwd=os.getcwd())
+
+    def run_version(self, repos_path):
+        # latest repository version, as an int
+        result = self.env.run('bin/migrate version %s' % repos_path)
+        return int(result.stdout.strip())
+
+    def run_db_version(self, url, repos_path):
+        # current database version, as an int
+        result = self.env.run('bin/migrate db_version %s %s' % (url, repos_path))
+        return int(result.stdout.strip())
diff --git a/migrate/tests/integrated/__init__.py b/migrate/tests/integrated/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/migrate/tests/integrated/__init__.py
diff --git a/migrate/tests/integrated/test_docs.py b/migrate/tests/integrated/test_docs.py
new file mode 100644
index 0000000..6aed071
--- /dev/null
+++ b/migrate/tests/integrated/test_docs.py
@@ -0,0 +1,18 @@
+import doctest
+import os
+
+
+from migrate.tests import fixture
+
# Collect tests for all handwritten docs: doc/*.rst
# FIX: renamed the module-level locals 'dir' and 'files', which shadowed
# the builtins of the same names; 'paths' and 'suite' keep their names.
doc_dir = ('..', '..', '..', 'docs')
abs_parts = (os.path.dirname(os.path.abspath(__file__)),) + doc_dir
dirpath = os.path.join(*abs_parts)
rst_files = [name for name in os.listdir(dirpath) if name.endswith('.rst')]
paths = [os.path.join(*(doc_dir + (name,))) for name in rst_files]
assert len(paths) > 0, dirpath
suite = doctest.DocFileSuite(*paths)


def test_docs():
    """Run the doctests collected from the .rst documentation files."""
    suite.debug()
diff --git a/migrate/tests/versioning/__init__.py b/migrate/tests/versioning/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/migrate/tests/versioning/__init__.py
diff --git a/migrate/tests/versioning/test_api.py b/migrate/tests/versioning/test_api.py
new file mode 100644
index 0000000..41152b0
--- /dev/null
+++ b/migrate/tests/versioning/test_api.py
@@ -0,0 +1,120 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+from migrate.versioning import api
+from migrate.versioning.exceptions import *
+
+from migrate.tests.fixture.pathed import *
+from migrate.tests.fixture import models
+from migrate.tests import fixture
+
+
+class TestAPI(Pathed):
+    """Smoke tests for the repository-level api.* commands."""
+
+    def test_help(self):
+        self.assertTrue(isinstance(api.help('help'), basestring))
+        self.assertRaises(UsageError, api.help)
+        self.assertRaises(UsageError, api.help, 'foobar')
+        self.assert_(isinstance(api.help('create'), str))
+
+        # test that all commands return some text
+        for cmd in api.__all__:
+            content = api.help(cmd)
+            self.assertTrue(content)
+
+    def test_create(self):
+        tmprepo = self.tmp_repos()
+        api.create(tmprepo, 'temp')
+
+        # repository already exists
+        self.assertRaises(KnownError, api.create, tmprepo, 'temp')
+
+    def test_script(self):
+        repo = self.tmp_repos()
+        api.create(repo, 'temp')
+        api.script('first version', repo)
+
+    def test_script_sql(self):
+        repo = self.tmp_repos()
+        api.create(repo, 'temp')
+        api.script_sql('postgres', repo)
+
+    def test_version(self):
+        repo = self.tmp_repos()
+        api.create(repo, 'temp')
+        api.version(repo)
+
+    def test_source(self):
+        repo = self.tmp_repos()
+        api.create(repo, 'temp')
+        api.script('first version', repo)
+        api.script_sql('default', repo)
+
+        # no repository
+        self.assertRaises(UsageError, api.source, 1)
+
+        # stdout
+        out = api.source(1, dest=None, repository=repo)
+        self.assertTrue(out)
+
+        # file
+        out = api.source(1, dest=self.tmp_repos(), repository=repo)
+        self.assertFalse(out)
+
+    def test_manage(self):
+        # only checks that manage.py generation does not blow up
+        output = api.manage(os.path.join(self.temp_usable_dir, 'manage.py'))
+
+
+class TestSchemaAPI(fixture.DB, Pathed):
+    """api.* commands run against a version-controlled database."""
+
+    def _setup(self, url):
+        # fresh repository + version control for every test/URL combination
+        super(TestSchemaAPI, self)._setup(url)
+        self.repo = self.tmp_repos()
+        api.create(self.repo, 'temp')
+        self.schema = api.version_control(url, self.repo)
+
+    def _teardown(self):
+        self.schema = api.drop_version_control(self.url, self.repo)
+        super(TestSchemaAPI, self)._teardown()
+
+    @fixture.usedb()
+    def test_workflow(self):
+        self.assertEqual(api.db_version(self.url, self.repo), 0)
+        api.script('First Version', self.repo)
+        self.assertEqual(api.db_version(self.url, self.repo), 0)
+        api.upgrade(self.url, self.repo, 1)
+        self.assertEqual(api.db_version(self.url, self.repo), 1)
+        api.downgrade(self.url, self.repo, 0)
+        self.assertEqual(api.db_version(self.url, self.repo), 0)
+        api.test(self.url, self.repo)
+        self.assertEqual(api.db_version(self.url, self.repo), 0)
+
+        # preview
+        # TODO: test output
+        out = api.upgrade(self.url, self.repo, preview_py=True)
+        out = api.upgrade(self.url, self.repo, preview_sql=True)
+
+        api.upgrade(self.url, self.repo, 1)
+        api.script_sql('default', self.repo)
+        self.assertRaises(UsageError, api.upgrade, self.url, self.repo, 2, preview_py=True)
+        out = api.upgrade(self.url, self.repo, 2, preview_sql=True)
+
+        # can't "upgrade" to a lower version: we are already at version 1
+        self.assertEqual(api.db_version(self.url, self.repo), 1)
+        self.assertRaises(KnownError, api.upgrade, self.url, self.repo, 0)
+
+    @fixture.usedb()
+    def test_compare_model_to_db(self):
+        diff = api.compare_model_to_db(self.url, self.repo, models.meta)
+
+    @fixture.usedb()
+    def test_create_model(self):
+        model = api.create_model(self.url, self.repo)
+
+    @fixture.usedb()
+    def test_make_update_script_for_model(self):
+        model = api.make_update_script_for_model(self.url, self.repo, models.meta_old_rundiffs, models.meta_rundiffs)
+
+    @fixture.usedb()
+    def test_update_db_from_model(self):
+        model = api.update_db_from_model(self.url, self.repo, models.meta_rundiffs)
diff --git a/migrate/tests/versioning/test_cfgparse.py b/migrate/tests/versioning/test_cfgparse.py
new file mode 100644
index 0000000..27f52cd
--- /dev/null
+++ b/migrate/tests/versioning/test_cfgparse.py
@@ -0,0 +1,27 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+from migrate.versioning import cfgparse
+from migrate.versioning.repository import *
+from migrate.versioning.template import Template
+from migrate.tests import fixture
+
+
+class TestConfigParser(fixture.Base):
+    """Tests for the repository config parser wrapper."""
+
+    def test_to_dict(self):
+        """Correctly interpret config results as dictionaries"""
+        parser = cfgparse.Parser(dict(default_value=42))
+        self.assert_(len(parser.sections()) == 0)
+        parser.add_section('section')
+        parser.set('section','option','value')
+        self.assertEqual(parser.get('section', 'option'), 'value')
+        self.assertEqual(parser.to_dict()['section']['option'], 'value')
+
+    def test_table_config(self):
+        """We should be able to specify the table to be used with a repository"""
+        # the generated config must differ when a custom version_table is set
+        default_text = Repository.prepare_config(Template().get_repository(),
+            'repository_name', {})
+        specified_text = Repository.prepare_config(Template().get_repository(),
+            'repository_name', {'version_table': '_other_table'})
+        self.assertNotEquals(default_text, specified_text)
diff --git a/migrate/tests/versioning/test_database.py b/migrate/tests/versioning/test_database.py
new file mode 100644
index 0000000..525385f
--- /dev/null
+++ b/migrate/tests/versioning/test_database.py
@@ -0,0 +1,11 @@
+from sqlalchemy import select
+from migrate.tests import fixture
+
+class TestConnect(fixture.DB):
+    # run inside a transaction (see fixture.DB connection levels)
+    level=fixture.DB.TXN
+
+    @fixture.usedb()
+    def test_connect(self):
+        """Connect to the database successfully"""
+        # Connection is done in fixture.DB setup; make sure we can do stuff
+        select(['42'],bind=self.engine).execute()
diff --git a/migrate/tests/versioning/test_genmodel.py b/migrate/tests/versioning/test_genmodel.py
new file mode 100644
index 0000000..61610a1
--- /dev/null
+++ b/migrate/tests/versioning/test_genmodel.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+
+from migrate.versioning.genmodel import *
+from migrate.versioning.exceptions import *
+
+from migrate.tests import fixture
+
+
+class TestModelGenerator(fixture.Pathed, fixture.DB):
+    # Placeholder: no tests yet; only fixes the DB connection level.
+    level = fixture.DB.TXN
diff --git a/migrate/tests/versioning/test_keyedinstance.py b/migrate/tests/versioning/test_keyedinstance.py
new file mode 100644
index 0000000..b2f87ac
--- /dev/null
+++ b/migrate/tests/versioning/test_keyedinstance.py
@@ -0,0 +1,45 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+from migrate.tests import fixture
+from migrate.versioning.util.keyedinstance import *
+
+class TestKeydInstance(fixture.Base):
+    # NOTE(review): class name is missing the 'e' of 'Keyed'; renaming
+    # would change the public test name, so it is left as-is.
+    def test_unique(self):
+        """UniqueInstance should produce unique object instances"""
+        class Uniq1(KeyedInstance):
+            @classmethod
+            def _key(cls,key):
+                return str(key)
+            def __init__(self,value):
+                self.value=value
+        class Uniq2(KeyedInstance):
+            @classmethod
+            def _key(cls,key):
+                return str(key)
+            def __init__(self,value):
+                self.value=value
+
+        a10 = Uniq1('a')
+
+        # Different key: different instance
+        b10 = Uniq1('b')
+        self.assert_(a10 is not b10)
+
+        # Different class: different instance
+        a20 = Uniq2('a')
+        self.assert_(a10 is not a20)
+
+        # Same key/class: same instance
+        a11 = Uniq1('a')
+        self.assert_(a10 is a11)
+
+        # __init__ is called
+        self.assertEquals(a10.value,'a')
+
+        # clear() causes us to forget all existing instances
+        Uniq1.clear()
+        a12 = Uniq1('a')
+        self.assert_(a10 is not a12)
+
+        # NOTE(review): this calls _key with no arguments and assumes the
+        # base implementation raises NotImplementedError before complaining
+        # about the missing 'key' argument -- confirm the base signature.
+        self.assertRaises(NotImplementedError, KeyedInstance._key)
diff --git a/migrate/tests/versioning/test_pathed.py b/migrate/tests/versioning/test_pathed.py
new file mode 100644
index 0000000..7616e9d
--- /dev/null
+++ b/migrate/tests/versioning/test_pathed.py
@@ -0,0 +1,51 @@
+from migrate.tests import fixture
+from migrate.versioning.pathed import *
+
+class TestPathed(fixture.Base):
+    def test_parent_path(self):
+        """Default parent_path should behave correctly"""
+        filepath='/fgsfds/moot.py'
+        dirpath='/fgsfds/moot'
+        sdirpath='/fgsfds/moot/'
+
+        # file, directory, and trailing-slash directory share one parent
+        result='/fgsfds'
+        self.assert_(result==Pathed._parent_path(filepath))
+        self.assert_(result==Pathed._parent_path(dirpath))
+        self.assert_(result==Pathed._parent_path(sdirpath))
+
+    def test_new(self):
+        """Pathed(path) shouldn't create duplicate objects of the same path"""
+        path='/fgsfds'
+        class Test(Pathed):
+            attr=None
+        o1=Test(path)
+        o2=Test(path)
+        self.assert_(isinstance(o1,Test))
+        self.assert_(o1.path==path)
+        # same path => same shared instance, so attribute writes are visible
+        # through either name
+        self.assert_(o1 is o2)
+        o1.attr='herring'
+        self.assert_(o2.attr=='herring')
+        o2.attr='shrubbery'
+        self.assert_(o1.attr=='shrubbery')
+
+    def test_parent(self):
+        """Parents should be fetched correctly"""
+        class Parent(Pathed):
+            parent=None
+            children=0
+            def _init_child(self,child,path):
+                """Keep a tally of children.
+                (A real class might do something more interesting here)
+                """
+                self.__class__.children+=1
+
+        class Child(Pathed):
+            parent=Parent
+
+        path='/fgsfds/moot.py'
+        parent_path='/fgsfds'
+        object=Child(path)
+        self.assert_(isinstance(object,Child))
+        self.assert_(isinstance(object.parent,Parent))
+        self.assert_(object.path==path)
+        self.assert_(object.parent.path==parent_path)
diff --git a/migrate/tests/versioning/test_repository.py b/migrate/tests/versioning/test_repository.py
new file mode 100644
index 0000000..a23cd45
--- /dev/null
+++ b/migrate/tests/versioning/test_repository.py
@@ -0,0 +1,198 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import shutil
+
+from migrate.versioning import exceptions
+from migrate.versioning.repository import *
+from migrate.versioning.script import *
+from nose.tools import raises
+
+from migrate.tests import fixture
+
+
+class TestRepository(fixture.Pathed):
+ def test_create(self):
+ """Repositories are created successfully"""
+ path = self.tmp_repos()
+ name = 'repository_name'
+
+ # Creating a repository that doesn't exist should succeed
+ repo = Repository.create(path, name)
+ config_path = repo.config.path
+ manage_path = os.path.join(repo.path, 'manage.py')
+ self.assert_(repo)
+
+ # Files should actually be created
+ self.assert_(os.path.exists(path))
+ self.assert_(os.path.exists(config_path))
+ self.assert_(os.path.exists(manage_path))
+
+ # Can't create it again: it already exists
+ self.assertRaises(exceptions.PathFoundError, Repository.create, path, name)
+ return path
+
+ def test_load(self):
+ """We should be able to load information about an existing repository"""
+ # Create a repository to load
+ path = self.test_create()
+ repos = Repository(path)
+
+ self.assert_(repos)
+ self.assert_(repos.config)
+ self.assert_(repos.config.get('db_settings', 'version_table'))
+
+ # version_table's default isn't none
+ self.assertNotEquals(repos.config.get('db_settings', 'version_table'), 'None')
+
+ def test_load_notfound(self):
+ """Nonexistant repositories shouldn't be loaded"""
+ path = self.tmp_repos()
+ self.assert_(not os.path.exists(path))
+ self.assertRaises(exceptions.InvalidRepositoryError, Repository, path)
+
+ def test_load_invalid(self):
+ """Invalid repos shouldn't be loaded"""
+ # Here, invalid=empty directory. There may be other conditions too,
+ # but we shouldn't need to test all of them
+ path = self.tmp_repos()
+ os.mkdir(path)
+ self.assertRaises(exceptions.InvalidRepositoryError, Repository, path)
+
+
+class TestVersionedRepository(fixture.Pathed):
+ """Tests on an existing repository with a single python script"""
+
+ def setUp(self):
+ super(TestVersionedRepository, self).setUp()
+ Repository.clear()
+ self.path_repos = self.tmp_repos()
+ Repository.create(self.path_repos, 'repository_name')
+
+ def test_version(self):
+ """We should correctly detect the version of a repository"""
+ repos = Repository(self.path_repos)
+
+ # Get latest version, or detect if a specified version exists
+ self.assertEquals(repos.latest, 0)
+ # repos.latest isn't an integer, but a VerNum
+ # (so we can't just assume the following tests are correct)
+ self.assert_(repos.latest >= 0)
+ self.assert_(repos.latest < 1)
+
+ # Create a script and test again
+ repos.create_script('')
+ self.assertEquals(repos.latest, 1)
+ self.assert_(repos.latest >= 0)
+ self.assert_(repos.latest >= 1)
+ self.assert_(repos.latest < 2)
+
+ # Create a new script and test again
+ repos.create_script('')
+ self.assertEquals(repos.latest, 2)
+ self.assert_(repos.latest >= 0)
+ self.assert_(repos.latest >= 1)
+ self.assert_(repos.latest >= 2)
+ self.assert_(repos.latest < 3)
+
+ def test_source(self):
+ """Get a script object by version number and view its source"""
+ # Load repository and commit script
+ repo = Repository(self.path_repos)
+ repo.create_script('')
+ repo.create_script_sql('postgres')
+
+ # Source is valid: script must have an upgrade function
+ # (not a very thorough test, but should be plenty)
+ source = repo.version(1).script().source()
+ self.assertTrue(source.find('def upgrade') >= 0)
+
+ source = repo.version(2).script('postgres', 'upgrade').source()
+ self.assertEqual(source.strip(), '')
+
+ def test_latestversion(self):
+ """Repository.version() (no params) returns the latest version"""
+ repos = Repository(self.path_repos)
+ repos.create_script('')
+ self.assert_(repos.version(repos.latest) is repos.version())
+ self.assert_(repos.version() is not None)
+
+ def test_changeset(self):
+ """Repositories can create changesets properly"""
+ # Create a nonzero-version repository of empty scripts
+ repos = Repository(self.path_repos)
+ for i in range(10):
+ repos.create_script('')
+
+ def check_changeset(params, length):
+ """Creates and verifies a changeset"""
+ changeset = repos.changeset('postgres', *params)
+ self.assertEquals(len(changeset), length)
+ self.assertTrue(isinstance(changeset, Changeset))
+ uniq = list()
+ # Changesets are iterable
+ for version, change in changeset:
+ self.assert_(isinstance(change, BaseScript))
+ # Changes aren't identical
+ self.assert_(id(change) not in uniq)
+ uniq.append(id(change))
+ return changeset
+
+ # Upgrade to a specified version...
+ cs = check_changeset((0, 10), 10)
+ self.assertEquals(cs.keys().pop(0),0 ) # 0 -> 1: index is starting version
+ self.assertEquals(cs.keys().pop(), 9) # 9 -> 10: index is starting version
+ self.assertEquals(cs.start, 0) # starting version
+ self.assertEquals(cs.end, 10) # ending version
+ check_changeset((0, 1), 1)
+ check_changeset((0, 5), 5)
+ check_changeset((0, 0), 0)
+ check_changeset((5, 5), 0)
+ check_changeset((10, 10), 0)
+ check_changeset((5, 10), 5)
+
+ # Can't request a changeset of higher version than this repository
+ self.assertRaises(Exception, repos.changeset, 'postgres', 5, 11)
+ self.assertRaises(Exception, repos.changeset, 'postgres', -1, 5)
+
+ # Upgrade to the latest version...
+ cs = check_changeset((0,), 10)
+ self.assertEquals(cs.keys().pop(0), 0)
+ self.assertEquals(cs.keys().pop(), 9)
+ self.assertEquals(cs.start, 0)
+ self.assertEquals(cs.end, 10)
+ check_changeset((1,), 9)
+ check_changeset((5,), 5)
+ check_changeset((9,), 1)
+ check_changeset((10,), 0)
+
+ # run changes
+ cs.run('postgres', 'upgrade')
+
+ # Can't request a changeset of higher/lower version than this repository
+ self.assertRaises(Exception, repos.changeset, 'postgres', 11)
+ self.assertRaises(Exception, repos.changeset, 'postgres', -1)
+
+ # Downgrade
+ cs = check_changeset((10, 0),10)
+ self.assertEquals(cs.keys().pop(0), 10) # 10 -> 9
+ self.assertEquals(cs.keys().pop(), 1) # 1 -> 0
+ self.assertEquals(cs.start, 10)
+ self.assertEquals(cs.end, 0)
+ check_changeset((10, 5), 5)
+ check_changeset((5, 0), 5)
+
+ def test_many_versions(self):
+ """Test what happens when lots of versions are created"""
+ repos = Repository(self.path_repos)
+ for i in range(1001):
+ repos.create_script('')
+
+ # since we normally create 3 digit ones, let's see if we blow up
+ self.assert_(os.path.exists('%s/versions/1000.py' % self.path_repos))
+ self.assert_(os.path.exists('%s/versions/1001.py' % self.path_repos))
+
+
+# TODO: test manage file
+# TODO: test changeset
diff --git a/migrate/tests/versioning/test_runchangeset.py b/migrate/tests/versioning/test_runchangeset.py
new file mode 100644
index 0000000..52b0215
--- /dev/null
+++ b/migrate/tests/versioning/test_runchangeset.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os,shutil
+
+from migrate.tests import fixture
+from migrate.versioning.schema import *
+from migrate.versioning import script
+
+
+class TestRunChangeset(fixture.Pathed,fixture.DB):
+ level=fixture.DB.CONNECT
+ def _setup(self, url):
+ super(TestRunChangeset, self)._setup(url)
+ Repository.clear()
+ self.path_repos=self.tmp_repos()
+ # Create repository, script
+ Repository.create(self.path_repos,'repository_name')
+
+ @fixture.usedb()
+ def test_changeset_run(self):
+ """Running a changeset against a repository gives expected results"""
+ repos=Repository(self.path_repos)
+ for i in range(10):
+ repos.create_script('')
+ try:
+ ControlledSchema(self.engine,repos).drop()
+ except:
+ pass
+ db=ControlledSchema.create(self.engine,repos)
+
+ # Scripts are empty; we'll check version # correctness.
+ # (Correct application of their content is checked elsewhere)
+ self.assertEquals(db.version,0)
+ db.upgrade(1)
+ self.assertEquals(db.version,1)
+ db.upgrade(5)
+ self.assertEquals(db.version,5)
+ db.upgrade(5)
+ self.assertEquals(db.version,5)
+ db.upgrade(None) # Latest is implied
+ self.assertEquals(db.version,10)
+ self.assertRaises(Exception,db.upgrade,11)
+ self.assertEquals(db.version,10)
+ db.upgrade(9)
+ self.assertEquals(db.version,9)
+ db.upgrade(0)
+ self.assertEquals(db.version,0)
+ self.assertRaises(Exception,db.upgrade,-1)
+ self.assertEquals(db.version,0)
+ #changeset = repos.changeset(self.url,0)
+ db.drop()
diff --git a/migrate/tests/versioning/test_schema.py b/migrate/tests/versioning/test_schema.py
new file mode 100644
index 0000000..d4b4861
--- /dev/null
+++ b/migrate/tests/versioning/test_schema.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import shutil
+
+from migrate.versioning.schema import *
+from migrate.versioning import script, exceptions, schemadiff
+
+from sqlalchemy import *
+
+from migrate.tests import fixture
+
+
+class TestControlledSchema(fixture.Pathed, fixture.DB):
+ # Transactions break postgres in this test; we'll clean up after ourselves
+ level = fixture.DB.CONNECT
+
+ def setUp(self):
+ super(TestControlledSchema, self).setUp()
+ self.path_repos = self.temp_usable_dir + '/repo/'
+ self.repos = Repository.create(self.path_repos, 'repo_name')
+
+ def _setup(self, url):
+ self.setUp()
+ super(TestControlledSchema, self)._setup(url)
+ self.cleanup()
+
+ def _teardown(self):
+ super(TestControlledSchema, self)._teardown()
+ self.cleanup()
+ self.tearDown()
+
+ def cleanup(self):
+ # drop existing version table if necessary
+ try:
+ ControlledSchema(self.engine, self.repos).drop()
+ except:
+ # No table to drop; that's fine, be silent
+ pass
+
+ def tearDown(self):
+ self.cleanup()
+ super(TestControlledSchema, self).tearDown()
+
+ @fixture.usedb()
+ def test_version_control(self):
+ """Establish version control on a particular database"""
+ # Establish version control on this database
+ dbcontrol = ControlledSchema.create(self.engine, self.repos)
+
+ # Trying to create another DB this way fails: table exists
+ self.assertRaises(exceptions.DatabaseAlreadyControlledError,
+ ControlledSchema.create, self.engine, self.repos)
+
+ # We can load a controlled DB this way, too
+ dbcontrol0 = ControlledSchema(self.engine, self.repos)
+ self.assertEquals(dbcontrol, dbcontrol0)
+
+ # We can also use a repository path, instead of a repository
+ dbcontrol0 = ControlledSchema(self.engine, self.repos.path)
+ self.assertEquals(dbcontrol, dbcontrol0)
+
+ # We don't have to use the same connection
+ engine = create_engine(self.url)
+ dbcontrol0 = ControlledSchema(engine, self.repos.path)
+ self.assertEquals(dbcontrol, dbcontrol0)
+
+ # Clean up:
+ dbcontrol.drop()
+
+ # Attempting to drop vc from a db without it should fail
+ self.assertRaises(exceptions.DatabaseNotControlledError, dbcontrol.drop)
+
+ # No table defined should raise error
+ self.assertRaises(exceptions.DatabaseNotControlledError,
+ ControlledSchema, self.engine, self.repos)
+
+ @fixture.usedb()
+ def test_version_control_specified(self):
+ """Establish version control with a specified version"""
+ # Establish version control on this database
+ version = 0
+ dbcontrol = ControlledSchema.create(self.engine, self.repos, version)
+ self.assertEquals(dbcontrol.version, version)
+
+ # Correct when we load it, too
+ dbcontrol = ControlledSchema(self.engine, self.repos)
+ self.assertEquals(dbcontrol.version, version)
+
+ dbcontrol.drop()
+
+ # Now try it with a nonzero value
+ version = 10
+ for i in range(version):
+ self.repos.create_script('')
+ self.assertEquals(self.repos.latest, version)
+
+ # Test with some mid-range value
+ dbcontrol = ControlledSchema.create(self.engine,self.repos, 5)
+ self.assertEquals(dbcontrol.version, 5)
+ dbcontrol.drop()
+
+ # Test with max value
+ dbcontrol = ControlledSchema.create(self.engine, self.repos, version)
+ self.assertEquals(dbcontrol.version, version)
+ dbcontrol.drop()
+
+ @fixture.usedb()
+ def test_version_control_invalid(self):
+ """Try to establish version control with an invalid version"""
+ versions = ('Thirteen', '-1', -1, '' , 13)
+ # A fresh repository doesn't go up to version 13 yet
+ for version in versions:
+ #self.assertRaises(ControlledSchema.InvalidVersionError,
+ # Can't have custom errors with assertRaises...
+ try:
+ ControlledSchema.create(self.engine, self.repos, version)
+ self.assert_(False, repr(version))
+ except exceptions.InvalidVersionError:
+ pass
+
+ @fixture.usedb()
+ def test_changeset(self):
+ """Create changeset from controlled schema"""
+ dbschema = ControlledSchema.create(self.engine, self.repos)
+
+ # empty schema doesn't have changesets
+ cs = dbschema.changeset()
+ self.assertEqual(cs, {})
+
+ for i in range(5):
+ self.repos.create_script('')
+ self.assertEquals(self.repos.latest, 5)
+
+ cs = dbschema.changeset(5)
+ self.assertEqual(len(cs), 5)
+
+ # cleanup
+ dbschema.drop()
+
+ @fixture.usedb()
+ def test_upgrade_runchange(self):
+ dbschema = ControlledSchema.create(self.engine, self.repos)
+
+ for i in range(10):
+ self.repos.create_script('')
+
+ self.assertEquals(self.repos.latest, 10)
+
+ dbschema.upgrade(10)
+
+ self.assertRaises(ValueError, dbschema.upgrade, 'a')
+ self.assertRaises(exceptions.InvalidVersionError, dbschema.runchange, 20, '', 1)
+
+ # TODO: test for table version in db
+
+ # cleanup
+ dbschema.drop()
+
+ @fixture.usedb()
+ def test_create_model(self):
+ """Test workflow to generate create_model"""
+ model = ControlledSchema.create_model(self.engine, self.repos, declarative=False)
+ self.assertTrue(isinstance(model, basestring))
+
+ model = ControlledSchema.create_model(self.engine, self.repos.path, declarative=True)
+ self.assertTrue(isinstance(model, basestring))
+
+ @fixture.usedb()
+ def test_compare_model_to_db(self):
+ meta = self.construct_model()
+
+ diff = ControlledSchema.compare_model_to_db(self.engine, meta, self.repos)
+ self.assertTrue(isinstance(diff, schemadiff.SchemaDiff))
+
+ diff = ControlledSchema.compare_model_to_db(self.engine, meta, self.repos.path)
+ self.assertTrue(isinstance(diff, schemadiff.SchemaDiff))
+ meta.drop_all(self.engine)
+
+ @fixture.usedb()
+ def test_update_db_from_model(self):
+ dbschema = ControlledSchema.create(self.engine, self.repos)
+
+ meta = self.construct_model()
+
+ dbschema.update_db_from_model(meta)
+
+ # TODO: test for table version in db
+
+ # cleanup
+ dbschema.drop()
+ meta.drop_all(self.engine)
+
+ def construct_model(self):
+ meta = MetaData()
+
+ user = Table('temp_model_schema', meta, Column('id', Integer), Column('user', String(245)))
+
+ return meta
+
+ # TODO: test how are tables populated in db
diff --git a/migrate/tests/versioning/test_schemadiff.py b/migrate/tests/versioning/test_schemadiff.py
new file mode 100644
index 0000000..6df6463
--- /dev/null
+++ b/migrate/tests/versioning/test_schemadiff.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+
+import os
+
+import sqlalchemy
+from sqlalchemy import *
+from nose.tools import eq_
+
+from migrate.versioning import genmodel, schemadiff
+from migrate.changeset import schema
+
+from migrate.tests import fixture
+
+
+class TestSchemaDiff(fixture.DB):
+ table_name = 'tmp_schemadiff'
+ level = fixture.DB.CONNECT
+
+ def _setup(self, url):
+ super(TestSchemaDiff, self)._setup(url)
+ self.meta = MetaData(self.engine, reflect=True)
+ self.meta.drop_all() # in case junk tables are lying around in the test database
+ self.meta = MetaData(self.engine, reflect=True) # needed if we just deleted some tables
+ self.table = Table(self.table_name, self.meta,
+ Column('id',Integer(), primary_key=True),
+ Column('name', UnicodeText()),
+ Column('data', UnicodeText()),
+ )
+
+ def _teardown(self):
+ if self.table.exists():
+ self.meta = MetaData(self.engine, reflect=True)
+ self.meta.drop_all()
+ super(TestSchemaDiff, self)._teardown()
+
+ def _applyLatestModel(self):
+ diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version'])
+ genmodel.ModelGenerator(diff).applyModel()
+
+ # TODO: support for diff/generator to extract differences between columns
+ #@fixture.usedb()
+ #def test_type_diff(self):
+ #"""Basic test for correct type diff"""
+ #self.table.create()
+ #self.meta = MetaData(self.engine)
+ #self.table = Table(self.table_name, self.meta,
+ #Column('id', Integer(), primary_key=True),
+ #Column('name', Unicode(45)),
+ #Column('data', Integer),
+ #)
+ #diff = schemadiff.getDiffOfModelAgainstDatabase\
+ #(self.meta, self.engine, excludeTables=['migrate_version'])
+
+ @fixture.usedb()
+ def test_rundiffs(self):
+
+ def assertDiff(isDiff, tablesMissingInDatabase, tablesMissingInModel, tablesWithDiff):
+ diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version'])
+ eq_(bool(diff), isDiff)
+ eq_( ([t.name for t in diff.tablesMissingInDatabase], [t.name for t in diff.tablesMissingInModel], [t.name for t in diff.tablesWithDiff]),
+ (tablesMissingInDatabase, tablesMissingInModel, tablesWithDiff) )
+
+ # Model is defined but database is empty.
+ assertDiff(True, [self.table_name], [], [])
+
+ # Check Python upgrade and downgrade of database from updated model.
+ diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version'])
+ decls, upgradeCommands, downgradeCommands = genmodel.ModelGenerator(diff).toUpgradeDowngradePython()
+ self.assertEqualsIgnoreWhitespace(decls, '''
+ from migrate.changeset import schema
+ meta = MetaData()
+ tmp_schemadiff = Table('tmp_schemadiff', meta,
+ Column('id', Integer(), primary_key=True, nullable=False),
+ Column('name', UnicodeText(length=None)),
+ Column('data', UnicodeText(length=None)),
+ )
+ ''')
+ self.assertEqualsIgnoreWhitespace(upgradeCommands,
+ '''meta.bind = migrate_engine
+ tmp_schemadiff.create()''')
+ self.assertEqualsIgnoreWhitespace(downgradeCommands,
+ '''meta.bind = migrate_engine
+ tmp_schemadiff.drop()''')
+
+ # Create table in database, now model should match database.
+ self._applyLatestModel()
+ assertDiff(False, [], [], [])
+
+ # Check Python code gen from database.
+ diff = schemadiff.getDiffOfModelAgainstDatabase(MetaData(), self.engine, excludeTables=['migrate_version'])
+ src = genmodel.ModelGenerator(diff).toPython()
+
+ exec src in locals()
+
+ c1 = Table('tmp_schemadiff', self.meta, autoload=True).c
+ c2 = tmp_schemadiff.c
+ self.compare_columns_equal(c1, c2, ['type'])
+ # TODO: get rid of ignoring type
+
+ if not self.engine.name == 'oracle':
+ # Add data, later we'll make sure it's still present.
+ result = self.engine.execute(self.table.insert(), id=1, name=u'mydata')
+ dataId = result.last_inserted_ids()[0]
+
+ # Modify table in model (by removing it and adding it back to model) -- drop column data and add column data2.
+ self.meta.remove(self.table)
+ self.table = Table(self.table_name,self.meta,
+ Column('id',Integer(),primary_key=True),
+ Column('name',UnicodeText(length=None)),
+ Column('data2',Integer(),nullable=True),
+ )
+ assertDiff(True, [], [], [self.table_name])
+
+ # Apply latest model changes and find no more diffs.
+ self._applyLatestModel()
+ assertDiff(False, [], [], [])
+
+ if not self.engine.name == 'oracle':
+ # Make sure data is still present.
+ result = self.engine.execute(self.table.select(self.table.c.id==dataId))
+ rows = result.fetchall()
+ eq_(len(rows), 1)
+ eq_(rows[0].name, 'mydata')
+
+ # Add data, later we'll make sure it's still present.
+ result = self.engine.execute(self.table.insert(), id=2, name=u'mydata2', data2=123)
+ dataId2 = result.last_inserted_ids()[0]
+
+ # Change column type in model.
+ self.meta.remove(self.table)
+ self.table = Table(self.table_name,self.meta,
+ Column('id',Integer(),primary_key=True),
+ Column('name',UnicodeText(length=None)),
+ Column('data2',String(255),nullable=True),
+ )
+ assertDiff(True, [], [], [self.table_name]) # TODO test type diff
+
+ # Apply latest model changes and find no more diffs.
+ self._applyLatestModel()
+ assertDiff(False, [], [], [])
+
+ if not self.engine.name == 'oracle':
+ # Make sure data is still present.
+ result = self.engine.execute(self.table.select(self.table.c.id==dataId2))
+ rows = result.fetchall()
+ self.assertEquals(len(rows), 1)
+ self.assertEquals(rows[0].name, 'mydata2')
+ self.assertEquals(rows[0].data2, '123')
+
+ # Delete data, since we're about to make a required column.
+ # Not even using sqlalchemy.PassiveDefault helps because we're doing explicit column select.
+ self.engine.execute(self.table.delete(), id=dataId)
+
+ if not self.engine.name == 'firebird':
+ # Change column nullable in model.
+ self.meta.remove(self.table)
+ self.table = Table(self.table_name,self.meta,
+ Column('id',Integer(),primary_key=True),
+ Column('name',UnicodeText(length=None)),
+ Column('data2',String(255),nullable=False),
+ )
+ assertDiff(True, [], [], [self.table_name]) # TODO test nullable diff
+
+ # Apply latest model changes and find no more diffs.
+ self._applyLatestModel()
+ assertDiff(False, [], [], [])
+
+ # Remove table from model.
+ self.meta.remove(self.table)
+ assertDiff(True, [], [self.table_name], [])
diff --git a/migrate/tests/versioning/test_script.py b/migrate/tests/versioning/test_script.py
new file mode 100644
index 0000000..51eeef8
--- /dev/null
+++ b/migrate/tests/versioning/test_script.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import sys
+import shutil
+
+from migrate.versioning import exceptions, version, repository
+from migrate.versioning.script import *
+from migrate.versioning.util import *
+
+from migrate.tests import fixture
+from migrate.tests.fixture.models import tmp_sql_table
+
+
+class TestBaseScript(fixture.Pathed):
+
+ def test_all(self):
+ """Testing all basic BaseScript operations"""
+ # verify / source / run
+ src = self.tmp()
+ open(src, 'w').close()
+ bscript = BaseScript(src)
+ BaseScript.verify(src)
+ self.assertEqual(bscript.source(), '')
+ self.assertRaises(NotImplementedError, bscript.run, 'foobar')
+
+
+class TestPyScript(fixture.Pathed, fixture.DB):
+ cls = PythonScript
+ def test_create(self):
+ """We can create a migration script"""
+ path = self.tmp_py()
+ # Creating a file that doesn't exist should succeed
+ self.cls.create(path)
+ self.assert_(os.path.exists(path))
+ # Created file should be a valid script (If not, raises an error)
+ self.cls.verify(path)
+ # Can't create it again: it already exists
+ self.assertRaises(exceptions.PathFoundError,self.cls.create,path)
+
+ @fixture.usedb(supported='sqlite')
+ def test_run(self):
+ script_path = self.tmp_py()
+ pyscript = PythonScript.create(script_path)
+ pyscript.run(self.engine, 1)
+ pyscript.run(self.engine, -1)
+
+ self.assertRaises(exceptions.ScriptError, pyscript.run, self.engine, 0)
+ self.assertRaises(exceptions.ScriptError, pyscript._func, 'foobar')
+
+ # clean pyc file
+ os.remove(script_path + 'c')
+
+ # test deprecated upgrade/downgrade with no arguments
+ contents = open(script_path, 'r').read()
+ f = open(script_path, 'w')
+ f.write(contents.replace("upgrade(migrate_engine)", "upgrade()"))
+ f.close()
+
+ pyscript = PythonScript(script_path)
+ pyscript._module = None
+ try:
+ pyscript.run(self.engine, 1)
+ pyscript.run(self.engine, -1)
+ except exceptions.ScriptError:
+ pass
+ else:
+ self.fail()
+
+ def test_verify_notfound(self):
+ """Correctly verify a python migration script: nonexistant file"""
+ path = self.tmp_py()
+ self.assertFalse(os.path.exists(path))
+ # Fails on empty path
+ self.assertRaises(exceptions.InvalidScriptError,self.cls.verify,path)
+ self.assertRaises(exceptions.InvalidScriptError,self.cls,path)
+
+ def test_verify_invalidpy(self):
+ """Correctly verify a python migration script: invalid python file"""
+ path=self.tmp_py()
+        # Create a file containing invalid python syntax
+ f = open(path,'w')
+ f.write("def fail")
+ f.close()
+ self.assertRaises(Exception,self.cls.verify_module,path)
+ # script isn't verified on creation, but on module reference
+ py = self.cls(path)
+ self.assertRaises(Exception,(lambda x: x.module),py)
+
+ def test_verify_nofuncs(self):
+ """Correctly verify a python migration script: valid python file; no upgrade func"""
+ path = self.tmp_py()
+        # Create a valid python file that defines no upgrade function
+ f = open(path, 'w')
+ f.write("def zergling():\n\tprint 'rush'")
+ f.close()
+ self.assertRaises(exceptions.InvalidScriptError, self.cls.verify_module, path)
+ # script isn't verified on creation, but on module reference
+ py = self.cls(path)
+ self.assertRaises(exceptions.InvalidScriptError,(lambda x: x.module),py)
+
+ @fixture.usedb(supported='sqlite')
+ def test_preview_sql(self):
+ """Preview SQL abstract from ORM layer (sqlite)"""
+ path = self.tmp_py()
+
+ f = open(path, 'w')
+ content = '''
+from migrate import *
+from sqlalchemy import *
+
+metadata = MetaData()
+
+UserGroup = Table('Link', metadata,
+ Column('link1ID', Integer),
+ Column('link2ID', Integer),
+ UniqueConstraint('link1ID', 'link2ID'))
+
+def upgrade(migrate_engine):
+ metadata.create_all(migrate_engine)
+ '''
+ f.write(content)
+ f.close()
+
+ pyscript = self.cls(path)
+ SQL = pyscript.preview_sql(self.url, 1)
+ self.assertEqualsIgnoreWhitespace("""
+ CREATE TABLE "Link"
+ ("link1ID" INTEGER,
+ "link2ID" INTEGER,
+ UNIQUE ("link1ID", "link2ID"))
+ """, SQL)
+ # TODO: test: No SQL should be executed!
+
+ def test_verify_success(self):
+ """Correctly verify a python migration script: success"""
+ path = self.tmp_py()
+ # Succeeds after creating
+ self.cls.create(path)
+ self.cls.verify(path)
+
+ # test for PythonScript.make_update_script_for_model
+
+ @fixture.usedb()
+ def test_make_update_script_for_model(self):
+ """Construct script source from differences of two models"""
+
+ self.setup_model_params()
+ self.write_file(self.first_model_path, self.base_source)
+ self.write_file(self.second_model_path, self.base_source + self.model_source)
+
+ source_script = self.pyscript.make_update_script_for_model(
+ engine=self.engine,
+ oldmodel=load_model('testmodel_first:meta'),
+ model=load_model('testmodel_second:meta'),
+ repository=self.repo_path,
+ )
+
+ self.assertTrue('User.create()' in source_script)
+ self.assertTrue('User.drop()' in source_script)
+
+ #@fixture.usedb()
+ #def test_make_update_script_for_model_equals(self):
+ # """Try to make update script from two identical models"""
+
+ # self.setup_model_params()
+ # self.write_file(self.first_model_path, self.base_source + self.model_source)
+ # self.write_file(self.second_model_path, self.base_source + self.model_source)
+
+ # source_script = self.pyscript.make_update_script_for_model(
+ # engine=self.engine,
+ # oldmodel=load_model('testmodel_first:meta'),
+ # model=load_model('testmodel_second:meta'),
+ # repository=self.repo_path,
+ # )
+
+ # self.assertFalse('User.create()' in source_script)
+ # self.assertFalse('User.drop()' in source_script)
+
+ def setup_model_params(self):
+ self.script_path = self.tmp_py()
+ self.repo_path = self.tmp()
+ self.first_model_path = os.path.join(self.temp_usable_dir, 'testmodel_first.py')
+ self.second_model_path = os.path.join(self.temp_usable_dir, 'testmodel_second.py')
+
+ self.base_source = """from sqlalchemy import *\nmeta = MetaData()\n"""
+ self.model_source = """
+User = Table('User', meta,
+ Column('id', Integer, primary_key=True),
+ Column('login', Unicode(40)),
+ Column('passwd', String(40)),
+)"""
+
+ self.repo = repository.Repository.create(self.repo_path, 'repo')
+ self.pyscript = PythonScript.create(self.script_path)
+
+ def write_file(self, path, contents):
+ f = open(path, 'w')
+ f.write(contents)
+ f.close()
+
+
+class TestSqlScript(fixture.Pathed, fixture.DB):
+
+ @fixture.usedb()
+ def test_error(self):
+ """Test if exception is raised on wrong script source"""
+ src = self.tmp()
+
+ f = open(src, 'w')
+ f.write("""foobar""")
+ f.close()
+
+ sqls = SqlScript(src)
+ self.assertRaises(Exception, sqls.run, self.engine)
+
+ @fixture.usedb()
+ def test_success(self):
+ """Test sucessful SQL execution"""
+ # cleanup and prepare python script
+ tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True)
+ script_path = self.tmp_py()
+ pyscript = PythonScript.create(script_path)
+
+ # populate python script
+ contents = open(script_path, 'r').read()
+ contents = contents.replace("pass", "tmp_sql_table.create(migrate_engine)")
+ contents = 'from migrate.tests.fixture.models import tmp_sql_table\n' + contents
+ f = open(script_path, 'w')
+ f.write(contents)
+ f.close()
+
+ # write SQL script from python script preview
+ pyscript = PythonScript(script_path)
+ src = self.tmp()
+ f = open(src, 'w')
+ f.write(pyscript.preview_sql(self.url, 1))
+ f.close()
+
+ # run the change
+ sqls = SqlScript(src)
+ sqls.run(self.engine, executemany=False)
+ tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True)
diff --git a/migrate/tests/versioning/test_shell.py b/migrate/tests/versioning/test_shell.py
new file mode 100644
index 0000000..1ce6dae
--- /dev/null
+++ b/migrate/tests/versioning/test_shell.py
@@ -0,0 +1,547 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import sys
+import tempfile
+try:
+ from runpy import run_module
+except ImportError:
+ pass #python2.4
+
+from sqlalchemy import MetaData, Table
+from nose.plugins.skip import SkipTest
+
+from migrate.versioning.repository import Repository
+from migrate.versioning import genmodel, shell, api
+from migrate.versioning.exceptions import *
+from migrate.tests.fixture import Shell, DB, usedb
+from migrate.tests.fixture import models
+
+
class TestShellCommands(Shell):
    """Tests migrate.py commands"""

    def test_help(self):
        """Displays default help dialog"""
        self.assertEqual(self.env.run('bin/migrate -h').returncode, 0)
        self.assertEqual(self.env.run('bin/migrate --help').returncode, 0)
        self.assertEqual(self.env.run('bin/migrate help').returncode, 0)

    def test_help_commands(self):
        """Display help on a specific command"""
        # we can only test that we get some output
        for cmd in api.__all__:
            result = self.env.run('bin/migrate help %s' % cmd)
            self.assertTrue(isinstance(result.stdout, basestring))
            self.assertTrue(result.stdout)
            self.assertFalse(result.stderr)

    def test_shutdown_logging(self):
        """Try to shutdown logging output"""
        repos = self.tmp_repos()
        result = self.env.run('bin/migrate create %s repository_name' % repos)
        # Both the long flag and the short -q alias must silence stdout.
        result = self.env.run('bin/migrate version %s --disable_logging' % repos)
        self.assertEqual(result.stdout, '')
        result = self.env.run('bin/migrate version %s -q' % repos)
        self.assertEqual(result.stdout, '')

        # TODO: assert logging messages to 0
        shell.main(['version', repos], logging=False)

    def test_main_with_runpy(self):
        # Smoke-test executing the shell module as __main__ via runpy;
        # any exception from the module body itself is deliberately ignored.
        if sys.version_info[:2] == (2, 4):
            raise SkipTest("runpy is not part of python2.4")
        try:
            run_module('migrate.versioning.shell', run_name='__main__')
        except:
            pass

    def test_main(self):
        """Test main() function"""
        # TODO: test output?
        repos = self.tmp_repos()
        shell.main(['help'])
        shell.main(['help', 'create'])
        shell.main(['create', 'repo_name', '--preview_sql'], repository=repos)
        shell.main(['version', '--', '--repository=%s' % repos])
        shell.main(['version', '-d', '--repository=%s' % repos, '--version=2'])
        # Invalid invocations exit via SystemExit; swallow it so the test
        # can exercise several failure paths in sequence.
        try:
            shell.main(['foobar'])
        except SystemExit, e:
            pass
        try:
            shell.main(['create', 'f', 'o', 'o'])
        except SystemExit, e:
            pass
        try:
            shell.main(['create'])
        except SystemExit, e:
            pass
        try:
            shell.main(['create', 'repo_name'], repository=repos)
        except SystemExit, e:
            pass

    def test_create(self):
        """Repositories are created successfully"""
        repos = self.tmp_repos()

        # Creating a file that doesn't exist should succeed
        result = self.env.run('bin/migrate create %s repository_name' % repos)

        # Files should actually be created
        self.assert_(os.path.exists(repos))

        # The default table should not be None
        repos_ = Repository(repos)
        self.assertNotEquals(repos_.config.get('db_settings', 'version_table'), 'None')

        # Can't create it again: it already exists
        result = self.env.run('bin/migrate create %s repository_name' % repos,
            expect_error=True)
        self.assertEqual(result.returncode, 2)

    def test_script(self):
        """We can create a migration script via the command line"""
        repos = self.tmp_repos()
        result = self.env.run('bin/migrate create %s repository_name' % repos)

        result = self.env.run('bin/migrate script --repository=%s Desc' % repos)
        self.assert_(os.path.exists('%s/versions/001_Desc.py' % repos))

        result = self.env.run('bin/migrate script More %s' % repos)
        self.assert_(os.path.exists('%s/versions/002_More.py' % repos))

        # Spaces in the description are converted to underscores in the name.
        result = self.env.run('bin/migrate script "Some Random name" %s' % repos)
        self.assert_(os.path.exists('%s/versions/003_Some_Random_name.py' % repos))

    def test_script_sql(self):
        """We can create a migration sql script via the command line"""
        repos = self.tmp_repos()
        result = self.env.run('bin/migrate create %s repository_name' % repos)

        result = self.env.run('bin/migrate script_sql mydb %s' % repos)
        self.assert_(os.path.exists('%s/versions/001_mydb_upgrade.sql' % repos))
        self.assert_(os.path.exists('%s/versions/001_mydb_downgrade.sql' % repos))

        # Test creating a second
        result = self.env.run('bin/migrate script_sql postgres --repository=%s' % repos)
        self.assert_(os.path.exists('%s/versions/002_postgres_upgrade.sql' % repos))
        self.assert_(os.path.exists('%s/versions/002_postgres_downgrade.sql' % repos))

        # TODO: test --previews

    def test_manage(self):
        """Create a project management script"""
        script = self.tmp_py()
        self.assert_(not os.path.exists(script))

        # No attempt is made to verify correctness of the repository path here
        result = self.env.run('bin/migrate manage %s --repository=/bla/' % script)
        self.assert_(os.path.exists(script))
+
+
class TestShellRepository(Shell):
    """Shell commands on an existing repository/python script"""

    def setUp(self):
        """Create repository, python change script"""
        super(TestShellRepository, self).setUp()
        self.path_repos = self.tmp_repos()
        result = self.env.run('bin/migrate create %s repository_name' % self.path_repos)

    def test_version(self):
        """Correctly detect repository version"""
        # Version: 0 (no scripts yet); successful execution
        result = self.env.run('bin/migrate version --repository=%s' % self.path_repos)
        self.assertEqual(result.stdout.strip(), "0")

        # Also works as a positional param
        result = self.env.run('bin/migrate version %s' % self.path_repos)
        self.assertEqual(result.stdout.strip(), "0")

        # Create a script and version should increment
        result = self.env.run('bin/migrate script Desc %s' % self.path_repos)
        result = self.env.run('bin/migrate version %s' % self.path_repos)
        self.assertEqual(result.stdout.strip(), "1")

    def test_source(self):
        """Correctly fetch a script's source"""
        result = self.env.run('bin/migrate script Desc --repository=%s' % self.path_repos)

        filename = '%s/versions/001_Desc.py' % self.path_repos
        source = open(filename).read()
        self.assert_(source.find('def upgrade') >= 0)

        # Version is now 1
        result = self.env.run('bin/migrate version %s' % self.path_repos)
        self.assertEqual(result.stdout.strip(), "1")

        # Output/verify the source of version 1
        result = self.env.run('bin/migrate source 1 --repository=%s' % self.path_repos)
        self.assertEqual(result.stdout.strip(), source.strip())

        # We can also send the source to a file... test that too
        result = self.env.run('bin/migrate source 1 %s --repository=%s' %
            (filename, self.path_repos))
        self.assert_(os.path.exists(filename))
        fd = open(filename)
        result = fd.read()
        self.assert_(result.strip() == source.strip())
+
+
class TestShellDatabase(Shell, DB):
    """Commands associated with a particular database"""
    # We'll need to clean up after ourself, since the shell creates its own txn;
    # we need to connect to the DB to see if things worked

    # Only require a connectable database for these tests.
    level = DB.CONNECT

    @usedb()
    def test_version_control(self):
        """Ensure we can set version control on a database"""
        path_repos = repos = self.tmp_repos()
        url = self.url
        result = self.env.run('bin/migrate create %s repository_name' % repos)

        # Dropping version control before it exists must fail.
        result = self.env.run('bin/migrate drop_version_control %(url)s %(repos)s'\
            % locals(), expect_error=True)
        self.assertEqual(result.returncode, 1)
        result = self.env.run('bin/migrate version_control %(url)s %(repos)s' % locals())

        # Clean up
        result = self.env.run('bin/migrate drop_version_control %(url)s %(repos)s' % locals())
        # Attempting to drop vc from a database without it should fail
        result = self.env.run('bin/migrate drop_version_control %(url)s %(repos)s'\
            % locals(), expect_error=True)
        self.assertEqual(result.returncode, 1)

    @usedb()
    def test_wrapped_kwargs(self):
        """Commands with default arguments set by manage.py"""
        path_repos = repos = self.tmp_repos()
        url = self.url
        result = self.env.run('bin/migrate create --name=repository_name %s' % repos)
        result = self.env.run('bin/migrate drop_version_control %(url)s %(repos)s' % locals(), expect_error=True)
        self.assertEqual(result.returncode, 1)
        result = self.env.run('bin/migrate version_control %(url)s %(repos)s' % locals())

        result = self.env.run('bin/migrate drop_version_control %(url)s %(repos)s' % locals())

    @usedb()
    def test_version_control_specified(self):
        """Ensure we can set version control to a particular version"""
        path_repos = self.tmp_repos()
        url = self.url
        result = self.env.run('bin/migrate create --name=repository_name %s' % path_repos)
        result = self.env.run('bin/migrate drop_version_control %(url)s %(path_repos)s' % locals(), expect_error=True)
        self.assertEqual(result.returncode, 1)

        # Fill the repository
        path_script = self.tmp_py()
        version = 2
        for i in range(version):
            result = self.env.run('bin/migrate script Desc --repository=%s' % path_repos)

        # Repository version is correct
        result = self.env.run('bin/migrate version %s' % path_repos)
        self.assertEqual(result.stdout.strip(), str(version))

        # Apply versioning to DB
        result = self.env.run('bin/migrate version_control %(url)s %(path_repos)s %(version)s' % locals())

        # Test db version number (should start at 2)
        result = self.env.run('bin/migrate db_version %(url)s %(path_repos)s' % locals())
        self.assertEqual(result.stdout.strip(), str(version))

        # Clean up
        result = self.env.run('bin/migrate drop_version_control %(url)s %(path_repos)s' % locals())

    @usedb()
    def test_upgrade(self):
        """Can upgrade a versioned database"""
        # Create a repository
        repos_name = 'repos_name'
        repos_path = self.tmp()
        result = self.env.run('bin/migrate create %(repos_path)s %(repos_name)s' % locals())
        self.assertEquals(self.run_version(repos_path), 0)

        # Version the DB
        result = self.env.run('bin/migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
        result = self.env.run('bin/migrate version_control %s %s' % (self.url, repos_path))

        # Upgrades with latest version == 0
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)
        result = self.env.run('bin/migrate upgrade %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)
        result = self.env.run('bin/migrate upgrade %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)
        # Upgrading past the latest version (1) or to a negative one must fail.
        result = self.env.run('bin/migrate upgrade %s %s 1' % (self.url, repos_path), expect_error=True)
        self.assertEquals(result.returncode, 1)
        result = self.env.run('bin/migrate upgrade %s %s -1' % (self.url, repos_path), expect_error=True)
        self.assertEquals(result.returncode, 2)

        # Add a script to the repository; upgrade the db
        result = self.env.run('bin/migrate script Desc --repository=%s' % (repos_path))
        self.assertEquals(self.run_version(repos_path), 1)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

        # Test preview
        result = self.env.run('bin/migrate upgrade %s %s 0 --preview_sql' % (self.url, repos_path))
        result = self.env.run('bin/migrate upgrade %s %s 0 --preview_py' % (self.url, repos_path))

        result = self.env.run('bin/migrate upgrade %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_db_version(self.url, repos_path), 1)

        # Downgrade must have a valid version specified
        result = self.env.run('bin/migrate downgrade %s %s' % (self.url, repos_path), expect_error=True)
        self.assertEquals(result.returncode, 2)
        result = self.env.run('bin/migrate downgrade %s %s -1' % (self.url, repos_path), expect_error=True)
        self.assertEquals(result.returncode, 2)
        result = self.env.run('bin/migrate downgrade %s %s 2' % (self.url, repos_path), expect_error=True)
        self.assertEquals(result.returncode, 2)
        self.assertEquals(self.run_db_version(self.url, repos_path), 1)

        result = self.env.run('bin/migrate downgrade %s %s 0' % (self.url, repos_path))
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

        # "Downgrading" upward is rejected and leaves the version unchanged.
        result = self.env.run('bin/migrate downgrade %s %s 1' % (self.url, repos_path), expect_error=True)
        self.assertEquals(result.returncode, 2)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

        result = self.env.run('bin/migrate drop_version_control %s %s' % (self.url, repos_path))

    def _run_test_sqlfile(self, upgrade_script, downgrade_script):
        # Shared driver for the .sql-script tests below: creates a repo,
        # installs the given scripts as version 1, then round-trips
        # upgrade/downgrade and checks the db version each time.
        # TODO: add test script that checks if db really changed
        repos_path = self.tmp()
        repos_name = 'repos'

        result = self.env.run('bin/migrate create %s %s' % (repos_path, repos_name))
        result = self.env.run('bin/migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
        result = self.env.run('bin/migrate version_control %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_version(repos_path), 0)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

        beforeCount = len(os.listdir(os.path.join(repos_path, 'versions'))) # hmm, this number changes sometimes based on running from svn
        result = self.env.run('bin/migrate script_sql %s --repository=%s' % ('postgres', repos_path))
        self.assertEquals(self.run_version(repos_path), 1)
        self.assertEquals(len(os.listdir(os.path.join(repos_path, 'versions'))), beforeCount + 2)

        # NOTE(review): handles from open(..., 'a') are never explicitly
        # closed; CPython refcounting closes them, but this leaks elsewhere.
        open('%s/versions/001_postgres_upgrade.sql' % repos_path, 'a').write(upgrade_script)
        open('%s/versions/001_postgres_downgrade.sql' % repos_path, 'a').write(downgrade_script)

        self.assertEquals(self.run_db_version(self.url, repos_path), 0)
        self.assertRaises(Exception, self.engine.text('select * from t_table').execute)

        result = self.env.run('bin/migrate upgrade %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_db_version(self.url, repos_path), 1)
        self.engine.text('select * from t_table').execute()

        result = self.env.run('bin/migrate downgrade %s %s 0' % (self.url, repos_path))
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)
        self.assertRaises(Exception, self.engine.text('select * from t_table').execute)

    # The tests below are written with some postgres syntax, but the stuff
    # being tested (.sql files) ought to work with any db.
    @usedb(supported='postgres')
    def test_sqlfile(self):
        upgrade_script = """
        create table t_table (
            id serial,
            primary key(id)
        );
        """
        downgrade_script = """
        drop table t_table;
        """
        # NOTE(review): assumes self.meta is provided by the DB fixture --
        # confirm; test_sqlfile_comment below does not do this drop.
        self.meta.drop_all()
        self._run_test_sqlfile(upgrade_script, downgrade_script)

    @usedb(supported='postgres')
    def test_sqlfile_comment(self):
        upgrade_script = """
        -- Comments in SQL break postgres autocommit
        create table t_table (
            id serial,
            primary key(id)
        );
        """
        downgrade_script = """
        -- Comments in SQL break postgres autocommit
        drop table t_table;
        """
        self._run_test_sqlfile(upgrade_script, downgrade_script)

    @usedb()
    def test_command_test(self):
        repos_name = 'repos_name'
        repos_path = self.tmp()

        result = self.env.run('bin/migrate create repository_name --repository=%s' % repos_path)
        result = self.env.run('bin/migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
        result = self.env.run('bin/migrate version_control %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_version(repos_path), 0)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

        # Empty script should succeed
        result = self.env.run('bin/migrate script Desc %s' % repos_path)
        result = self.env.run('bin/migrate test %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_version(repos_path), 1)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

        # Error script should fail
        script_path = self.tmp_py()
        script_text='''
    from sqlalchemy import *
    from migrate import *

    def upgrade():
        print 'fgsfds'
        raise Exception()

    def downgrade():
        print 'sdfsgf'
        raise Exception()
    '''.replace("\n    ", "\n")
        file = open(script_path, 'w')
        file.write(script_text)
        file.close()

        result = self.env.run('bin/migrate test %s %s bla' % (self.url, repos_path), expect_error=True)
        self.assertEqual(result.returncode, 2)
        self.assertEquals(self.run_version(repos_path), 1)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

        # Nonempty script using migrate_engine should succeed
        script_path = self.tmp_py()
        script_text = '''
    from sqlalchemy import *
    from migrate import *

    from migrate.changeset import schema

    meta = MetaData(migrate_engine)
    account = Table('account', meta,
        Column('id', Integer, primary_key=True),
        Column('login', Text),
        Column('passwd', Text),
    )
    def upgrade():
        # Upgrade operations go here. Don't create your own engine; use the engine
        # named 'migrate_engine' imported from migrate.
        meta.create_all()

    def downgrade():
        # Operations to reverse the above upgrade go here.
        meta.drop_all()
    '''.replace("\n    ", "\n")
        file = open(script_path, 'w')
        file.write(script_text)
        file.close()
        result = self.env.run('bin/migrate test %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_version(repos_path), 1)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

    @usedb()
    def test_rundiffs_in_shell(self):
        # This is a variant of the test_schemadiff tests but run through the shell level.
        # These shell tests are hard to debug (since they keep forking processes), so they shouldn't replace the lower-level tests.
        repos_name = 'repos_name'
        repos_path = self.tmp()
        script_path = self.tmp_py()
        model_module = 'migrate.tests.fixture.models:meta_rundiffs'
        old_model_module = 'migrate.tests.fixture.models:meta_old_rundiffs'

        # Create empty repository.
        self.meta = MetaData(self.engine, reflect=True)
        self.meta.reflect()
        self.meta.drop_all()  # in case junk tables are lying around in the test database

        result = self.env.run('bin/migrate create %s %s' % (repos_path, repos_name))
        result = self.env.run('bin/migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
        result = self.env.run('bin/migrate version_control %s %s' % (self.url, repos_path))
        self.assertEquals(self.run_version(repos_path), 0)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0)

        # Setup helper script.
        result = self.env.run('bin/migrate manage %s --repository=%s --url=%s --model=%s'\
            % (script_path, repos_path, self.url, model_module))
        self.assert_(os.path.exists(script_path))

        # Model is defined but database is empty.
        result = self.env.run('bin/migrate compare_model_to_db %s %s --model=%s' \
            % (self.url, repos_path, model_module))
        self.assert_("tables missing in database: tmp_account_rundiffs" in result.stdout)

        # Test Deprecation
        result = self.env.run('bin/migrate compare_model_to_db %s %s --model=%s' \
            % (self.url, repos_path, model_module.replace(":", ".")), expect_error=True)
        self.assertEqual(result.returncode, 0)
        self.assertTrue("DeprecationWarning" in result.stderr)
        self.assert_("tables missing in database: tmp_account_rundiffs" in result.stdout)

        # Update db to latest model.
        result = self.env.run('bin/migrate update_db_from_model %s %s %s'\
            % (self.url, repos_path, model_module))
        self.assertEquals(self.run_version(repos_path), 0)
        self.assertEquals(self.run_db_version(self.url, repos_path), 0) # version did not get bumped yet because new version not yet created

        result = self.env.run('bin/migrate compare_model_to_db %s %s %s'\
            % (self.url, repos_path, model_module))
        self.assert_("No schema diffs" in result.stdout)

        result = self.env.run('bin/migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
        result = self.env.run('bin/migrate version_control %s %s' % (self.url, repos_path))

        # create_model prints python source; exec it to get the tables back.
        result = self.env.run('bin/migrate create_model %s %s' % (self.url, repos_path))
        temp_dict = dict()
        exec result.stdout in temp_dict

        # TODO: compare whole table
        self.compare_columns_equal(models.tmp_account_rundiffs.c, temp_dict['tmp_account_rundiffs'].c)
        #self.assertTrue("""tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
          #Column('id', Integer(),  primary_key=True, nullable=False),
          #Column('login', String(length=None, convert_unicode=False, assert_unicode=None)),
          #Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None))""" in result.stdout)

        # We're happy with db changes, make first db upgrade script to go from version 0 -> 1.
        result = self.env.run('bin/migrate make_update_script_for_model', expect_error=True)
        self.assertTrue('Not enough arguments' in result.stderr)

        result_script = self.env.run('bin/migrate make_update_script_for_model %s %s %s %s'\
            % (self.url, repos_path, old_model_module, model_module))
        self.assertEqualsIgnoreWhitespace(result_script.stdout,
        '''from sqlalchemy import *
        from migrate import *

        from migrate.changeset import schema

        meta = MetaData()
        tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
            Column('id', Integer(), primary_key=True, nullable=False),
            Column('login', String(length=None, convert_unicode=False, assert_unicode=None)),
            Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None)),
        )

        def upgrade(migrate_engine):
            # Upgrade operations go here. Don't create your own engine; bind migrate_engine
            # to your metadata
            meta.bind = migrate_engine
            tmp_account_rundiffs.create()

        def downgrade(migrate_engine):
            # Operations to reverse the above upgrade go here.
            meta.bind = migrate_engine
            tmp_account_rundiffs.drop()''')

        # Save the upgrade script.
        result = self.env.run('bin/migrate script Desc %s' % repos_path)
        upgrade_script_path = '%s/versions/001_Desc.py' % repos_path
        open(upgrade_script_path, 'w').write(result_script.stdout)

        result = self.env.run('bin/migrate compare_model_to_db %s %s %s'\
            % (self.url, repos_path, model_module))
        self.assert_("No schema diffs" in result.stdout)

        self.meta.drop_all()  # in case junk tables are lying around in the test database
diff --git a/migrate/tests/versioning/test_template.py b/migrate/tests/versioning/test_template.py
new file mode 100644
index 0000000..c6a45a6
--- /dev/null
+++ b/migrate/tests/versioning/test_template.py
@@ -0,0 +1,70 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+import os
+import shutil
+
+import migrate.versioning.templates
+from migrate.versioning.template import *
+from migrate.versioning import api
+
+from migrate.tests import fixture
+
+
class TestTemplate(fixture.Pathed):
    """Tests for Template: locating the default repository/script
    templates and overriding them with a user-supplied template theme."""

    def test_templates(self):
        """We can find the path to all repository templates"""
        path = str(Template())
        self.assert_(os.path.exists(path))

    def test_repository(self):
        """We can find the path to the default repository"""
        path = Template().get_repository()
        self.assert_(os.path.exists(path))

    def test_script(self):
        """We can find the path to the default migration script"""
        path = Template().get_script()
        self.assert_(os.path.exists(path))

    def _write_file(self, path, contents):
        """Overwrite *path* with *contents*, closing the handle promptly.

        Replaces the former ``f = open(path, 'w').write(contents)`` idiom,
        which bound None to ``f`` and relied on refcounting to close the
        file (a leak on non-refcounting interpreters).
        """
        f = open(path, 'w')
        try:
            f.write(contents)
        finally:
            f.close()

    def test_custom_templates_and_themes(self):
        """Users can define their own templates with themes"""
        new_templates_dir = os.path.join(self.temp_usable_dir, 'templates')
        manage_tmpl_file = os.path.join(new_templates_dir, 'manage/custom.py_tmpl')
        repository_tmpl_file = os.path.join(new_templates_dir, 'repository/custom/README')
        script_tmpl_file = os.path.join(new_templates_dir, 'script/custom.py_tmpl')
        sql_script_tmpl_file = os.path.join(new_templates_dir, 'sql_script/custom.py_tmpl')

        MANAGE_CONTENTS = 'print "manage.py"'
        README_CONTENTS = 'MIGRATE README!'
        SCRIPT_FILE_CONTENTS = 'print "script.py"'
        new_repo_dest = self.tmp_repos()
        new_manage_dest = self.tmp_py()

        # make new templates dir: copy the stock templates, then clone the
        # default repository template as the 'custom' theme
        shutil.copytree(migrate.versioning.templates.__path__[0], new_templates_dir)
        shutil.copytree(os.path.join(new_templates_dir, 'repository/default'),
            os.path.join(new_templates_dir, 'repository/custom'))

        # edit templates so each theme file has recognizable contents
        self._write_file(manage_tmpl_file, MANAGE_CONTENTS)
        self._write_file(repository_tmpl_file, README_CONTENTS)
        self._write_file(script_tmpl_file, SCRIPT_FILE_CONTENTS)
        self._write_file(sql_script_tmpl_file, SCRIPT_FILE_CONTENTS)

        # create repository, manage file and python script
        kw = {}
        kw['templates_path'] = new_templates_dir
        kw['templates_theme'] = 'custom'
        api.create(new_repo_dest, 'repo_name', **kw)
        api.script('test', new_repo_dest, **kw)
        api.script_sql('postgres', new_repo_dest, **kw)
        api.manage(new_manage_dest, **kw)

        # assert changes: everything generated must carry the custom contents
        self.assertEqual(open(new_manage_dest).read(), MANAGE_CONTENTS)
        self.assertEqual(open(os.path.join(new_repo_dest, 'manage.py')).read(), MANAGE_CONTENTS)
        self.assertEqual(open(os.path.join(new_repo_dest, 'README')).read(), README_CONTENTS)
        self.assertEqual(open(os.path.join(new_repo_dest, 'versions/001_test.py')).read(), SCRIPT_FILE_CONTENTS)
        self.assertEqual(open(os.path.join(new_repo_dest, 'versions/002_postgres_downgrade.sql')).read(), SCRIPT_FILE_CONTENTS)
        self.assertEqual(open(os.path.join(new_repo_dest, 'versions/002_postgres_upgrade.sql')).read(), SCRIPT_FILE_CONTENTS)
diff --git a/migrate/tests/versioning/test_util.py b/migrate/tests/versioning/test_util.py
new file mode 100644
index 0000000..76bd1fd
--- /dev/null
+++ b/migrate/tests/versioning/test_util.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+
+from sqlalchemy import *
+
+from migrate.tests import fixture
+from migrate.versioning.util import *
+
+
class TestUtil(fixture.Pathed):
    """Tests for migrate.versioning.util helpers."""

    def test_construct_engine(self):
        """Construct engine the smart way"""
        url = 'sqlite://'

        engine = construct_engine(url)
        self.assert_(engine.name == 'sqlite')

        # keyword arg (engine_arg_* prefix is stripped and forwarded)
        engine = construct_engine(url, engine_arg_encoding=True)
        self.assertTrue(engine.dialect.encoding)

        # dict
        engine = construct_engine(url, engine_dict={'encoding': True})
        self.assertTrue(engine.dialect.encoding)

        # engine parameter: an existing engine is passed through unchanged
        engine_orig = create_engine('sqlite://')
        engine = construct_engine(engine_orig)
        self.assertEqual(engine, engine_orig)

        # test precedence: engine_arg_* keywords win over engine_dict
        engine = construct_engine(url, engine_dict={'encoding': False},
            engine_arg_encoding=True)
        self.assertTrue(engine.dialect.encoding)

        # deprecated echo=True parameter
        engine = construct_engine(url, echo='True')
        self.assertTrue(engine.echo)

        # unsupported argument
        self.assertRaises(ValueError, construct_engine, 1)

    def test_asbool(self):
        """test asbool parsing"""
        result = asbool(True)
        self.assertEqual(result, True)

        result = asbool(False)
        self.assertEqual(result, False)

        result = asbool('y')
        self.assertEqual(result, True)

        result = asbool('n')
        self.assertEqual(result, False)

        # Unrecognized strings and non-string objects are rejected.
        self.assertRaises(ValueError, asbool, 'test')
        self.assertRaises(ValueError, asbool, object)


    def test_load_model(self):
        """load model from dotted name"""
        model_path = os.path.join(self.temp_usable_dir, 'test_load_model.py')

        f = open(model_path, 'w')
        f.write("class FakeFloat(int): pass")
        f.close()

        # Dotted form, colon form, and an already-loaded object all work.
        FakeFloat = load_model('test_load_model.FakeFloat')
        self.assert_(isinstance(FakeFloat(), int))

        FakeFloat = load_model('test_load_model:FakeFloat')
        self.assert_(isinstance(FakeFloat(), int))

        FakeFloat = load_model(FakeFloat)
        self.assert_(isinstance(FakeFloat(), int))

    def test_guess_obj_type(self):
        """guess object type from string"""
        # int, then bool, then fall back to the original string
        result = guess_obj_type('7')
        self.assertEqual(result, 7)

        result = guess_obj_type('y')
        self.assertEqual(result, True)

        result = guess_obj_type('test')
        self.assertEqual(result, 'test')
diff --git a/migrate/tests/versioning/test_version.py b/migrate/tests/versioning/test_version.py
new file mode 100644
index 0000000..0be192f
--- /dev/null
+++ b/migrate/tests/versioning/test_version.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from migrate.versioning.version import *
+from migrate.versioning.exceptions import *
+
+from migrate.tests import fixture
+
+
class TestVerNum(fixture.Base):
    """Tests for VerNum: construction, interning, and rich comparisons."""

    def test_invalid(self):
        """Disallow invalid version numbers"""
        versions = ('-1', -1, 'Thirteen', '')
        for version in versions:
            self.assertRaises(ValueError, VerNum, version)

    def test_str(self):
        """Test str and repr version numbers"""
        self.assertEqual(str(VerNum(2)), '2')
        self.assertEqual(repr(VerNum(2)), '<VerNum(2)>')

    def test_is(self):
        """Two version with the same number should be equal"""
        # VerNum instances are interned: same number -> same object.
        a = VerNum(1)
        b = VerNum(1)
        self.assert_(a is b)

        self.assertEqual(VerNum(VerNum(2)), VerNum(2))

    def test_add(self):
        # Arithmetic returns VerNum and compares equal to ints and strings.
        self.assertEqual(VerNum(1) + VerNum(1), VerNum(2))
        self.assertEqual(VerNum(1) + 1, 2)
        self.assertEqual(VerNum(1) + 1, '2')
        self.assert_(isinstance(VerNum(1) + 1, VerNum))

    def test_sub(self):
        self.assertEqual(VerNum(1) - 1, 0)
        self.assert_(isinstance(VerNum(1) - 1, VerNum))
        # Versions cannot go below zero.
        self.assertRaises(ValueError, lambda: VerNum(0) - 1)

    def test_eq(self):
        """Two versions are equal"""
        self.assertEqual(VerNum(1), VerNum('1'))
        self.assertEqual(VerNum(1), 1)
        self.assertEqual(VerNum(1), '1')
        self.assertNotEqual(VerNum(1), 2)

    def test_ne(self):
        self.assert_(VerNum(1) != 2)
        self.assertFalse(VerNum(1) != 1)

    def test_lt(self):
        self.assertFalse(VerNum(1) < 1)
        self.assert_(VerNum(1) < 2)
        self.assertFalse(VerNum(2) < 1)

    def test_le(self):
        self.assert_(VerNum(1) <= 1)
        self.assert_(VerNum(1) <= 2)
        self.assertFalse(VerNum(2) <= 1)

    def test_gt(self):
        self.assertFalse(VerNum(1) > 1)
        self.assertFalse(VerNum(1) > 2)
        self.assert_(VerNum(2) > 1)

    def test_ge(self):
        self.assert_(VerNum(1) >= 1)
        self.assert_(VerNum(2) >= 1)
        self.assertFalse(VerNum(1) >= 2)
+
+
class TestVersion(fixture.Pathed):
    """Tests for version Collection/Version and script selection."""

    def setUp(self):
        # NOTE(review): only delegates to super(); redundant and could be
        # removed.
        super(TestVersion, self).setUp()

    def test_str_to_filename(self):
        # Descriptions are slugified: runs of non-filename chars collapse
        # to single underscores.
        self.assertEquals(str_to_filename(''), '')
        self.assertEquals(str_to_filename('__'), '_')
        self.assertEquals(str_to_filename('a'), 'a')
        self.assertEquals(str_to_filename('Abc Def'), 'Abc_Def')
        self.assertEquals(str_to_filename('Abc "D" Ef'), 'Abc_D_Ef')
        self.assertEquals(str_to_filename("Abc's Stuff"), 'Abc_s_Stuff')
        self.assertEquals(str_to_filename("a b"), 'a_b')

    def test_collection(self):
        """Let's see how we handle versions collection"""
        coll = Collection(self.temp_usable_dir)
        coll.create_new_python_version("foo bar")
        coll.create_new_sql_version("postgres")
        coll.create_new_sql_version("sqlite")
        coll.create_new_python_version("")

        # Four scripts created -> latest version is 4.
        self.assertEqual(coll.latest, 4)
        self.assertEqual(len(coll.versions), 4)
        self.assertEqual(coll.version(4), coll.version(coll.latest))

        # A second Collection over the same dir sees the same versions.
        coll2 = Collection(self.temp_usable_dir)
        self.assertEqual(coll.versions, coll2.versions)

        Collection.clear()

    def test_old_repository(self):
        # A bare numeric file (pre-directory-layout repository) is rejected.
        # NOTE(review): file handle from open() is never closed.
        open(os.path.join(self.temp_usable_dir, '1'), 'w')
        self.assertRaises(Exception, Collection, self.temp_usable_dir)

    #TODO: def test_collection_unicode(self):
    #    pass

    def test_create_new_python_version(self):
        coll = Collection(self.temp_usable_dir)
        coll.create_new_python_version("'")

        ver = coll.version()
        self.assert_(ver.script().source())

    def test_create_new_sql_version(self):
        # SQL versions produce paired upgrade/downgrade scripts.
        coll = Collection(self.temp_usable_dir)
        coll.create_new_sql_version("sqlite")

        ver = coll.version()
        ver_up = ver.script('sqlite', 'upgrade')
        ver_down = ver.script('sqlite', 'downgrade')
        ver_up.source()
        ver_down.source()

    def test_selection(self):
        """Verify right sql script is selected"""

        # Create empty directory.
        path = self.tmp_repos()
        os.mkdir(path)

        # Create files -- files must be present or you'll get an exception later.
        python_file = '001_initial_.py'
        sqlite_upgrade_file = '001_sqlite_upgrade.sql'
        default_upgrade_file = '001_default_upgrade.sql'
        for file_ in [sqlite_upgrade_file, default_upgrade_file, python_file]:
            filepath = '%s/%s' % (path, file_)
            open(filepath, 'w').close()

        # Selection order: exact dialect .sql, then default .sql, then .py.
        ver = Version(1, path, [sqlite_upgrade_file])
        self.assertEquals(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file)

        ver = Version(1, path, [default_upgrade_file])
        self.assertEquals(os.path.basename(ver.script('default', 'upgrade').path), default_upgrade_file)

        ver = Version(1, path, [sqlite_upgrade_file, default_upgrade_file])
        self.assertEquals(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file)

        ver = Version(1, path, [sqlite_upgrade_file, default_upgrade_file, python_file])
        self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), default_upgrade_file)

        ver = Version(1, path, [sqlite_upgrade_file, python_file])
        self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), python_file)

    def test_bad_version(self):
        ver = Version(1, self.temp_usable_dir, [])
        # Scripts must follow the NNN_description naming convention.
        self.assertRaises(ScriptError, ver.add_script, '123.sql')

        # NOTE(review): file handle from open() is never closed.
        pyscript = os.path.join(self.temp_usable_dir, 'bla.py')
        open(pyscript, 'w')
        ver.add_script(pyscript)
        # Adding a script with a duplicate name is rejected.
        self.assertRaises(ScriptError, ver.add_script, 'bla.py')
diff --git a/migrate/versioning/migrate_repository.py b/migrate/versioning/migrate_repository.py
index 27c4d35..53833bb 100644
--- a/migrate/versioning/migrate_repository.py
+++ b/migrate/versioning/migrate_repository.py
@@ -6,6 +6,9 @@
import os
import sys
+import logging
+
+log = logging.getLogger(__name__)
def usage():
@@ -22,13 +25,13 @@ def usage():
def delete_file(filepath):
"""Deletes a file and prints a message."""
- print ' Deleting file: %s' % filepath
+ log.info('Deleting file: %s' % filepath)
os.remove(filepath)
def move_file(src, tgt):
"""Moves a file and prints a message."""
- print ' Moving file %s to %s' % (src, tgt)
+ log.info('Moving file %s to %s' % (src, tgt))
if os.path.exists(tgt):
raise Exception(
'Cannot move file %s because target %s already exists' % \
@@ -38,13 +41,13 @@ def move_file(src, tgt):
def delete_directory(dirpath):
"""Delete a directory and print a message."""
- print ' Deleting directory: %s' % dirpath
+ log.info('Deleting directory: %s' % dirpath)
os.rmdir(dirpath)
def migrate_repository(repos):
"""Does the actual migration to the new repository format."""
- print 'Migrating repository at: %s to new format' % repos
+ log.info('Migrating repository at: %s to new format' % repos)
versions = '%s/versions' % repos
dirs = os.listdir(versions)
# Only use int's in list.
@@ -52,7 +55,7 @@ def migrate_repository(repos):
numdirs.sort() # Sort list.
for dirname in numdirs:
origdir = '%s/%s' % (versions, dirname)
- print ' Working on directory: %s' % origdir
+ log.info('Working on directory: %s' % origdir)
files = os.listdir(origdir)
files.sort()
for filename in files: