From bce8ed304274f767b0a12eadfdf07e220495c160 Mon Sep 17 00:00:00 2001 From: Eli Qiao Date: Tue, 23 Dec 2014 17:07:15 +0800 Subject: Make sure sort_key_attr is QueryableAttribute when query When doing query.order_by, sort_key_attr is get from model class, we need to make sure sort_key_attr is really a QueryableAttribute type instance before we do the query or it will cause errors. This will prevent if there IS a function which name is same as sort_key. Closes-Bug: 1405069 Change-Id: I8a3eb08ab3469ec08e05bfce754b664943d65c83 --- oslo_db/sqlalchemy/utils.py | 6 ++- .../tests/old_import_api/sqlalchemy/test_utils.py | 58 ++++++++++++++-------- oslo_db/tests/sqlalchemy/test_utils.py | 58 ++++++++++++++-------- 3 files changed, 78 insertions(+), 44 deletions(-) diff --git a/oslo_db/sqlalchemy/utils.py b/oslo_db/sqlalchemy/utils.py index 6a66bb9..919ac9e 100644 --- a/oslo_db/sqlalchemy/utils.py +++ b/oslo_db/sqlalchemy/utils.py @@ -32,6 +32,7 @@ from sqlalchemy.engine import url as sa_url from sqlalchemy.ext.compiler import compiles from sqlalchemy import func from sqlalchemy import Index +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy.sql.expression import literal_column @@ -147,8 +148,9 @@ def paginate_query(query, model, limit, sort_keys, marker=None, raise ValueError(_("Unknown sort direction, " "must be 'desc' or 'asc'")) try: - sort_key_attr = getattr(model, current_sort_key) - except AttributeError: + sort_key_attr = inspect(model).\ + all_orm_descriptors[current_sort_key] + except KeyError: raise exception.InvalidSortKey() query = query.order_by(sort_dir_func(sort_key_attr)) diff --git a/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py b/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py index 6e865c6..fea7fbd 100644 --- a/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py @@ -43,6 +43,7 @@ from oslo_db.tests.old_import_api import utils as test_utils SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.'))) +Base = declarative_base() class TestSanitizeDbUrl(test_base.BaseTestCase): @@ -86,6 +87,17 @@ class FakeModel(object): return '' % self.values +class FakeTable(Base): + __tablename__ = 'fake_table' + + user_id = Column(String(50), primary_key=True) + project_id = Column(String(50)) + snapshot_id = Column(String(50)) + + def foo(self): + pass + + class TestPaginateQuery(test_base.BaseTestCase): def setUp(self): super(TestPaginateQuery, self).setUp() @@ -94,23 +106,17 @@ class TestPaginateQuery(test_base.BaseTestCase): self.query = self.mox.CreateMockAnything() self.mox.StubOutWithMock(sqlalchemy, 'asc') self.mox.StubOutWithMock(sqlalchemy, 'desc') - self.marker = FakeModel({ - 'user_id': 'user', - 'project_id': 'p', - 'snapshot_id': 's', - }) - self.model = FakeModel({ - 'user_id': 'user', - 'project_id': 'project', - 'snapshot_id': 'snapshot', - }) + self.marker = FakeTable(user_id='user', + project_id='p', + snapshot_id='s') + self.model = FakeTable def test_paginate_query_no_pagination_no_sort_dirs(self): - sqlalchemy.asc('user').AndReturn('asc_3') + sqlalchemy.asc(self.model.user_id).AndReturn('asc_3') self.query.order_by('asc_3').AndReturn(self.query) - sqlalchemy.asc('project').AndReturn('asc_2') + sqlalchemy.asc(self.model.project_id).AndReturn('asc_2') self.query.order_by('asc_2').AndReturn(self.query) - sqlalchemy.asc('snapshot').AndReturn('asc_1') + sqlalchemy.asc(self.model.snapshot_id).AndReturn('asc_1') self.query.order_by('asc_1').AndReturn(self.query) self.query.limit(5).AndReturn(self.query) self.mox.ReplayAll() @@ -118,9 +124,9 @@ class TestPaginateQuery(test_base.BaseTestCase): ['user_id', 'project_id', 'snapshot_id']) def test_paginate_query_no_pagination(self): - sqlalchemy.asc('user').AndReturn('asc') + sqlalchemy.asc(self.model.user_id).AndReturn('asc') self.query.order_by('asc').AndReturn(self.query) - sqlalchemy.desc('project').AndReturn('desc') + sqlalchemy.desc(self.model.project_id).AndReturn('desc') self.query.order_by('desc').AndReturn(self.query) self.query.limit(5).AndReturn(self.query) self.mox.ReplayAll() @@ -129,13 +135,23 @@ class TestPaginateQuery(test_base.BaseTestCase): sort_dirs=['asc', 'desc']) def test_paginate_query_attribute_error(self): - sqlalchemy.asc('user').AndReturn('asc') + sqlalchemy.asc(self.model.user_id).AndReturn('asc') self.query.order_by('asc').AndReturn(self.query) self.mox.ReplayAll() self.assertRaises(exception.InvalidSortKey, utils.paginate_query, self.query, self.model, 5, ['user_id', 'non-existent key']) + def test_paginate_query_attribute_error_invalid_sortkey(self): + self.assertRaises(exception.InvalidSortKey, + utils.paginate_query, self.query, + self.model, 5, ['bad_user_id']) + + def test_paginate_query_attribute_error_invalid_sortkey_2(self): + self.assertRaises(exception.InvalidSortKey, + utils.paginate_query, self.query, + self.model, 5, ['foo']) + def test_paginate_query_assertion_error(self): self.mox.ReplayAll() self.assertRaises(AssertionError, @@ -153,13 +169,13 @@ class TestPaginateQuery(test_base.BaseTestCase): sort_dir=None, sort_dirs=['asc', 'desk']) def test_paginate_query(self): - sqlalchemy.asc('user').AndReturn('asc_1') + sqlalchemy.asc(self.model.user_id).AndReturn('asc_1') self.query.order_by('asc_1').AndReturn(self.query) - sqlalchemy.desc('project').AndReturn('desc_1') + sqlalchemy.desc(self.model.project_id).AndReturn('desc_1') self.query.order_by('desc_1').AndReturn(self.query) self.mox.StubOutWithMock(sqlalchemy.sql, 'and_') - sqlalchemy.sql.and_(False).AndReturn('some_crit') - sqlalchemy.sql.and_(True, False).AndReturn('another_crit') + sqlalchemy.sql.and_(mock.ANY).AndReturn('some_crit') + sqlalchemy.sql.and_(mock.ANY, mock.ANY).AndReturn('another_crit') self.mox.StubOutWithMock(sqlalchemy.sql, 'or_') sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f') self.query.filter('some_f').AndReturn(self.query) @@ -171,7 +187,7 @@ class TestPaginateQuery(test_base.BaseTestCase): sort_dirs=['asc', 'desc']) def test_paginate_query_value_error(self): - sqlalchemy.asc('user').AndReturn('asc_1') + sqlalchemy.asc(self.model.user_id).AndReturn('asc_1') self.query.order_by('asc_1').AndReturn(self.query) self.mox.ReplayAll() self.assertRaises(ValueError, utils.paginate_query, diff --git a/oslo_db/tests/sqlalchemy/test_utils.py b/oslo_db/tests/sqlalchemy/test_utils.py index 509cf48..06a1445 100644 --- a/oslo_db/tests/sqlalchemy/test_utils.py +++ b/oslo_db/tests/sqlalchemy/test_utils.py @@ -41,6 +41,7 @@ from oslo_db.sqlalchemy import utils from oslo_db.tests import utils as test_utils +Base = declarative_base() SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.'))) @@ -64,6 +65,17 @@ class CustomType(UserDefinedType): return "CustomType" +class FakeTable(Base): + __tablename__ = 'fake_table' + + user_id = Column(String(50), primary_key=True) + project_id = Column(String(50)) + snapshot_id = Column(String(50)) + + def foo(self): + pass + + class FakeModel(object): def __init__(self, values): self.values = values @@ -93,23 +105,17 @@ class TestPaginateQuery(test_base.BaseTestCase): self.query = self.mox.CreateMockAnything() self.mox.StubOutWithMock(sqlalchemy, 'asc') self.mox.StubOutWithMock(sqlalchemy, 'desc') - self.marker = FakeModel({ - 'user_id': 'user', - 'project_id': 'p', - 'snapshot_id': 's', - }) - self.model = FakeModel({ - 'user_id': 'user', - 'project_id': 'project', - 'snapshot_id': 'snapshot', - }) + self.marker = FakeTable(user_id='user', + project_id='p', + snapshot_id='s') + self.model = FakeTable def test_paginate_query_no_pagination_no_sort_dirs(self): - sqlalchemy.asc('user').AndReturn('asc_3') + sqlalchemy.asc(self.model.user_id).AndReturn('asc_3') self.query.order_by('asc_3').AndReturn(self.query) - sqlalchemy.asc('project').AndReturn('asc_2') + sqlalchemy.asc(self.model.project_id).AndReturn('asc_2') self.query.order_by('asc_2').AndReturn(self.query) - sqlalchemy.asc('snapshot').AndReturn('asc_1') + sqlalchemy.asc(self.model.snapshot_id).AndReturn('asc_1') self.query.order_by('asc_1').AndReturn(self.query) self.query.limit(5).AndReturn(self.query) self.mox.ReplayAll() @@ -117,9 +123,9 @@ class TestPaginateQuery(test_base.BaseTestCase): ['user_id', 'project_id', 'snapshot_id']) def test_paginate_query_no_pagination(self): - sqlalchemy.asc('user').AndReturn('asc') + sqlalchemy.asc(self.model.user_id).AndReturn('asc') self.query.order_by('asc').AndReturn(self.query) - sqlalchemy.desc('project').AndReturn('desc') + sqlalchemy.desc(self.model.project_id).AndReturn('desc') self.query.order_by('desc').AndReturn(self.query) self.query.limit(5).AndReturn(self.query) self.mox.ReplayAll() @@ -128,13 +134,23 @@ class TestPaginateQuery(test_base.BaseTestCase): sort_dirs=['asc', 'desc']) def test_paginate_query_attribute_error(self): - sqlalchemy.asc('user').AndReturn('asc') + sqlalchemy.asc(self.model.user_id).AndReturn('asc') self.query.order_by('asc').AndReturn(self.query) self.mox.ReplayAll() self.assertRaises(exception.InvalidSortKey, utils.paginate_query, self.query, self.model, 5, ['user_id', 'non-existent key']) + def test_paginate_query_attribute_error_invalid_sortkey(self): + self.assertRaises(exception.InvalidSortKey, + utils.paginate_query, self.query, + self.model, 5, ['bad_user_id']) + + def test_paginate_query_attribute_error_invalid_sortkey_2(self): + self.assertRaises(exception.InvalidSortKey, + utils.paginate_query, self.query, + self.model, 5, ['foo']) + def test_paginate_query_assertion_error(self): self.mox.ReplayAll() self.assertRaises(AssertionError, @@ -152,13 +168,13 @@ class TestPaginateQuery(test_base.BaseTestCase): sort_dir=None, sort_dirs=['asc', 'desk']) def test_paginate_query(self): - sqlalchemy.asc('user').AndReturn('asc_1') + sqlalchemy.asc(self.model.user_id).AndReturn('asc_1') self.query.order_by('asc_1').AndReturn(self.query) - sqlalchemy.desc('project').AndReturn('desc_1') + sqlalchemy.desc(self.model.project_id).AndReturn('desc_1') self.query.order_by('desc_1').AndReturn(self.query) self.mox.StubOutWithMock(sqlalchemy.sql, 'and_') - sqlalchemy.sql.and_(False).AndReturn('some_crit') - sqlalchemy.sql.and_(True, False).AndReturn('another_crit') + sqlalchemy.sql.and_(mock.ANY).AndReturn('some_crit') + sqlalchemy.sql.and_(mock.ANY, mock.ANY).AndReturn('another_crit') self.mox.StubOutWithMock(sqlalchemy.sql, 'or_') sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f') self.query.filter('some_f').AndReturn(self.query) @@ -170,7 +186,7 @@ class TestPaginateQuery(test_base.BaseTestCase): sort_dirs=['asc', 'desc']) def test_paginate_query_value_error(self): - sqlalchemy.asc('user').AndReturn('asc_1') + sqlalchemy.asc(self.model.user_id).AndReturn('asc_1') self.query.order_by('asc_1').AndReturn(self.query) self.mox.ReplayAll() self.assertRaises(ValueError, utils.paginate_query, -- cgit v1.2.1