path: root/oslo_db/sqlalchemy/utils.py
Diffstat (limited to 'oslo_db/sqlalchemy/utils.py')
-rw-r--r--  oslo_db/sqlalchemy/utils.py  1012
1 file changed, 1012 insertions, 0 deletions
diff --git a/oslo_db/sqlalchemy/utils.py b/oslo_db/sqlalchemy/utils.py
new file mode 100644
index 0000000..6a66bb9
--- /dev/null
+++ b/oslo_db/sqlalchemy/utils.py
@@ -0,0 +1,1012 @@
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2010-2011 OpenStack Foundation.
+# Copyright 2012 Justin Santa Barbara
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import collections
+import logging
+import re
+
+from oslo.utils import timeutils
+import six
+import sqlalchemy
+from sqlalchemy import Boolean
+from sqlalchemy import CheckConstraint
+from sqlalchemy import Column
+from sqlalchemy.engine import Connectable
+from sqlalchemy.engine import reflection
+from sqlalchemy.engine import url as sa_url
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy import func
+from sqlalchemy import Index
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy.sql.expression import literal_column
+from sqlalchemy.sql.expression import UpdateBase
+from sqlalchemy.sql import text
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.types import NullType
+
+from oslo_db import exception
+from oslo_db._i18n import _, _LI, _LW
+from oslo_db.sqlalchemy import models
+
+# NOTE(ochuprykov): Add references for backwards compatibility
+InvalidSortKey = exception.InvalidSortKey
+ColumnError = exception.ColumnError
+
+LOG = logging.getLogger(__name__)
+
+_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+")
+
+
+def get_callable_name(function):
+ # TODO(harlowja): Replace this once
+ # it is possible to use https://review.openstack.org/#/c/122495/ which is
+ # a more complete and expansive module that does a similar thing...
+ try:
+ method_self = six.get_method_self(function)
+ except AttributeError:
+ method_self = None
+ if method_self is not None:
+ if isinstance(method_self, six.class_types):
+ im_class = method_self
+ else:
+ im_class = type(method_self)
+ try:
+ parts = (im_class.__module__, function.__qualname__)
+ except AttributeError:
+ parts = (im_class.__module__, im_class.__name__, function.__name__)
+ else:
+ try:
+ parts = (function.__module__, function.__qualname__)
+ except AttributeError:
+ parts = (function.__module__, function.__name__)
+ return '.'.join(parts)
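+
+# Illustrative sketch (not part of the original module): for a module-level
+# function the result is "module.function"; for a bound method it includes the
+# class (via __qualname__ on Python 3).  Assuming a hypothetical module
+# "myapp.tasks":
+#
+#     def cleanup():
+#         pass
+#
+#     class Worker(object):
+#         def run(self):
+#             pass
+#
+#     get_callable_name(cleanup)       # -> 'myapp.tasks.cleanup'
+#     get_callable_name(Worker().run)  # -> 'myapp.tasks.Worker.run'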
+
+
+def sanitize_db_url(url):
+ match = _DBURL_REGEX.match(url)
+ if match:
+ return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):])
+ return url
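+
+# Illustrative sketch (URL is an assumption): credentials are masked while the
+# rest of the URL is preserved, so the result is safe to log, e.g.
+#
+#     sanitize_db_url('mysql+pymysql://nova:secret@10.0.0.1/nova')
+#     # -> 'mysql+pymysql://****:****@10.0.0.1/nova'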
+
+
+# copy from glance/db/sqlalchemy/api.py
+def paginate_query(query, model, limit, sort_keys, marker=None,
+ sort_dir=None, sort_dirs=None):
+ """Returns a query with sorting / pagination criteria added.
+
+ Pagination works by requiring a unique sort_key, specified by sort_keys.
+ (If sort_keys is not unique, then we risk looping through values.)
+ We use the last row in the previous page as the 'marker' for pagination.
+    So we must return values that follow the passed marker in the sort order.
+    With a single-valued sort_key, this would be easy: sort_key > X.
+    With a compound-valued sort_key (k1, k2, k3), we must do this to repeat
+    the lexicographical ordering:
+ (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
+
+ We also have to cope with different sort_directions.
+
+    Typically, the id of the last row is used as the client-facing pagination
+    marker; the actual marker object is then fetched from the db and
+    passed in to us as marker.
+
+ :param query: the query object to which we should add paging/sorting
+ :param model: the ORM model class
+ :param limit: maximum number of items to return
+ :param sort_keys: array of attributes by which results should be sorted
+    :param marker: the last item of the previous page; we return the next
+                   results after this value.
+ :param sort_dir: direction in which results should be sorted (asc, desc)
+ :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys
+
+ :rtype: sqlalchemy.orm.query.Query
+ :return: The query with sorting/pagination added.
+ """
+
+ if 'id' not in sort_keys:
+ # TODO(justinsb): If this ever gives a false-positive, check
+        # the actual primary key, rather than assuming it is 'id'
+ LOG.warning(_LW('Id not in sort_keys; is sort_keys unique?'))
+
+ assert(not (sort_dir and sort_dirs))
+
+ # Default the sort direction to ascending
+ if sort_dirs is None and sort_dir is None:
+ sort_dir = 'asc'
+
+ # Ensure a per-column sort direction
+ if sort_dirs is None:
+ sort_dirs = [sort_dir for _sort_key in sort_keys]
+
+ assert(len(sort_dirs) == len(sort_keys))
+
+ # Add sorting
+ for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs):
+ try:
+ sort_dir_func = {
+ 'asc': sqlalchemy.asc,
+ 'desc': sqlalchemy.desc,
+ }[current_sort_dir]
+ except KeyError:
+ raise ValueError(_("Unknown sort direction, "
+ "must be 'desc' or 'asc'"))
+ try:
+ sort_key_attr = getattr(model, current_sort_key)
+ except AttributeError:
+ raise exception.InvalidSortKey()
+ query = query.order_by(sort_dir_func(sort_key_attr))
+
+ # Add pagination
+ if marker is not None:
+ marker_values = []
+ for sort_key in sort_keys:
+ v = getattr(marker, sort_key)
+ marker_values.append(v)
+
+ # Build up an array of sort criteria as in the docstring
+ criteria_list = []
+ for i in range(len(sort_keys)):
+ crit_attrs = []
+ for j in range(i):
+ model_attr = getattr(model, sort_keys[j])
+ crit_attrs.append((model_attr == marker_values[j]))
+
+ model_attr = getattr(model, sort_keys[i])
+ if sort_dirs[i] == 'desc':
+ crit_attrs.append((model_attr < marker_values[i]))
+ else:
+ crit_attrs.append((model_attr > marker_values[i]))
+
+ criteria = sqlalchemy.sql.and_(*crit_attrs)
+ criteria_list.append(criteria)
+
+ f = sqlalchemy.sql.or_(*criteria_list)
+ query = query.filter(f)
+
+ if limit is not None:
+ query = query.limit(limit)
+
+ return query
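+
+# Illustrative usage sketch (models.Instance and marker_id are assumptions,
+# not part of this module): the client usually sends the id of the last seen
+# row, so the marker row is fetched first and passed in whole:
+#
+#     marker_row = None
+#     if marker_id is not None:
+#         marker_row = session.query(models.Instance).get(marker_id)
+#
+#     query = session.query(models.Instance)
+#     query = paginate_query(query, models.Instance, limit=50,
+#                            sort_keys=['created_at', 'id'],
+#                            sort_dirs=['desc', 'asc'],
+#                            marker=marker_row)
+#     page = query.all()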
+
+
+def _read_deleted_filter(query, db_model, deleted):
+ if 'deleted' not in db_model.__table__.columns:
+        raise ValueError(_("There is no `deleted` column in the `%s` table. "
+                           "The project doesn't use the soft-delete feature.")
+ % db_model.__name__)
+
+ default_deleted_value = db_model.__table__.c.deleted.default.arg
+ if deleted:
+ query = query.filter(db_model.deleted != default_deleted_value)
+ else:
+ query = query.filter(db_model.deleted == default_deleted_value)
+ return query
+
+
+def _project_filter(query, db_model, project_id):
+ if 'project_id' not in db_model.__table__.columns:
+        raise ValueError(_("There is no `project_id` column in the `%s` "
+                           "table.") % db_model.__name__)
+
+ if isinstance(project_id, (list, tuple, set)):
+ query = query.filter(db_model.project_id.in_(project_id))
+ else:
+ query = query.filter(db_model.project_id == project_id)
+
+ return query
+
+
+def model_query(model, session, args=None, **kwargs):
+ """Query helper for db.sqlalchemy api methods.
+
+ This accounts for `deleted` and `project_id` fields.
+
+ :param model: Model to query. Must be a subclass of ModelBase.
+ :type model: models.ModelBase
+
+ :param session: The session to use.
+ :type session: sqlalchemy.orm.session.Session
+
+    :param args: Arguments to query. If None, the model is used.
+ :type args: tuple
+
+ Keyword arguments:
+
+ :keyword project_id: If present, allows filtering by project_id(s).
+ Can be either a project_id value, or an iterable of
+ project_id values, or None. If an iterable is passed,
+                         only rows whose project_id column value is in the
+                         `project_id` list will be returned. If None is
+                         passed, only rows that are not bound to any project
+                         will be returned.
+ :type project_id: iterable,
+ model.__table__.columns.project_id.type,
+ None type
+
+ :keyword deleted: If present, allows filtering by deleted field.
+                      If True is passed, only deleted entries will be
+                      returned; if False, only existing (non-deleted) entries.
+ :type deleted: bool
+
+
+ Usage:
+
+ .. code-block:: python
+
+ from oslo_db.sqlalchemy import utils
+
+
+ def get_instance_by_uuid(uuid):
+ session = get_session()
+            with session.begin():
+ return (utils.model_query(models.Instance, session=session)
+ .filter(models.Instance.uuid == uuid)
+ .first())
+
+ def get_nodes_stat():
+ data = (Node.id, Node.cpu, Node.ram, Node.hdd)
+
+ session = get_session()
+            with session.begin():
+ return utils.model_query(Node, session=session, args=data).all()
+
+    You can also create your own helper based on ``utils.model_query()``.
+    For example, this can be useful if you plan to derive the ``project_id``
+    and ``deleted`` parameters from the project's ``context``:
+
+ .. code-block:: python
+
+ from oslo_db.sqlalchemy import utils
+
+
+ def _model_query(context, model, session=None, args=None,
+ project_id=None, project_only=False,
+ read_deleted=None):
+
+            # We assume that the functions ``_get_project_id()`` and
+            # ``_get_deleted_dict()`` handle the passed parameters and the
+            # context object (for example, deciding whether to restrict a
+            # user to querying only their own entries by project_id, or to
+            # allow only admins to read deleted entries). As return values,
+            # we expect ``project_id`` and ``deleted`` values suitable for
+            # the ``model_query()`` signature.
+ kwargs = {}
+ if project_id is not None:
+ kwargs['project_id'] = _get_project_id(context, project_id,
+ project_only)
+ if read_deleted is not None:
+ kwargs['deleted'] = _get_deleted_dict(context, read_deleted)
+ session = session or get_session()
+
+ with session.begin():
+ return utils.model_query(model, session=session,
+ args=args, **kwargs)
+
+ def get_instance_by_uuid(context, uuid):
+ return (_model_query(context, models.Instance, read_deleted='yes')
+ .filter(models.Instance.uuid == uuid)
+ .first())
+
+ def get_nodes_data(context, project_id, project_only='allow_none'):
+ data = (Node.id, Node.cpu, Node.ram, Node.hdd)
+
+ return (_model_query(context, Node, args=data, project_id=project_id,
+ project_only=project_only)
+ .all())
+
+ """
+
+ if not issubclass(model, models.ModelBase):
+ raise TypeError(_("model should be a subclass of ModelBase"))
+
+ query = session.query(model) if not args else session.query(*args)
+ if 'deleted' in kwargs:
+ query = _read_deleted_filter(query, model, kwargs['deleted'])
+ if 'project_id' in kwargs:
+ query = _project_filter(query, model, kwargs['project_id'])
+
+ return query
+
+
+def get_table(engine, name):
+    """Return a SQLAlchemy table, reflected dynamically from the database.
+
+ Needed because the models don't work for us in migrations
+ as models will be far out of sync with the current data.
+
+ .. warning::
+
+ Do not use this method when creating ForeignKeys in database migrations
+ because sqlalchemy needs the same MetaData object to hold information
+ about the parent table and the reference table in the ForeignKey. This
+ method uses a unique MetaData object per table object so it won't work
+ with ForeignKey creation.
+ """
+ metadata = MetaData()
+ metadata.bind = engine
+ return Table(name, metadata, autoload=True)
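+
+# Illustrative migration sketch (table/column names are assumptions): the
+# reflected Table mirrors the live schema rather than the ORM models, so it
+# can be used safely in a migration script:
+#
+#     def upgrade(migrate_engine):
+#         instances = get_table(migrate_engine, 'instances')
+#         migrate_engine.execute(
+#             instances.update().values(host=None).where(
+#                 instances.c.deleted != 0))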
+
+
+class InsertFromSelect(UpdateBase):
+ """Form the base for `INSERT INTO table (SELECT ... )` statement."""
+ def __init__(self, table, select):
+ self.table = table
+ self.select = select
+
+
+@compiles(InsertFromSelect)
+def visit_insert_from_select(element, compiler, **kw):
+ """Form the `INSERT INTO table (SELECT ... )` statement."""
+ return "INSERT INTO %s %s" % (
+ compiler.process(element.table, asfrom=True),
+ compiler.process(element.select))
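+
+# Illustrative sketch (table names are assumptions): the construct compiles to
+# "INSERT INTO <table> <select>" and is used by the sqlite helpers below to
+# copy all rows from one table into another:
+#
+#     src = Table('instances', meta, autoload=True)
+#     dst = Table('instances__tmp__', meta, autoload=True)
+#     migrate_engine.execute(InsertFromSelect(dst, src.select()))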
+
+
+def _get_not_supported_column(col_name_col_instance, column_name):
+ try:
+ column = col_name_col_instance[column_name]
+ except KeyError:
+        msg = _("Please specify column %s in the col_name_col_instance "
+                "param. It is required because the column has a type that "
+                "is unsupported by SQLite.")
+ raise exception.ColumnError(msg % column_name)
+
+ if not isinstance(column, Column):
+        msg = _("The col_name_col_instance param has the wrong type of "
+                "column instance for column %s. It should be an instance "
+                "of sqlalchemy.Column.")
+ raise exception.ColumnError(msg % column_name)
+ return column
+
+
+def drop_old_duplicate_entries_from_table(migrate_engine, table_name,
+ use_soft_delete, *uc_column_names):
+    """Drop all old rows having the same values in the uc_column_names columns.
+
+    This method drops (or marks as `deleted` if use_soft_delete is True) old
+    duplicate rows from the table with name `table_name`.
+
+    :param migrate_engine: SQLAlchemy engine
+    :param table_name: Table with duplicates
+    :param use_soft_delete: If True, rows will be marked as `deleted`;
+                            if False, rows will be removed from the table
+ :param uc_column_names: Unique constraint columns
+ """
+ meta = MetaData()
+ meta.bind = migrate_engine
+
+ table = Table(table_name, meta, autoload=True)
+ columns_for_group_by = [table.c[name] for name in uc_column_names]
+
+ columns_for_select = [func.max(table.c.id)]
+ columns_for_select.extend(columns_for_group_by)
+
+ duplicated_rows_select = sqlalchemy.sql.select(
+ columns_for_select, group_by=columns_for_group_by,
+ having=func.count(table.c.id) > 1)
+
+ for row in migrate_engine.execute(duplicated_rows_select).fetchall():
+ # NOTE(boris-42): Do not remove row that has the biggest ID.
+ delete_condition = table.c.id != row[0]
+ is_none = None # workaround for pyflakes
+ delete_condition &= table.c.deleted_at == is_none
+ for name in uc_column_names:
+ delete_condition &= table.c[name] == row[name]
+
+ rows_to_delete_select = sqlalchemy.sql.select(
+ [table.c.id]).where(delete_condition)
+ for row in migrate_engine.execute(rows_to_delete_select).fetchall():
+ LOG.info(_LI("Deleting duplicated row with id: %(id)s from table: "
+ "%(table)s"), dict(id=row[0], table=table_name))
+
+ if use_soft_delete:
+ delete_statement = table.update().\
+ where(delete_condition).\
+ values({
+ 'deleted': literal_column('id'),
+ 'updated_at': literal_column('updated_at'),
+ 'deleted_at': timeutils.utcnow()
+ })
+ else:
+ delete_statement = table.delete().where(delete_condition)
+ migrate_engine.execute(delete_statement)
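+
+# Illustrative migration sketch (table/column names are assumptions): before
+# adding a unique constraint on (instance_uuid, key), keep only the newest row
+# per pair and soft-delete the older duplicates:
+#
+#     drop_old_duplicate_entries_from_table(
+#         migrate_engine, 'instance_metadata', True, 'instance_uuid', 'key')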
+
+
+def _get_default_deleted_value(table):
+ if isinstance(table.c.id.type, Integer):
+ return 0
+ if isinstance(table.c.id.type, String):
+ return ""
+    raise exception.ColumnError(_("Unsupported id column type"))
+
+
+def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes):
+ table = get_table(migrate_engine, table_name)
+
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ real_indexes = insp.get_indexes(table_name)
+ existing_index_names = dict(
+ [(index['name'], index['column_names']) for index in real_indexes])
+
+ # NOTE(boris-42): Restore indexes on `deleted` column
+ for index in indexes:
+ if 'deleted' not in index['column_names']:
+ continue
+ name = index['name']
+ if name in existing_index_names:
+ column_names = [table.c[c] for c in existing_index_names[name]]
+ old_index = Index(name, *column_names, unique=index["unique"])
+ old_index.drop(migrate_engine)
+
+ column_names = [table.c[c] for c in index['column_names']]
+ new_index = Index(index["name"], *column_names, unique=index["unique"])
+ new_index.create(migrate_engine)
+
+
+def change_deleted_column_type_to_boolean(migrate_engine, table_name,
+ **col_name_col_instance):
+ if migrate_engine.name == "sqlite":
+ return _change_deleted_column_type_to_boolean_sqlite(
+ migrate_engine, table_name, **col_name_col_instance)
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ indexes = insp.get_indexes(table_name)
+
+ table = get_table(migrate_engine, table_name)
+
+ old_deleted = Column('old_deleted', Boolean, default=False)
+ old_deleted.create(table, populate_default=False)
+
+ table.update().\
+ where(table.c.deleted == table.c.id).\
+ values(old_deleted=True).\
+ execute()
+
+ table.c.deleted.drop()
+ table.c.old_deleted.alter(name="deleted")
+
+ _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
+
+
+def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name,
+ **col_name_col_instance):
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ table = get_table(migrate_engine, table_name)
+
+ columns = []
+ for column in table.columns:
+ column_copy = None
+ if column.name != "deleted":
+ if isinstance(column.type, NullType):
+ column_copy = _get_not_supported_column(col_name_col_instance,
+ column.name)
+ else:
+ column_copy = column.copy()
+ else:
+ column_copy = Column('deleted', Boolean, default=0)
+ columns.append(column_copy)
+
+ constraints = [constraint.copy() for constraint in table.constraints]
+
+ meta = table.metadata
+ new_table = Table(table_name + "__tmp__", meta,
+ *(columns + constraints))
+ new_table.create()
+
+ indexes = []
+ for index in insp.get_indexes(table_name):
+ column_names = [new_table.c[c] for c in index['column_names']]
+ indexes.append(Index(index["name"], *column_names,
+ unique=index["unique"]))
+
+ c_select = []
+ for c in table.c:
+ if c.name != "deleted":
+ c_select.append(c)
+ else:
+ c_select.append(table.c.deleted == table.c.id)
+
+ ins = InsertFromSelect(new_table, sqlalchemy.sql.select(c_select))
+ migrate_engine.execute(ins)
+
+ table.drop()
+ for index in indexes:
+ index.create(migrate_engine)
+
+ new_table.rename(table_name)
+ new_table.update().\
+ where(new_table.c.deleted == new_table.c.id).\
+ values(deleted=True).\
+ execute()
+
+
+def change_deleted_column_type_to_id_type(migrate_engine, table_name,
+ **col_name_col_instance):
+ if migrate_engine.name == "sqlite":
+ return _change_deleted_column_type_to_id_type_sqlite(
+ migrate_engine, table_name, **col_name_col_instance)
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ indexes = insp.get_indexes(table_name)
+
+ table = get_table(migrate_engine, table_name)
+
+ new_deleted = Column('new_deleted', table.c.id.type,
+ default=_get_default_deleted_value(table))
+ new_deleted.create(table, populate_default=True)
+
+ deleted = True # workaround for pyflakes
+ table.update().\
+ where(table.c.deleted == deleted).\
+ values(new_deleted=table.c.id).\
+ execute()
+ table.c.deleted.drop()
+ table.c.new_deleted.alter(name="deleted")
+
+ _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
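+
+# Illustrative migration sketch: typically called from a sqlalchemy-migrate
+# upgrade() to convert a Boolean `deleted` column into one matching the id
+# type (live rows get 0 or "", deleted rows get their own id):
+#
+#     def upgrade(migrate_engine):
+#         change_deleted_column_type_to_id_type(migrate_engine, 'instances')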
+
+
+def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name,
+ **col_name_col_instance):
+    # NOTE(boris-42): sqlalchemy-migrate can't drop a column with check
+    #                 constraints in a sqlite DB, and our `deleted` column has
+    #                 2 check constraints. So there is only one way to remove
+ # these constraints:
+ # 1) Create new table with the same columns, constraints
+ # and indexes. (except deleted column).
+ # 2) Copy all data from old to new table.
+ # 3) Drop old table.
+ # 4) Rename new table to old table name.
+ insp = reflection.Inspector.from_engine(migrate_engine)
+ meta = MetaData(bind=migrate_engine)
+ table = Table(table_name, meta, autoload=True)
+ default_deleted_value = _get_default_deleted_value(table)
+
+ columns = []
+ for column in table.columns:
+ column_copy = None
+ if column.name != "deleted":
+ if isinstance(column.type, NullType):
+ column_copy = _get_not_supported_column(col_name_col_instance,
+ column.name)
+ else:
+ column_copy = column.copy()
+ else:
+ column_copy = Column('deleted', table.c.id.type,
+ default=default_deleted_value)
+ columns.append(column_copy)
+
+ def is_deleted_column_constraint(constraint):
+        # NOTE(boris-42): There is no other way to check whether a
+        #                 CheckConstraint is associated with the deleted
+        #                 column.
+ if not isinstance(constraint, CheckConstraint):
+ return False
+ sqltext = str(constraint.sqltext)
+        # NOTE(I159): in order to omit the CHECK constraint corresponding
+        # to the `deleted` column, we have to test these patterns, which may
+        # vary depending on the SQLAlchemy version used.
+ constraint_markers = (
+ "deleted in (0, 1)",
+ "deleted IN (:deleted_1, :deleted_2)",
+ "deleted IN (:param_1, :param_2)"
+ )
+ return any(sqltext.endswith(marker) for marker in constraint_markers)
+
+ constraints = []
+ for constraint in table.constraints:
+ if not is_deleted_column_constraint(constraint):
+ constraints.append(constraint.copy())
+
+ new_table = Table(table_name + "__tmp__", meta,
+ *(columns + constraints))
+ new_table.create()
+
+ indexes = []
+ for index in insp.get_indexes(table_name):
+ column_names = [new_table.c[c] for c in index['column_names']]
+ indexes.append(Index(index["name"], *column_names,
+ unique=index["unique"]))
+
+ ins = InsertFromSelect(new_table, table.select())
+ migrate_engine.execute(ins)
+
+ table.drop()
+ for index in indexes:
+ index.create(migrate_engine)
+
+ new_table.rename(table_name)
+ deleted = True # workaround for pyflakes
+ new_table.update().\
+ where(new_table.c.deleted == deleted).\
+ values(deleted=new_table.c.id).\
+ execute()
+
+ # NOTE(boris-42): Fix value of deleted column: False -> "" or 0.
+ deleted = False # workaround for pyflakes
+ new_table.update().\
+ where(new_table.c.deleted == deleted).\
+ values(deleted=default_deleted_value).\
+ execute()
+
+
+def get_connect_string(backend, database, user=None, passwd=None,
+ host='localhost'):
+    """Get a database connection string.
+
+    Try to get a connection with a very specific set of values; if we get
+    these, then we'll run the tests, otherwise they are skipped.
+
+ DEPRECATED: this function is deprecated and will be removed from oslo.db
+ in a few releases. Please use the provisioning system for dealing
+ with URLs and database provisioning.
+
+ """
+ args = {'backend': backend,
+ 'user': user,
+ 'passwd': passwd,
+ 'host': host,
+ 'database': database}
+ if backend == 'sqlite':
+ template = '%(backend)s:///%(database)s'
+ else:
+ template = "%(backend)s://%(user)s:%(passwd)s@%(host)s/%(database)s"
+ return template % args
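+
+# Illustrative sketch of the produced strings (credentials are assumptions):
+#
+#     get_connect_string('postgresql', 'openstack_citest',
+#                        user='citest', passwd='citest')
+#     # -> 'postgresql://citest:citest@localhost/openstack_citest'
+#
+#     get_connect_string('sqlite', '/tmp/test.db')
+#     # -> 'sqlite:////tmp/test.db'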
+
+
+def is_backend_avail(backend, database, user=None, passwd=None):
+ """Return True if the given backend is available.
+
+
+ DEPRECATED: this function is deprecated and will be removed from oslo.db
+ in a few releases. Please use the provisioning system to access
+ databases based on backend availability.
+
+ """
+ from oslo_db.sqlalchemy import provision
+
+ connect_uri = get_connect_string(backend=backend,
+ database=database,
+ user=user,
+ passwd=passwd)
+ try:
+ eng = provision.Backend._ensure_backend_available(connect_uri)
+ eng.dispose()
+ except exception.BackendNotAvailable:
+ return False
+ else:
+ return True
+
+
+def get_db_connection_info(conn_pieces):
+ database = conn_pieces.path.strip('/')
+ loc_pieces = conn_pieces.netloc.split('@')
+ host = loc_pieces[1]
+
+ auth_pieces = loc_pieces[0].split(':')
+ user = auth_pieces[0]
+ password = ""
+ if len(auth_pieces) > 1:
+ password = auth_pieces[1].strip()
+
+ return (user, password, database, host)
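+
+# Illustrative sketch (URL is an assumption): conn_pieces is a parsed URL,
+# such as the result of six.moves.urllib.parse.urlparse():
+#
+#     from six.moves.urllib import parse
+#     pieces = parse.urlparse('mysql://nova:secret@127.0.0.1/nova_api')
+#     get_db_connection_info(pieces)
+#     # -> ('nova', 'secret', 'nova_api', '127.0.0.1')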
+
+
+def index_exists(migrate_engine, table_name, index_name):
+ """Check if given index exists.
+
+ :param migrate_engine: sqlalchemy engine
+ :param table_name: name of the table
+ :param index_name: name of the index
+ """
+ inspector = reflection.Inspector.from_engine(migrate_engine)
+ indexes = inspector.get_indexes(table_name)
+ index_names = [index['name'] for index in indexes]
+ return index_name in index_names
+
+
+def add_index(migrate_engine, table_name, index_name, idx_columns):
+ """Create an index for given columns.
+
+ :param migrate_engine: sqlalchemy engine
+ :param table_name: name of the table
+ :param index_name: name of the index
+ :param idx_columns: tuple with names of columns that will be indexed
+ """
+ table = get_table(migrate_engine, table_name)
+ if not index_exists(migrate_engine, table_name, index_name):
+ index = Index(
+ index_name, *[getattr(table.c, col) for col in idx_columns]
+ )
+ index.create()
+ else:
+ raise ValueError("Index '%s' already exists!" % index_name)
+
+
+def drop_index(migrate_engine, table_name, index_name):
+ """Drop index with given name.
+
+ :param migrate_engine: sqlalchemy engine
+ :param table_name: name of the table
+ :param index_name: name of the index
+ """
+ table = get_table(migrate_engine, table_name)
+ for index in table.indexes:
+ if index.name == index_name:
+ index.drop()
+ break
+ else:
+ raise ValueError("Index '%s' not found!" % index_name)
+
+
+def change_index_columns(migrate_engine, table_name, index_name, new_columns):
+ """Change set of columns that are indexed by given index.
+
+ :param migrate_engine: sqlalchemy engine
+ :param table_name: name of the table
+ :param index_name: name of the index
+ :param new_columns: tuple with names of columns that will be indexed
+ """
+ drop_index(migrate_engine, table_name, index_name)
+ add_index(migrate_engine, table_name, index_name, new_columns)
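+
+# Illustrative migration sketch (table/index/column names are assumptions):
+#
+#     if not index_exists(migrate_engine, 'instances', 'instances_host_idx'):
+#         add_index(migrate_engine, 'instances', 'instances_host_idx',
+#                   ('host',))
+#     # later, widen the index to cover (host, node)
+#     change_index_columns(migrate_engine, 'instances', 'instances_host_idx',
+#                          ('host', 'node'))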
+
+
+def column_exists(engine, table_name, column):
+ """Check if table has given column.
+
+ :param engine: sqlalchemy engine
+ :param table_name: name of the table
+    :param column: name of the column
+ """
+ t = get_table(engine, table_name)
+ return column in t.c
+
+
+class DialectFunctionDispatcher(object):
+ @classmethod
+ def dispatch_for_dialect(cls, expr, multiple=False):
+ """Provide dialect-specific functionality within distinct functions.
+
+ e.g.::
+
+ @dispatch_for_dialect("*")
+ def set_special_option(engine):
+ pass
+
+ @set_special_option.dispatch_for("sqlite")
+ def set_sqlite_special_option(engine):
+ return engine.execute("sqlite thing")
+
+ @set_special_option.dispatch_for("mysql+mysqldb")
+ def set_mysqldb_special_option(engine):
+ return engine.execute("mysqldb thing")
+
+        After the above registration, the ``set_special_option()`` function
+        is now a dispatcher and can be called with a SQLAlchemy ``Engine``,
+        ``Connection``, URL string, or ``sqlalchemy.engine.URL`` object::
+
+ eng = create_engine('...')
+ result = set_special_option(eng)
+
+ The filter system supports two modes, "multiple" and "single".
+ The default is "single", and requires that one and only one function
+ match for a given backend. In this mode, the function may also
+ have a return value, which will be returned by the top level
+ call.
+
+        "multiple" mode, on the other hand, does not support return
+        values, but allows for any number of matching functions, where
+        each function will be called::
+
+ # the initial call sets this up as a "multiple" dispatcher
+ @dispatch_for_dialect("*", multiple=True)
+ def set_options(engine):
+ # set options that apply to *all* engines
+
+ @set_options.dispatch_for("postgresql")
+ def set_postgresql_options(engine):
+ # set options that apply to all Postgresql engines
+
+ @set_options.dispatch_for("postgresql+psycopg2")
+ def set_postgresql_psycopg2_options(engine):
+ # set options that apply only to "postgresql+psycopg2"
+
+ @set_options.dispatch_for("*+pyodbc")
+ def set_pyodbc_options(engine):
+ # set options that apply to all pyodbc backends
+
+        Note that in both modes, any number of additional arguments can be
+        accepted by member functions. For example, a dictionary of options
+        may be passed in and populated::
+
+ @dispatch_for_dialect("*", multiple=True)
+ def set_engine_options(url, opts):
+ pass
+
+ @set_engine_options.dispatch_for("mysql+mysqldb")
+ def _mysql_set_default_charset_to_utf8(url, opts):
+ opts.setdefault('charset', 'utf-8')
+
+ @set_engine_options.dispatch_for("sqlite")
+ def _set_sqlite_in_memory_check_same_thread(url, opts):
+ if url.database in (None, 'memory'):
+ opts['check_same_thread'] = False
+
+ opts = {}
+ set_engine_options(url, opts)
+
+ The driver specifiers are of the form:
+ ``<database | *>[+<driver | *>]``. That is, database name or "*",
+ followed by an optional ``+`` sign with driver or "*". Omitting
+ the driver name implies all drivers for that database.
+
+ """
+ if multiple:
+ cls = DialectMultiFunctionDispatcher
+ else:
+ cls = DialectSingleFunctionDispatcher
+ return cls().dispatch_for(expr)
+
+ _db_plus_driver_reg = re.compile(r'([^+]+?)(?:\+(.+))?$')
+
+ def dispatch_for(self, expr):
+ def decorate(fn):
+ dbname, driver = self._parse_dispatch(expr)
+ if fn is self:
+ fn = fn._last
+ self._last = fn
+ self._register(expr, dbname, driver, fn)
+ return self
+ return decorate
+
+ def _parse_dispatch(self, text):
+ m = self._db_plus_driver_reg.match(text)
+ if not m:
+ raise ValueError("Couldn't parse database[+driver]: %r" % text)
+ return m.group(1) or '*', m.group(2) or '*'
+
+ def __call__(self, *arg, **kw):
+ target = arg[0]
+ return self._dispatch_on(
+ self._url_from_target(target), target, arg, kw)
+
+ def _url_from_target(self, target):
+ if isinstance(target, Connectable):
+ return target.engine.url
+ elif isinstance(target, six.string_types):
+ if "://" not in target:
+ target_url = sa_url.make_url("%s://" % target)
+ else:
+ target_url = sa_url.make_url(target)
+ return target_url
+ elif isinstance(target, sa_url.URL):
+ return target
+ else:
+ raise ValueError("Invalid target type: %r" % target)
+
+ def dispatch_on_drivername(self, drivername):
+ """Return a sub-dispatcher for the given drivername.
+
+ This provides a means of calling a different function, such as the
+ "*" function, for a given target object that normally refers
+ to a sub-function.
+
+ """
+ dbname, driver = self._db_plus_driver_reg.match(drivername).group(1, 2)
+
+ def go(*arg, **kw):
+ return self._dispatch_on_db_driver(dbname, "*", arg, kw)
+
+ return go
+
+ def _dispatch_on(self, url, target, arg, kw):
+ dbname, driver = self._db_plus_driver_reg.match(
+ url.drivername).group(1, 2)
+ if not driver:
+ driver = url.get_dialect().driver
+
+ return self._dispatch_on_db_driver(dbname, driver, arg, kw)
+
+ def _invoke_fn(self, fn, arg, kw):
+ return fn(*arg, **kw)
+
+
+class DialectSingleFunctionDispatcher(DialectFunctionDispatcher):
+ def __init__(self):
+ self.reg = collections.defaultdict(dict)
+
+ def _register(self, expr, dbname, driver, fn):
+ fn_dict = self.reg[dbname]
+ if driver in fn_dict:
+ raise TypeError("Multiple functions for expression %r" % expr)
+ fn_dict[driver] = fn
+
+ def _matches(self, dbname, driver):
+ for db in (dbname, '*'):
+ subdict = self.reg[db]
+ for drv in (driver, '*'):
+ if drv in subdict:
+ return subdict[drv]
+ else:
+ raise ValueError(
+ "No default function found for driver: %r" %
+ ("%s+%s" % (dbname, driver)))
+
+ def _dispatch_on_db_driver(self, dbname, driver, arg, kw):
+ fn = self._matches(dbname, driver)
+ return self._invoke_fn(fn, arg, kw)
+
+
+class DialectMultiFunctionDispatcher(DialectFunctionDispatcher):
+ def __init__(self):
+ self.reg = collections.defaultdict(
+ lambda: collections.defaultdict(list))
+
+ def _register(self, expr, dbname, driver, fn):
+ self.reg[dbname][driver].append(fn)
+
+ def _matches(self, dbname, driver):
+ if driver != '*':
+ drivers = (driver, '*')
+ else:
+ drivers = ('*', )
+
+ for db in (dbname, '*'):
+ subdict = self.reg[db]
+ for drv in drivers:
+ for fn in subdict[drv]:
+ yield fn
+
+ def _dispatch_on_db_driver(self, dbname, driver, arg, kw):
+ for fn in self._matches(dbname, driver):
+ if self._invoke_fn(fn, arg, kw) is not None:
+ raise TypeError(
+ "Return value not allowed for "
+ "multiple filtered function")
+
+dispatch_for_dialect = DialectFunctionDispatcher.dispatch_for_dialect
+
+
+def get_non_innodb_tables(connectable, skip_tables=('migrate_version',
+ 'alembic_version')):
+ """Get a list of tables which don't use InnoDB storage engine.
+
+ :param connectable: a SQLAlchemy Engine or a Connection instance
+    :param skip_tables: a list of tables to skip, i.e. tables that are
+                        allowed to use a different storage engine
+ """
+
+ query_str = """
+ SELECT table_name
+ FROM information_schema.tables
+ WHERE table_schema = :database AND
+ engine != 'InnoDB'
+ """
+
+ params = {}
+ if skip_tables:
+ params = dict(
+ ('skip_%s' % i, table_name)
+ for i, table_name in enumerate(skip_tables)
+ )
+
+        placeholders = ', '.join(':' + p for p in params)
+        query_str += ' AND table_name NOT IN (%s)' % placeholders
+
+ params['database'] = connectable.engine.url.database
+ query = text(query_str)
+ noninnodb = connectable.execute(query, **params)
+ return [i[0] for i in noninnodb]
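+
+# Illustrative sketch (error message is an assumption): typically used as a
+# sanity check on MySQL, since the query above reads information_schema:
+#
+#     non_innodb = get_non_innodb_tables(engine)
+#     if non_innodb:
+#         raise RuntimeError("Tables not using InnoDB: %s"
+#                            % ', '.join(non_innodb))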