diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-09-13 16:53:57 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-09-13 16:53:57 -0400 |
commit | fa080367259f49c93f5055b5b49d175649126d16 (patch) | |
tree | 07be66e413850cb0f546f624159a2529a7252ff0 /tests/test_version_table.py | |
parent | f41a1cfcea4360e7e0d445094a7f108000a3c3f6 (diff) | |
download | alembic-fa080367259f49c93f5055b5b49d175649126d16.tar.gz |
- finish up
Diffstat (limited to 'tests/test_version_table.py')
-rw-r--r-- | tests/test_version_table.py | 48 |
1 files changed, 25 insertions, 23 deletions
diff --git a/tests/test_version_table.py b/tests/test_version_table.py index 98dec50..28ac6ca 100644 --- a/tests/test_version_table.py +++ b/tests/test_version_table.py @@ -1,6 +1,8 @@ -import unittest +from alembic.testing.fixtures import TestBase -from sqlalchemy import Table, MetaData, Column, String, create_engine +from alembic.testing import config, eq_, assert_raises + +from sqlalchemy import Table, MetaData, Column, String from sqlalchemy.engine.reflection import Inspector from alembic.util import CommandError @@ -9,15 +11,11 @@ version_table = Table('version_table', MetaData(), Column('version_num', String(32), nullable=False)) -class TestMigrationContext(unittest.TestCase): - _bind = [] +class TestMigrationContext(TestBase): - @property - def bind(self): - if not self._bind: - engine = create_engine('sqlite:///', echo=True) - self._bind.append(engine) - return self._bind[0] + @classmethod + def setup_class(cls): + cls.bind = config.db def setUp(self): self.connection = self.bind.connect() @@ -26,6 +24,7 @@ class TestMigrationContext(unittest.TestCase): def tearDown(self): version_table.drop(self.connection, checkfirst=True) self.transaction.rollback() + self.connection.close() def make_one(self, **kwargs): from alembic.migration import MigrationContext @@ -36,49 +35,52 @@ class TestMigrationContext(unittest.TestCase): rows = result.fetchall() if len(rows) == 0: return None - self.assertEqual(len(rows), 1) + eq_(len(rows), 1) return rows[0]['version_num'] def test_config_default_version_table_name(self): context = self.make_one(dialect_name='sqlite') - self.assertEqual(context._version.name, 'alembic_version') + eq_(context._version.name, 'alembic_version') def test_config_explicit_version_table_name(self): context = self.make_one(dialect_name='sqlite', opts={'version_table': 'explicit'}) - self.assertEqual(context._version.name, 'explicit') + eq_(context._version.name, 'explicit') def test_config_explicit_version_table_schema(self): context = self.make_one(dialect_name='sqlite', opts={'version_table_schema': 'explicit'}) - self.assertEqual(context._version.schema, 'explicit') + eq_(context._version.schema, 'explicit') def test_get_current_revision_creates_version_table(self): context = self.make_one(connection=self.connection, opts={'version_table': 'version_table'}) - self.assertEqual(context.get_current_revision(), None) + eq_(context.get_current_revision(), None) insp = Inspector(self.connection) - self.assertTrue('version_table' in insp.get_table_names()) + assert ('version_table' in insp.get_table_names()) def test_get_current_revision(self): context = self.make_one(connection=self.connection, opts={'version_table': 'version_table'}) version_table.create(self.connection) - self.assertEqual(context.get_current_revision(), None) + eq_(context.get_current_revision(), None) self.connection.execute( version_table.insert().values(version_num='revid')) - self.assertEqual(context.get_current_revision(), 'revid') + eq_(context.get_current_revision(), 'revid') def test_get_current_revision_error_if_starting_rev_given_online(self): context = self.make_one(connection=self.connection, opts={'starting_rev': 'boo'}) - self.assertRaises(CommandError, context.get_current_revision) + assert_raises( + CommandError, + context.get_current_revision + ) def test_get_current_revision_offline(self): context = self.make_one(dialect_name='sqlite', opts={'starting_rev': 'startrev', 'as_sql': True}) - self.assertEqual(context.get_current_revision(), 'startrev') + eq_(context.get_current_revision(), 'startrev') def test__update_current_rev(self): version_table.create(self.connection) @@ -86,8 +88,8 @@ class TestMigrationContext(unittest.TestCase): opts={'version_table': 'version_table'}) context._update_current_rev(None, 'a') - self.assertEqual(self.get_revision(), 'a') + eq_(self.get_revision(), 'a') context._update_current_rev('a', 'b') - self.assertEqual(self.get_revision(), 'b') + eq_(self.get_revision(), 'b') context._update_current_rev('b', None) - self.assertEqual(self.get_revision(), None) + eq_(self.get_revision(), None) |