diff options
Diffstat (limited to 'oslo_db/sqlalchemy/migration_cli/ext_alembic.py')
-rw-r--r-- | oslo_db/sqlalchemy/migration_cli/ext_alembic.py | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/oslo_db/sqlalchemy/migration_cli/ext_alembic.py b/oslo_db/sqlalchemy/migration_cli/ext_alembic.py index 243ae47..1dbf88f 100644 --- a/oslo_db/sqlalchemy/migration_cli/ext_alembic.py +++ b/oslo_db/sqlalchemy/migration_cli/ext_alembic.py @@ -17,7 +17,6 @@ from alembic import config as alembic_config import alembic.migration as alembic_migration from oslo_db.sqlalchemy.migration_cli import ext_base -from oslo_db.sqlalchemy import session as db_session class AlembicExtension(ext_base.MigrationExtensionBase): @@ -28,31 +27,41 @@ class AlembicExtension(ext_base.MigrationExtensionBase): def enabled(self): return os.path.exists(self.alembic_ini_path) - def __init__(self, migration_config): + def __init__(self, engine, migration_config): """Extension to provide alembic features. + :param engine: SQLAlchemy engine instance for a given database + :type engine: sqlalchemy.engine.Engine :param migration_config: Stores specific configuration for migrations :type migration_config: dict """ self.alembic_ini_path = migration_config.get('alembic_ini_path', '') self.config = alembic_config.Config(self.alembic_ini_path) + # TODO(viktors): Remove this, when we will use Alembic 0.7.5 or + # higher, because the ``attributes`` dictionary was + # added to Alembic in version 0.7.5. + if not hasattr(self.config, 'attributes'): + self.config.attributes = {} # option should be used if script is not in default directory repo_path = migration_config.get('alembic_repo_path') if repo_path: self.config.set_main_option('script_location', repo_path) - self.db_url = migration_config['db_url'] + self.engine = engine def upgrade(self, version): - return alembic.command.upgrade(self.config, version or 'head') + with self.engine.begin() as connection: + self.config.attributes['connection'] = connection + return alembic.command.upgrade(self.config, version or 'head') def downgrade(self, version): if isinstance(version, int) or version is None or version.isdigit(): version = 'base' - return alembic.command.downgrade(self.config, version) + with self.engine.begin() as connection: + self.config.attributes['connection'] = connection + return alembic.command.downgrade(self.config, version) def version(self): - engine = db_session.create_engine(self.db_url) - with engine.connect() as conn: + with self.engine.connect() as conn: context = alembic_migration.MigrationContext.configure(conn) return context.get_current_revision() @@ -65,8 +74,10 @@ class AlembicExtension(ext_base.MigrationExtensionBase): state :type autogenerate: bool """ - return alembic.command.revision(self.config, message=message, - autogenerate=autogenerate) + with self.engine.begin() as connection: + self.config.attributes['connection'] = connection + return alembic.command.revision(self.config, message=message, + autogenerate=autogenerate) def stamp(self, revision): """Stamps database with provided revision. @@ -75,4 +86,6 @@ class AlembicExtension(ext_base.MigrationExtensionBase): database with most recent revision :type revision: string """ - return alembic.command.stamp(self.config, revision=revision) + with self.engine.begin() as connection: + self.config.attributes['connection'] = connection + return alembic.command.stamp(self.config, revision=revision) |