summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVictor Stinner <vstinner@redhat.com>2015-05-14 00:50:34 +0200
committerVictor Stinner <vstinner@redhat.com>2015-05-19 14:54:20 -0700
commit60f80b5ded35613f5d3fda235e18b81ffd0f45d1 (patch)
tree6d8df960534ca0d5b29628a93e96a739f2b7df8d
parent1cbe72fbd87cb663bb38fe7f86a18d1e5d72df9b (diff)
downloadoslo-db-60f80b5ded35613f5d3fda235e18b81ffd0f45d1.tar.gz
Add a keys() method to SQLAlchemy ModelBase
With this additional method, it now possible to write directly dict(obj), instead of dict(obj.iteritems()), to cast an object to a dictionary. Modify also ModelIterator: it doesn't inherit from ModelBase anymore. Change-Id: I702be362a58155a28482e733e60539d36c039509
-rw-r--r--oslo_db/sqlalchemy/models.py8
-rw-r--r--oslo_db/tests/sqlalchemy/test_models.py23
2 files changed, 28 insertions, 3 deletions
diff --git a/oslo_db/sqlalchemy/models.py b/oslo_db/sqlalchemy/models.py
index 80fa01d..8edfe72 100644
--- a/oslo_db/sqlalchemy/models.py
+++ b/oslo_db/sqlalchemy/models.py
@@ -90,14 +90,18 @@ class ModelBase(six.Iterator):
Includes attributes from joins.
"""
- local = dict(self)
+ local = dict((key, value) for key, value in self)
joined = dict([(k, v) for k, v in six.iteritems(self.__dict__)
if not k[0] == '_'])
local.update(joined)
return six.iteritems(local)
+ def keys(self):
+ """Make the model object behave like a dict."""
+ return [key for key, value in self.iteritems()]
+
-class ModelIterator(ModelBase, six.Iterator):
+class ModelIterator(six.Iterator):
def __init__(self, model, columns):
self.model = model
diff --git a/oslo_db/tests/sqlalchemy/test_models.py b/oslo_db/tests/sqlalchemy/test_models.py
index 4a45576..eac66f5 100644
--- a/oslo_db/tests/sqlalchemy/test_models.py
+++ b/oslo_db/tests/sqlalchemy/test_models.py
@@ -40,7 +40,8 @@ class ModelBaseTest(test_base.DbTestCase):
'get',
'update',
'save',
- 'iteritems')
+ 'iteritems',
+ 'keys')
for method in dict_methods:
self.assertTrue(hasattr(models.ModelBase, method),
"Method %s() is not found" % method)
@@ -80,6 +81,18 @@ class ModelBaseTest(test_base.DbTestCase):
self.ekm.update(h)
self.assertEqual(dict(self.ekm.iteritems()), expected)
+ def test_modelbase_dict(self):
+ h = {'a': '1', 'b': '2'}
+ expected = {
+ 'id': None,
+ 'smth': None,
+ 'name': 'NAME',
+ 'a': '1',
+ 'b': '2',
+ }
+ self.ekm.update(h)
+ self.assertEqual(dict(self.ekm), expected)
+
def test_modelbase_iter(self):
expected = {
'id': None,
@@ -97,6 +110,14 @@ class ModelBaseTest(test_base.DbTestCase):
self.assertEqual(len(expected), found_items)
+ def test_modelbase_keys(self):
+ self.assertEqual(set(self.ekm.keys()),
+ set(('id', 'smth', 'name')))
+
+ self.ekm.update({'a': '1', 'b': '2'})
+ self.assertEqual(set(self.ekm.keys()),
+ set(('a', 'b', 'id', 'smth', 'name')))
+
def test_modelbase_several_iters(self):
mb = ExtraKeysModel()
it1 = iter(mb)