summaryrefslogtreecommitdiff
path: root/keystone
diff options
context:
space:
mode:
authorStephen Finucane <stephenfin@redhat.com>2022-01-19 14:38:10 +0000
committerStephen Finucane <stephenfin@redhat.com>2022-01-24 16:03:44 +0000
commit0b906c65294ac9f0723ab9f598f7f9f4c660d667 (patch)
tree74f983fe6b9fd23c8bf17b1971e8caa96add9bea /keystone
parentaebd037f501ca3aa9e927f304382c231dacb2938 (diff)
downloadkeystone-0b906c65294ac9f0723ab9f598f7f9f4c660d667.tar.gz
sql: Vendor 'oslo_db.sqlalchemy.migration'
This is deprecated and will be removed in a future release of oslo.db. Even without that stick to prod us, we're going to need to use some of the sqlalchemy-migrate APIs and it's simpler to talk to this for everything rather than using oslo.db for some stuff and sqlalchemy-migrate for the remainder. Change-Id: Ib25c75a99794a04b6549e6b5184a2029955befc1 Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
Diffstat (limited to 'keystone')
-rw-r--r--keystone/cmd/cli.py13
-rw-r--r--keystone/common/sql/upgrades.py145
-rw-r--r--keystone/common/utils.py7
-rw-r--r--keystone/tests/unit/common/sql/test_upgrades.py252
-rw-r--r--keystone/tests/unit/test_cli.py12
5 files changed, 386 insertions, 43 deletions
diff --git a/keystone/cmd/cli.py b/keystone/cmd/cli.py
index 3bf3e3ebc..fae129a96 100644
--- a/keystone/cmd/cli.py
+++ b/keystone/cmd/cli.py
@@ -18,9 +18,8 @@ import os
import sys
import uuid
-import migrate
from oslo_config import cfg
-from oslo_db.sqlalchemy import migration
+from oslo_db import exception as db_exception
from oslo_log import log
from oslo_serialization import jsonutils
import pbr.version
@@ -279,7 +278,7 @@ class DbSync(BaseApp):
status = 0
try:
expand_version = upgrades.get_db_version(repo='expand_repo')
- except migration.exception.DBMigrationError:
+ except db_exception.DBMigrationError:
LOG.info(
'Your database is not currently under version '
'control or the database is already controlled. Your '
@@ -290,17 +289,15 @@ class DbSync(BaseApp):
try:
migrate_version = upgrades.get_db_version(
repo='data_migration_repo')
- except migration.exception.DBMigrationError:
+ except db_exception.DBMigrationError:
migrate_version = 0
try:
contract_version = upgrades.get_db_version(repo='contract_repo')
- except migration.exception.DBMigrationError:
+ except db_exception.DBMigrationError:
contract_version = 0
- repo = migrate.versioning.repository.Repository(
- upgrades.find_repo('expand_repo'))
- migration_script_version = int(max(repo.versions.versions))
+ migration_script_version = upgrades.LATEST_VERSION
if (
contract_version > migrate_version or
diff --git a/keystone/common/sql/upgrades.py b/keystone/common/sql/upgrades.py
index 4e959df3f..00d21aed7 100644
--- a/keystone/common/sql/upgrades.py
+++ b/keystone/common/sql/upgrades.py
@@ -16,15 +16,17 @@
import os
-import migrate
-from migrate.versioning import api as versioning_api
+from migrate import exceptions as migrate_exceptions
+from migrate.versioning import api as migrate_api
+from migrate.versioning import repository as migrate_repository
from oslo_db import exception as db_exception
-from oslo_db.sqlalchemy import migration
+import sqlalchemy as sa
from keystone.common import sql
from keystone import exception
INITIAL_VERSION = 72
+LATEST_VERSION = 79
EXPAND_REPO = 'expand_repo'
DATA_MIGRATION_REPO = 'data_migration_repo'
CONTRACT_REPO = 'contract_repo'
@@ -34,9 +36,9 @@ class Repository(object):
def __init__(self, engine, repo_name):
self.repo_name = repo_name
- self.repo_path = find_repo(self.repo_name)
+ self.repo_path = _get_migrate_repo_path(self.repo_name)
self.min_version = INITIAL_VERSION
- self.schema_ = versioning_api.ControlledSchema.create(
+ self.schema_ = migrate_api.ControlledSchema.create(
engine, self.repo_path, self.min_version)
self.max_version = self.schema_.repository.version().version
@@ -44,7 +46,7 @@ class Repository(object):
version = version or self.max_version
err = ''
upgrade = True
- version = versioning_api._migrate_version(
+ version = migrate_api._migrate_version(
self.schema_, version, upgrade, err)
validate_upgrade_order(self.repo_name, target_repo_version=version)
if not current_schema:
@@ -62,13 +64,13 @@ class Repository(object):
@property
def version(self):
with sql.session_for_read() as session:
- return migration.db_version(
- session.get_bind(), self.repo_path, self.min_version)
+ return _migrate_db_version(
+ session.get_bind(), self.repo_path, self.min_version,
+ )
-def find_repo(repo_name):
- """Return the absolute path to the named repository."""
- path = os.path.abspath(
+def _get_migrate_repo_path(repo_name):
+ abs_path = os.path.abspath(
os.path.join(
os.path.dirname(sql.__file__),
'legacy_migrations',
@@ -76,21 +78,114 @@ def find_repo(repo_name):
)
)
- if not os.path.isdir(path):
- raise exception.MigrationNotProvided(sql.__name__, path)
+ if not os.path.isdir(abs_path):
+ raise exception.MigrationNotProvided(sql.__name__, abs_path)
- return path
+ return abs_path
+
+
+def _find_migrate_repo(abs_path):
+ """Get the project's change script repository
+
+ :param abs_path: Absolute path to migrate repository
+ """
+ if not os.path.exists(abs_path):
+ raise db_exception.DBMigrationError("Path %s not found" % abs_path)
+ return migrate_repository.Repository(abs_path)
+
+
+def _migrate_db_version_control(engine, abs_path, version=None):
+ """Mark a database as under this repository's version control.
+
+ Once a database is under version control, schema changes should
+ only be done via change scripts in this repository.
+
+ :param engine: SQLAlchemy engine instance for a given database
+ :param abs_path: Absolute path to migrate repository
+ :param version: Initial database version
+ """
+ repository = _find_migrate_repo(abs_path)
+
+ try:
+ migrate_api.version_control(engine, repository, version)
+ except migrate_exceptions.InvalidVersionError as ex:
+ raise db_exception.DBMigrationError("Invalid version : %s" % ex)
+ except migrate_exceptions.DatabaseAlreadyControlledError:
+ raise db_exception.DBMigrationError("Database is already controlled.")
+
+ return version
+
+
+def _migrate_db_version(engine, abs_path, init_version):
+ """Show the current version of the repository.
+
+ :param engine: SQLAlchemy engine instance for a given database
+ :param abs_path: Absolute path to migrate repository
+ :param init_version: Initial database version
+ """
+ repository = _find_migrate_repo(abs_path)
+ try:
+ return migrate_api.db_version(engine, repository)
+ except migrate_exceptions.DatabaseNotControlledError:
+ pass
+
+ meta = sa.MetaData()
+ meta.reflect(bind=engine)
+ tables = meta.tables
+ if (
+ len(tables) == 0 or
+ 'alembic_version' in tables or
+ 'migrate_version' in tables
+ ):
+ _migrate_db_version_control(engine, abs_path, version=init_version)
+ return migrate_api.db_version(engine, repository)
+
+ msg = _(
+ "The database is not under version control, but has tables. "
+ "Please stamp the current version of the schema manually."
+ )
+ raise db_exception.DBMigrationError(msg)
+
+
+def _migrate_db_sync(engine, abs_path, version=None, init_version=0):
+ """Upgrade or downgrade a database.
+
+ Function runs the upgrade() or downgrade() functions in change scripts.
+
+ :param engine: SQLAlchemy engine instance for a given database
+ :param abs_path: Absolute path to migrate repository.
+ :param version: Database will upgrade/downgrade until this version.
+ If None - database will update to the latest available version.
+ :param init_version: Initial database version
+ """
+
+ if version is not None:
+ try:
+ version = int(version)
+ except ValueError:
+ msg = _("version should be an integer")
+ raise db_exception.DBMigrationError(msg)
+
+ current_version = _migrate_db_version(engine, abs_path, init_version)
+ repository = _find_migrate_repo(abs_path)
+
+ if version is None or version > current_version:
+ try:
+ return migrate_api.upgrade(engine, repository, version)
+ except Exception as ex:
+ raise db_exception.DBMigrationError(ex)
+ else:
+ return migrate_api.downgrade(engine, repository, version)
def _sync_repo(repo_name):
- abs_path = find_repo(repo_name)
+ abs_path = _get_migrate_repo_path(repo_name)
with sql.session_for_write() as session:
engine = session.get_bind()
- migration.db_sync(
- engine,
- abs_path,
+ _migrate_db_sync(
+ engine=engine,
+ abs_path=abs_path,
init_version=INITIAL_VERSION,
- sanity_check=False,
)
@@ -115,9 +210,13 @@ def offline_sync_database_to_version(version=None):
def get_db_version(repo=EXPAND_REPO):
+ abs_path = _get_migrate_repo_path(repo)
with sql.session_for_read() as session:
- repo = find_repo(repo)
- return migration.db_version(session.get_bind(), repo, INITIAL_VERSION)
+ return _migrate_db_version(
+ session.get_bind(),
+ abs_path,
+ INITIAL_VERSION,
+ )
def validate_upgrade_order(repo_name, target_repo_version=None):
@@ -145,8 +244,8 @@ def validate_upgrade_order(repo_name, target_repo_version=None):
# find the latest version that the current command will upgrade to if there
# wasn't a version specified for upgrade.
if not target_repo_version:
- abs_path = find_repo(repo_name)
- repo = migrate.versioning.repository.Repository(abs_path)
+ abs_path = _get_migrate_repo_path(repo_name)
+ repo = _find_migrate_repo(abs_path)
target_repo_version = int(repo.latest)
# get current version of the command that runs before the current command.
diff --git a/keystone/common/utils.py b/keystone/common/utils.py
index 7c3e7ae1a..70d277e52 100644
--- a/keystone/common/utils.py
+++ b/keystone/common/utils.py
@@ -17,6 +17,7 @@
# under the License.
import collections.abc
+import contextlib
import grp
import hashlib
import itertools
@@ -489,3 +490,9 @@ def create_directory(directory, keystone_user_id=None, keystone_group_id=None):
'Unable to change the ownership of key repository without '
'a keystone user ID and keystone group ID both being '
'provided: %s', directory)
+
+
+@contextlib.contextmanager
+def nested_contexts(*contexts):
+ with contextlib.ExitStack() as stack:
+ yield [stack.enter_context(c) for c in contexts]
diff --git a/keystone/tests/unit/common/sql/test_upgrades.py b/keystone/tests/unit/common/sql/test_upgrades.py
new file mode 100644
index 000000000..c6c4a2e56
--- /dev/null
+++ b/keystone/tests/unit/common/sql/test_upgrades.py
@@ -0,0 +1,252 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+from unittest import mock
+
+from migrate import exceptions as migrate_exception
+from migrate.versioning import api as migrate_api
+from migrate.versioning import repository as migrate_repository
+from oslo_db import exception as db_exception
+from oslo_db.sqlalchemy import enginefacade
+from oslo_db.sqlalchemy import test_fixtures as db_fixtures
+from oslotest import base as test_base
+import sqlalchemy
+
+from keystone.common.sql import upgrades
+from keystone.common import utils
+
+
+class TestMigrationCommon(
+ db_fixtures.OpportunisticDBTestMixin, test_base.BaseTestCase,
+):
+
+ def setUp(self):
+ super().setUp()
+
+ self.engine = enginefacade.writer.get_engine()
+
+ self.path = tempfile.mkdtemp('test_migration')
+ self.path1 = tempfile.mkdtemp('test_migration')
+ self.return_value = '/home/openstack/migrations'
+ self.return_value1 = '/home/extension/migrations'
+ self.init_version = 1
+ self.test_version = 123
+
+ self.patcher_repo = mock.patch.object(migrate_repository, 'Repository')
+ self.repository = self.patcher_repo.start()
+ self.repository.side_effect = [self.return_value, self.return_value1]
+
+ self.mock_api_db = mock.patch.object(migrate_api, 'db_version')
+ self.mock_api_db_version = self.mock_api_db.start()
+ self.mock_api_db_version.return_value = self.test_version
+
+ def tearDown(self):
+ os.rmdir(self.path)
+ self.mock_api_db.stop()
+ self.patcher_repo.stop()
+ super().tearDown()
+
+ def test_find_migrate_repo_path_not_found(self):
+ self.assertRaises(
+ db_exception.DBMigrationError,
+ upgrades._find_migrate_repo,
+ "/foo/bar/",
+ )
+
+ def test_find_migrate_repo_called_once(self):
+ my_repository = upgrades._find_migrate_repo(self.path)
+ self.repository.assert_called_once_with(self.path)
+ self.assertEqual(self.return_value, my_repository)
+
+ def test_find_migrate_repo_called_few_times(self):
+ repo1 = upgrades._find_migrate_repo(self.path)
+ repo2 = upgrades._find_migrate_repo(self.path1)
+ self.assertNotEqual(repo1, repo2)
+
+ def test_db_version_control(self):
+ with utils.nested_contexts(
+ mock.patch.object(upgrades, '_find_migrate_repo'),
+ mock.patch.object(migrate_api, 'version_control'),
+ ) as (mock_find_repo, mock_version_control):
+ mock_find_repo.return_value = self.return_value
+
+ version = upgrades._migrate_db_version_control(
+ self.engine, self.path, self.test_version)
+
+ self.assertEqual(self.test_version, version)
+ mock_version_control.assert_called_once_with(
+ self.engine, self.return_value, self.test_version)
+
+ @mock.patch.object(upgrades, '_find_migrate_repo')
+ @mock.patch.object(migrate_api, 'version_control')
+ def test_db_version_control_version_less_than_actual_version(
+ self, mock_version_control, mock_find_repo,
+ ):
+ mock_find_repo.return_value = self.return_value
+ mock_version_control.side_effect = \
+ migrate_exception.DatabaseAlreadyControlledError
+ self.assertRaises(
+ db_exception.DBMigrationError,
+ upgrades._migrate_db_version_control, self.engine,
+ self.path, self.test_version - 1)
+
+ @mock.patch.object(upgrades, '_find_migrate_repo')
+ @mock.patch.object(migrate_api, 'version_control')
+ def test_db_version_control_version_greater_than_actual_version(
+ self, mock_version_control, mock_find_repo,
+ ):
+ mock_find_repo.return_value = self.return_value
+ mock_version_control.side_effect = \
+ migrate_exception.InvalidVersionError
+ self.assertRaises(
+ db_exception.DBMigrationError,
+ upgrades._migrate_db_version_control, self.engine,
+ self.path, self.test_version + 1)
+
+ def test_db_version_return(self):
+ ret_val = upgrades._migrate_db_version(
+ self.engine, self.path, self.init_version)
+ self.assertEqual(self.test_version, ret_val)
+
+ def test_db_version_raise_not_controlled_error_first(self):
+ with mock.patch.object(
+ upgrades, '_migrate_db_version_control',
+ ) as mock_ver:
+ self.mock_api_db_version.side_effect = [
+ migrate_exception.DatabaseNotControlledError('oups'),
+ self.test_version]
+
+ ret_val = upgrades._migrate_db_version(
+ self.engine, self.path, self.init_version)
+ self.assertEqual(self.test_version, ret_val)
+ mock_ver.assert_called_once_with(
+ self.engine, self.path, version=self.init_version)
+
+ def test_db_version_raise_not_controlled_error_tables(self):
+ with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
+ self.mock_api_db_version.side_effect = \
+ migrate_exception.DatabaseNotControlledError('oups')
+ my_meta = mock.MagicMock()
+ my_meta.tables = {'a': 1, 'b': 2}
+ mock_meta.return_value = my_meta
+
+ self.assertRaises(
+ db_exception.DBMigrationError, upgrades._migrate_db_version,
+ self.engine, self.path, self.init_version)
+
+ @mock.patch.object(migrate_api, 'version_control')
+ def test_db_version_raise_not_controlled_error_no_tables(self, mock_vc):
+ with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
+ self.mock_api_db_version.side_effect = (
+ migrate_exception.DatabaseNotControlledError('oups'),
+ self.init_version)
+ my_meta = mock.MagicMock()
+ my_meta.tables = {}
+ mock_meta.return_value = my_meta
+
+ upgrades._migrate_db_version(
+ self.engine, self.path, self.init_version)
+
+ mock_vc.assert_called_once_with(
+ self.engine, self.return_value1, self.init_version)
+
+ @mock.patch.object(migrate_api, 'version_control')
+ def test_db_version_raise_not_controlled_alembic_tables(self, mock_vc):
+ # When there are tables but the alembic control table
+ # (alembic_version) is present, attempt to version the db.
+ # This simulates the case where there is are multiple repos (different
+ # abs_paths) and a different path has been versioned already.
+ with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
+ self.mock_api_db_version.side_effect = [
+ migrate_exception.DatabaseNotControlledError('oups'), None]
+ my_meta = mock.MagicMock()
+ my_meta.tables = {'alembic_version': 1, 'b': 2}
+ mock_meta.return_value = my_meta
+
+ upgrades._migrate_db_version(
+ self.engine, self.path, self.init_version)
+
+ mock_vc.assert_called_once_with(
+ self.engine, self.return_value1, self.init_version)
+
+ @mock.patch.object(migrate_api, 'version_control')
+ def test_db_version_raise_not_controlled_migrate_tables(self, mock_vc):
+ # When there are tables but the sqlalchemy-migrate control table
+ # (migrate_version) is present, attempt to version the db.
+ # This simulates the case where there is are multiple repos (different
+ # abs_paths) and a different path has been versioned already.
+ with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
+ self.mock_api_db_version.side_effect = [
+ migrate_exception.DatabaseNotControlledError('oups'), None]
+ my_meta = mock.MagicMock()
+ my_meta.tables = {'migrate_version': 1, 'b': 2}
+ mock_meta.return_value = my_meta
+
+ upgrades._migrate_db_version(
+ self.engine, self.path, self.init_version)
+
+ mock_vc.assert_called_once_with(
+ self.engine, self.return_value1, self.init_version)
+
+ def test_db_sync_wrong_version(self):
+ self.assertRaises(
+ db_exception.DBMigrationError,
+ upgrades._migrate_db_sync, self.engine, self.path, 'foo')
+
+ @mock.patch.object(migrate_api, 'upgrade')
+ def test_db_sync_script_not_present(self, upgrade):
+ # For non existent upgrades script file sqlalchemy-migrate will raise
+ # VersionNotFoundError which will be wrapped in DBMigrationError.
+ upgrade.side_effect = migrate_exception.VersionNotFoundError
+ self.assertRaises(
+ db_exception.DBMigrationError,
+ upgrades._migrate_db_sync, self.engine, self.path,
+ self.test_version + 1)
+
+ @mock.patch.object(migrate_api, 'upgrade')
+ def test_db_sync_known_error_raised(self, upgrade):
+ upgrade.side_effect = migrate_exception.KnownError
+ self.assertRaises(
+ db_exception.DBMigrationError,
+ upgrades._migrate_db_sync, self.engine, self.path,
+ self.test_version + 1)
+
+ def test_db_sync_upgrade(self):
+ init_ver = 55
+ with utils.nested_contexts(
+ mock.patch.object(upgrades, '_find_migrate_repo'),
+ mock.patch.object(migrate_api, 'upgrade')
+ ) as (mock_find_repo, mock_upgrade):
+ mock_find_repo.return_value = self.return_value
+ self.mock_api_db_version.return_value = self.test_version - 1
+
+ upgrades._migrate_db_sync(
+ self.engine, self.path, self.test_version, init_ver)
+
+ mock_upgrade.assert_called_once_with(
+ self.engine, self.return_value, self.test_version)
+
+ def test_db_sync_downgrade(self):
+ with utils.nested_contexts(
+ mock.patch.object(upgrades, '_find_migrate_repo'),
+ mock.patch.object(migrate_api, 'downgrade')
+ ) as (mock_find_repo, mock_downgrade):
+ mock_find_repo.return_value = self.return_value
+ self.mock_api_db_version.return_value = self.test_version + 1
+
+ upgrades._migrate_db_sync(
+ self.engine, self.path, self.test_version)
+
+ mock_downgrade.assert_called_once_with(
+ self.engine, self.return_value, self.test_version)
diff --git a/keystone/tests/unit/test_cli.py b/keystone/tests/unit/test_cli.py
index e2c42ec8a..c94d8c196 100644
--- a/keystone/tests/unit/test_cli.py
+++ b/keystone/tests/unit/test_cli.py
@@ -25,7 +25,6 @@ import fixtures
import freezegun
import http.client
import oslo_config.fixture
-from oslo_db.sqlalchemy import migration
from oslo_log import log
from oslo_serialization import jsonutils
from oslo_upgradecheck import upgradecheck
@@ -805,17 +804,6 @@ class CliDBSyncTestCase(unit.BaseTestCase):
cli.DbSync.main()
self._assert_correct_call(upgrades.contract_schema)
- @mock.patch('keystone.cmd.cli.upgrades.get_db_version')
- def test_db_sync_check_when_database_is_empty(self, mocked_get_db_version):
- e = migration.exception.DBMigrationError("Invalid version")
- mocked_get_db_version.side_effect = e
- checker = cli.DbSync()
-
- log_info = self.useFixture(fixtures.FakeLogger(level=log.INFO))
- status = checker.check_db_sync_status()
- self.assertIn("not currently under version control", log_info.output)
- self.assertEqual(status, 2)
-
class TestMappingPopulate(unit.SQLDriverOverrides, unit.TestCase):