diff options
author | Victor Stinner <vstinner@redhat.com> | 2015-05-14 00:50:34 +0200 |
---|---|---|
committer | Victor Stinner <vstinner@redhat.com> | 2015-05-19 14:54:20 -0700 |
commit | 60f80b5ded35613f5d3fda235e18b81ffd0f45d1 (patch) | |
tree | 6d8df960534ca0d5b29628a93e96a739f2b7df8d | |
parent | 1cbe72fbd87cb663bb38fe7f86a18d1e5d72df9b (diff) | |
download | oslo-db-60f80b5ded35613f5d3fda235e18b81ffd0f45d1.tar.gz |
Add a keys() method to SQLAlchemy ModelBase
With this additional method, it now possible to write directly
dict(obj), instead of dict(obj.iteritems()), to cast an object to a
dictionary.
Modify also ModelIterator: it doesn't inherit from ModelBase anymore.
Change-Id: I702be362a58155a28482e733e60539d36c039509
-rw-r--r-- | oslo_db/sqlalchemy/models.py | 8 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_models.py | 23 |
2 files changed, 28 insertions, 3 deletions
diff --git a/oslo_db/sqlalchemy/models.py b/oslo_db/sqlalchemy/models.py index 80fa01d..8edfe72 100644 --- a/oslo_db/sqlalchemy/models.py +++ b/oslo_db/sqlalchemy/models.py @@ -90,14 +90,18 @@ class ModelBase(six.Iterator): Includes attributes from joins. """ - local = dict(self) + local = dict((key, value) for key, value in self) joined = dict([(k, v) for k, v in six.iteritems(self.__dict__) if not k[0] == '_']) local.update(joined) return six.iteritems(local) + def keys(self): + """Make the model object behave like a dict.""" + return [key for key, value in self.iteritems()] + -class ModelIterator(ModelBase, six.Iterator): +class ModelIterator(six.Iterator): def __init__(self, model, columns): self.model = model diff --git a/oslo_db/tests/sqlalchemy/test_models.py b/oslo_db/tests/sqlalchemy/test_models.py index 4a45576..eac66f5 100644 --- a/oslo_db/tests/sqlalchemy/test_models.py +++ b/oslo_db/tests/sqlalchemy/test_models.py @@ -40,7 +40,8 @@ class ModelBaseTest(test_base.DbTestCase): 'get', 'update', 'save', - 'iteritems') + 'iteritems', + 'keys') for method in dict_methods: self.assertTrue(hasattr(models.ModelBase, method), "Method %s() is not found" % method) @@ -80,6 +81,18 @@ class ModelBaseTest(test_base.DbTestCase): self.ekm.update(h) self.assertEqual(dict(self.ekm.iteritems()), expected) + def test_modelbase_dict(self): + h = {'a': '1', 'b': '2'} + expected = { + 'id': None, + 'smth': None, + 'name': 'NAME', + 'a': '1', + 'b': '2', + } + self.ekm.update(h) + self.assertEqual(dict(self.ekm), expected) + def test_modelbase_iter(self): expected = { 'id': None, @@ -97,6 +110,14 @@ class ModelBaseTest(test_base.DbTestCase): self.assertEqual(len(expected), found_items) + def test_modelbase_keys(self): + self.assertEqual(set(self.ekm.keys()), + set(('id', 'smth', 'name'))) + + self.ekm.update({'a': '1', 'b': '2'}) + self.assertEqual(set(self.ekm.keys()), + set(('a', 'b', 'id', 'smth', 'name'))) + def test_modelbase_several_iters(self): mb = ExtraKeysModel() it1 = iter(mb) |