diff options
-rw-r--r-- | oslo_db/sqlalchemy/utils.py | 37 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_utils.py | 35 |
2 files changed, 62 insertions, 10 deletions
diff --git a/oslo_db/sqlalchemy/utils.py b/oslo_db/sqlalchemy/utils.py index 7e18543..4f6a072 100644 --- a/oslo_db/sqlalchemy/utils.py +++ b/oslo_db/sqlalchemy/utils.py @@ -18,6 +18,7 @@ import collections import contextlib +import itertools import logging import re @@ -53,6 +54,10 @@ LOG = logging.getLogger(__name__) _DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+") +_VALID_SORT_DIR = [ + "-".join(x) for x in itertools.product(["asc", "desc"], + ["nullsfirst", "nullslast"])] + def sanitize_db_url(url): match = _DBURL_REGEX.match(url) @@ -88,6 +93,8 @@ def paginate_query(query, model, limit, sort_keys, marker=None, :param marker: the last item of the previous page; we returns the next results after this value. :param sort_dir: direction in which results should be sorted (asc, desc) + suffix -nullsfirst, -nullslast can be added to defined + the ordering of null values :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys :rtype: sqlalchemy.orm.query.Query @@ -114,21 +121,31 @@ def paginate_query(query, model, limit, sort_keys, marker=None, # Add sorting for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): try: + inspect(model).all_orm_descriptors[current_sort_key] + except KeyError: + raise exception.InvalidSortKey() + else: + sort_key_attr = getattr(model, current_sort_key) + + try: + main_sort_dir, __, null_sort_dir = current_sort_dir.partition("-") sort_dir_func = { 'asc': sqlalchemy.asc, 'desc': sqlalchemy.desc, - }[current_sort_dir] + }[main_sort_dir] + + null_order_by_stmt = { + "": None, + "nullsfirst": sort_key_attr.is_(None), + "nullslast": sort_key_attr.isnot(None), + }[null_sort_dir] except KeyError: raise ValueError(_("Unknown sort direction, " - "must be 'desc' or 'asc'")) - try: - inspect(model).\ - all_orm_descriptors[current_sort_key] - except KeyError: - raise exception.InvalidSortKey() - else: - sort_key_attr = getattr(model, current_sort_key) + "must be one of: %s") % + ", ".join(_VALID_SORT_DIR)) + if null_order_by_stmt is not None: + query = query.order_by(sqlalchemy.desc(null_order_by_stmt)) query = query.order_by(sort_dir_func(sort_key_attr)) # Add pagination @@ -147,7 +164,7 @@ def paginate_query(query, model, limit, sort_keys, marker=None, crit_attrs.append((model_attr == marker_values[j])) model_attr = getattr(model, sort_keys[i]) - if sort_dirs[i] == 'desc': + if sort_dirs[i].startswith('desc'): crit_attrs.append((model_attr < marker_values[i])) else: crit_attrs.append((model_attr > marker_values[i])) diff --git a/oslo_db/tests/sqlalchemy/test_utils.py b/oslo_db/tests/sqlalchemy/test_utils.py index 054f99d..b7a00fb 100644 --- a/oslo_db/tests/sqlalchemy/test_utils.py +++ b/oslo_db/tests/sqlalchemy/test_utils.py @@ -166,6 +166,11 @@ class TestPaginateQuery(test_base.BaseTestCase): utils.paginate_query, self.query, self.model, 5, ['foo']) + def test_paginate_query_attribute_error_invalid_sortkey_3(self): + self.assertRaises(exception.InvalidSortKey, + utils.paginate_query, self.query, + self.model, 5, ['asc-nullinvalid']) + def test_paginate_query_assertion_error(self): self.mox.ReplayAll() self.assertRaises(AssertionError, @@ -200,6 +205,36 @@ class TestPaginateQuery(test_base.BaseTestCase): marker=self.marker, sort_dirs=['asc', 'desc']) + def test_paginate_query_null(self): + self.mox.StubOutWithMock(self.model.user_id, 'isnot') + self.model.user_id.isnot(None).AndReturn('asc_null_1') + sqlalchemy.desc('asc_null_1').AndReturn('asc_null_2') + self.query.order_by('asc_null_2').AndReturn(self.query) + + sqlalchemy.asc(self.model.user_id).AndReturn('asc_1') + self.query.order_by('asc_1').AndReturn(self.query) + + self.mox.StubOutWithMock(self.model.project_id, 'is_') + self.model.project_id.is_(None).AndReturn('desc_null_1') + sqlalchemy.desc('desc_null_1').AndReturn('desc_null_2') + self.query.order_by('desc_null_2').AndReturn(self.query) + + sqlalchemy.desc(self.model.project_id).AndReturn('desc_1') + self.query.order_by('desc_1').AndReturn(self.query) + + self.mox.StubOutWithMock(sqlalchemy.sql, 'and_') + sqlalchemy.sql.and_(mock.ANY).AndReturn('some_crit') + sqlalchemy.sql.and_(mock.ANY, mock.ANY).AndReturn('another_crit') + self.mox.StubOutWithMock(sqlalchemy.sql, 'or_') + sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f') + self.query.filter('some_f').AndReturn(self.query) + self.query.limit(5).AndReturn(self.query) + self.mox.ReplayAll() + utils.paginate_query(self.query, self.model, 5, + ['user_id', 'project_id'], + marker=self.marker, + sort_dirs=['asc-nullslast', 'desc-nullsfirst']) + def test_paginate_query_value_error(self): sqlalchemy.asc(self.model.user_id).AndReturn('asc_1') self.query.order_by('asc_1').AndReturn(self.query) |