From 03c729c4791d8fd0feb115fa18b239d993771497 Mon Sep 17 00:00:00 2001 From: Erik Olof Gunnar Andersson Date: Sun, 9 Oct 2022 00:02:14 -0700 Subject: Cleaned up and optimized sqlalchemy base - Cleaned up sql syntax. - Fixed minor inconsistencies. - Improved test coverage. Change-Id: I5986dc9fdd1607119a71872637f836d211186c1e --- designate/sqlalchemy/base.py | 110 +++++++++++------------- designate/tests/test_sqlalchemy.py | 74 ++++++++++++---- designate/tests/test_storage/__init__.py | 30 +++---- designate/tests/test_storage/test_sqlalchemy.py | 42 ++++----- 4 files changed, 140 insertions(+), 116 deletions(-) diff --git a/designate/sqlalchemy/base.py b/designate/sqlalchemy/base.py index d72f36ca..d9a7eb66 100644 --- a/designate/sqlalchemy/base.py +++ b/designate/sqlalchemy/base.py @@ -78,7 +78,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): @property def session(self): # NOTE: This uses a thread local store, allowing each greenthread to - # have it's own session stored correctly. Without this, each + # have its own session stored correctly. Without this, each # greenthread may end up using a single global session, which # leads to bad things happening. @@ -101,51 +101,38 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): @staticmethod def _apply_criterion(table, query, criterion): - if criterion is not None: - for name, value in criterion.items(): - column = getattr(table.c, name) + if criterion is None: + return query + for name, value in criterion.items(): + column = getattr(table.c, name) + + if isinstance(value, str): # Wildcard value: '%' - if isinstance(value, str) and '%' in value: + if '%' in value: query = query.where(column.like(value)) - - elif (isinstance(value, str) and - value.startswith('!')): - queryval = value[1:] - query = query.where(column != queryval) - - elif (isinstance(value, str) and - value.startswith('<=')): - queryval = value[2:] - query = query.where(column <= queryval) - - elif (isinstance(value, str) and - value.startswith('<')): - queryval = value[1:] - query = query.where(column < queryval) - - elif (isinstance(value, str) and - value.startswith('>=')): - queryval = value[2:] - query = query.where(column >= queryval) - - elif (isinstance(value, str) and - value.startswith('>')): - queryval = value[1:] - query = query.where(column > queryval) - - elif (isinstance(value, str) and - value.startswith('BETWEEN')): - elements = [i.strip(" ") for i in - value.split(" ", 1)[1].strip(" ").split(",")] - query = query.where(between( - column, elements[0], elements[1])) - - elif isinstance(value, list): - query = query.where(column.in_(value)) - + elif value.startswith('!'): + query = query.where(column != value[1:]) + elif value.startswith('<='): + query = query.where(column <= value[2:]) + elif value.startswith('<'): + query = query.where(column < value[1:]) + elif value.startswith('>='): + query = query.where(column >= value[2:]) + elif value.startswith('>'): + query = query.where(column > value[1:]) + elif value.startswith('BETWEEN'): + elements = [i.strip(' ') for i in + value.split(' ', 1)[1].strip(' ').split(',')] + query = query.where( + between(column, elements[0], elements[1]) + ) else: query = query.where(column == value) + elif isinstance(value, list): + query = query.where(column.in_(value)) + else: + query = query.where(column == value) return query @@ -211,8 +198,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): try: resultproxy = self.session.execute(query, [dict(values)]) except oslo_db_exception.DBDuplicateEntry: - msg = "Duplicate %s" % obj.obj_name() - raise exc_dup(msg) + raise exc_dup("Duplicate %s" % obj.obj_name()) # Refetch the row, for generated columns etc query = select([table]).where( @@ -247,8 +233,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): results = resultproxy.fetchall() if len(results) != 1: - msg = "Could not find %s" % cls.obj_name() - raise exc_notfound(msg) + raise exc_notfound("Could not find %s" % cls.obj_name()) else: return _set_object_from_model(cls(), results[0]) else: @@ -300,13 +285,17 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): records_table, recordsets_table.c.id == records_table.c.recordset_id) - inner_q = select([recordsets_table.c.id, # 0 - RS ID - zones_table.c.name] # 1 - ZONE NAME - ).select_from(rzjoin).\ + inner_q = ( + select([recordsets_table.c.id, # 0 - RS ID + zones_table.c.name]). # 1 - ZONE NAME + select_from(rzjoin). where(zones_table.c.deleted == '0') + ) - count_q = select([func.count(distinct(recordsets_table.c.id))]).\ + count_q = ( + select([func.count(distinct(recordsets_table.c.id))]). select_from(rzjoin).where(zones_table.c.deleted == '0') + ) if index_hint: inner_q = inner_q.with_hint(recordsets_table, index_hint, @@ -507,7 +496,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): current_rrset.records.append(rrdata) else: - # We've already got an rrset, add the rdata + # We've already got a rrset, add the rdata if record[r_map['id']] is not None: rrdata = objects.Record() @@ -517,8 +506,8 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): current_rrset.records.append(rrdata) # If the last record examined was a new rrset, or there is only 1 rrset - if len(rrsets) == 0 or \ - (len(rrsets) != 0 and rrsets[-1] != current_rrset): + if (len(rrsets) == 0 or + (len(rrsets) != 0 and rrsets[-1] != current_rrset)): if current_rrset is not None: rrsets.append(current_rrset) @@ -539,9 +528,11 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): for skip_value in skip_values: values.pop(skip_value, None) - query = table.update()\ - .where(table.c.id == obj.id)\ - .values(**values) + query = ( + table.update(). + where(table.c.id == obj.id). + values(**values) + ) query = self._apply_tenant_criteria(context, table, query) query = self._apply_deleted_criteria(context, table, query) @@ -550,12 +541,10 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): try: resultproxy = self.session.execute(query) except oslo_db_exception.DBDuplicateEntry: - msg = "Duplicate %s" % obj.obj_name() - raise exc_dup(msg) + raise exc_dup("Duplicate %s" % obj.obj_name()) if resultproxy.rowcount != 1: - msg = "Could not find %s" % obj.obj_name() - raise exc_notfound(msg) + raise exc_notfound("Could not find %s" % obj.obj_name()) # Refetch the row, for generated columns etc query = select([table]).where(table.c.id == obj.id) @@ -598,8 +587,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta): resultproxy = self.session.execute(query) if resultproxy.rowcount != 1: - msg = "Could not find %s" % obj.obj_name() - raise exc_notfound(msg) + raise exc_notfound("Could not find %s" % obj.obj_name()) # Refetch the row, for generated columns etc query = select([table]).where(table.c.id == obj.id) diff --git a/designate/tests/test_sqlalchemy.py b/designate/tests/test_sqlalchemy.py index 8481229d..b411d579 100644 --- a/designate/tests/test_sqlalchemy.py +++ b/designate/tests/test_sqlalchemy.py @@ -37,79 +37,79 @@ class SQLAlchemyTestCase(TestCase): self.query = mock.Mock() def test_wildcard(self): - criterion = {"a": "%foo%"} + criterion = {'a': '%foo%'} - op = dummy_table.c.a.like("%foo") + op = dummy_table.c.a.like('%foo') with mock.patch.object(dummy_table.c.a, 'operate') as func: func.return_value = op base.SQLAlchemy._apply_criterion( dummy_table, self.query, criterion) - func.assert_called_with(operators.like_op, "%foo%", escape=None) + func.assert_called_with(operators.like_op, '%foo%', escape=None) self.query.where.assert_called_with(op) def test_ne(self): - criterion = {"a": "!foo"} + criterion = {'a': '!foo'} - op = dummy_table.c.a != "foo" + op = dummy_table.c.a != 'foo' with mock.patch.object(dummy_table.c.a, 'operate') as func: func.return_value = op base.SQLAlchemy._apply_criterion( dummy_table, self.query, criterion) - func.assert_called_with(operator.ne, "foo") + func.assert_called_with(operator.ne, 'foo') self.query.where.assert_called_with(op) def test_le(self): - criterion = {"a": "<=foo"} + criterion = {'a': '<=foo'} - op = dummy_table.c.a <= "foo" + op = dummy_table.c.a <= 'foo' with mock.patch.object(dummy_table.c.a, 'operate') as func: func.return_value = op base.SQLAlchemy._apply_criterion( dummy_table, self.query, criterion) - func.assert_called_with(operator.le, "foo") + func.assert_called_with(operator.le, 'foo') self.query.where.assert_called_with(op) def test_lt(self): - criterion = {"a": "=foo"} + criterion = {'a': '>=foo'} - op = dummy_table.c.a >= "foo" + op = dummy_table.c.a >= 'foo' with mock.patch.object(dummy_table.c.a, 'operate') as func: func.return_value = op base.SQLAlchemy._apply_criterion( dummy_table, self.query, criterion) - func.assert_called_with(operator.ge, "foo") + func.assert_called_with(operator.ge, 'foo') self.query.where.assert_called_with(op) def test_gt(self): - criterion = {"a": ">foo"} + criterion = {'a': '>foo'} - op = dummy_table.c.a > "foo" + op = dummy_table.c.a > 'foo' with mock.patch.object(dummy_table.c.a, 'operate') as func: func.return_value = op base.SQLAlchemy._apply_criterion( dummy_table, self.query, criterion) - func.assert_called_with(operator.gt, "foo") + func.assert_called_with(operator.gt, 'foo') self.query.where.assert_called_with(op) def test_between(self): - criterion = {"a": "BETWEEN 1,3"} + criterion = {'a': 'BETWEEN 1,3'} op = dummy_table.c.a.between(1, 3) with mock.patch.object(dummy_table.c.a, 'operate') as func: @@ -120,3 +120,39 @@ class SQLAlchemyTestCase(TestCase): func.assert_called_with(operators.between_op, '1', '3', symmetric=False) self.query.where.assert_called_with(op) + + def test_regular_string(self): + criterion = {'a': 'foo'} + + op = dummy_table.c.a.like('foo') + with mock.patch.object(dummy_table.c.a, 'operate') as func: + func.return_value = op + + base.SQLAlchemy._apply_criterion( + dummy_table, self.query, criterion) + func.assert_called_with(operator.eq, 'foo') + self.query.where.assert_called_with(op) + + def test_list(self): + criterion = {'a': ['foo']} + + op = dummy_table.c.a.between(1, 3) + with mock.patch.object(dummy_table.c.a, 'operate') as func: + func.return_value = op + + base.SQLAlchemy._apply_criterion( + dummy_table, self.query, criterion) + func.assert_called_with(operators.in_op, ['foo']) + self.query.where.assert_called_with(op) + + def test_boolean(self): + criterion = {'a': True} + + op = dummy_table.c.a.like('foo') + with mock.patch.object(dummy_table.c.a, 'operate') as func: + func.return_value = op + + base.SQLAlchemy._apply_criterion( + dummy_table, self.query, criterion) + func.assert_called_with(operator.eq, True) + self.query.where.assert_called_with(op) diff --git a/designate/tests/test_storage/__init__.py b/designate/tests/test_storage/__init__.py index be4069e8..dc93e14a 100644 --- a/designate/tests/test_storage/__init__.py +++ b/designate/tests/test_storage/__init__.py @@ -1544,7 +1544,7 @@ class StorageTestCase(object): def test_create_tld(self): values = { 'name': 'com', - 'description': u'This is a comment.' + 'description': 'This is a comment.' } result = self.storage.create_tld( @@ -1869,8 +1869,8 @@ class StorageTestCase(object): def test_create_pool_with_all_relations(self): values = { - 'name': u'Pool', - 'description': u'Pool description', + 'name': 'Pool', + 'description': 'Pool description', 'attributes': [{'key': 'scope', 'value': 'public'}], 'ns_records': [{'priority': 1, 'hostname': 'ns1.example.org.'}], 'nameservers': [{'host': "192.0.2.1", 'port': 53}], @@ -2029,8 +2029,8 @@ class StorageTestCase(object): def test_update_pool_with_all_relations(self): values = { - 'name': u'Pool-A', - 'description': u'Pool-A description', + 'name': 'Pool-A', + 'description': 'Pool-A description', 'attributes': [{'key': 'scope', 'value': 'public'}], 'ns_records': [{'priority': 1, 'hostname': 'ns1.example.org.'}], 'nameservers': [{'host': "192.0.2.1", 'port': 53}], @@ -2054,8 +2054,8 @@ class StorageTestCase(object): # we trigger an update rather than a create. values = { 'id': created_pool_id, - 'name': u'Pool-B', - 'description': u'Pool-B description', + 'name': 'Pool-B', + 'description': 'Pool-B description', 'attributes': [{'key': 'scope', 'value': 'private'}], 'ns_records': [{'priority': 1, 'hostname': 'ns2.example.org.'}], 'nameservers': [{'host': "192.0.2.5", 'port': 53}], @@ -2534,7 +2534,7 @@ class StorageTestCase(object): pool = self.create_pool(fixture=0) # Create 10 PoolTargets - created = [self.create_pool_target(pool, description=u'Target %d' % i) + created = [self.create_pool_target(pool, description='Target %d' % i) for i in range(10)] # Ensure we can page through the results. @@ -2546,9 +2546,9 @@ class StorageTestCase(object): # Create two pool_targets pool_target_one = self.create_pool_target( - pool, fixture=0, description=u'One') + pool, fixture=0, description='One') pool_target_two = self.create_pool_target( - pool, fixture=1, description=u'Two') + pool, fixture=1, description='Two') # Verify pool_target_one criterion = dict(description=pool_target_one['description']) @@ -2588,9 +2588,9 @@ class StorageTestCase(object): # Create two pool_targets pool_target_one = self.create_pool_target( - pool, fixture=0, description=u'One') + pool, fixture=0, description='One') pool_target_two = self.create_pool_target( - pool, fixture=1, description=u'Two') + pool, fixture=1, description='Two') # Verify pool_target_one criterion = dict(description=pool_target_one['description']) @@ -2622,16 +2622,16 @@ class StorageTestCase(object): def test_update_pool_target(self): pool = self.create_pool(fixture=0) - pool_target = self.create_pool_target(pool, description=u'One') + pool_target = self.create_pool_target(pool, description='One') # Update the pool_target - pool_target.description = u'Two' + pool_target.description = 'Two' pool_target = self.storage.update_pool_target( self.admin_context, pool_target) # Verify the new values - self.assertEqual(u'Two', pool_target.description) + self.assertEqual('Two', pool_target.description) # Ensure the version column was incremented self.assertEqual(2, pool_target.version) diff --git a/designate/tests/test_storage/test_sqlalchemy.py b/designate/tests/test_storage/test_sqlalchemy.py index 3133bce8..c2aeb7d1 100644 --- a/designate/tests/test_storage/test_sqlalchemy.py +++ b/designate/tests/test_storage/test_sqlalchemy.py @@ -30,27 +30,27 @@ class SqlalchemyStorageTest(StorageTestCase, TestCase): def test_schema_table_names(self): table_names = [ - u'blacklists', - u'pool_also_notifies', - u'pool_attributes', - u'pool_nameservers', - u'pool_ns_records', - u'pool_target_masters', - u'pool_target_options', - u'pool_targets', - u'pools', - u'quotas', - u'records', - u'recordsets', - u'service_statuses', - u'tlds', - u'tsigkeys', - u'zone_attributes', - u'zone_masters', - u'zone_tasks', - u'zone_transfer_accepts', - u'zone_transfer_requests', - u'zones' + 'blacklists', + 'pool_also_notifies', + 'pool_attributes', + 'pool_nameservers', + 'pool_ns_records', + 'pool_target_masters', + 'pool_target_options', + 'pool_targets', + 'pools', + 'quotas', + 'records', + 'recordsets', + 'service_statuses', + 'tlds', + 'tsigkeys', + 'zone_attributes', + 'zone_masters', + 'zone_tasks', + 'zone_transfer_accepts', + 'zone_transfer_requests', + 'zones' ] inspector = self.storage.get_inspector() -- cgit v1.2.1