summaryrefslogtreecommitdiff
path: root/tests/unit/scheduler/test_base_filter.py
blob: a47b0db43c89313f52eb11b595e12c9a73f08aaf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright (c) 2013 OpenStack Foundation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
from oslotest import base as test_base
from oslotest import moxstubout

from openstack.common.scheduler import base_filter


class TestBaseFilter(test_base.BaseTestCase):
    """Checks BaseFilter.filter_all against a stubbed-out _filter_one."""

    def setUp(self):
        super(TestBaseFilter, self).setUp()
        mox_fixture = self.useFixture(moxstubout.MoxStubout())
        self.mox = mox_fixture.mox
        self.filter = base_filter.BaseFilter()

    def test_filter_one_is_called(self):
        # filter_all must consult _filter_one for every object, in order,
        # and keep only the objects for which it answered True.
        props = {'x': 'y'}
        # (object, answer the stubbed _filter_one gives for that object)
        verdicts = [(1, False), (2, True), (3, True), (4, False)]
        self.mox.StubOutWithMock(self.filter, '_filter_one')

        # Record the expected calls in the same order as the objects.
        for obj, verdict in verdicts:
            self.filter._filter_one(obj, props).AndReturn(verdict)

        self.mox.ReplayAll()

        candidates = [obj for obj, _ in verdicts]
        kept = list(self.filter.filter_all(candidates, props))
        self.assertEqual([2, 3], kept)


class FakeExtension(object):
    """Bare-bones extension object used by FakeExtensionManager.

    The only thing it carries is the wrapped object, available as the
    ``plugin`` attribute.
    """

    def __init__(self, plugin):
        # Keep a reference to whatever was wrapped; nothing is copied.
        self.plugin = plugin


class BaseFakeFilter(base_filter.BaseFilter):
    """Common base for the fake filters.

    TestBaseFilterHandler passes this class to BaseFilterHandler, and
    test_get_all_classes expects only its subclasses to be reported.
    """


class FakeFilter1(BaseFakeFilter):
    """BaseFakeFilter subclass that has a fake entry point.

    FakeExtensionManager lists this class, so it is expected to show up
    in the result of get_all_classes.
    """


class FakeFilter2(BaseFakeFilter):
    """BaseFakeFilter subclass without an entry point.

    FakeExtensionManager does not list this class, so it should not be
    included in the result of get_all_classes.
    """


class FakeFilter3(base_filter.BaseFilter):
    """Filter that is not a BaseFakeFilter subclass.

    FakeExtensionManager lists it, but because it derives directly from
    BaseFilter it should not be included in get_all_classes.
    """


class FakeFilter4(BaseFakeFilter):
    """BaseFakeFilter subclass that has a fake entry point.

    Listed by FakeExtensionManager, so it should be included in the
    result of get_all_classes.
    """


class FakeFilter5(BaseFakeFilter):
    """BaseFakeFilter subclass without an entry point.

    Not listed by FakeExtensionManager, so it should not be included in
    get_all_classes. It also sets run_filter_once_per_request.
    """
    # The run-once test expects filter_all on this class to be invoked a
    # single time across calls made with indexes 0, 1 and 2.
    run_filter_once_per_request = True


class FakeExtensionManager(list):
    """List-based replacement for ``ExtensionManager`` in these tests.

    It contains one FakeExtension for each of FakeFilter1, FakeFilter3
    and FakeFilter4, and remembers the namespace it was created with.
    """

    def __init__(self, namespace):
        wrapped = [FakeExtension(cls)
                   for cls in (FakeFilter1, FakeFilter3, FakeFilter4)]
        super(FakeExtensionManager, self).__init__(wrapped)
        self.namespace = namespace


class TestBaseFilterHandler(test_base.BaseTestCase):
    """Tests for BaseFilterHandler with ExtensionManager stubbed out."""

    def setUp(self):
        super(TestBaseFilterHandler, self).setUp()
        stubout_fixture = self.useFixture(moxstubout.MoxStubout())
        self.stubs = stubout_fixture.stubs
        # Replace ExtensionManager with the fake defined in this module
        # before the handler under test is created.
        extension_module = base_filter.base_handler.extension
        self.stubs.Set(extension_module, 'ExtensionManager',
                       FakeExtensionManager)
        self.handler = base_filter.BaseFilterHandler(BaseFakeFilter,
                                                     'fake_filters')

    def test_get_all_classes(self):
        # A fake filter is expected from get_all_classes only when both
        # conditions hold:
        # * it is derived from BaseFakeFilter, and
        # * it has a python entrypoint assigned, i.e. it is one of the
        #   classes returned by FakeExtensionManager.
        # FakeFilter1 and FakeFilter4 are the only classes meeting both.
        self.assertEqual([FakeFilter1, FakeFilter4],
                         self.handler.get_all_classes())

    def _get_filtered_objects(self, filter_classes, index=0):
        # Helper: run the handler over a fixed list of objects and a fixed
        # set of filter properties, varying only the filters and index.
        return self.handler.get_filtered_objects(
            filter_classes, [1, 2, 3, 4], {'x': 'y'}, index)

    @mock.patch.object(FakeFilter4, 'filter_all')
    @mock.patch.object(FakeFilter3, 'filter_all', return_value=None)
    def test_get_filtered_objects_return_none(self, fake3_filter_all,
                                              fake4_filter_all):
        # FakeFilter3.filter_all is patched to return None. The overall
        # result must then be None and FakeFilter4, which comes later in
        # the chain, must never be asked.
        chain = [FakeFilter1, FakeFilter2, FakeFilter3, FakeFilter4]
        outcome = self._get_filtered_objects(chain)
        self.assertIsNone(outcome)
        self.assertFalse(fake4_filter_all.called)

    def test_get_filtered_objects(self):
        # With unpatched fake filters every object is expected back.
        chain = [FakeFilter1, FakeFilter2, FakeFilter3, FakeFilter4]
        outcome = self._get_filtered_objects(chain)
        self.assertEqual([1, 2, 3, 4], outcome)

    def test_get_filtered_objects_with_filter_run_once(self):
        everything = [1, 2, 3, 4]
        chain = [FakeFilter5]
        patcher = mock.patch.object(FakeFilter5, 'filter_all',
                                    return_value=everything)

        with patcher as fake5_filter_all:
            # The same objects must come back for each index while the
            # patched filter_all call count stays at one throughout.
            for index in (0, 1, 2):
                outcome = self._get_filtered_objects(chain, index=index)
                self.assertEqual(everything, outcome)
                self.assertEqual(1, fake5_filter_all.call_count)