summaryrefslogtreecommitdiff
path: root/oslo_db
diff options
context:
space:
mode:
authorEli Qiao <taget@linux.vnet.ibm.com>2014-12-23 17:07:15 +0800
committerEli Qiao <taget@linux.vnet.ibm.com>2015-01-09 06:12:06 +0800
commitbce8ed304274f767b0a12eadfdf07e220495c160 (patch)
treed24d775703c8e258b1ecd9cdeea8d73fa04d9958 /oslo_db
parent98b434db7d1b52aee32782065df155988918cc3f (diff)
downloadoslo-db-bce8ed304274f767b0a12eadfdf07e220495c160.tar.gz
Make sure sort_key_attr is QueryableAttribute when query
When doing query.order_by, sort_key_attr is get from model class, we need to make sure sort_key_attr is really a QueryableAttribute type instance before we do the query or it will cause errors. This will prevent if there IS a function which name is same as sort_key. Closes-Bug: 1405069 Change-Id: I8a3eb08ab3469ec08e05bfce754b664943d65c83
Diffstat (limited to 'oslo_db')
-rw-r--r--oslo_db/sqlalchemy/utils.py6
-rw-r--r--oslo_db/tests/old_import_api/sqlalchemy/test_utils.py58
-rw-r--r--oslo_db/tests/sqlalchemy/test_utils.py58
3 files changed, 78 insertions, 44 deletions
diff --git a/oslo_db/sqlalchemy/utils.py b/oslo_db/sqlalchemy/utils.py
index 6a66bb9..919ac9e 100644
--- a/oslo_db/sqlalchemy/utils.py
+++ b/oslo_db/sqlalchemy/utils.py
@@ -32,6 +32,7 @@ from sqlalchemy.engine import url as sa_url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy import func
from sqlalchemy import Index
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy.sql.expression import literal_column
@@ -147,8 +148,9 @@ def paginate_query(query, model, limit, sort_keys, marker=None,
raise ValueError(_("Unknown sort direction, "
"must be 'desc' or 'asc'"))
try:
- sort_key_attr = getattr(model, current_sort_key)
- except AttributeError:
+ sort_key_attr = inspect(model).\
+ all_orm_descriptors[current_sort_key]
+ except KeyError:
raise exception.InvalidSortKey()
query = query.order_by(sort_dir_func(sort_key_attr))
diff --git a/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py b/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py
index 6e865c6..fea7fbd 100644
--- a/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py
+++ b/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py
@@ -43,6 +43,7 @@ from oslo_db.tests.old_import_api import utils as test_utils
SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.')))
+Base = declarative_base()
class TestSanitizeDbUrl(test_base.BaseTestCase):
@@ -86,6 +87,17 @@ class FakeModel(object):
return '<FakeModel: %s>' % self.values
+class FakeTable(Base):
+ __tablename__ = 'fake_table'
+
+ user_id = Column(String(50), primary_key=True)
+ project_id = Column(String(50))
+ snapshot_id = Column(String(50))
+
+ def foo(self):
+ pass
+
+
class TestPaginateQuery(test_base.BaseTestCase):
def setUp(self):
super(TestPaginateQuery, self).setUp()
@@ -94,23 +106,17 @@ class TestPaginateQuery(test_base.BaseTestCase):
self.query = self.mox.CreateMockAnything()
self.mox.StubOutWithMock(sqlalchemy, 'asc')
self.mox.StubOutWithMock(sqlalchemy, 'desc')
- self.marker = FakeModel({
- 'user_id': 'user',
- 'project_id': 'p',
- 'snapshot_id': 's',
- })
- self.model = FakeModel({
- 'user_id': 'user',
- 'project_id': 'project',
- 'snapshot_id': 'snapshot',
- })
+ self.marker = FakeTable(user_id='user',
+ project_id='p',
+ snapshot_id='s')
+ self.model = FakeTable
def test_paginate_query_no_pagination_no_sort_dirs(self):
- sqlalchemy.asc('user').AndReturn('asc_3')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc_3')
self.query.order_by('asc_3').AndReturn(self.query)
- sqlalchemy.asc('project').AndReturn('asc_2')
+ sqlalchemy.asc(self.model.project_id).AndReturn('asc_2')
self.query.order_by('asc_2').AndReturn(self.query)
- sqlalchemy.asc('snapshot').AndReturn('asc_1')
+ sqlalchemy.asc(self.model.snapshot_id).AndReturn('asc_1')
self.query.order_by('asc_1').AndReturn(self.query)
self.query.limit(5).AndReturn(self.query)
self.mox.ReplayAll()
@@ -118,9 +124,9 @@ class TestPaginateQuery(test_base.BaseTestCase):
['user_id', 'project_id', 'snapshot_id'])
def test_paginate_query_no_pagination(self):
- sqlalchemy.asc('user').AndReturn('asc')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc')
self.query.order_by('asc').AndReturn(self.query)
- sqlalchemy.desc('project').AndReturn('desc')
+ sqlalchemy.desc(self.model.project_id).AndReturn('desc')
self.query.order_by('desc').AndReturn(self.query)
self.query.limit(5).AndReturn(self.query)
self.mox.ReplayAll()
@@ -129,13 +135,23 @@ class TestPaginateQuery(test_base.BaseTestCase):
sort_dirs=['asc', 'desc'])
def test_paginate_query_attribute_error(self):
- sqlalchemy.asc('user').AndReturn('asc')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc')
self.query.order_by('asc').AndReturn(self.query)
self.mox.ReplayAll()
self.assertRaises(exception.InvalidSortKey,
utils.paginate_query, self.query,
self.model, 5, ['user_id', 'non-existent key'])
+ def test_paginate_query_attribute_error_invalid_sortkey(self):
+ self.assertRaises(exception.InvalidSortKey,
+ utils.paginate_query, self.query,
+ self.model, 5, ['bad_user_id'])
+
+ def test_paginate_query_attribute_error_invalid_sortkey_2(self):
+ self.assertRaises(exception.InvalidSortKey,
+ utils.paginate_query, self.query,
+ self.model, 5, ['foo'])
+
def test_paginate_query_assertion_error(self):
self.mox.ReplayAll()
self.assertRaises(AssertionError,
@@ -153,13 +169,13 @@ class TestPaginateQuery(test_base.BaseTestCase):
sort_dir=None, sort_dirs=['asc', 'desk'])
def test_paginate_query(self):
- sqlalchemy.asc('user').AndReturn('asc_1')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
self.query.order_by('asc_1').AndReturn(self.query)
- sqlalchemy.desc('project').AndReturn('desc_1')
+ sqlalchemy.desc(self.model.project_id).AndReturn('desc_1')
self.query.order_by('desc_1').AndReturn(self.query)
self.mox.StubOutWithMock(sqlalchemy.sql, 'and_')
- sqlalchemy.sql.and_(False).AndReturn('some_crit')
- sqlalchemy.sql.and_(True, False).AndReturn('another_crit')
+ sqlalchemy.sql.and_(mock.ANY).AndReturn('some_crit')
+ sqlalchemy.sql.and_(mock.ANY, mock.ANY).AndReturn('another_crit')
self.mox.StubOutWithMock(sqlalchemy.sql, 'or_')
sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f')
self.query.filter('some_f').AndReturn(self.query)
@@ -171,7 +187,7 @@ class TestPaginateQuery(test_base.BaseTestCase):
sort_dirs=['asc', 'desc'])
def test_paginate_query_value_error(self):
- sqlalchemy.asc('user').AndReturn('asc_1')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
self.query.order_by('asc_1').AndReturn(self.query)
self.mox.ReplayAll()
self.assertRaises(ValueError, utils.paginate_query,
diff --git a/oslo_db/tests/sqlalchemy/test_utils.py b/oslo_db/tests/sqlalchemy/test_utils.py
index 509cf48..06a1445 100644
--- a/oslo_db/tests/sqlalchemy/test_utils.py
+++ b/oslo_db/tests/sqlalchemy/test_utils.py
@@ -41,6 +41,7 @@ from oslo_db.sqlalchemy import utils
from oslo_db.tests import utils as test_utils
+Base = declarative_base()
SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.')))
@@ -64,6 +65,17 @@ class CustomType(UserDefinedType):
return "CustomType"
+class FakeTable(Base):
+ __tablename__ = 'fake_table'
+
+ user_id = Column(String(50), primary_key=True)
+ project_id = Column(String(50))
+ snapshot_id = Column(String(50))
+
+ def foo(self):
+ pass
+
+
class FakeModel(object):
def __init__(self, values):
self.values = values
@@ -93,23 +105,17 @@ class TestPaginateQuery(test_base.BaseTestCase):
self.query = self.mox.CreateMockAnything()
self.mox.StubOutWithMock(sqlalchemy, 'asc')
self.mox.StubOutWithMock(sqlalchemy, 'desc')
- self.marker = FakeModel({
- 'user_id': 'user',
- 'project_id': 'p',
- 'snapshot_id': 's',
- })
- self.model = FakeModel({
- 'user_id': 'user',
- 'project_id': 'project',
- 'snapshot_id': 'snapshot',
- })
+ self.marker = FakeTable(user_id='user',
+ project_id='p',
+ snapshot_id='s')
+ self.model = FakeTable
def test_paginate_query_no_pagination_no_sort_dirs(self):
- sqlalchemy.asc('user').AndReturn('asc_3')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc_3')
self.query.order_by('asc_3').AndReturn(self.query)
- sqlalchemy.asc('project').AndReturn('asc_2')
+ sqlalchemy.asc(self.model.project_id).AndReturn('asc_2')
self.query.order_by('asc_2').AndReturn(self.query)
- sqlalchemy.asc('snapshot').AndReturn('asc_1')
+ sqlalchemy.asc(self.model.snapshot_id).AndReturn('asc_1')
self.query.order_by('asc_1').AndReturn(self.query)
self.query.limit(5).AndReturn(self.query)
self.mox.ReplayAll()
@@ -117,9 +123,9 @@ class TestPaginateQuery(test_base.BaseTestCase):
['user_id', 'project_id', 'snapshot_id'])
def test_paginate_query_no_pagination(self):
- sqlalchemy.asc('user').AndReturn('asc')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc')
self.query.order_by('asc').AndReturn(self.query)
- sqlalchemy.desc('project').AndReturn('desc')
+ sqlalchemy.desc(self.model.project_id).AndReturn('desc')
self.query.order_by('desc').AndReturn(self.query)
self.query.limit(5).AndReturn(self.query)
self.mox.ReplayAll()
@@ -128,13 +134,23 @@ class TestPaginateQuery(test_base.BaseTestCase):
sort_dirs=['asc', 'desc'])
def test_paginate_query_attribute_error(self):
- sqlalchemy.asc('user').AndReturn('asc')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc')
self.query.order_by('asc').AndReturn(self.query)
self.mox.ReplayAll()
self.assertRaises(exception.InvalidSortKey,
utils.paginate_query, self.query,
self.model, 5, ['user_id', 'non-existent key'])
+ def test_paginate_query_attribute_error_invalid_sortkey(self):
+ self.assertRaises(exception.InvalidSortKey,
+ utils.paginate_query, self.query,
+ self.model, 5, ['bad_user_id'])
+
+ def test_paginate_query_attribute_error_invalid_sortkey_2(self):
+ self.assertRaises(exception.InvalidSortKey,
+ utils.paginate_query, self.query,
+ self.model, 5, ['foo'])
+
def test_paginate_query_assertion_error(self):
self.mox.ReplayAll()
self.assertRaises(AssertionError,
@@ -152,13 +168,13 @@ class TestPaginateQuery(test_base.BaseTestCase):
sort_dir=None, sort_dirs=['asc', 'desk'])
def test_paginate_query(self):
- sqlalchemy.asc('user').AndReturn('asc_1')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
self.query.order_by('asc_1').AndReturn(self.query)
- sqlalchemy.desc('project').AndReturn('desc_1')
+ sqlalchemy.desc(self.model.project_id).AndReturn('desc_1')
self.query.order_by('desc_1').AndReturn(self.query)
self.mox.StubOutWithMock(sqlalchemy.sql, 'and_')
- sqlalchemy.sql.and_(False).AndReturn('some_crit')
- sqlalchemy.sql.and_(True, False).AndReturn('another_crit')
+ sqlalchemy.sql.and_(mock.ANY).AndReturn('some_crit')
+ sqlalchemy.sql.and_(mock.ANY, mock.ANY).AndReturn('another_crit')
self.mox.StubOutWithMock(sqlalchemy.sql, 'or_')
sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f')
self.query.filter('some_f').AndReturn(self.query)
@@ -170,7 +186,7 @@ class TestPaginateQuery(test_base.BaseTestCase):
sort_dirs=['asc', 'desc'])
def test_paginate_query_value_error(self):
- sqlalchemy.asc('user').AndReturn('asc_1')
+ sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
self.query.order_by('asc_1').AndReturn(self.query)
self.mox.ReplayAll()
self.assertRaises(ValueError, utils.paginate_query,