summaryrefslogtreecommitdiff
path: root/tests/test_version_table.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-09-13 16:53:57 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-09-13 16:53:57 -0400
commitfa080367259f49c93f5055b5b49d175649126d16 (patch)
tree07be66e413850cb0f546f624159a2529a7252ff0 /tests/test_version_table.py
parentf41a1cfcea4360e7e0d445094a7f108000a3c3f6 (diff)
downloadalembic-fa080367259f49c93f5055b5b49d175649126d16.tar.gz
- finish up
Diffstat (limited to 'tests/test_version_table.py')
-rw-r--r--tests/test_version_table.py48
1 files changed, 25 insertions, 23 deletions
diff --git a/tests/test_version_table.py b/tests/test_version_table.py
index 98dec50..28ac6ca 100644
--- a/tests/test_version_table.py
+++ b/tests/test_version_table.py
@@ -1,6 +1,8 @@
-import unittest
+from alembic.testing.fixtures import TestBase
-from sqlalchemy import Table, MetaData, Column, String, create_engine
+from alembic.testing import config, eq_, assert_raises
+
+from sqlalchemy import Table, MetaData, Column, String
from sqlalchemy.engine.reflection import Inspector
from alembic.util import CommandError
@@ -9,15 +11,11 @@ version_table = Table('version_table', MetaData(),
Column('version_num', String(32), nullable=False))
-class TestMigrationContext(unittest.TestCase):
- _bind = []
+class TestMigrationContext(TestBase):
- @property
- def bind(self):
- if not self._bind:
- engine = create_engine('sqlite:///', echo=True)
- self._bind.append(engine)
- return self._bind[0]
+ @classmethod
+ def setup_class(cls):
+ cls.bind = config.db
def setUp(self):
self.connection = self.bind.connect()
@@ -26,6 +24,7 @@ class TestMigrationContext(unittest.TestCase):
def tearDown(self):
version_table.drop(self.connection, checkfirst=True)
self.transaction.rollback()
+ self.connection.close()
def make_one(self, **kwargs):
from alembic.migration import MigrationContext
@@ -36,49 +35,52 @@ class TestMigrationContext(unittest.TestCase):
rows = result.fetchall()
if len(rows) == 0:
return None
- self.assertEqual(len(rows), 1)
+ eq_(len(rows), 1)
return rows[0]['version_num']
def test_config_default_version_table_name(self):
context = self.make_one(dialect_name='sqlite')
- self.assertEqual(context._version.name, 'alembic_version')
+ eq_(context._version.name, 'alembic_version')
def test_config_explicit_version_table_name(self):
context = self.make_one(dialect_name='sqlite',
opts={'version_table': 'explicit'})
- self.assertEqual(context._version.name, 'explicit')
+ eq_(context._version.name, 'explicit')
def test_config_explicit_version_table_schema(self):
context = self.make_one(dialect_name='sqlite',
opts={'version_table_schema': 'explicit'})
- self.assertEqual(context._version.schema, 'explicit')
+ eq_(context._version.schema, 'explicit')
def test_get_current_revision_creates_version_table(self):
context = self.make_one(connection=self.connection,
opts={'version_table': 'version_table'})
- self.assertEqual(context.get_current_revision(), None)
+ eq_(context.get_current_revision(), None)
insp = Inspector(self.connection)
- self.assertTrue('version_table' in insp.get_table_names())
+ assert ('version_table' in insp.get_table_names())
def test_get_current_revision(self):
context = self.make_one(connection=self.connection,
opts={'version_table': 'version_table'})
version_table.create(self.connection)
- self.assertEqual(context.get_current_revision(), None)
+ eq_(context.get_current_revision(), None)
self.connection.execute(
version_table.insert().values(version_num='revid'))
- self.assertEqual(context.get_current_revision(), 'revid')
+ eq_(context.get_current_revision(), 'revid')
def test_get_current_revision_error_if_starting_rev_given_online(self):
context = self.make_one(connection=self.connection,
opts={'starting_rev': 'boo'})
- self.assertRaises(CommandError, context.get_current_revision)
+ assert_raises(
+ CommandError,
+ context.get_current_revision
+ )
def test_get_current_revision_offline(self):
context = self.make_one(dialect_name='sqlite',
opts={'starting_rev': 'startrev',
'as_sql': True})
- self.assertEqual(context.get_current_revision(), 'startrev')
+ eq_(context.get_current_revision(), 'startrev')
def test__update_current_rev(self):
version_table.create(self.connection)
@@ -86,8 +88,8 @@ class TestMigrationContext(unittest.TestCase):
opts={'version_table': 'version_table'})
context._update_current_rev(None, 'a')
- self.assertEqual(self.get_revision(), 'a')
+ eq_(self.get_revision(), 'a')
context._update_current_rev('a', 'b')
- self.assertEqual(self.get_revision(), 'b')
+ eq_(self.get_revision(), 'b')
context._update_current_rev('b', None)
- self.assertEqual(self.get_revision(), None)
+ eq_(self.get_revision(), None)