author     Jenkins <jenkins@review.openstack.org>     2014-02-20 13:48:16 +0000
committer  Gerrit Code Review <review@openstack.org>  2014-02-20 13:48:16 +0000
commit     e6971a49ab07b904f5192e6725e16a931d6c5ae3 (patch)
tree       2b247f745b4ee832ff975fa702aaf33a2f2af7e4
parent     a5c865d0ae2a6981daa07e59bad25778526b26da (diff)
parent     28fadaf6f6deeb3c8c1bff261fb00b2d92289acb (diff)
download   designate-e6971a49ab07b904f5192e6725e16a931d6c5ae3.tar.gz
Merge "Adds support for paging in the storage layer"
-rw-r--r--  designate/exceptions.py                                        4
-rw-r--r--  designate/openstack/common/db/__init__.py                      0
-rw-r--r--  designate/openstack/common/db/api.py                         147
-rw-r--r--  designate/openstack/common/db/exception.py                    56
-rw-r--r--  designate/openstack/common/db/options.py                     177
-rw-r--r--  designate/openstack/common/db/sqlalchemy/__init__.py           0
-rw-r--r--  designate/openstack/common/db/sqlalchemy/migration.py        268
-rw-r--r--  designate/openstack/common/db/sqlalchemy/models.py           115
-rw-r--r--  designate/openstack/common/db/sqlalchemy/provision.py        187
-rw-r--r--  designate/openstack/common/db/sqlalchemy/session.py          809
-rw-r--r--  designate/openstack/common/db/sqlalchemy/test_base.py        149
-rw-r--r--  designate/openstack/common/db/sqlalchemy/test_migrations.py  269
-rw-r--r--  designate/openstack/common/db/sqlalchemy/utils.py            547
-rw-r--r--  designate/storage/api.py                                      48
-rw-r--r--  designate/storage/base.py                                     81
-rw-r--r--  designate/storage/impl_sqlalchemy/__init__.py                134
-rw-r--r--  designate/tests/test_central/test_service.py                   4
-rw-r--r--  designate/tests/test_storage/__init__.py                     112
-rw-r--r--  designate/tests/test_storage/test_api.py                      72
-rw-r--r--  openstack-common.conf                                          2
-rw-r--r--  requirements.txt                                               1
21 files changed, 3064 insertions, 118 deletions
diff --git a/designate/exceptions.py b/designate/exceptions.py
index b018c5b6..7d800c13 100644
--- a/designate/exceptions.py
+++ b/designate/exceptions.py
@@ -87,6 +87,10 @@ class NetworkEndpointNotFound(BadRequest):
     error_code = 403
 
 
+class MarkerNotFound(BadRequest):
+    error_type = 'marker_not_found'
+
+
 class InvalidOperation(BadRequest):
     error_code = 400
     error_type = 'invalid_operation'
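
The new MarkerNotFound exception is raised by the storage layer when the paging marker supplied by a client (the id of the last item on the previous page) can no longer be resolved. A hedged sketch of handling it; find_domains() and its marker/limit signature are assumptions for illustration, not part of this patch:

from designate import exceptions

def list_domains_page(storage, context, marker=None, limit=20):
    try:
        # marker/limit paging as added by this change (signature assumed).
        return storage.find_domains(context, marker=marker, limit=limit)
    except exceptions.MarkerNotFound:
        # The marker id no longer exists.  As a BadRequest subclass this
        # surfaces as an HTTP 400 with error_type 'marker_not_found';
        # here we simply restart paging from the first page.
        return storage.find_domains(context, marker=None, limit=limit)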
diff --git a/designate/openstack/common/db/__init__.py b/designate/openstack/common/db/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/designate/openstack/common/db/__init__.py
diff --git a/designate/openstack/common/db/api.py b/designate/openstack/common/db/api.py
new file mode 100644
index 00000000..e8ca92da
--- /dev/null
+++ b/designate/openstack/common/db/api.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2013 Rackspace Hosting
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Multiple DB API backend support.
+
+A DB backend module should implement a method named 'get_backend' which
+takes no arguments. The method can return any object that implements DB
+API methods.
+"""
+
+import functools
+import logging
+import time
+
+from designate.openstack.common.db import exception
+from designate.openstack.common.gettextutils import _ # noqa
+from designate.openstack.common import importutils
+
+
+LOG = logging.getLogger(__name__)
+
+
+def safe_for_db_retry(f):
+ """Enable db-retry for decorated function, if config option enabled."""
+ f.__dict__['enable_retry'] = True
+ return f
+
+
+class wrap_db_retry(object):
+ """Retry db.api methods, if DBConnectionError() raised
+
+    Retry decorated db.api methods. If `use_db_reconnect` is enabled
+    in config, this decorator is applied to all db.api functions
+    marked with the @safe_for_db_retry decorator.
+    The decorator catches DBConnectionError() and retries the function in
+    a loop until it succeeds, or until the maximum retry count is reached.
+ """
+
+ def __init__(self, retry_interval, max_retries, inc_retry_interval,
+ max_retry_interval):
+ super(wrap_db_retry, self).__init__()
+
+ self.retry_interval = retry_interval
+ self.max_retries = max_retries
+ self.inc_retry_interval = inc_retry_interval
+ self.max_retry_interval = max_retry_interval
+
+ def __call__(self, f):
+ @functools.wraps(f)
+ def wrapper(*args, **kwargs):
+ next_interval = self.retry_interval
+ remaining = self.max_retries
+
+ while True:
+ try:
+ return f(*args, **kwargs)
+ except exception.DBConnectionError as e:
+ if remaining == 0:
+ LOG.exception(_('DB exceeded retry limit.'))
+ raise exception.DBError(e)
+ if remaining != -1:
+ remaining -= 1
+ LOG.exception(_('DB connection error.'))
+ # NOTE(vsergeyev): We are using patched time module, so
+ # this effectively yields the execution
+ # context to another green thread.
+ time.sleep(next_interval)
+ if self.inc_retry_interval:
+ next_interval = min(
+ next_interval * 2,
+ self.max_retry_interval
+ )
+ return wrapper
+
+
+class DBAPI(object):
+ def __init__(self, backend_name, backend_mapping=None, **kwargs):
+ """Initialize the choosen DB API backend.
+
+ :param backend_name: name of the backend to load
+ :type backend_name: str
+
+ :param backend_mapping: backend name -> module/class to load mapping
+ :type backend_mapping: dict
+
+ Keyword arguments:
+
+ :keyword use_db_reconnect: retry DB transactions on disconnect or not
+ :type use_db_reconnect: bool
+
+ :keyword retry_interval: seconds between transaction retries
+ :type retry_interval: int
+
+ :keyword inc_retry_interval: increase retry interval or not
+ :type inc_retry_interval: bool
+
+ :keyword max_retry_interval: max interval value between retries
+ :type max_retry_interval: int
+
+ :keyword max_retries: max number of retries before an error is raised
+ :type max_retries: int
+
+ """
+
+ if backend_mapping is None:
+ backend_mapping = {}
+
+ # Import the untranslated name if we don't have a
+ # mapping.
+ backend_path = backend_mapping.get(backend_name, backend_name)
+ backend_mod = importutils.import_module(backend_path)
+ self.__backend = backend_mod.get_backend()
+
+ self.use_db_reconnect = kwargs.get('use_db_reconnect', False)
+ self.retry_interval = kwargs.get('retry_interval', 1)
+ self.inc_retry_interval = kwargs.get('inc_retry_interval', True)
+ self.max_retry_interval = kwargs.get('max_retry_interval', 10)
+ self.max_retries = kwargs.get('max_retries', 20)
+
+ def __getattr__(self, key):
+ attr = getattr(self.__backend, key)
+
+ if not hasattr(attr, '__call__'):
+ return attr
+ # NOTE(vsergeyev): If `use_db_reconnect` option is set to True, retry
+ # DB API methods, decorated with @safe_for_db_retry
+ # on disconnect.
+ if self.use_db_reconnect and hasattr(attr, 'enable_retry'):
+ attr = wrap_db_retry(
+ retry_interval=self.retry_interval,
+ max_retries=self.max_retries,
+ inc_retry_interval=self.inc_retry_interval,
+ max_retry_interval=self.max_retry_interval)(attr)
+
+ return attr
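
The api.py module above expects each backend module to expose a get_backend() callable; DBAPI then proxies attribute access to whatever that returns and, when use_db_reconnect is enabled, wraps functions marked with @safe_for_db_retry in wrap_db_retry. A hedged sketch of that contract; the module and function names below are illustrative, not part of the patch:

# illustrative_backend.py -- a hypothetical backend module (name assumed)
import sys

from designate.openstack.common.db import api as db_api


@db_api.safe_for_db_retry            # opt this call into wrap_db_retry
def get_record(record_id):
    return {'id': record_id}


def get_backend():
    # Any object exposing the DB API methods works; the module itself will do.
    return sys.modules[__name__]


# Elsewhere in the application (options come from db/options.py):
#   dbapi = db_api.DBAPI('illustrative_backend', use_db_reconnect=True,
#                        retry_interval=1, max_retries=5)
#   dbapi.get_record(1)   # retried on DBConnectionError when enabled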
diff --git a/designate/openstack/common/db/exception.py b/designate/openstack/common/db/exception.py
new file mode 100644
index 00000000..fd3ba246
--- /dev/null
+++ b/designate/openstack/common/db/exception.py
@@ -0,0 +1,56 @@
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""DB related custom exceptions."""
+
+import six
+
+from designate.openstack.common.gettextutils import _
+
+
+class DBError(Exception):
+ """Wraps an implementation specific exception."""
+ def __init__(self, inner_exception=None):
+ self.inner_exception = inner_exception
+ super(DBError, self).__init__(six.text_type(inner_exception))
+
+
+class DBDuplicateEntry(DBError):
+ """Wraps an implementation specific exception."""
+ def __init__(self, columns=[], inner_exception=None):
+ self.columns = columns
+ super(DBDuplicateEntry, self).__init__(inner_exception)
+
+
+class DBDeadlock(DBError):
+ def __init__(self, inner_exception=None):
+ super(DBDeadlock, self).__init__(inner_exception)
+
+
+class DBInvalidUnicodeParameter(Exception):
+ message = _("Invalid Parameter: "
+ "Unicode is not supported by the current database.")
+
+
+class DbMigrationError(DBError):
+ """Wraps migration specific exception."""
+ def __init__(self, message=None):
+ super(DbMigrationError, self).__init__(message)
+
+
+class DBConnectionError(DBError):
+ """Wraps connection specific exception."""
+ pass
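
These exception classes give callers a single DBError base to catch, plus narrower types such as DBDuplicateEntry, which carries the violating column names once session.py has parsed them. A hedged sketch of typical handling; save_user() and its session argument are placeholders:

from designate.openstack.common.db import exception as db_exc


def save_user(session, user):
    try:
        user.save(session)                 # may violate a unique constraint
    except db_exc.DBDuplicateEntry as e:
        # e.columns lists the offending columns parsed from the driver error.
        raise ValueError('duplicate entry on columns: %s' % e.columns)
    except db_exc.DBDeadlock:
        raise                              # leave it to a retry layer
    except db_exc.DBError as e:
        # Generic wrapper around the driver-specific exception.
        raise RuntimeError('database failure: %s' % e.inner_exception)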
diff --git a/designate/openstack/common/db/options.py b/designate/openstack/common/db/options.py
new file mode 100644
index 00000000..6e15a0f5
--- /dev/null
+++ b/designate/openstack/common/db/options.py
@@ -0,0 +1,177 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import copy
+import os
+
+from oslo.config import cfg
+
+
+sqlite_db_opts = [
+ cfg.StrOpt('sqlite_db',
+ default='designate.sqlite',
+ help='The file name to use with SQLite'),
+ cfg.BoolOpt('sqlite_synchronous',
+ default=True,
+ help='If True, SQLite uses synchronous mode'),
+]
+
+database_opts = [
+ cfg.StrOpt('backend',
+ default='sqlalchemy',
+ deprecated_name='db_backend',
+ deprecated_group='DEFAULT',
+ help='The backend to use for db'),
+ cfg.StrOpt('connection',
+ default='sqlite:///' +
+ os.path.abspath(os.path.join(os.path.dirname(__file__),
+ '../', '$sqlite_db')),
+ help='The SQLAlchemy connection string used to connect to the '
+ 'database',
+ secret=True,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_connection',
+ group='DEFAULT'),
+ cfg.DeprecatedOpt('sql_connection',
+ group='DATABASE'),
+ cfg.DeprecatedOpt('connection',
+ group='sql'), ]),
+ cfg.IntOpt('idle_timeout',
+ default=3600,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout',
+ group='DEFAULT'),
+ cfg.DeprecatedOpt('sql_idle_timeout',
+ group='DATABASE'),
+ cfg.DeprecatedOpt('idle_timeout',
+ group='sql')],
+ help='Timeout before idle sql connections are reaped'),
+ cfg.IntOpt('min_pool_size',
+ default=1,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size',
+ group='DEFAULT'),
+ cfg.DeprecatedOpt('sql_min_pool_size',
+ group='DATABASE')],
+ help='Minimum number of SQL connections to keep open in a '
+ 'pool'),
+ cfg.IntOpt('max_pool_size',
+ default=None,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size',
+ group='DEFAULT'),
+ cfg.DeprecatedOpt('sql_max_pool_size',
+ group='DATABASE')],
+ help='Maximum number of SQL connections to keep open in a '
+ 'pool'),
+ cfg.IntOpt('max_retries',
+ default=10,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries',
+ group='DEFAULT'),
+ cfg.DeprecatedOpt('sql_max_retries',
+ group='DATABASE')],
+ help='Maximum db connection retries during startup. '
+ '(setting -1 implies an infinite retry count)'),
+ cfg.IntOpt('retry_interval',
+ default=10,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval',
+ group='DEFAULT'),
+ cfg.DeprecatedOpt('reconnect_interval',
+ group='DATABASE')],
+ help='Interval between retries of opening a sql connection'),
+ cfg.IntOpt('max_overflow',
+ default=None,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow',
+ group='DEFAULT'),
+ cfg.DeprecatedOpt('sqlalchemy_max_overflow',
+ group='DATABASE')],
+ help='If set, use this value for max_overflow with sqlalchemy'),
+ cfg.IntOpt('connection_debug',
+ default=0,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug',
+ group='DEFAULT')],
+ help='Verbosity of SQL debugging information. 0=None, '
+ '100=Everything'),
+ cfg.BoolOpt('connection_trace',
+ default=False,
+ deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace',
+ group='DEFAULT')],
+ help='Add python stack traces to SQL as comment strings'),
+ cfg.IntOpt('pool_timeout',
+ default=None,
+ deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout',
+ group='DATABASE')],
+ help='If set, use this value for pool_timeout with sqlalchemy'),
+ cfg.BoolOpt('use_db_reconnect',
+ default=False,
+ help='Enable the experimental use of database reconnect '
+ 'on connection lost'),
+ cfg.IntOpt('db_retry_interval',
+ default=1,
+ help='seconds between db connection retries'),
+ cfg.BoolOpt('db_inc_retry_interval',
+ default=True,
+ help='Whether to increase interval between db connection '
+ 'retries, up to db_max_retry_interval'),
+ cfg.IntOpt('db_max_retry_interval',
+ default=10,
+ help='max seconds between db connection retries, if '
+ 'db_inc_retry_interval is enabled'),
+ cfg.IntOpt('db_max_retries',
+ default=20,
+ help='maximum db connection retries before error is raised. '
+ '(setting -1 implies an infinite retry count)'),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(sqlite_db_opts)
+CONF.register_opts(database_opts, 'database')
+
+
+def set_defaults(sql_connection, sqlite_db, max_pool_size=None,
+ max_overflow=None, pool_timeout=None):
+ """Set defaults for configuration variables."""
+ cfg.set_defaults(database_opts,
+ connection=sql_connection)
+ cfg.set_defaults(sqlite_db_opts,
+ sqlite_db=sqlite_db)
+ # Update the QueuePool defaults
+ if max_pool_size is not None:
+ cfg.set_defaults(database_opts,
+ max_pool_size=max_pool_size)
+ if max_overflow is not None:
+ cfg.set_defaults(database_opts,
+ max_overflow=max_overflow)
+ if pool_timeout is not None:
+ cfg.set_defaults(database_opts,
+ pool_timeout=pool_timeout)
+
+
+_opts = [
+ (None, sqlite_db_opts),
+ ('database', database_opts),
+]
+
+
+def list_opts():
+ """Returns a list of oslo.config options available in the library.
+
+ The returned list includes all oslo.config options which may be registered
+ at runtime by the library.
+
+ Each element of the list is a tuple. The first element is the name of the
+ group under which the list of elements in the second element will be
+ registered. A group name of None corresponds to the [DEFAULT] group in
+ config files.
+
+ The purpose of this is to allow tools like the Oslo sample config file
+ generator to discover the options exposed to users by this library.
+
+ :returns: a list of (group_name, opts) tuples
+ """
+ return [(g, copy.deepcopy(o)) for g, o in _opts]
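
The options above are registered on the global CONF at import time, most of them under the [database] group; set_defaults() lets a consuming service choose its own connection string and pool sizes before configuration files are parsed, while list_opts() feeds the sample-config generator. A short sketch of wiring this up; the connection string and file names are illustrative:

from oslo.config import cfg

from designate.openstack.common.db import options

# Pick service-specific defaults before parsing any config files.
options.set_defaults(sql_connection='sqlite:///designate-test.sqlite',
                     sqlite_db='designate-test.sqlite',
                     max_pool_size=10)

cfg.CONF([], project='designate')

print(cfg.CONF.database.connection)     # default chosen above
print(cfg.CONF.database.max_retries)    # 10, from database_opts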
diff --git a/designate/openstack/common/db/sqlalchemy/__init__.py b/designate/openstack/common/db/sqlalchemy/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/designate/openstack/common/db/sqlalchemy/__init__.py
diff --git a/designate/openstack/common/db/sqlalchemy/migration.py b/designate/openstack/common/db/sqlalchemy/migration.py
new file mode 100644
index 00000000..c88fa614
--- /dev/null
+++ b/designate/openstack/common/db/sqlalchemy/migration.py
@@ -0,0 +1,268 @@
+# coding: utf-8
+#
+# Copyright (c) 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# Base on code in migrate/changeset/databases/sqlite.py which is under
+# the following license:
+#
+# The MIT License
+#
+# Copyright (c) 2009 Evan Rosson, Jan Dittberner, Domen Kožar
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+import os
+import re
+
+from migrate.changeset import ansisql
+from migrate.changeset.databases import sqlite
+from migrate import exceptions as versioning_exceptions
+from migrate.versioning import api as versioning_api
+from migrate.versioning.repository import Repository
+import sqlalchemy
+from sqlalchemy.schema import UniqueConstraint
+
+from designate.openstack.common.db import exception
+from designate.openstack.common.gettextutils import _
+
+
+def _get_unique_constraints(self, table):
+ """Retrieve information about existing unique constraints of the table
+
+ This feature is needed for _recreate_table() to work properly.
+ Unfortunately, it's not available in sqlalchemy 0.7.x/0.8.x.
+
+ """
+
+ data = table.metadata.bind.execute(
+ """SELECT sql
+ FROM sqlite_master
+ WHERE
+ type='table' AND
+ name=:table_name""",
+ table_name=table.name
+ ).fetchone()[0]
+
+ UNIQUE_PATTERN = "CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)"
+ return [
+ UniqueConstraint(
+ *[getattr(table.columns, c.strip(' "')) for c in cols.split(",")],
+ name=name
+ )
+ for name, cols in re.findall(UNIQUE_PATTERN, data)
+ ]
+
+
+def _recreate_table(self, table, column=None, delta=None, omit_uniques=None):
+ """Recreate the table properly
+
+ Unlike the corresponding original method of sqlalchemy-migrate this one
+ doesn't drop existing unique constraints when creating a new one.
+
+ """
+
+ table_name = self.preparer.format_table(table)
+
+ # we remove all indexes so as not to have
+ # problems during copy and re-create
+ for index in table.indexes:
+ index.drop()
+
+ # reflect existing unique constraints
+ for uc in self._get_unique_constraints(table):
+ table.append_constraint(uc)
+ # omit given unique constraints when creating a new table if required
+ table.constraints = set([
+ cons for cons in table.constraints
+ if omit_uniques is None or cons.name not in omit_uniques
+ ])
+
+ self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
+ self.execute()
+
+ insertion_string = self._modify_table(table, column, delta)
+
+ table.create(bind=self.connection)
+ self.append(insertion_string % {'table_name': table_name})
+ self.execute()
+ self.append('DROP TABLE migration_tmp')
+ self.execute()
+
+
+def _visit_migrate_unique_constraint(self, *p, **k):
+ """Drop the given unique constraint
+
+    The corresponding original method of sqlalchemy-migrate just
+    raises a NotImplementedError.
+
+ """
+
+ self.recreate_table(p[0].table, omit_uniques=[p[0].name])
+
+
+def patch_migrate():
+ """A workaround for SQLite's inability to alter things
+
+    SQLite's abilities to alter tables are very limited (please read
+ http://www.sqlite.org/lang_altertable.html for more details).
+ E. g. one can't drop a column or a constraint in SQLite. The
+ workaround for this is to recreate the original table omitting
+ the corresponding constraint (or column).
+
+    The sqlalchemy-migrate library has a recreate_table() method that
+    implements this workaround, but it does it wrong:
+
+ - information about unique constraints of a table
+ is not retrieved. So if you have a table with one
+ unique constraint and a migration adding another one
+ you will end up with a table that has only the
+ latter unique constraint, and the former will be lost
+
+ - dropping of unique constraints is not supported at all
+
+ The proper way to fix this is to provide a pull-request to
+ sqlalchemy-migrate, but the project seems to be dead. So we
+ can go on with monkey-patching of the lib at least for now.
+
+ """
+
+ # this patch is needed to ensure that recreate_table() doesn't drop
+ # existing unique constraints of the table when creating a new one
+ helper_cls = sqlite.SQLiteHelper
+ helper_cls.recreate_table = _recreate_table
+ helper_cls._get_unique_constraints = _get_unique_constraints
+
+ # this patch is needed to be able to drop existing unique constraints
+ constraint_cls = sqlite.SQLiteConstraintDropper
+ constraint_cls.visit_migrate_unique_constraint = \
+ _visit_migrate_unique_constraint
+ constraint_cls.__bases__ = (ansisql.ANSIColumnDropper,
+ sqlite.SQLiteConstraintGenerator)
+
+
+def db_sync(engine, abs_path, version=None, init_version=0):
+ """Upgrade or downgrade a database.
+
+ Function runs the upgrade() or downgrade() functions in change scripts.
+
+ :param engine: SQLAlchemy engine instance for a given database
+ :param abs_path: Absolute path to migrate repository.
+ :param version: Database will upgrade/downgrade until this version.
+ If None - database will update to the latest
+ available version.
+ :param init_version: Initial database version
+ """
+ if version is not None:
+ try:
+ version = int(version)
+ except ValueError:
+ raise exception.DbMigrationError(
+ message=_("version should be an integer"))
+
+ current_version = db_version(engine, abs_path, init_version)
+ repository = _find_migrate_repo(abs_path)
+ _db_schema_sanity_check(engine)
+ if version is None or version > current_version:
+ return versioning_api.upgrade(engine, repository, version)
+ else:
+ return versioning_api.downgrade(engine, repository,
+ version)
+
+
+def _db_schema_sanity_check(engine):
+ """Ensure all database tables were created with required parameters.
+
+ :param engine: SQLAlchemy engine instance for a given database
+
+ """
+
+ if engine.name == 'mysql':
+ onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION '
+ 'from information_schema.TABLES '
+ 'where TABLE_SCHEMA=%s and '
+ 'TABLE_COLLATION NOT LIKE "%%utf8%%"')
+
+ table_names = [res[0] for res in engine.execute(onlyutf8_sql,
+ engine.url.database)]
+ if len(table_names) > 0:
+ raise ValueError(_('Tables "%s" have non utf8 collation, '
+ 'please make sure all tables are CHARSET=utf8'
+ ) % ','.join(table_names))
+
+
+def db_version(engine, abs_path, init_version):
+ """Show the current version of the repository.
+
+ :param engine: SQLAlchemy engine instance for a given database
+ :param abs_path: Absolute path to migrate repository
+    :param init_version: Initial database version
+ """
+ repository = _find_migrate_repo(abs_path)
+ try:
+ return versioning_api.db_version(engine, repository)
+ except versioning_exceptions.DatabaseNotControlledError:
+ meta = sqlalchemy.MetaData()
+ meta.reflect(bind=engine)
+ tables = meta.tables
+ if len(tables) == 0 or 'alembic_version' in tables:
+            db_version_control(engine, abs_path, init_version)
+ return versioning_api.db_version(engine, repository)
+ else:
+ raise exception.DbMigrationError(
+ message=_(
+ "The database is not under version control, but has "
+ "tables. Please stamp the current version of the schema "
+ "manually."))
+
+
+def db_version_control(engine, abs_path, version=None):
+ """Mark a database as under this repository's version control.
+
+ Once a database is under version control, schema changes should
+ only be done via change scripts in this repository.
+
+ :param engine: SQLAlchemy engine instance for a given database
+ :param abs_path: Absolute path to migrate repository
+ :param version: Initial database version
+ """
+ repository = _find_migrate_repo(abs_path)
+ versioning_api.version_control(engine, repository, version)
+ return version
+
+
+def _find_migrate_repo(abs_path):
+ """Get the project's change script repository
+
+ :param abs_path: Absolute path to migrate repository
+ """
+ if not os.path.exists(abs_path):
+ raise exception.DbMigrationError("Path %s not found" % abs_path)
+ return Repository(abs_path)
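
db_sync() drives a sqlalchemy-migrate repository up (or down) to a requested version, stamping an uncontrolled empty database via db_version_control() on first use, and patch_migrate() works around SQLite's limited ALTER TABLE support. A hedged sketch of calling it directly; the repository path is a placeholder:

import sqlalchemy

from designate.openstack.common.db.sqlalchemy import migration

engine = sqlalchemy.create_engine('sqlite:///designate-test.sqlite')
repo = '/path/to/migrate_repo'          # placeholder: a migrate repository

migration.patch_migrate()               # monkey-patch sqlalchemy-migrate
print(migration.db_version(engine, repo, init_version=0))
migration.db_sync(engine, repo, version=None)   # upgrade to latest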
diff --git a/designate/openstack/common/db/sqlalchemy/models.py b/designate/openstack/common/db/sqlalchemy/models.py
new file mode 100644
index 00000000..01296da4
--- /dev/null
+++ b/designate/openstack/common/db/sqlalchemy/models.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2011 X.commerce, a business unit of eBay Inc.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2011 Piston Cloud Computing, Inc.
+# Copyright 2012 Cloudscaling Group, Inc.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+SQLAlchemy models.
+"""
+
+import six
+
+from sqlalchemy import Column, Integer
+from sqlalchemy import DateTime
+from sqlalchemy.orm import object_mapper
+
+from designate.openstack.common import timeutils
+
+
+class ModelBase(object):
+ """Base class for models."""
+ __table_initialized__ = False
+
+ def save(self, session):
+ """Save this object."""
+
+        # NOTE(boris-42): This part of code should look like:
+ # session.add(self)
+ # session.flush()
+ # But there is a bug in sqlalchemy and eventlet that
+ # raises NoneType exception if there is no running
+ # transaction and rollback is called. As long as
+ # sqlalchemy has this bug we have to create transaction
+ # explicitly.
+ with session.begin(subtransactions=True):
+ session.add(self)
+ session.flush()
+
+ def __setitem__(self, key, value):
+ setattr(self, key, value)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def get(self, key, default=None):
+ return getattr(self, key, default)
+
+ @property
+ def _extra_keys(self):
+ """Specifies custom fields
+
+ Subclasses can override this property to return a list
+ of custom fields that should be included in their dict
+ representation.
+
+ For reference check tests/db/sqlalchemy/test_models.py
+ """
+ return []
+
+ def __iter__(self):
+ columns = dict(object_mapper(self).columns).keys()
+ # NOTE(russellb): Allow models to specify other keys that can be looked
+ # up, beyond the actual db columns. An example would be the 'name'
+ # property for an Instance.
+ columns.extend(self._extra_keys)
+ self._i = iter(columns)
+ return self
+
+ def next(self):
+ n = six.advance_iterator(self._i)
+ return n, getattr(self, n)
+
+ def update(self, values):
+ """Make the model object behave like a dict."""
+ for k, v in six.iteritems(values):
+ setattr(self, k, v)
+
+ def iteritems(self):
+ """Make the model object behave like a dict.
+
+ Includes attributes from joins.
+ """
+ local = dict(self)
+ joined = dict([(k, v) for k, v in six.iteritems(self.__dict__)
+ if not k[0] == '_'])
+ local.update(joined)
+ return six.iteritems(local)
+
+
+class TimestampMixin(object):
+ created_at = Column(DateTime, default=lambda: timeutils.utcnow())
+ updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow())
+
+
+class SoftDeleteMixin(object):
+ deleted_at = Column(DateTime)
+ deleted = Column(Integer, default=0)
+
+ def soft_delete(self, session):
+ """Mark this object as deleted."""
+ self.deleted = self.id
+ self.deleted_at = timeutils.utcnow()
+ self.save(session=session)
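
ModelBase gives declarative models dict-style access plus a save() helper, and the mixins add timestamp and soft-delete columns. A minimal sketch of a model composed from them; the table and columns are illustrative:

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base

from designate.openstack.common.db.sqlalchemy import models

Base = declarative_base()


class Record(models.TimestampMixin, models.SoftDeleteMixin,
             models.ModelBase, Base):
    __tablename__ = 'records'           # illustrative table
    id = Column(Integer, primary_key=True)
    name = Column(String(255))


# record = Record(); record['name'] = 'www'        # dict-style __setitem__
# record.save(session); record.soft_delete(session)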
diff --git a/designate/openstack/common/db/sqlalchemy/provision.py b/designate/openstack/common/db/sqlalchemy/provision.py
new file mode 100644
index 00000000..47582ec8
--- /dev/null
+++ b/designate/openstack/common/db/sqlalchemy/provision.py
@@ -0,0 +1,187 @@
+# Copyright 2013 Mirantis.inc
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Provision test environment for specific DB backends"""
+
+import argparse
+import os
+import random
+import string
+
+from six import moves
+import sqlalchemy
+
+from designate.openstack.common.db import exception as exc
+
+
+SQL_CONNECTION = os.getenv('OS_TEST_DBAPI_ADMIN_CONNECTION', 'sqlite://')
+
+
+def _gen_credentials(*names):
+ """Generate credentials."""
+ auth_dict = {}
+ for name in names:
+ val = ''.join(random.choice(string.ascii_lowercase)
+ for i in moves.range(10))
+ auth_dict[name] = val
+ return auth_dict
+
+
+def _get_engine(uri=SQL_CONNECTION):
+ """Engine creation
+
+    By default the uri is SQL_CONNECTION, which holds the admin credentials.
+    Call the function without arguments to get an admin connection. An admin
+    connection is required to create a temporary user and database for each
+    particular test. Otherwise, pass an existing connection uri to reconnect
+    to the temporary database.
+ """
+ return sqlalchemy.create_engine(uri, poolclass=sqlalchemy.pool.NullPool)
+
+
+def _execute_sql(engine, sql, driver):
+ """Initialize connection, execute sql query and close it."""
+ try:
+ with engine.connect() as conn:
+ if driver == 'postgresql':
+ conn.connection.set_isolation_level(0)
+ for s in sql:
+ conn.execute(s)
+ except sqlalchemy.exc.OperationalError:
+ msg = ('%s does not match database admin '
+ 'credentials or database does not exist.')
+ raise exc.DBConnectionError(msg % SQL_CONNECTION)
+
+
+def create_database(engine):
+ """Provide temporary user and database for each particular test."""
+ driver = engine.name
+
+ auth = _gen_credentials('database', 'user', 'passwd')
+
+ sqls = {
+ 'mysql': [
+ "drop database if exists %(database)s;",
+ "grant all on %(database)s.* to '%(user)s'@'localhost'"
+ " identified by '%(passwd)s';",
+ "create database %(database)s;",
+ ],
+ 'postgresql': [
+ "drop database if exists %(database)s;",
+ "drop user if exists %(user)s;",
+ "create user %(user)s with password '%(passwd)s';",
+ "create database %(database)s owner %(user)s;",
+ ]
+ }
+
+ if driver == 'sqlite':
+ return 'sqlite:////tmp/%s' % auth['database']
+
+ try:
+ sql_rows = sqls[driver]
+ except KeyError:
+ raise ValueError('Unsupported RDBMS %s' % driver)
+ sql_query = map(lambda x: x % auth, sql_rows)
+
+ _execute_sql(engine, sql_query, driver)
+
+ params = auth.copy()
+ params['backend'] = driver
+ return "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s" % params
+
+
+def drop_database(engine, current_uri):
+ """Drop temporary database and user after each particular test."""
+ engine = _get_engine(current_uri)
+ admin_engine = _get_engine()
+ driver = engine.name
+ auth = {'database': engine.url.database, 'user': engine.url.username}
+
+ if driver == 'sqlite':
+ try:
+ os.remove(auth['database'])
+ except OSError:
+ pass
+ return
+
+ sqls = {
+ 'mysql': [
+ "drop database if exists %(database)s;",
+ "drop user '%(user)s'@'localhost';",
+ ],
+ 'postgresql': [
+ "drop database if exists %(database)s;",
+ "drop user if exists %(user)s;",
+ ]
+ }
+
+ try:
+ sql_rows = sqls[driver]
+ except KeyError:
+ raise ValueError('Unsupported RDBMS %s' % driver)
+ sql_query = map(lambda x: x % auth, sql_rows)
+
+ _execute_sql(admin_engine, sql_query, driver)
+
+
+def main():
+ """Controller to handle commands
+
+ ::create: Create test user and database with random names.
+ ::drop: Drop user and database created by previous command.
+ """
+ parser = argparse.ArgumentParser(
+ description='Controller to handle database creation and dropping'
+ ' commands.',
+ epilog='Under normal circumstances is not used directly.'
+ ' Used in .testr.conf to automate test database creation'
+ ' and dropping processes.')
+ subparsers = parser.add_subparsers(
+ help='Subcommands to manipulate temporary test databases.')
+
+ create = subparsers.add_parser(
+ 'create',
+ help='Create temporary test '
+ 'databases and users.')
+ create.set_defaults(which='create')
+ create.add_argument(
+ 'instances_count',
+ type=int,
+ help='Number of databases to create.')
+
+ drop = subparsers.add_parser(
+ 'drop',
+ help='Drop temporary test databases and users.')
+ drop.set_defaults(which='drop')
+ drop.add_argument(
+ 'instances',
+ nargs='+',
+ help='List of databases uri to be dropped.')
+
+ args = parser.parse_args()
+
+ engine = _get_engine()
+ which = args.which
+
+ if which == "create":
+ for i in range(int(args.instances_count)):
+ print(create_database(engine))
+ elif which == "drop":
+ for db in args.instances:
+ drop_database(engine, db)
+
+
+if __name__ == "__main__":
+ main()
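
provision.py is a small CLI meant to be called from .testr.conf: 'create N' prints N connection URIs for throwaway test databases provisioned with random credentials, and 'drop' removes them again. The same can be done programmatically; a hedged sketch using the module's default sqlite admin connection:

import sqlalchemy

from designate.openstack.common.db.sqlalchemy import provision

# Admin engine; 'sqlite://' mirrors the module's default admin connection.
admin_engine = sqlalchemy.create_engine('sqlite://')

uri = provision.create_database(admin_engine)   # e.g. sqlite:////tmp/<random>
print(uri)

# ... run tests against `uri` ...

provision.drop_database(admin_engine, uri)      # remove the throwaway db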
diff --git a/designate/openstack/common/db/sqlalchemy/session.py b/designate/openstack/common/db/sqlalchemy/session.py
new file mode 100644
index 00000000..faaf7eb4
--- /dev/null
+++ b/designate/openstack/common/db/sqlalchemy/session.py
@@ -0,0 +1,809 @@
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Session Handling for SQLAlchemy backend.
+
+Recommended ways to use sessions within this framework:
+
+* Don't use them explicitly; this is like running with ``AUTOCOMMIT=1``.
+ `model_query()` will implicitly use a session when called without one
+ supplied. This is the ideal situation because it will allow queries
+ to be automatically retried if the database connection is interrupted.
+
+ .. note:: Automatic retry will be enabled in a future patch.
+
+ It is generally fine to issue several queries in a row like this. Even though
+ they may be run in separate transactions and/or separate sessions, each one
+ will see the data from the prior calls. If needed, undo- or rollback-like
+ functionality should be handled at a logical level. For an example, look at
+ the code around quotas and `reservation_rollback()`.
+
+ Examples:
+
+ .. code:: python
+
+ def get_foo(context, foo):
+ return (model_query(context, models.Foo).
+ filter_by(foo=foo).
+ first())
+
+ def update_foo(context, id, newfoo):
+ (model_query(context, models.Foo).
+ filter_by(id=id).
+ update({'foo': newfoo}))
+
+ def create_foo(context, values):
+ foo_ref = models.Foo()
+ foo_ref.update(values)
+ foo_ref.save()
+ return foo_ref
+
+
+* Within the scope of a single method, keep all the reads and writes within
+ the context managed by a single session. In this way, the session's
+ `__exit__` handler will take care of calling `flush()` and `commit()` for
+ you. If using this approach, you should not explicitly call `flush()` or
+ `commit()`. Any error within the context of the session will cause the
+ session to emit a `ROLLBACK`. Database errors like `IntegrityError` will be
+ raised in `session`'s `__exit__` handler, and any try/except within the
+ context managed by `session` will not be triggered. And catching other
+ non-database errors in the session will not trigger the ROLLBACK, so
+ exception handlers should always be outside the session, unless the
+ developer wants to do a partial commit on purpose. If the connection is
+ dropped before this is possible, the database will implicitly roll back the
+ transaction.
+
+ .. note:: Statements in the session scope will not be automatically retried.
+
+ If you create models within the session, they need to be added, but you
+ do not need to call `model.save()`:
+
+ .. code:: python
+
+ def create_many_foo(context, foos):
+ session = sessionmaker()
+ with session.begin():
+ for foo in foos:
+ foo_ref = models.Foo()
+ foo_ref.update(foo)
+ session.add(foo_ref)
+
+ def update_bar(context, foo_id, newbar):
+ session = sessionmaker()
+ with session.begin():
+ foo_ref = (model_query(context, models.Foo, session).
+ filter_by(id=foo_id).
+ first())
+ (model_query(context, models.Bar, session).
+ filter_by(id=foo_ref['bar_id']).
+ update({'bar': newbar}))
+
+ .. note:: `update_bar` is a trivially simple example of using
+ ``with session.begin``. Whereas `create_many_foo` is a good example of
+ when a transaction is needed, it is always best to use as few queries as
+ possible.
+
+ The two queries in `update_bar` can be better expressed using a single query
+ which avoids the need for an explicit transaction. It can be expressed like
+ so:
+
+ .. code:: python
+
+ def update_bar(context, foo_id, newbar):
+ subq = (model_query(context, models.Foo.id).
+ filter_by(id=foo_id).
+ limit(1).
+ subquery())
+ (model_query(context, models.Bar).
+ filter_by(id=subq.as_scalar()).
+ update({'bar': newbar}))
+
+ For reference, this emits approximately the following SQL statement:
+
+ .. code:: sql
+
+ UPDATE bar SET bar = ${newbar}
+ WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1);
+
+ .. note:: `create_duplicate_foo` is a trivially simple example of catching an
+    exception while using ``with session.begin``. Here we create two
+    instances with the same primary key; the exception must be caught
+    outside the context managed by a single session:
+
+ .. code:: python
+
+ def create_duplicate_foo(context):
+ foo1 = models.Foo()
+ foo2 = models.Foo()
+ foo1.id = foo2.id = 1
+ session = sessionmaker()
+ try:
+ with session.begin():
+ session.add(foo1)
+ session.add(foo2)
+ except exception.DBDuplicateEntry as e:
+ handle_error(e)
+
+* Passing an active session between methods. Sessions should only be passed
+ to private methods. The private method must use a subtransaction; otherwise
+ SQLAlchemy will throw an error when you call `session.begin()` on an existing
+ transaction. Public methods should not accept a session parameter and should
+ not be involved in sessions within the caller's scope.
+
+ Note that this incurs more overhead in SQLAlchemy than the above means
+ due to nesting transactions, and it is not possible to implicitly retry
+ failed database operations when using this approach.
+
+ This also makes code somewhat more difficult to read and debug, because a
+ single database transaction spans more than one method. Error handling
+ becomes less clear in this situation. When this is needed for code clarity,
+ it should be clearly documented.
+
+ .. code:: python
+
+ def myfunc(foo):
+ session = sessionmaker()
+ with session.begin():
+ # do some database things
+ bar = _private_func(foo, session)
+ return bar
+
+ def _private_func(foo, session=None):
+ if not session:
+ session = sessionmaker()
+ with session.begin(subtransaction=True):
+ # do some other database things
+ return bar
+
+
+There are some things which it is best to avoid:
+
+* Don't keep a transaction open any longer than necessary.
+
+ This means that your ``with session.begin()`` block should be as short
+ as possible, while still containing all the related calls for that
+ transaction.
+
+* Avoid ``with_lockmode('UPDATE')`` when possible.
+
+ In MySQL/InnoDB, when a ``SELECT ... FOR UPDATE`` query does not match
+ any rows, it will take a gap-lock. This is a form of write-lock on the
+ "gap" where no rows exist, and prevents any other writes to that space.
+ This can effectively prevent any INSERT into a table by locking the gap
+ at the end of the index. Similar problems will occur if the SELECT FOR UPDATE
+ has an overly broad WHERE clause, or doesn't properly use an index.
+
+ One idea proposed at ODS Fall '12 was to use a normal SELECT to test the
+ number of rows matching a query, and if only one row is returned,
+ then issue the SELECT FOR UPDATE.
+
+ The better long-term solution is to use
+ ``INSERT .. ON DUPLICATE KEY UPDATE``.
+ However, this can not be done until the "deleted" columns are removed and
+ proper UNIQUE constraints are added to the tables.
+
+
+Enabling soft deletes:
+
+* To use/enable soft-deletes, the `SoftDeleteMixin` must be added
+ to your model class. For example:
+
+ .. code:: python
+
+ class NovaBase(models.SoftDeleteMixin, models.ModelBase):
+ pass
+
+
+Efficient use of soft deletes:
+
+* There are two possible ways to mark a record as deleted:
+ `model.soft_delete()` and `query.soft_delete()`.
+
+ The `model.soft_delete()` method works with a single already-fetched entry.
+ `query.soft_delete()` makes only one db request for all entries that
+ correspond to the query.
+
+* In almost all cases you should use `query.soft_delete()`. Some examples:
+
+ .. code:: python
+
+ def soft_delete_bar():
+ count = model_query(BarModel).find(some_condition).soft_delete()
+ if count == 0:
+ raise Exception("0 entries were soft deleted")
+
+ def complex_soft_delete_with_synchronization_bar(session=None):
+ if session is None:
+ session = sessionmaker()
+ with session.begin(subtransactions=True):
+ count = (model_query(BarModel).
+ find(some_condition).
+ soft_delete(synchronize_session=True))
+ # Here synchronize_session is required, because we
+ # don't know what is going on in outer session.
+ if count == 0:
+ raise Exception("0 entries were soft deleted")
+
+* There is only one situation where `model.soft_delete()` is appropriate: when
+ you fetch a single record, work with it, and mark it as deleted in the same
+ transaction.
+
+ .. code:: python
+
+ def soft_delete_bar_model():
+ session = sessionmaker()
+ with session.begin():
+ bar_ref = model_query(BarModel).find(some_condition).first()
+ # Work with bar_ref
+ bar_ref.soft_delete(session=session)
+
+ However, if you need to work with all entries that correspond to query and
+ then soft delete them you should use the `query.soft_delete()` method:
+
+ .. code:: python
+
+ def soft_delete_multi_models():
+ session = sessionmaker()
+ with session.begin():
+ query = (model_query(BarModel, session=session).
+ find(some_condition))
+ model_refs = query.all()
+ # Work with model_refs
+ query.soft_delete(synchronize_session=False)
+ # synchronize_session=False should be set if there is no outer
+ # session and these entries are not used after this.
+
+ When working with many rows, it is very important to use query.soft_delete,
+ which issues a single query. Using `model.soft_delete()`, as in the following
+ example, is very inefficient.
+
+ .. code:: python
+
+ for bar_ref in bar_refs:
+ bar_ref.soft_delete(session=session)
+ # This will produce count(bar_refs) db requests.
+
+"""
+
+import functools
+import logging
+import re
+import time
+
+import six
+from sqlalchemy import exc as sqla_exc
+from sqlalchemy.interfaces import PoolListener
+import sqlalchemy.orm
+from sqlalchemy.pool import NullPool, StaticPool
+from sqlalchemy.sql.expression import literal_column
+
+from designate.openstack.common.db import exception
+from designate.openstack.common.gettextutils import _
+from designate.openstack.common import timeutils
+
+
+LOG = logging.getLogger(__name__)
+
+
+class SqliteForeignKeysListener(PoolListener):
+ """Ensures that the foreign key constraints are enforced in SQLite.
+
+ The foreign key constraints are disabled by default in SQLite,
+ so the foreign key constraints will be enabled here for every
+ database connection
+ """
+ def connect(self, dbapi_con, con_record):
+ dbapi_con.execute('pragma foreign_keys=ON')
+
+
+# note(boris-42): In current versions of DB backends unique constraint
+# violation messages follow the structure:
+#
+# sqlite:
+# 1 column - (IntegrityError) column c1 is not unique
+# N columns - (IntegrityError) column c1, c2, ..., N are not unique
+#
+# sqlite since 3.7.16:
+# 1 column - (IntegrityError) UNIQUE constraint failed: tbl.k1
+#
+# N columns - (IntegrityError) UNIQUE constraint failed: tbl.k1, tbl.k2
+#
+# postgres:
+# 1 column - (IntegrityError) duplicate key value violates unique
+# constraint "users_c1_key"
+# N columns - (IntegrityError) duplicate key value violates unique
+# constraint "name_of_our_constraint"
+#
+# mysql:
+# 1 column - (IntegrityError) (1062, "Duplicate entry 'value_of_c1' for key
+# 'c1'")
+# N columns - (IntegrityError) (1062, "Duplicate entry 'values joined
+# with -' for key 'name_of_our_constraint'")
+_DUP_KEY_RE_DB = {
+ "sqlite": (re.compile(r"^.*columns?([^)]+)(is|are)\s+not\s+unique$"),
+ re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(.+)$")),
+ "postgresql": (re.compile(r"^.*duplicate\s+key.*\"([^\"]+)\"\s*\n.*$"),),
+ "mysql": (re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$"),)
+}
+
+
+def _raise_if_duplicate_entry_error(integrity_error, engine_name):
+ """Raise exception if two entries are duplicated.
+
+    A DBDuplicateEntry exception is raised if the integrity error wraps
+    a unique constraint violation.
+ """
+
+ def get_columns_from_uniq_cons_or_name(columns):
+ # note(vsergeyev): UniqueConstraint name convention: "uniq_t0c10c2"
+        #                  where `t` is the table name and columns `c1`, `c2`
+ # are in UniqueConstraint.
+ uniqbase = "uniq_"
+ if not columns.startswith(uniqbase):
+ if engine_name == "postgresql":
+ return [columns[columns.index("_") + 1:columns.rindex("_")]]
+ return [columns]
+ return columns[len(uniqbase):].split("0")[1:]
+
+ if engine_name not in ["mysql", "sqlite", "postgresql"]:
+ return
+
+ # FIXME(johannes): The usage of the .message attribute has been
+ # deprecated since Python 2.6. However, the exceptions raised by
+ # SQLAlchemy can differ when using unicode() and accessing .message.
+ # An audit across all three supported engines will be necessary to
+ # ensure there are no regressions.
+ for pattern in _DUP_KEY_RE_DB[engine_name]:
+ match = pattern.match(integrity_error.message)
+ if match:
+ break
+ else:
+ return
+
+ columns = match.group(1)
+
+ if engine_name == "sqlite":
+ columns = [c.split('.')[-1] for c in columns.strip().split(", ")]
+ else:
+ columns = get_columns_from_uniq_cons_or_name(columns)
+ raise exception.DBDuplicateEntry(columns, integrity_error)
+
+
+# NOTE(comstud): In current versions of DB backends, Deadlock violation
+# messages follow the structure:
+#
+# mysql:
+# (OperationalError) (1213, 'Deadlock found when trying to get lock; try '
+# 'restarting transaction') <query_str> <query_args>
+_DEADLOCK_RE_DB = {
+ "mysql": re.compile(r"^.*\(1213, 'Deadlock.*")
+}
+
+
+def _raise_if_deadlock_error(operational_error, engine_name):
+ """Raise exception on deadlock condition.
+
+ Raise DBDeadlock exception if OperationalError contains a Deadlock
+ condition.
+ """
+ re = _DEADLOCK_RE_DB.get(engine_name)
+ if re is None:
+ return
+ # FIXME(johannes): The usage of the .message attribute has been
+ # deprecated since Python 2.6. However, the exceptions raised by
+ # SQLAlchemy can differ when using unicode() and accessing .message.
+ # An audit across all three supported engines will be necessary to
+ # ensure there are no regressions.
+ m = re.match(operational_error.message)
+ if not m:
+ return
+ raise exception.DBDeadlock(operational_error)
+
+
+def _wrap_db_error(f):
+ #TODO(rpodolyaka): in a subsequent commit make this a class decorator to
+    # ensure it can only be applied to instances of Session subclasses (as
+    # we use the Session instance's bind attribute below)
+
+ @functools.wraps(f)
+ def _wrap(self, *args, **kwargs):
+ try:
+ return f(self, *args, **kwargs)
+ except UnicodeEncodeError:
+ raise exception.DBInvalidUnicodeParameter()
+ except sqla_exc.OperationalError as e:
+ _raise_if_db_connection_lost(e, self.bind)
+ _raise_if_deadlock_error(e, self.bind.dialect.name)
+ # NOTE(comstud): A lot of code is checking for OperationalError
+ # so let's not wrap it for now.
+ raise
+ # note(boris-42): We should catch unique constraint violation and
+ # wrap it by our own DBDuplicateEntry exception. Unique constraint
+ # violation is wrapped by IntegrityError.
+ except sqla_exc.IntegrityError as e:
+ # note(boris-42): SqlAlchemy doesn't unify errors from different
+ # DBs so we must do this. Also in some tables (for example
+            # instance_types) there is more than one unique constraint. This
+            # means we should get the names of the columns whose values
+            # violate the unique constraint from the error message.
+ _raise_if_duplicate_entry_error(e, self.bind.dialect.name)
+ raise exception.DBError(e)
+ except Exception as e:
+ LOG.exception(_('DB exception wrapped.'))
+ raise exception.DBError(e)
+ return _wrap
+
+
+def _synchronous_switch_listener(dbapi_conn, connection_rec):
+ """Switch sqlite connections to non-synchronous mode."""
+ dbapi_conn.execute("PRAGMA synchronous = OFF")
+
+
+def _add_regexp_listener(dbapi_con, con_record):
+ """Add REGEXP function to sqlite connections."""
+
+ def regexp(expr, item):
+ reg = re.compile(expr)
+ return reg.search(six.text_type(item)) is not None
+ dbapi_con.create_function('regexp', 2, regexp)
+
+
+def _thread_yield(dbapi_con, con_record):
+ """Ensure other greenthreads get a chance to be executed.
+
+ If we use eventlet.monkey_patch(), eventlet.greenthread.sleep(0) will
+ execute instead of time.sleep(0).
+ Force a context switch. With common database backends (eg MySQLdb and
+ sqlite), there is no implicit yield caused by network I/O since they are
+ implemented by C libraries that eventlet cannot monkey patch.
+ """
+ time.sleep(0)
+
+
+def _ping_listener(engine, dbapi_conn, connection_rec, connection_proxy):
+ """Ensures that MySQL and DB2 connections are alive.
+
+ Borrowed from:
+ http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f
+ """
+ cursor = dbapi_conn.cursor()
+ try:
+ ping_sql = 'select 1'
+ if engine.name == 'ibm_db_sa':
+ # DB2 requires a table expression
+ ping_sql = 'select 1 from (values (1)) AS t1'
+ cursor.execute(ping_sql)
+ except Exception as ex:
+ if engine.dialect.is_disconnect(ex, dbapi_conn, cursor):
+ msg = _('Database server has gone away: %s') % ex
+ LOG.warning(msg)
+ raise sqla_exc.DisconnectionError(msg)
+ else:
+ raise
+
+
+def _set_mode_traditional(dbapi_con, connection_rec, connection_proxy):
+ """Set engine mode to 'traditional'.
+
+ Required to prevent silent truncates at insert or update operations
+    under MySQL. By default MySQL truncates an inserted string if it is
+    longer than the declared field, emitting only a warning. That is
+    fraught with data corruption.
+ """
+ dbapi_con.cursor().execute("SET SESSION sql_mode = TRADITIONAL;")
+
+
+def _is_db_connection_error(args):
+ """Return True if error in connecting to db."""
+ # NOTE(adam_g): This is currently MySQL specific and needs to be extended
+ # to support Postgres and others.
+ # For the db2, the error code is -30081 since the db2 is still not ready
+ conn_err_codes = ('2002', '2003', '2006', '2013', '-30081')
+ for err_code in conn_err_codes:
+ if args.find(err_code) != -1:
+ return True
+ return False
+
+
+def _raise_if_db_connection_lost(error, engine):
+ # NOTE(vsergeyev): Function is_disconnect(e, connection, cursor)
+ # requires connection and cursor in incoming parameters,
+ # but we have no possibility to create connection if DB
+ # is not available, so in such case reconnect fails.
+ # But is_disconnect() ignores these parameters, so it
+ # makes sense to pass to function None as placeholder
+ # instead of connection and cursor.
+ if engine.dialect.is_disconnect(error, None, None):
+ raise exception.DBConnectionError(error)
+
+
+def create_engine(sql_connection, sqlite_fk=False,
+ mysql_traditional_mode=False, idle_timeout=3600,
+ connection_debug=0, max_pool_size=None, max_overflow=None,
+ pool_timeout=None, sqlite_synchronous=True,
+ connection_trace=False, max_retries=10, retry_interval=10):
+ """Return a new SQLAlchemy engine."""
+
+ connection_dict = sqlalchemy.engine.url.make_url(sql_connection)
+
+ engine_args = {
+ "pool_recycle": idle_timeout,
+ "echo": False,
+ 'convert_unicode': True,
+ }
+
+ # Map our SQL debug level to SQLAlchemy's options
+ if connection_debug >= 100:
+ engine_args['echo'] = 'debug'
+ elif connection_debug >= 50:
+ engine_args['echo'] = True
+
+ if "sqlite" in connection_dict.drivername:
+ if sqlite_fk:
+ engine_args["listeners"] = [SqliteForeignKeysListener()]
+ engine_args["poolclass"] = NullPool
+
+ if sql_connection == "sqlite://":
+ engine_args["poolclass"] = StaticPool
+ engine_args["connect_args"] = {'check_same_thread': False}
+ else:
+ if max_pool_size is not None:
+ engine_args['pool_size'] = max_pool_size
+ if max_overflow is not None:
+ engine_args['max_overflow'] = max_overflow
+ if pool_timeout is not None:
+ engine_args['pool_timeout'] = pool_timeout
+
+ engine = sqlalchemy.create_engine(sql_connection, **engine_args)
+
+ sqlalchemy.event.listen(engine, 'checkin', _thread_yield)
+
+ if engine.name in ['mysql', 'ibm_db_sa']:
+ callback = functools.partial(_ping_listener, engine)
+ sqlalchemy.event.listen(engine, 'checkout', callback)
+ if engine.name == 'mysql':
+ if mysql_traditional_mode:
+ sqlalchemy.event.listen(engine, 'checkout',
+ _set_mode_traditional)
+ else:
+ LOG.warning(_("This application has not enabled MySQL "
+ "traditional mode, which means silent "
+ "data corruption may occur. "
+ "Please encourage the application "
+ "developers to enable this mode."))
+ elif 'sqlite' in connection_dict.drivername:
+ if not sqlite_synchronous:
+ sqlalchemy.event.listen(engine, 'connect',
+ _synchronous_switch_listener)
+ sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener)
+
+ if connection_trace and engine.dialect.dbapi.__name__ == 'MySQLdb':
+ _patch_mysqldb_with_stacktrace_comments()
+
+ try:
+ engine.connect()
+ except sqla_exc.OperationalError as e:
+ if not _is_db_connection_error(e.args[0]):
+ raise
+
+ remaining = max_retries
+ if remaining == -1:
+ remaining = 'infinite'
+ while True:
+ msg = _('SQL connection failed. %s attempts left.')
+ LOG.warning(msg % remaining)
+ if remaining != 'infinite':
+ remaining -= 1
+ time.sleep(retry_interval)
+ try:
+ engine.connect()
+ break
+ except sqla_exc.OperationalError as e:
+ if (remaining != 'infinite' and remaining == 0) or \
+ not _is_db_connection_error(e.args[0]):
+ raise
+ return engine
+
+
+class Query(sqlalchemy.orm.query.Query):
+ """Subclass of sqlalchemy.query with soft_delete() method."""
+ def soft_delete(self, synchronize_session='evaluate'):
+ return self.update({'deleted': literal_column('id'),
+ 'updated_at': literal_column('updated_at'),
+ 'deleted_at': timeutils.utcnow()},
+ synchronize_session=synchronize_session)
+
+
+class Session(sqlalchemy.orm.session.Session):
+ """Custom Session class to avoid SqlAlchemy Session monkey patching."""
+ @_wrap_db_error
+ def query(self, *args, **kwargs):
+ return super(Session, self).query(*args, **kwargs)
+
+ @_wrap_db_error
+ def flush(self, *args, **kwargs):
+ return super(Session, self).flush(*args, **kwargs)
+
+ @_wrap_db_error
+ def execute(self, *args, **kwargs):
+ return super(Session, self).execute(*args, **kwargs)
+
+
+def get_maker(engine, autocommit=True, expire_on_commit=False):
+ """Return a SQLAlchemy sessionmaker using the given engine."""
+ return sqlalchemy.orm.sessionmaker(bind=engine,
+ class_=Session,
+ autocommit=autocommit,
+ expire_on_commit=expire_on_commit,
+ query_cls=Query)
+
+
+def _patch_mysqldb_with_stacktrace_comments():
+ """Adds current stack trace as a comment in queries.
+
+ Patches MySQLdb.cursors.BaseCursor._do_query.
+ """
+ import MySQLdb.cursors
+ import traceback
+
+ old_mysql_do_query = MySQLdb.cursors.BaseCursor._do_query
+
+ def _do_query(self, q):
+ stack = ''
+ for filename, line, method, function in traceback.extract_stack():
+ # exclude various common things from trace
+ if filename.endswith('session.py') and method == '_do_query':
+ continue
+ if filename.endswith('api.py') and method == 'wrapper':
+ continue
+ if filename.endswith('utils.py') and method == '_inner':
+ continue
+ if filename.endswith('exception.py') and method == '_wrap':
+ continue
+ # db/api is just a wrapper around db/sqlalchemy/api
+ if filename.endswith('db/api.py'):
+ continue
+ # only trace inside designate
+ index = filename.rfind('designate')
+ if index == -1:
+ continue
+ stack += "File:%s:%s Method:%s() Line:%s | " \
+ % (filename[index:], line, method, function)
+
+ # strip trailing " | " from stack
+ if stack:
+ stack = stack[:-3]
+ qq = "%s /* %s */" % (q, stack)
+ else:
+ qq = q
+ old_mysql_do_query(self, qq)
+
+ setattr(MySQLdb.cursors.BaseCursor, '_do_query', _do_query)
+
+
+class EngineFacade(object):
+ """A helper class for removing of global engine instances from designate.db.
+
+ As a library, designate.db can't decide where to store/when to create engine
+ and sessionmaker instances, so this must be left for a target application.
+
+ On the other hand, in order to simplify the adoption of designate.db changes,
+ we'll provide a helper class, which creates engine and sessionmaker
+ on its instantiation and provides get_engine()/get_session() methods
+ that are compatible with corresponding utility functions that currently
+ exist in target projects, e.g. in Nova.
+
+ engine/sessionmaker instances will still be global (and they are meant to
+ be global), but they will be stored in the app context, rather than in the
+ designate.db context.
+
+ Note: use of this helper is completely optional and you are encouraged to
+ integrate engine/sessionmaker instances into your apps any way you like
+ (e.g. one might want to bind a session to a request context). Two important
+ things to remember:
+ 1. An Engine instance is effectively a pool of DB connections, so it's
+ meant to be shared (and it's thread-safe).
+ 2. A Session instance is not meant to be shared and represents a DB
+ transactional context (i.e. it's not thread-safe). sessionmaker is
+ a factory of sessions.
+
+ """
+
+ def __init__(self, sql_connection,
+ sqlite_fk=False, mysql_traditional_mode=False,
+ autocommit=True, expire_on_commit=False, **kwargs):
+ """Initialize engine and sessionmaker instances.
+
+ :param sqlite_fk: enable foreign keys in SQLite
+ :type sqlite_fk: bool
+
+ :param mysql_traditional_mode: enable traditional mode in MySQL
+ :type mysql_traditional_mode: bool
+
+ :param autocommit: use autocommit mode for created Session instances
+ :type autocommit: bool
+
+ :param expire_on_commit: expire session objects on commit
+ :type expire_on_commit: bool
+
+ Keyword arguments:
+
+ :keyword idle_timeout: timeout before idle sql connections are reaped
+ (defaults to 3600)
+ :keyword connection_debug: verbosity of SQL debugging information.
+ 0=None, 100=Everything (defaults to 0)
+ :keyword max_pool_size: maximum number of SQL connections to keep open
+ in a pool (defaults to SQLAlchemy settings)
+ :keyword max_overflow: if set, use this value for max_overflow with
+ sqlalchemy (defaults to SQLAlchemy settings)
+ :keyword pool_timeout: if set, use this value for pool_timeout with
+ sqlalchemy (defaults to SQLAlchemy settings)
+ :keyword sqlite_synchronous: if True, SQLite uses synchronous mode
+ (defaults to True)
+ :keyword connection_trace: add python stack traces to SQL as comment
+ strings (defaults to False)
+ :keyword max_retries: maximum db connection retries during startup.
+ (setting -1 implies an infinite retry count)
+ (defaults to 10)
+ :keyword retry_interval: interval between retries of opening a sql
+ connection (defaults to 10)
+
+ """
+
+ super(EngineFacade, self).__init__()
+
+ self._engine = create_engine(
+ sql_connection=sql_connection,
+ sqlite_fk=sqlite_fk,
+ mysql_traditional_mode=mysql_traditional_mode,
+ idle_timeout=kwargs.get('idle_timeout', 3600),
+ connection_debug=kwargs.get('connection_debug', 0),
+ max_pool_size=kwargs.get('max_pool_size', None),
+ max_overflow=kwargs.get('max_overflow', None),
+ pool_timeout=kwargs.get('pool_timeout', None),
+ sqlite_synchronous=kwargs.get('sqlite_synchronous', True),
+ connection_trace=kwargs.get('connection_trace', False),
+ max_retries=kwargs.get('max_retries', 10),
+ retry_interval=kwargs.get('retry_interval', 10))
+ self._session_maker = get_maker(
+ engine=self._engine,
+ autocommit=autocommit,
+ expire_on_commit=expire_on_commit)
+
+ def get_engine(self):
+ """Get the engine instance (note, that it's shared)."""
+
+ return self._engine
+
+ def get_session(self, **kwargs):
+ """Get a Session instance.
+
+ If passed, keyword arguments values override the ones used when the
+ sessionmaker instance was created.
+
+ :keyword autocommit: use autocommit mode for created Session instances
+ :type autocommit: bool
+
+ :keyword expire_on_commit: expire session objects on commit
+ :type expire_on_commit: bool
+
+ """
+
+ for arg in list(kwargs):
+ if arg not in ('autocommit', 'expire_on_commit'):
+ del kwargs[arg]
+
+ return self._session_maker(**kwargs)
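The hunk above adds EngineFacade without showing how a consumer would drive it. A minimal, hypothetical sketch follows; the sqlite URL and the option values are assumptions for illustration, not part of this change:

    # Illustrative only; EngineFacade comes from the module patched above.
    from designate.openstack.common.db.sqlalchemy import session

    facade = session.EngineFacade('sqlite://',        # in-memory DB, example value
                                  autocommit=True,
                                  expire_on_commit=False,
                                  max_retries=3)      # retry connecting 3 times

    engine = facade.get_engine()       # shared, thread-safe engine
    db_session = facade.get_session()  # new Session; not meant to be shared

    print(db_session.execute('SELECT 1').scalar())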
diff --git a/designate/openstack/common/db/sqlalchemy/test_base.py b/designate/openstack/common/db/sqlalchemy/test_base.py
new file mode 100644
index 00000000..a97875d0
--- /dev/null
+++ b/designate/openstack/common/db/sqlalchemy/test_base.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import abc
+import functools
+import os
+
+import fixtures
+import six
+
+from designate.openstack.common.db.sqlalchemy import session
+from designate.openstack.common.db.sqlalchemy import utils
+from designate.openstack.common import test
+
+
+class DbFixture(fixtures.Fixture):
+ """Basic database fixture.
+
+ Allows tests to run against various db backends, such as SQLite, MySQL
+ and PostgreSQL. The SQLite backend is used by default. To override the
+ default backend URI, set the env variable OS_TEST_DBAPI_CONNECTION to
+ database admin credentials for the specific backend.
+ """
+
+ def _get_uri(self):
+ return os.getenv('OS_TEST_DBAPI_CONNECTION', 'sqlite://')
+
+ def __init__(self, test):
+ super(DbFixture, self).__init__()
+
+ self.test = test
+
+ def setUp(self):
+ super(DbFixture, self).setUp()
+
+ self.test.engine = session.create_engine(self._get_uri())
+ self.test.sessionmaker = session.get_maker(self.test.engine)
+ self.addCleanup(self.test.engine.dispose)
+
+
+class DbTestCase(test.BaseTestCase):
+ """Base class for testing of DB code.
+
+ Uses `DbFixture`. Intended to be the main database test case, running all
+ the tests on a given backend with a user-defined URI. Backend-specific
+ tests should be decorated with the `backend_specific` decorator.
+ """
+
+ FIXTURE = DbFixture
+
+ def setUp(self):
+ super(DbTestCase, self).setUp()
+ self.useFixture(self.FIXTURE(self))
+
+
+ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql']
+
+
+def backend_specific(*dialects):
+ """Decorator to skip backend specific tests on inappropriate engines.
+
+ ::dialects: list of dialects names under which the test will be launched.
+ """
+ def wrap(f):
+ @functools.wraps(f)
+ def ins_wrap(self):
+ if not set(dialects).issubset(ALLOWED_DIALECTS):
+ raise ValueError(
+ "Please use allowed dialects: %s" % ALLOWED_DIALECTS)
+ if self.engine.name not in dialects:
+ msg = ('The test "%s" can be run '
+ 'only on %s. Current engine is %s.')
+ args = (f.__name__, ' '.join(dialects), self.engine.name)
+ self.skip(msg % args)
+ else:
+ return f(self)
+ return ins_wrap
+ return wrap
+
+
+@six.add_metaclass(abc.ABCMeta)
+class OpportunisticFixture(DbFixture):
+ """Base fixture to use default CI databases.
+
+ These databases are provided by the OpenStack CI infrastructure, but for
+ correct functioning in a local environment they must be created
+ manually.
+ """
+
+ DRIVER = abc.abstractproperty(lambda: None)
+ DBNAME = PASSWORD = USERNAME = 'openstack_citest'
+
+ def _get_uri(self):
+ return utils.get_connect_string(backend=self.DRIVER,
+ user=self.USERNAME,
+ passwd=self.PASSWORD,
+ database=self.DBNAME)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class OpportunisticTestCase(DbTestCase):
+ """Base test case to use default CI databases.
+
+ Subclasses of this test case run only when the openstack_citest
+ database is available; otherwise the tests are skipped.
+ """
+
+ FIXTURE = abc.abstractproperty(lambda: None)
+
+ def setUp(self):
+ credentials = {
+ 'backend': self.FIXTURE.DRIVER,
+ 'user': self.FIXTURE.USERNAME,
+ 'passwd': self.FIXTURE.PASSWORD,
+ 'database': self.FIXTURE.DBNAME}
+
+ if self.FIXTURE.DRIVER and not utils.is_backend_avail(**credentials):
+ msg = '%s backend is not available.' % self.FIXTURE.DRIVER
+ return self.skip(msg)
+
+ super(OpportunisticTestCase, self).setUp()
+
+
+class MySQLOpportunisticFixture(OpportunisticFixture):
+ DRIVER = 'mysql'
+
+
+class PostgreSQLOpportunisticFixture(OpportunisticFixture):
+ DRIVER = 'postgresql'
+
+
+class MySQLOpportunisticTestCase(OpportunisticTestCase):
+ FIXTURE = MySQLOpportunisticFixture
+
+
+class PostgreSQLOpportunisticTestCase(OpportunisticTestCase):
+ FIXTURE = PostgreSQLOpportunisticFixture
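For orientation, a hypothetical test built on these helpers could look like the sketch below; the class name and query are assumptions, and it would run only when the MySQL openstack_citest database is reachable:

    # Illustrative only; the helpers come from the file added above.
    from designate.openstack.common.db.sqlalchemy import test_base


    class ExampleDbTest(test_base.MySQLOpportunisticTestCase):

        @test_base.backend_specific('mysql')
        def test_select_one(self):
            # self.engine and self.sessionmaker are provided by DbFixture.
            result = self.engine.execute('SELECT 1').scalar()
            self.assertEqual(1, result)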
diff --git a/designate/openstack/common/db/sqlalchemy/test_migrations.py b/designate/openstack/common/db/sqlalchemy/test_migrations.py
new file mode 100644
index 00000000..5e2e278f
--- /dev/null
+++ b/designate/openstack/common/db/sqlalchemy/test_migrations.py
@@ -0,0 +1,269 @@
+# Copyright 2010-2011 OpenStack Foundation
+# Copyright 2012-2013 IBM Corp.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import functools
+import logging
+import os
+import subprocess
+
+import lockfile
+from six import moves
+from six.moves.urllib import parse
+import sqlalchemy
+import sqlalchemy.exc
+
+from designate.openstack.common.db.sqlalchemy import utils
+from designate.openstack.common.gettextutils import _
+from designate.openstack.common import test
+
+LOG = logging.getLogger(__name__)
+
+
+def _have_mysql(user, passwd, database):
+ present = os.environ.get('TEST_MYSQL_PRESENT')
+ if present is None:
+ return utils.is_backend_avail(backend='mysql',
+ user=user,
+ passwd=passwd,
+ database=database)
+ return present.lower() in ('', 'true')
+
+
+def _have_postgresql(user, passwd, database):
+ present = os.environ.get('TEST_POSTGRESQL_PRESENT')
+ if present is None:
+ return utils.is_backend_avail(backend='postgres',
+ user=user,
+ passwd=passwd,
+ database=database)
+ return present.lower() in ('', 'true')
+
+
+def _set_db_lock(lock_path=None, lock_prefix=None):
+ def decorator(f):
+ @functools.wraps(f)
+ def wrapper(*args, **kwargs):
+ try:
+ path = lock_path or os.environ.get("DESIGNATE_LOCK_PATH")
+ lock = lockfile.FileLock(os.path.join(path, lock_prefix))
+ with lock:
+ LOG.debug(_('Got lock "%s"') % f.__name__)
+ return f(*args, **kwargs)
+ finally:
+ LOG.debug(_('Lock released "%s"') % f.__name__)
+ return wrapper
+ return decorator
+
+
+class BaseMigrationTestCase(test.BaseTestCase):
+ """Base class fort testing of migration utils."""
+
+ def __init__(self, *args, **kwargs):
+ super(BaseMigrationTestCase, self).__init__(*args, **kwargs)
+
+ self.DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__),
+ 'test_migrations.conf')
+ # Test machines can set the TEST_MIGRATIONS_CONF variable
+ # to override the location of the config file for migration testing
+ self.CONFIG_FILE_PATH = os.environ.get('TEST_MIGRATIONS_CONF',
+ self.DEFAULT_CONFIG_FILE)
+ self.test_databases = {}
+ self.migration_api = None
+
+ def setUp(self):
+ super(BaseMigrationTestCase, self).setUp()
+
+ # Load test databases from the config file. Only do this
+ # once. No need to re-run this on each test...
+ LOG.debug('config_path is %s' % self.CONFIG_FILE_PATH)
+ if os.path.exists(self.CONFIG_FILE_PATH):
+ cp = moves.configparser.RawConfigParser()
+ try:
+ cp.read(self.CONFIG_FILE_PATH)
+ defaults = cp.defaults()
+ for key, value in defaults.items():
+ self.test_databases[key] = value
+ except moves.configparser.ParsingError as e:
+ self.fail("Failed to read test_migrations.conf config "
+ "file. Got error: %s" % e)
+ else:
+ self.fail("Failed to find test_migrations.conf config "
+ "file.")
+
+ self.engines = {}
+ for key, value in self.test_databases.items():
+ self.engines[key] = sqlalchemy.create_engine(value)
+
+ # We start each test case with a completely blank slate.
+ self._reset_databases()
+
+ def tearDown(self):
+ # We destroy the test data store between each test case,
+ # and recreate it, which ensures that we have no side-effects
+ # from the tests
+ self._reset_databases()
+ super(BaseMigrationTestCase, self).tearDown()
+
+ def execute_cmd(self, cmd=None):
+ process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT)
+ output = process.communicate()[0]
+ LOG.debug(output)
+ self.assertEqual(0, process.returncode,
+ "Failed to run: %s\n%s" % (cmd, output))
+
+ def _reset_pg(self, conn_pieces):
+ (user,
+ password,
+ database,
+ host) = utils.get_db_connection_info(conn_pieces)
+ os.environ['PGPASSWORD'] = password
+ os.environ['PGUSER'] = user
+ # note(boris-42): We must create and drop the database; we can't
+ # drop the database we are connected to, so for such
+ # operations there is a special database, template1.
+ sqlcmd = ("psql -w -U %(user)s -h %(host)s -c"
+ " '%(sql)s' -d template1")
+
+ sql = ("drop database if exists %s;") % database
+ droptable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
+ self.execute_cmd(droptable)
+
+ sql = ("create database %s;") % database
+ createtable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
+ self.execute_cmd(createtable)
+
+ os.unsetenv('PGPASSWORD')
+ os.unsetenv('PGUSER')
+
+ @_set_db_lock(lock_prefix='migration_tests-')
+ def _reset_databases(self):
+ for key, engine in self.engines.items():
+ conn_string = self.test_databases[key]
+ conn_pieces = parse.urlparse(conn_string)
+ engine.dispose()
+ if conn_string.startswith('sqlite'):
+ # We can just delete the SQLite database, which is
+ # the easiest and cleanest solution
+ db_path = conn_pieces.path.strip('/')
+ if os.path.exists(db_path):
+ os.unlink(db_path)
+ # No need to recreate the SQLite DB. SQLite will
+ # create it for us if it's not there...
+ elif conn_string.startswith('mysql'):
+ # We can execute the MySQL client to destroy and re-create
+ # the MYSQL database, which is easier and less error-prone
+ # than using SQLAlchemy to do this via MetaData...trust me.
+ (user, password, database, host) = \
+ utils.get_db_connection_info(conn_pieces)
+ sql = ("drop database if exists %(db)s; "
+ "create database %(db)s;") % {'db': database}
+ cmd = ("mysql -u \"%(user)s\" -p\"%(password)s\" -h %(host)s "
+ "-e \"%(sql)s\"") % {'user': user, 'password': password,
+ 'host': host, 'sql': sql}
+ self.execute_cmd(cmd)
+ elif conn_string.startswith('postgresql'):
+ self._reset_pg(conn_pieces)
+
+
+class WalkVersionsMixin(object):
+ def _walk_versions(self, engine=None, snake_walk=False, downgrade=True):
+ # Determine latest version script from the repo, then
+ # upgrade from 1 through to the latest, with no data
+ # in the databases. This just checks that the schema itself
+ # upgrades successfully.
+
+ # Place the database under version control
+ self.migration_api.version_control(engine, self.REPOSITORY,
+ self.INIT_VERSION)
+ self.assertEqual(self.INIT_VERSION,
+ self.migration_api.db_version(engine,
+ self.REPOSITORY))
+
+ LOG.debug('latest version is %s' % self.REPOSITORY.latest)
+ versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
+
+ for version in versions:
+ # upgrade -> downgrade -> upgrade
+ self._migrate_up(engine, version, with_data=True)
+ if snake_walk:
+ downgraded = self._migrate_down(
+ engine, version - 1, with_data=True)
+ if downgraded:
+ self._migrate_up(engine, version)
+
+ if downgrade:
+ # Now walk it back down to 0 from the latest, testing
+ # the downgrade paths.
+ for version in reversed(versions):
+ # downgrade -> upgrade -> downgrade
+ downgraded = self._migrate_down(engine, version - 1)
+
+ if snake_walk and downgraded:
+ self._migrate_up(engine, version)
+ self._migrate_down(engine, version - 1)
+
+ def _migrate_down(self, engine, version, with_data=False):
+ try:
+ self.migration_api.downgrade(engine, self.REPOSITORY, version)
+ except NotImplementedError:
+ # NOTE(sirp): some migrations, namely release-level
+ # migrations, don't support a downgrade.
+ return False
+
+ self.assertEqual(
+ version, self.migration_api.db_version(engine, self.REPOSITORY))
+
+ # NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target'
+ # version). So if we have any downgrade checks, they need to be run for
+ # the previous (higher numbered) migration.
+ if with_data:
+ post_downgrade = getattr(
+ self, "_post_downgrade_%03d" % (version + 1), None)
+ if post_downgrade:
+ post_downgrade(engine)
+
+ return True
+
+ def _migrate_up(self, engine, version, with_data=False):
+ """migrate up to a new version of the db.
+
+ We allow for data insertion and post checks at every
+ migration version with special _pre_upgrade_### and
+ _check_### functions in the main test.
+ """
+ # NOTE(sdague): try block is here because it's impossible to debug
+ # where a failed data migration happens otherwise
+ try:
+ if with_data:
+ data = None
+ pre_upgrade = getattr(
+ self, "_pre_upgrade_%03d" % version, None)
+ if pre_upgrade:
+ data = pre_upgrade(engine)
+
+ self.migration_api.upgrade(engine, self.REPOSITORY, version)
+ self.assertEqual(version,
+ self.migration_api.db_version(engine,
+ self.REPOSITORY))
+ if with_data:
+ check = getattr(self, "_check_%03d" % version, None)
+ if check:
+ check(engine, data)
+ except Exception:
+ LOG.error("Failed to migrate to version %s on engine %s" %
+ (version, engine))
+ raise
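A project would typically wire this mixin up roughly as follows; the repository path, INIT_VERSION, and the use of sqlalchemy-migrate's versioning API are assumptions for the sketch, not part of this change:

    # Illustrative only; BaseMigrationTestCase/WalkVersionsMixin are added above.
    from migrate.versioning import api as versioning_api
    from migrate.versioning.repository import Repository


    class ExampleMigrationsTest(BaseMigrationTestCase, WalkVersionsMixin):

        REPOSITORY = Repository('designate/storage/impl_sqlalchemy/migrate_repo')
        INIT_VERSION = 0

        def setUp(self):
            super(ExampleMigrationsTest, self).setUp()
            self.migration_api = versioning_api

        def test_walk_versions(self):
            # Walk every configured engine up (and back down) the migrations.
            for engine in self.engines.values():
                self._walk_versions(engine, snake_walk=True, downgrade=True)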
diff --git a/designate/openstack/common/db/sqlalchemy/utils.py b/designate/openstack/common/db/sqlalchemy/utils.py
new file mode 100644
index 00000000..563d5bab
--- /dev/null
+++ b/designate/openstack/common/db/sqlalchemy/utils.py
@@ -0,0 +1,547 @@
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2010-2011 OpenStack Foundation.
+# Copyright 2012 Justin Santa Barbara
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import logging
+import re
+
+from migrate.changeset import UniqueConstraint
+import sqlalchemy
+from sqlalchemy import Boolean
+from sqlalchemy import CheckConstraint
+from sqlalchemy import Column
+from sqlalchemy.engine import reflection
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy import func
+from sqlalchemy import Index
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy.sql.expression import literal_column
+from sqlalchemy.sql.expression import UpdateBase
+from sqlalchemy.sql import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.types import NullType
+
+from designate.openstack.common.gettextutils import _
+from designate.openstack.common import timeutils
+
+
+LOG = logging.getLogger(__name__)
+
+_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+")
+
+
+def sanitize_db_url(url):
+ match = _DBURL_REGEX.match(url)
+ if match:
+ return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):])
+ return url
+
+
+class InvalidSortKey(Exception):
+ message = _("Sort key supplied was not valid.")
+
+
+# copy from glance/db/sqlalchemy/api.py
+def paginate_query(query, model, limit, sort_keys, marker=None,
+ sort_dir=None, sort_dirs=None):
+ """Returns a query with sorting / pagination criteria added.
+
+ Pagination works by requiring a unique sort_key, specified by sort_keys.
+ (If sort_keys is not unique, then we risk looping through values.)
+ We use the last row in the previous page as the 'marker' for pagination.
+ So we must return values that follow the passed marker in the order.
+ With a single-valued sort_key, this would be easy: sort_key > X.
+ With a compound sort_key (k1, k2, k3), we must do the following to repeat
+ the lexicographical ordering:
+ (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
+
+ We also have to cope with different sort_directions.
+
+ Typically, the id of the last row is used as the client-facing pagination
+ marker, then the actual marker object must be fetched from the db and
+ passed in to us as marker.
+
+ :param query: the query object to which we should add paging/sorting
+ :param model: the ORM model class
+ :param limit: maximum number of items to return
+ :param sort_keys: array of attributes by which results should be sorted
+ :param marker: the last item of the previous page; we return the next
+ results after this value.
+ :param sort_dir: direction in which results should be sorted (asc, desc)
+ :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys
+
+ :rtype: sqlalchemy.orm.query.Query
+ :return: The query with sorting/pagination added.
+ """
+
+ if 'id' not in sort_keys:
+ # TODO(justinsb): If this ever gives a false positive, check
+ # the actual primary key, rather than assuming it is 'id'
+ LOG.warning(_('Id not in sort_keys; is sort_keys unique?'))
+
+ assert(not (sort_dir and sort_dirs))
+
+ # Default the sort direction to ascending
+ if sort_dirs is None and sort_dir is None:
+ sort_dir = 'asc'
+
+ # Ensure a per-column sort direction
+ if sort_dirs is None:
+ sort_dirs = [sort_dir for _sort_key in sort_keys]
+
+ assert(len(sort_dirs) == len(sort_keys))
+
+ # Add sorting
+ for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs):
+ try:
+ sort_dir_func = {
+ 'asc': sqlalchemy.asc,
+ 'desc': sqlalchemy.desc,
+ }[current_sort_dir]
+ except KeyError:
+ raise ValueError(_("Unknown sort direction, "
+ "must be 'desc' or 'asc'"))
+ try:
+ sort_key_attr = getattr(model, current_sort_key)
+ except AttributeError:
+ raise InvalidSortKey()
+ query = query.order_by(sort_dir_func(sort_key_attr))
+
+ # Add pagination
+ if marker is not None:
+ marker_values = []
+ for sort_key in sort_keys:
+ v = getattr(marker, sort_key)
+ marker_values.append(v)
+
+ # Build up an array of sort criteria as in the docstring
+ criteria_list = []
+ for i in range(len(sort_keys)):
+ crit_attrs = []
+ for j in range(i):
+ model_attr = getattr(model, sort_keys[j])
+ crit_attrs.append((model_attr == marker_values[j]))
+
+ model_attr = getattr(model, sort_keys[i])
+ if sort_dirs[i] == 'desc':
+ crit_attrs.append((model_attr < marker_values[i]))
+ else:
+ crit_attrs.append((model_attr > marker_values[i]))
+
+ criteria = sqlalchemy.sql.and_(*crit_attrs)
+ criteria_list.append(criteria)
+
+ f = sqlalchemy.sql.or_(*criteria_list)
+ query = query.filter(f)
+
+ if limit is not None:
+ query = query.limit(limit)
+
+ return query
+
+
+def get_table(engine, name):
+ """Returns an sqlalchemy table dynamically from db.
+
+ Needed because the models don't work for us in migrations
+ as models will be far out of sync with the current data.
+ """
+ metadata = MetaData()
+ metadata.bind = engine
+ return Table(name, metadata, autoload=True)
+
+
+class InsertFromSelect(UpdateBase):
+ """Form the base for `INSERT INTO table (SELECT ... )` statement."""
+ def __init__(self, table, select):
+ self.table = table
+ self.select = select
+
+
+@compiles(InsertFromSelect)
+def visit_insert_from_select(element, compiler, **kw):
+ """Form the `INSERT INTO table (SELECT ... )` statement."""
+ return "INSERT INTO %s %s" % (
+ compiler.process(element.table, asfrom=True),
+ compiler.process(element.select))
+
+
+class ColumnError(Exception):
+ """Error raised when no column or an invalid column is found."""
+
+
+def _get_not_supported_column(col_name_col_instance, column_name):
+ try:
+ column = col_name_col_instance[column_name]
+ except KeyError:
+ msg = _("Please specify column %s in col_name_col_instance "
+ "param. It is required because column has unsupported "
+ "type by sqlite).")
+ raise ColumnError(msg % column_name)
+
+ if not isinstance(column, Column):
+ msg = _("col_name_col_instance param has wrong type of "
+ "column instance for column %s It should be instance "
+ "of sqlalchemy.Column.")
+ raise ColumnError(msg % column_name)
+ return column
+
+
+def drop_unique_constraint(migrate_engine, table_name, uc_name, *columns,
+ **col_name_col_instance):
+ """Drop unique constraint from table.
+
+ This method drops UC from table and works for mysql, postgresql and sqlite.
+ In mysql and postgresql we are able to use "alter table" construction.
+ Sqlalchemy doesn't support some sqlite column types and replaces their
+ type with NullType in metadata. We process these columns and replace
+ NullType with the correct column type.
+
+ :param migrate_engine: sqlalchemy engine
+ :param table_name: name of table that contains uniq constraint.
+ :param uc_name: name of uniq constraint that will be dropped.
+ :param columns: columns that are in uniq constraint.
+ :param col_name_col_instance: contains pairs column_name=column_instance.
+ column_instance is an instance of Column. These params
+ are required only for columns that have types
+ unsupported by sqlite, for example BigInteger.
+ """
+
+ meta = MetaData()
+ meta.bind = migrate_engine
+ t = Table(table_name, meta, autoload=True)
+
+ if migrate_engine.name == "sqlite":
+ override_cols = [
+ _get_not_supported_column(col_name_col_instance, col.name)
+ for col in t.columns
+ if isinstance(col.type, NullType)
+ ]
+ for col in override_cols:
+ t.columns.replace(col)
+
+ uc = UniqueConstraint(*columns, table=t, name=uc_name)
+ uc.drop()
+
+
+def drop_old_duplicate_entries_from_table(migrate_engine, table_name,
+ use_soft_delete, *uc_column_names):
+ """Drop all old rows having the same values for columns in uc_columns.
+
+ This method drops (or marks as `deleted` if use_soft_delete is True) old
+ duplicate rows from the table with name `table_name`.
+
+ :param migrate_engine: Sqlalchemy engine
+ :param table_name: Table with duplicates
+ :param use_soft_delete: If True - values will be marked as `deleted`,
+ if False - values will be removed from table
+ :param uc_column_names: Unique constraint columns
+ """
+ meta = MetaData()
+ meta.bind = migrate_engine
+
+ table = Table(table_name, meta, autoload=True)
+ columns_for_group_by = [table.c[name] for name in uc_column_names]
+
+ columns_for_select = [func.max(table.c.id)]
+ columns_for_select.extend(columns_for_group_by)
+
+ duplicated_rows_select = select(columns_for_select,
+ group_by=columns_for_group_by,
+ having=func.count(table.c.id) > 1)
+
+ for row in migrate_engine.execute(duplicated_rows_select):
+ # NOTE(boris-42): Do not remove row that has the biggest ID.
+ delete_condition = table.c.id != row[0]
+ is_none = None # workaround for pyflakes
+ delete_condition &= table.c.deleted_at == is_none
+ for name in uc_column_names:
+ delete_condition &= table.c[name] == row[name]
+
+ rows_to_delete_select = select([table.c.id]).where(delete_condition)
+ for row in migrate_engine.execute(rows_to_delete_select).fetchall():
+ LOG.info(_("Deleting duplicated row with id: %(id)s from table: "
+ "%(table)s") % dict(id=row[0], table=table_name))
+
+ if use_soft_delete:
+ delete_statement = table.update().\
+ where(delete_condition).\
+ values({
+ 'deleted': literal_column('id'),
+ 'updated_at': literal_column('updated_at'),
+ 'deleted_at': timeutils.utcnow()
+ })
+ else:
+ delete_statement = table.delete().where(delete_condition)
+ migrate_engine.execute(delete_statement)
+
+
+def _get_default_deleted_value(table):
+ if isinstance(table.c.id.type, Integer):
+ return 0
+ if isinstance(table.c.id.type, String):
+ return ""
+ raise ColumnError(_("Unsupported id columns type"))
+
+
+def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes):
+ table = get_table(migrate_engine, table_name)
+
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ real_indexes = insp.get_indexes(table_name)
+ existing_index_names = dict(
+ [(index['name'], index['column_names']) for index in real_indexes])
+
+ # NOTE(boris-42): Restore indexes on `deleted` column
+ for index in indexes:
+ if 'deleted' not in index['column_names']:
+ continue
+ name = index['name']
+ if name in existing_index_names:
+ column_names = [table.c[c] for c in existing_index_names[name]]
+ old_index = Index(name, *column_names, unique=index["unique"])
+ old_index.drop(migrate_engine)
+
+ column_names = [table.c[c] for c in index['column_names']]
+ new_index = Index(index["name"], *column_names, unique=index["unique"])
+ new_index.create(migrate_engine)
+
+
+def change_deleted_column_type_to_boolean(migrate_engine, table_name,
+ **col_name_col_instance):
+ if migrate_engine.name == "sqlite":
+ return _change_deleted_column_type_to_boolean_sqlite(
+ migrate_engine, table_name, **col_name_col_instance)
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ indexes = insp.get_indexes(table_name)
+
+ table = get_table(migrate_engine, table_name)
+
+ old_deleted = Column('old_deleted', Boolean, default=False)
+ old_deleted.create(table, populate_default=False)
+
+ table.update().\
+ where(table.c.deleted == table.c.id).\
+ values(old_deleted=True).\
+ execute()
+
+ table.c.deleted.drop()
+ table.c.old_deleted.alter(name="deleted")
+
+ _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
+
+
+def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name,
+ **col_name_col_instance):
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ table = get_table(migrate_engine, table_name)
+
+ columns = []
+ for column in table.columns:
+ column_copy = None
+ if column.name != "deleted":
+ if isinstance(column.type, NullType):
+ column_copy = _get_not_supported_column(col_name_col_instance,
+ column.name)
+ else:
+ column_copy = column.copy()
+ else:
+ column_copy = Column('deleted', Boolean, default=0)
+ columns.append(column_copy)
+
+ constraints = [constraint.copy() for constraint in table.constraints]
+
+ meta = table.metadata
+ new_table = Table(table_name + "__tmp__", meta,
+ *(columns + constraints))
+ new_table.create()
+
+ indexes = []
+ for index in insp.get_indexes(table_name):
+ column_names = [new_table.c[c] for c in index['column_names']]
+ indexes.append(Index(index["name"], *column_names,
+ unique=index["unique"]))
+
+ c_select = []
+ for c in table.c:
+ if c.name != "deleted":
+ c_select.append(c)
+ else:
+ c_select.append(table.c.deleted == table.c.id)
+
+ ins = InsertFromSelect(new_table, select(c_select))
+ migrate_engine.execute(ins)
+
+ table.drop()
+ [index.create(migrate_engine) for index in indexes]
+
+ new_table.rename(table_name)
+ new_table.update().\
+ where(new_table.c.deleted == new_table.c.id).\
+ values(deleted=True).\
+ execute()
+
+
+def change_deleted_column_type_to_id_type(migrate_engine, table_name,
+ **col_name_col_instance):
+ if migrate_engine.name == "sqlite":
+ return _change_deleted_column_type_to_id_type_sqlite(
+ migrate_engine, table_name, **col_name_col_instance)
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ indexes = insp.get_indexes(table_name)
+
+ table = get_table(migrate_engine, table_name)
+
+ new_deleted = Column('new_deleted', table.c.id.type,
+ default=_get_default_deleted_value(table))
+ new_deleted.create(table, populate_default=True)
+
+ deleted = True # workaround for pyflakes
+ table.update().\
+ where(table.c.deleted == deleted).\
+ values(new_deleted=table.c.id).\
+ execute()
+ table.c.deleted.drop()
+ table.c.new_deleted.alter(name="deleted")
+
+ _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
+
+
+def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name,
+ **col_name_col_instance):
+ # NOTE(boris-42): sqlalchemy-migrate can't drop a column with check
+ # constraints in a sqlite DB, and our `deleted` column has
+ # 2 check constraints. So there is only one way to remove
+ # these constraints:
+ # 1) Create a new table with the same columns, constraints
+ # and indexes (except the deleted column).
+ # 2) Copy all data from old to new table.
+ # 3) Drop old table.
+ # 4) Rename new table to old table name.
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ meta = MetaData(bind=migrate_engine)
+ table = Table(table_name, meta, autoload=True)
+ default_deleted_value = _get_default_deleted_value(table)
+
+ columns = []
+ for column in table.columns:
+ column_copy = None
+ if column.name != "deleted":
+ if isinstance(column.type, NullType):
+ column_copy = _get_not_supported_column(col_name_col_instance,
+ column.name)
+ else:
+ column_copy = column.copy()
+ else:
+ column_copy = Column('deleted', table.c.id.type,
+ default=default_deleted_value)
+ columns.append(column_copy)
+
+ def is_deleted_column_constraint(constraint):
+ # NOTE(boris-42): There is no other way to check whether a CheckConstraint
+ # is associated with the deleted column.
+ if not isinstance(constraint, CheckConstraint):
+ return False
+ sqltext = str(constraint.sqltext)
+ return (sqltext.endswith("deleted in (0, 1)") or
+ sqltext.endswith("deleted IN (:deleted_1, :deleted_2)"))
+
+ constraints = []
+ for constraint in table.constraints:
+ if not is_deleted_column_constraint(constraint):
+ constraints.append(constraint.copy())
+
+ new_table = Table(table_name + "__tmp__", meta,
+ *(columns + constraints))
+ new_table.create()
+
+ indexes = []
+ for index in insp.get_indexes(table_name):
+ column_names = [new_table.c[c] for c in index['column_names']]
+ indexes.append(Index(index["name"], *column_names,
+ unique=index["unique"]))
+
+ ins = InsertFromSelect(new_table, table.select())
+ migrate_engine.execute(ins)
+
+ table.drop()
+ [index.create(migrate_engine) for index in indexes]
+
+ new_table.rename(table_name)
+ deleted = True # workaround for pyflakes
+ new_table.update().\
+ where(new_table.c.deleted == deleted).\
+ values(deleted=new_table.c.id).\
+ execute()
+
+ # NOTE(boris-42): Fix value of deleted column: False -> "" or 0.
+ deleted = False # workaround for pyflakes
+ new_table.update().\
+ where(new_table.c.deleted == deleted).\
+ values(deleted=default_deleted_value).\
+ execute()
+
+
+def get_connect_string(backend, database, user=None, passwd=None):
+ """Get database connection
+
+ Try to get a connection with a very specific set of values, if we get
+ these then we'll run the tests, otherwise they are skipped
+ """
+ args = {'backend': backend,
+ 'user': user,
+ 'passwd': passwd,
+ 'database': database}
+ if backend == 'sqlite':
+ template = '%(backend)s:///%(database)s'
+ else:
+ template = "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
+ return template % args
+
+
+def is_backend_avail(backend, database, user=None, passwd=None):
+ try:
+ connect_uri = get_connect_string(backend=backend,
+ database=database,
+ user=user,
+ passwd=passwd)
+ engine = sqlalchemy.create_engine(connect_uri)
+ connection = engine.connect()
+ except Exception:
+ # intentionally catch all to handle exceptions even if we don't
+ # have any backend code loaded.
+ return False
+ else:
+ connection.close()
+ engine.dispose()
+ return True
+
+
+def get_db_connection_info(conn_pieces):
+ database = conn_pieces.path.strip('/')
+ loc_pieces = conn_pieces.netloc.split('@')
+ host = loc_pieces[1]
+
+ auth_pieces = loc_pieces[0].split(':')
+ user = auth_pieces[0]
+ password = ""
+ if len(auth_pieces) > 1:
+ password = auth_pieces[1].strip()
+
+ return (user, password, database, host)
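To make the paginate_query contract concrete, here is a small hedged sketch of driving it directly; the model and session handling are assumptions for the example, and Designate's actual wiring is in the storage driver changes below:

    # Illustrative only; paginate_query is defined in the file added above.
    from designate.openstack.common.db.sqlalchemy import utils


    def list_page(session, model, marker_id=None, limit=20):
        query = session.query(model)

        # paginate_query expects the marker *row*, not just its id.
        marker_row = None
        if marker_id is not None:
            marker_row = session.query(model).filter_by(id=marker_id).one()

        query = utils.paginate_query(query, model, limit,
                                     sort_keys=['created_at', 'id'],
                                     marker=marker_row,
                                     sort_dir='asc')
        return query.all()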
diff --git a/designate/storage/api.py b/designate/storage/api.py
index 083c6c5d..3d82eac8 100644
--- a/designate/storage/api.py
+++ b/designate/storage/api.py
@@ -55,14 +55,16 @@ class StorageAPI(object):
"""
return self.storage.get_quota(context, quota_id)
- def find_quotas(self, context, criterion=None):
+ def find_quotas(self, context, criterion=None, marker=None, limit=None,
+ sort_key=None, sort_dir=None):
"""
Find Quotas
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
- return self.storage.find_quotas(context, criterion)
+ return self.storage.find_quotas(
+ context, criterion, marker, limit, sort_key, sort_dir)
def find_quota(self, context, criterion):
"""
@@ -140,14 +142,16 @@ class StorageAPI(object):
"""
return self.storage.get_server(context, server_id)
- def find_servers(self, context, criterion=None):
+ def find_servers(self, context, criterion=None, marker=None, limit=None,
+ sort_key=None, sort_dir=None):
"""
Find Servers
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
- return self.storage.find_servers(context, criterion)
+ return self.storage.find_servers(
+ context, criterion, marker, limit, sort_key, sort_dir)
def find_server(self, context, criterion):
"""
@@ -225,14 +229,16 @@ class StorageAPI(object):
"""
return self.storage.get_tld(context, tld_id)
- def find_tlds(self, context, criterion=None):
+ def find_tlds(self, context, criterion=None, marker=None, limit=None,
+ sort_key=None, sort_dir=None):
"""
Find TLDs
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
- return self.storage.find_tlds(context, criterion)
+ return self.storage.find_tlds(
+ context, criterion, marker, limit, sort_key, sort_dir)
def find_tld(self, context, criterion):
"""
@@ -309,14 +315,16 @@ class StorageAPI(object):
"""
return self.storage.get_tsigkey(context, tsigkey_id)
- def find_tsigkeys(self, context, criterion=None):
+ def find_tsigkeys(self, context, criterion=None, marker=None, limit=None,
+ sort_key=None, sort_dir=None):
"""
Find Tsigkey
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
- return self.storage.find_tsigkeys(context, criterion)
+ return self.storage.find_tsigkeys(
+ context, criterion, marker, limit, sort_key, sort_dir)
def find_tsigkey(self, context, criterion):
"""
@@ -419,14 +427,16 @@ class StorageAPI(object):
"""
return self.storage.get_domain(context, domain_id)
- def find_domains(self, context, criterion=None):
+ def find_domains(self, context, criterion=None, marker=None, limit=None,
+ sort_key=None, sort_dir=None):
"""
Find Domains
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
- return self.storage.find_domains(context, criterion)
+ return self.storage.find_domains(
+ context, criterion, marker, limit, sort_key, sort_dir)
def find_domain(self, context, criterion):
"""
@@ -515,14 +525,16 @@ class StorageAPI(object):
"""
return self.storage.get_recordset(context, recordset_id)
- def find_recordsets(self, context, criterion=None):
+ def find_recordsets(self, context, criterion=None, marker=None, limit=None,
+ sort_key=None, sort_dir=None):
"""
Find RecordSets.
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
- return self.storage.find_recordsets(context, criterion)
+ return self.storage.find_recordsets(
+ context, criterion, marker, limit, sort_key, sort_dir)
def find_recordset(self, context, criterion=None):
"""
@@ -612,14 +624,16 @@ class StorageAPI(object):
"""
return self.storage.get_record(context, record_id)
- def find_records(self, context, criterion=None):
+ def find_records(self, context, criterion=None, marker=None, limit=None,
+ sort_key=None, sort_dir=None):
"""
Find Records.
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
- return self.storage.find_records(context, criterion)
+ return self.storage.find_records(
+ context, criterion, marker, limit, sort_key, sort_dir)
def find_record(self, context, criterion=None):
"""
@@ -705,14 +719,16 @@ class StorageAPI(object):
"""
return self.storage.get_blacklist(context, blacklist_id)
- def find_blacklists(self, context, criterion=None):
+ def find_blacklists(self, context, criterion=None, marker=None, limit=None,
+ sort_key=None, sort_dir=None):
"""
Find all Blacklisted Domains
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
- return self.storage.find_blacklists(context, criterion)
+ return self.storage.find_blacklists(
+ context, criterion, marker, limit, sort_key, sort_dir)
def find_blacklist(self, context, criterion):
"""
diff --git a/designate/storage/base.py b/designate/storage/base.py
index 11f650d8..dcb1ddbc 100644
--- a/designate/storage/base.py
+++ b/designate/storage/base.py
@@ -18,6 +18,7 @@ from designate.plugin import DriverPlugin
class Storage(DriverPlugin):
+
""" Base class for storage plugins """
__metaclass__ = abc.ABCMeta
__plugin_ns__ = 'designate.storage'
@@ -42,12 +43,19 @@ class Storage(DriverPlugin):
"""
@abc.abstractmethod
- def find_quotas(self, context, criterion):
+ def find_quotas(self, context, criterion=None, marker=None,
+ limit=None, sort_key=None, sort_dir=None):
"""
Find Quotas
:param context: RPC Context.
:param criterion: Criteria to filter by.
+ :param marker: Resource ID after which the requested page will start.
+ :param limit: Integer limit (page size) of objects returned after the
+ marker.
+ :param sort_key: Key by which to sort.
+ :param sort_dir: Direction in which to sort, using sort_key.
"""
@abc.abstractmethod
@@ -88,12 +96,19 @@ class Storage(DriverPlugin):
"""
@abc.abstractmethod
- def find_servers(self, context, criterion=None):
+ def find_servers(self, context, criterion=None, marker=None,
+ limit=None, sort_key=None, sort_dir=None):
"""
Find Servers.
:param context: RPC Context.
:param criterion: Criteria to filter by.
+ :param marker: Resource ID after which the requested page will start.
+ :param limit: Integer limit (page size) of objects returned after the
+ marker.
+ :param sort_key: Key by which to sort.
+ :param sort_dir: Direction in which to sort, using sort_key.
"""
@abc.abstractmethod
@@ -143,21 +158,34 @@ class Storage(DriverPlugin):
"""
@abc.abstractmethod
- def find_tlds(self, context, criterion=None):
+ def find_tlds(self, context, criterion=None, marker=None,
+ limit=None, sort_key=None, sort_dir=None):
"""
Find TLDs
:param context: RPC Context.
:param criterion: Criteria to filter by.
+ :param marker: Resource ID after which the requested page will start.
+ :param limit: Integer limit (page size) of objects returned after the
+ marker.
+ :param sort_key: Key by which to sort.
+ :param sort_dir: Direction in which to sort, using sort_key.
"""
@abc.abstractmethod
- def find_tld(self, context, criterion):
+ def find_tld(self, context, criterion=None):
"""
Find a single TLD.
:param context: RPC Context.
:param criterion: Criteria to filter by.
"""
@abc.abstractmethod
@@ -188,12 +216,19 @@ class Storage(DriverPlugin):
"""
@abc.abstractmethod
- def find_tsigkeys(self, context, criterion=None):
+ def find_tsigkeys(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
"""
Find TSIG Keys.
:param context: RPC Context.
:param criterion: Criteria to filter by.
+ :param marker: Resource ID after which the requested page will start.
+ :param limit: Integer limit (page size) of objects returned after the
+ marker.
+ :param sort_key: Key by which to sort.
+ :param sort_dir: Direction in which to sort, using sort_key.
"""
@abc.abstractmethod
@@ -268,12 +303,19 @@ class Storage(DriverPlugin):
"""
@abc.abstractmethod
- def find_domains(self, context, criterion=None):
+ def find_domains(self, context, criterion=None, marker=None,
+ limit=None, sort_key=None, sort_dir=None):
"""
Find Domains
:param context: RPC Context.
:param criterion: Criteria to filter by.
+ :param marker: Resource ID after which the requested page will start.
+ :param limit: Integer limit (page size) of objects returned after the
+ marker.
+ :param sort_key: Key by which to sort.
+ :param sort_dir: Direction in which to sort, using sort_key.
"""
@abc.abstractmethod
@@ -333,13 +375,20 @@ class Storage(DriverPlugin):
"""
@abc.abstractmethod
- def find_recordsets(self, context, domain_id, criterion=None):
+ def find_recordsets(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
"""
Find RecordSets.
:param context: RPC Context.
:param domain_id: Domain ID where the recordsets reside.
:param criterion: Criteria to filter by.
+ :param marker: Resource ID after which the requested page will start.
+ :param limit: Integer limit (page size) of objects returned after the
+ marker.
+ :param sort_key: Key by which to sort.
+ :param sort_dir: Direction in which to sort, using sort_key.
"""
@abc.abstractmethod
@@ -399,12 +448,19 @@ class Storage(DriverPlugin):
"""
@abc.abstractmethod
- def find_records(self, context, criterion=None):
+ def find_records(self, context, criterion=None, marker=None,
+ limit=None, sort_key=None, sort_dir=None):
"""
Find Records.
:param context: RPC Context.
:param criterion: Criteria to filter by.
+ :param marker: Resource ID after which the requested page will start.
+ :param limit: Integer limit (page size) of objects returned after the
+ marker.
+ :param sort_key: Key by which to sort.
+ :param sort_dir: Direction in which to sort, using sort_key.
"""
@abc.abstractmethod
@@ -462,12 +518,19 @@ class Storage(DriverPlugin):
"""
@abc.abstractmethod
- def find_blacklists(self, context, criterion):
+ def find_blacklists(self, context, criterion=None, marker=None,
+ limit=None, sort_key=None, sort_dir=None):
"""
Find Blacklists
:param context: RPC Context.
:param criterion: Criteria to filter by.
+ :param marker: Resource ID after which the requested page will start.
+ :param limit: Integer limit (page size) of objects returned after the
+ marker.
+ :param sort_key: Key by which to sort.
+ :param sort_dir: Direction in which to sort, using sort_key.
"""
@abc.abstractmethod
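As a caller-side illustration of the new paging arguments documented above (the storage handle, context, and criterion values are assumptions for the example):

    # Illustrative only; shows how the new marker/limit/sort arguments compose.
    def first_two_pages(storage_api, context, tenant_id):
        criterion = {'tenant_id': tenant_id}

        page_one = storage_api.find_domains(context, criterion,
                                            limit=20, sort_key='name',
                                            sort_dir='asc')
        if not page_one:
            return page_one, []

        # The next page starts after the last item of the previous page;
        # an unknown marker id raises exceptions.MarkerNotFound.
        page_two = storage_api.find_domains(context, criterion,
                                            marker=page_one[-1]['id'],
                                            limit=20, sort_key='name',
                                            sort_dir='asc')
        return page_one, page_two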
diff --git a/designate/storage/impl_sqlalchemy/__init__.py b/designate/storage/impl_sqlalchemy/__init__.py
index d49f6112..a8a86e3c 100644
--- a/designate/storage/impl_sqlalchemy/__init__.py
+++ b/designate/storage/impl_sqlalchemy/__init__.py
@@ -18,6 +18,7 @@ from sqlalchemy.orm import exc
from sqlalchemy import distinct, func
from oslo.config import cfg
from designate.openstack.common import log as logging
+from designate.openstack.common.db.sqlalchemy.utils import paginate_query
from designate import exceptions
from designate.storage import base
from designate.storage.impl_sqlalchemy import models
@@ -26,6 +27,7 @@ from designate.sqlalchemy.session import get_session
from designate.sqlalchemy.session import get_engine
from designate.sqlalchemy.session import SQLOPTS
+
LOG = logging.getLogger(__name__)
cfg.CONF.register_group(cfg.OptGroup(
@@ -92,7 +94,8 @@ class SQLAlchemyStorage(base.Storage):
return query
- def _find(self, model, context, criterion, one=False):
+ def _find(self, model, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
"""
Base "finder" method
@@ -112,7 +115,23 @@ class SQLAlchemyStorage(base.Storage):
except (exc.NoResultFound, exc.MultipleResultsFound):
raise exceptions.NotFound()
else:
+ # If a marker id was supplied, resolve it to the marker row.
# Othwewise, return all matching records
+ if marker is not None:
+ try:
+ marker = self._find(model, context, {'id': marker},
+ one=True)
+ except exceptions.NotFound:
+ raise exceptions.MarkerNotFound(
+ 'Marker %s could not be found' % marker)
+ sort_key = sort_key or 'created_at'
+ sort_dir = sort_dir or 'asc'
+
+ query = paginate_query(
+ query, model, limit,
+ [sort_key, 'id', 'created_at'], marker=marker,
+ sort_dir=sort_dir)
+
return query.all()
## CRUD for our resources (quota, server, tsigkey, tenant, domain & record)
@@ -126,9 +145,12 @@ class SQLAlchemyStorage(base.Storage):
##
# Quota Methods
- def _find_quotas(self, context, criterion, one=False):
+ def _find_quotas(self, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
try:
- return self._find(models.Quota, context, criterion, one)
+ return self._find(models.Quota, context, criterion, one=one,
+ marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
except exceptions.NotFound:
raise exceptions.QuotaNotFound()
@@ -149,8 +171,11 @@ class SQLAlchemyStorage(base.Storage):
return dict(quota)
- def find_quotas(self, context, criterion=None):
- quotas = self._find_quotas(context, criterion)
+ def find_quotas(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
+ quotas = self._find_quotas(context, criterion, marker=marker,
+ limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
return [dict(q) for q in quotas]
@@ -177,9 +202,12 @@ class SQLAlchemyStorage(base.Storage):
quota.delete(self.session)
# Server Methods
- def _find_servers(self, context, criterion, one=False):
+ def _find_servers(self, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
try:
- return self._find(models.Server, context, criterion, one)
+ return self._find(models.Server, context, criterion, one,
+ marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
except exceptions.NotFound:
raise exceptions.ServerNotFound()
@@ -195,9 +223,11 @@ class SQLAlchemyStorage(base.Storage):
return dict(server)
- def find_servers(self, context, criterion=None):
- servers = self._find_servers(context, criterion)
-
+ def find_servers(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
+ servers = self._find_servers(context, criterion, marker=marker,
+ limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
return [dict(s) for s in servers]
def get_server(self, context, server_id):
@@ -222,9 +252,12 @@ class SQLAlchemyStorage(base.Storage):
server.delete(self.session)
# TLD Methods
- def _find_tlds(self, context, criterion, one=False):
+ def _find_tlds(self, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
try:
- return self._find(models.Tld, context, criterion, one)
+ return self._find(models.Tld, context, criterion, one=one,
+ marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
except exceptions.NotFound:
raise exceptions.TLDNotFound()
@@ -239,8 +272,10 @@ class SQLAlchemyStorage(base.Storage):
return dict(tld)
- def find_tlds(self, context, criterion=None):
- tlds = self._find_tlds(context, criterion)
+ def find_tlds(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
+ tlds = self._find_tlds(context, criterion, marker=marker, limit=limit,
+ sort_key=sort_key, sort_dir=sort_dir)
return [dict(s) for s in tlds]
def find_tld(self, context, criterion=None):
@@ -267,9 +302,12 @@ class SQLAlchemyStorage(base.Storage):
tld.delete(self.session)
# TSIG Key Methods
- def _find_tsigkeys(self, context, criterion, one=False):
+ def _find_tsigkeys(self, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
try:
- return self._find(models.TsigKey, context, criterion, one)
+ return self._find(models.TsigKey, context, criterion, one=one,
+ marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
except exceptions.NotFound:
raise exceptions.TsigKeyNotFound()
@@ -285,8 +323,11 @@ class SQLAlchemyStorage(base.Storage):
return dict(tsigkey)
- def find_tsigkeys(self, context, criterion=None):
- tsigkeys = self._find_tsigkeys(context, criterion)
+ def find_tsigkeys(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
+ tsigkeys = self._find_tsigkeys(context, criterion, marker=marker,
+ limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
return [dict(t) for t in tsigkeys]
@@ -352,9 +393,12 @@ class SQLAlchemyStorage(base.Storage):
##
## Domain Methods
##
- def _find_domains(self, context, criterion, one=False):
+ def _find_domains(self, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
try:
- return self._find(models.Domain, context, criterion, one)
+ return self._find(models.Domain, context, criterion, one=one,
+ marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
except exceptions.NotFound:
raise exceptions.DomainNotFound()
@@ -375,8 +419,11 @@ class SQLAlchemyStorage(base.Storage):
return dict(domain)
- def find_domains(self, context, criterion=None):
- domains = self._find_domains(context, criterion)
+ def find_domains(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
+ domains = self._find_domains(context, criterion, marker=marker,
+ limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
return [dict(d) for d in domains]
@@ -412,9 +459,13 @@ class SQLAlchemyStorage(base.Storage):
return query.count()
# RecordSet Methods
- def _find_recordsets(self, context, criterion, one=False):
+ def _find_recordsets(self, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None,
+ sort_dir=None):
try:
- return self._find(models.RecordSet, context, criterion, one)
+ return self._find(models.RecordSet, context, criterion, one=one,
+ marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
except exceptions.NotFound:
raise exceptions.RecordSetNotFound()
@@ -441,8 +492,11 @@ class SQLAlchemyStorage(base.Storage):
return dict(recordset)
- def find_recordsets(self, context, criterion=None):
- recordsets = self._find_recordsets(context, criterion)
+ def find_recordsets(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
+ recordsets = self._find_recordsets(
+ context, criterion, marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
return [dict(r) for r in recordsets]
@@ -479,9 +533,12 @@ class SQLAlchemyStorage(base.Storage):
return query.count()
# Record Methods
- def _find_records(self, context, criterion, one=False):
+ def _find_records(self, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
try:
- return self._find(models.Record, context, criterion, one)
+ return self._find(models.Record, context, criterion, one=one,
+ marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
except exceptions.NotFound:
raise exceptions.RecordNotFound()
@@ -506,8 +563,11 @@ class SQLAlchemyStorage(base.Storage):
return dict(record)
- def find_records(self, context, criterion=None):
- records = self._find_records(context, criterion)
+ def find_records(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
+ records = self._find_records(
+ context, criterion, marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
return [dict(r) for r in records]
@@ -549,9 +609,12 @@ class SQLAlchemyStorage(base.Storage):
#
# Blacklist Methods
#
- def _find_blacklist(self, context, criterion, one=False):
+ def _find_blacklist(self, context, criterion, one=False,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
try:
- return self._find(models.Blacklists, context, criterion, one)
+ return self._find(models.Blacklists, context, criterion, one=one,
+ marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
except exceptions.NotFound:
raise exceptions.BlacklistNotFound()
@@ -567,8 +630,11 @@ class SQLAlchemyStorage(base.Storage):
return dict(blacklist)
- def find_blacklists(self, context, criterion=None):
- blacklists = self._find_blacklist(context, criterion)
+ def find_blacklists(self, context, criterion=None,
+ marker=None, limit=None, sort_key=None, sort_dir=None):
+ blacklists = self._find_blacklist(
+ context, criterion, marker=marker, limit=limit, sort_key=sort_key,
+ sort_dir=sort_dir)
return [dict(b) for b in blacklists]
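Each find_* method above now simply forwards the new paging arguments (marker, limit, sort_key, sort_dir) to the shared _find() helper. A minimal sketch of how a caller could walk a result set with these parameters; the storage instance, the context, the process() helper, and the 'created_at' sort column are assumptions for illustration, not part of this change:

    # Page through domains five at a time using the new paging arguments.
    marker = None
    while True:
        page = storage.find_domains(context, criterion=None,
                                    marker=marker, limit=5,
                                    sort_key='created_at', sort_dir='asc')
        if not page:
            break
        for domain in page:
            process(domain)  # placeholder for real per-item work
        # The next page starts after the last item already seen.
        marker = page[-1]['id']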
diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py
index cff71192..41f4ec07 100644
--- a/designate/tests/test_central/test_service.py
+++ b/designate/tests/test_central/test_service.py
@@ -968,8 +968,8 @@ class CentralServiceTest(CentralTestCase):
self.admin_context, criterion)
self.assertEqual(len(recordsets), 2)
- self.assertEqual(recordsets[0]['name'], 'mail.%s' % domain['name'])
- self.assertEqual(recordsets[1]['name'], 'www.%s' % domain['name'])
+ self.assertEqual(recordsets[0]['name'], 'www.%s' % domain['name'])
+ self.assertEqual(recordsets[1]['name'], 'mail.%s' % domain['name'])
def test_find_recordset(self):
domain = self.create_domain()
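The expected order of the two recordsets flips here, presumably because results now come back through the paged query path, whose default ordering differs from the previous behaviour. A sketch of an order-independent assertion, had the test only cared about membership (not how this suite actually asserts it):

    names = sorted(r['name'] for r in recordsets)
    expected = sorted(['mail.%s' % domain['name'],
                       'www.%s' % domain['name']])
    assert names == expected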
diff --git a/designate/tests/test_storage/__init__.py b/designate/tests/test_storage/__init__.py
index 2668f6e4..076e49fc 100644
--- a/designate/tests/test_storage/__init__.py
+++ b/designate/tests/test_storage/__init__.py
@@ -14,6 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import testtools
+import uuid
from designate.openstack.common import log as logging
from designate import exceptions
@@ -77,6 +78,26 @@ class StorageTestCase(object):
return fixture, self.storage.create_record(
context, domain['id'], recordset['id'], fixture)
+ def _ensure_paging(self, data, method):
+ """
+ Given a list of created items, iterate through them and make sure
+ they match the items returned in paged results.
+ """
+ found = method(self.admin_context, limit=5)
+ x = 0
+ for i in xrange(0, len(data)):
+ self.assertEqual(data[i]['id'], found[x]['id'])
+ x += 1
+ if x == len(found):
+ x = 0
+ found = method(
+ self.admin_context, limit=5, marker=found[-1:][0]['id'])
+
+ def test_paging_marker_not_found(self):
+ with testtools.ExpectedException(exceptions.MarkerNotFound):
+ self.storage.find_servers(
+ self.admin_context, marker=str(uuid.uuid4()), limit=5)
+
# Quota Tests
def test_create_quota(self):
values = self.get_quota_fixture()
@@ -270,20 +291,20 @@ class StorageTestCase(object):
self.assertEqual(actual, [])
# Create a single server
- _, server_one = self.create_server()
+ _, server = self.create_server()
actual = self.storage.find_servers(self.admin_context)
self.assertEqual(len(actual), 1)
+ self.assertEqual(str(actual[0]['name']), str(server['name']))
- self.assertEqual(str(actual[0]['name']), str(server_one['name']))
-
- # Create a second server
- _, server_two = self.create_server(fixture=1)
-
- actual = self.storage.find_servers(self.admin_context)
- self.assertEqual(len(actual), 2)
+ # The order of the items found later is the reverse of the order
+ # in which they were created
+ created = [self.create_server(
+ values={'name': 'ns%s.example.org.' % i})[1]
+ for i in xrange(10, 20)]
+ created.insert(0, server)
- self.assertEqual(str(actual[1]['name']), str(server_two['name']))
+ self._ensure_paging(created, self.storage.find_servers)
def test_find_servers_criterion(self):
_, server_one = self.create_server(0)
@@ -388,24 +409,22 @@ class StorageTestCase(object):
self.assertEqual(actual, [])
# Create a single tsigkey
- _, tsigkey_one = self.create_tsigkey()
+ _, tsig = self.create_tsigkey()
actual = self.storage.find_tsigkeys(self.admin_context)
self.assertEqual(len(actual), 1)
- self.assertEqual(actual[0]['name'], tsigkey_one['name'])
- self.assertEqual(actual[0]['algorithm'], tsigkey_one['algorithm'])
- self.assertEqual(actual[0]['secret'], tsigkey_one['secret'])
+ self.assertEqual(actual[0]['name'], tsig['name'])
+ self.assertEqual(actual[0]['algorithm'], tsig['algorithm'])
+ self.assertEqual(actual[0]['secret'], tsig['secret'])
- # Create a second tsigkey
- _, tsigkey_two = self.create_tsigkey(fixture=1)
+ # The order of the items found later is the reverse of the order
+ # in which they were created
+ created = [self.create_tsigkey(values={'name': 'tsig%s.' % i})[1]
+ for i in xrange(10, 20)]
+ created.insert(0, tsig)
- actual = self.storage.find_tsigkeys(self.admin_context)
- self.assertEqual(len(actual), 2)
-
- self.assertEqual(actual[1]['name'], tsigkey_two['name'])
- self.assertEqual(actual[1]['algorithm'], tsigkey_two['algorithm'])
- self.assertEqual(actual[1]['secret'], tsigkey_two['secret'])
+ self._ensure_paging(created, self.storage.find_tsigkeys)
def test_find_tsigkeys_criterion(self):
_, tsigkey_one = self.create_tsigkey(fixture=0)
@@ -582,19 +601,21 @@ class StorageTestCase(object):
self.assertEqual(actual, [])
# Create a single domain
- fixture_one, domain_one = self.create_domain()
+ fixture_one, domain = self.create_domain()
actual = self.storage.find_domains(self.admin_context)
self.assertEqual(len(actual), 1)
- self.assertEqual(actual[0]['name'], domain_one['name'])
- self.assertEqual(actual[0]['email'], domain_one['email'])
+ self.assertEqual(actual[0]['name'], domain['name'])
+ self.assertEqual(actual[0]['email'], domain['email'])
- # Create a second domain
- self.create_domain(fixture=1)
+ # The order of the items found later is the reverse of the order
+ # in which they were created
+ created = [self.create_domain(values={'name': 'x%s.org.' % i})[1]
+ for i in xrange(10, 20)]
+ created.insert(0, domain)
- actual = self.storage.find_domains(self.admin_context)
- self.assertEqual(len(actual), 2)
+ self._ensure_paging(created, self.storage.find_domains)
def test_find_domains_criterion(self):
_, domain_one = self.create_domain(0)
@@ -812,14 +833,14 @@ class StorageTestCase(object):
self.assertEqual(actual[0]['name'], recordset_one['name'])
self.assertEqual(actual[0]['type'], recordset_one['type'])
- # Create a second recordset
- _, recordset_two = self.create_recordset(domain, fixture=1)
-
- actual = self.storage.find_recordsets(self.admin_context, criterion)
- self.assertEqual(len(actual), 2)
+ # The order of the items found later is the reverse of the order
+ # in which they were created
+ created = [self.create_recordset(
+ domain, values={'name': 'test%s' % i + '.%s'})[1]
+ for i in xrange(10, 20)]
+ created.insert(0, recordset_one)
- self.assertEqual(actual[1]['name'], recordset_two['name'])
- self.assertEqual(actual[1]['type'], recordset_two['type'])
+ self._ensure_paging(created, self.storage.find_recordsets)
def test_find_recordsets_criterion(self):
_, domain = self.create_domain()
@@ -1016,22 +1037,23 @@ class StorageTestCase(object):
self.assertEqual(actual, [])
# Create a single record
- _, record_one = self.create_record(domain, recordset, fixture=0)
+ _, record = self.create_record(domain, recordset, fixture=0)
actual = self.storage.find_records(self.admin_context, criterion)
self.assertEqual(len(actual), 1)
- self.assertEqual(actual[0]['data'], record_one['data'])
- self.assertIn('status', record_one)
-
- # Create a second record
- _, record_two = self.create_record(domain, recordset, fixture=1)
+ self.assertEqual(actual[0]['data'], record['data'])
+ self.assertIn('status', record)
- actual = self.storage.find_records(self.admin_context, criterion)
- self.assertEqual(len(actual), 2)
+ # The order of the items found later is the reverse of the order
+ # in which they were created
+ created = [self.create_record(
+ domain, recordset,
+ values={'data': '192.0.0.%s' % i})[1]
+ for i in xrange(10, 20)]
+ created.insert(0, record)
- self.assertEqual(actual[1]['data'], record_two['data'])
- self.assertIn('status', record_two)
+ self._ensure_paging(created, self.storage.find_records)
def test_find_records_criterion(self):
_, domain = self.create_domain()
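The _ensure_paging helper above drives a find_* method page by page, always using the id of the last item returned as the marker for the next call. A standalone sketch of the same loop, reusing the exceptions import already present in this module; the generator name and the default limit are illustrative only:

    def iterate_pages(method, context, limit=5):
        # Walk every page of a paged find_* call, yielding items one by one.
        marker = None
        while True:
            try:
                page = method(context, limit=limit, marker=marker)
            except exceptions.MarkerNotFound:
                # The marker row vanished (or never existed); stop paging.
                break
            if not page:
                break
            for item in page:
                yield item
            marker = page[-1]['id']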
diff --git a/designate/tests/test_storage/test_api.py b/designate/tests/test_storage/test_api.py
index 9062852e..03b55e8b 100644
--- a/designate/tests/test_storage/test_api.py
+++ b/designate/tests/test_storage/test_api.py
@@ -90,12 +90,20 @@ class StorageAPITest(TestCase):
def test_find_quotas(self):
context = mock.sentinel.context
criterion = mock.sentinel.criterion
+ marker = mock.sentinel.marker
+ limit = mock.sentinel.limit
+ sort_key = mock.sentinel.sort_key
+ sort_dir = mock.sentinel.sort_dir
quota = mock.sentinel.quota
self._set_side_effect('find_quotas', [[quota]])
- result = self.storage_api.find_quotas(context, criterion)
- self._assert_called_with('find_quotas', context, criterion)
+ result = self.storage_api.find_quotas(
+ context, criterion,
+ marker, limit, sort_key, sort_dir)
+ self._assert_called_with(
+ 'find_quotas', context, criterion,
+ marker, limit, sort_key, sort_dir)
self.assertEqual([quota], result)
def test_find_quota(self):
@@ -198,12 +206,21 @@ class StorageAPITest(TestCase):
def test_find_servers(self):
context = mock.sentinel.context
criterion = mock.sentinel.criterion
+ marker = mock.sentinel.marker
+ limit = mock.sentinel.limit
+ sort_key = mock.sentinel.sort_key
+ sort_dir = mock.sentinel.sort_dir
+
server = mock.sentinel.server
self._set_side_effect('find_servers', [[server]])
- result = self.storage_api.find_servers(context, criterion)
- self._assert_called_with('find_servers', context, criterion)
+ result = self.storage_api.find_servers(
+ context, criterion,
+ marker, limit, sort_key, sort_dir)
+ self._assert_called_with(
+ 'find_servers', context, criterion,
+ marker, limit, sort_key, sort_dir)
self.assertEqual([server], result)
def test_find_server(self):
@@ -306,12 +323,19 @@ class StorageAPITest(TestCase):
def test_find_tsigkeys(self):
context = mock.sentinel.context
criterion = mock.sentinel.criterion
+ marker = mock.sentinel.marker
+ limit = mock.sentinel.limit
+ sort_key = mock.sentinel.sort_key
+ sort_dir = mock.sentinel.sort_dir
tsigkey = mock.sentinel.tsigkey
self._set_side_effect('find_tsigkeys', [[tsigkey]])
- result = self.storage_api.find_tsigkeys(context, criterion)
- self._assert_called_with('find_tsigkeys', context, criterion)
+ result = self.storage_api.find_tsigkeys(
+ context, criterion, marker, limit, sort_key, sort_dir)
+ self._assert_called_with(
+ 'find_tsigkeys', context, criterion,
+ marker, limit, sort_key, sort_dir)
self.assertEqual([tsigkey], result)
def test_find_tsigkey(self):
@@ -444,12 +468,20 @@ class StorageAPITest(TestCase):
def test_find_domains(self):
context = mock.sentinel.context
criterion = mock.sentinel.criterion
+ marker = mock.sentinel.marker
+ limit = mock.sentinel.limit
+ sort_key = mock.sentinel.sort_key
+ sort_dir = mock.sentinel.sort_dir
domain = mock.sentinel.domain
self._set_side_effect('find_domains', [[domain]])
- result = self.storage_api.find_domains(context, criterion)
- self._assert_called_with('find_domains', context, criterion)
+ result = self.storage_api.find_domains(
+ context, criterion,
+ marker, limit, sort_key, sort_dir)
+ self._assert_called_with(
+ 'find_domains', context, criterion,
+ marker, limit, sort_key, sort_dir)
self.assertEqual([domain], result)
def test_find_domain(self):
@@ -552,12 +584,20 @@ class StorageAPITest(TestCase):
def test_find_recordsets(self):
context = mock.sentinel.context
criterion = mock.sentinel.criterion
+ marker = mock.sentinel.marker
+ limit = mock.sentinel.limit
+ sort_key = mock.sentinel.sort_key
+ sort_dir = mock.sentinel.sort_dir
recordset = mock.sentinel.recordset
self._set_side_effect('find_recordsets', [[recordset]])
- result = self.storage_api.find_recordsets(context, criterion)
- self._assert_called_with('find_recordsets', context, criterion)
+ result = self.storage_api.find_recordsets(
+ context, criterion,
+ marker, limit, sort_key, sort_dir)
+ self._assert_called_with(
+ 'find_recordsets', context, criterion,
+ marker, limit, sort_key, sort_dir)
self.assertEqual([recordset], result)
def test_find_recordset(self):
@@ -660,12 +700,20 @@ class StorageAPITest(TestCase):
def test_find_records(self):
context = mock.sentinel.context
criterion = mock.sentinel.criterion
+ marker = mock.sentinel.marker
+ limit = mock.sentinel.limit
+ sort_key = mock.sentinel.sort_key
+ sort_dir = mock.sentinel.sort_dir
record = mock.sentinel.record
self._set_side_effect('find_records', [[record]])
- result = self.storage_api.find_records(context, criterion)
- self._assert_called_with('find_records', context, criterion)
+ result = self.storage_api.find_records(
+ context, criterion,
+ marker, limit, sort_key, sort_dir)
+ self._assert_called_with(
+ 'find_records', context, criterion,
+ marker, limit, sort_key, sort_dir)
self.assertEqual([record], result)
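All of these tests follow the same sentinel pass-through pattern: each new paging argument handed to the storage API must reach the backend unchanged, and the backend's result must come back untouched. A stripped-down, self-contained illustration using plain mock; the Proxy class here is a hypothetical stand-in for the storage API wrapper, not the real one:

    import mock

    class Proxy(object):
        # Hypothetical stand-in for the storage API wrapper under test.
        def __init__(self, backend):
            self.backend = backend

        def find_domains(self, context, criterion=None, marker=None,
                         limit=None, sort_key=None, sort_dir=None):
            return self.backend.find_domains(context, criterion, marker,
                                             limit, sort_key, sort_dir)

    backend = mock.Mock()
    backend.find_domains.return_value = [mock.sentinel.domain]

    result = Proxy(backend).find_domains(
        mock.sentinel.context, mock.sentinel.criterion,
        mock.sentinel.marker, mock.sentinel.limit,
        mock.sentinel.sort_key, mock.sentinel.sort_dir)

    # Sentinels are unique singletons, so this verifies every argument was
    # forwarded untouched and the stubbed return value came back as-is.
    backend.find_domains.assert_called_once_with(
        mock.sentinel.context, mock.sentinel.criterion,
        mock.sentinel.marker, mock.sentinel.limit,
        mock.sentinel.sort_key, mock.sentinel.sort_dir)
    assert result == [mock.sentinel.domain]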
diff --git a/openstack-common.conf b/openstack-common.conf
index 818fa05f..9edf292b 100644
--- a/openstack-common.conf
+++ b/openstack-common.conf
@@ -2,6 +2,8 @@
# The list of modules to copy from oslo-incubator.git
module=context
+module=db
+module=db.sqlalchemy
module=excutils
module=fixture
module=gettextutils
diff --git a/requirements.txt b/requirements.txt
index 883137c1..5177ac2e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ Flask>=0.10,<1.0
iso8601>=0.1.8
jsonschema>=2.0.0,<3.0.0
kombu>=2.4.8
+lockfile>=0.8
netaddr>=0.7.6
oslo.config>=1.2.0
oslo.rootwrap