summaryrefslogtreecommitdiff
path: root/testtools/matchers/_datastructures.py
blob: 7c0f9bd78f937b6165e6f317647a395949eed5ea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Copyright (c) 2009-2015 testtools developers. See LICENSE for details.

__all__ = [
    'ContainsAll',
    'MatchesListwise',
    'MatchesSetwise',
    'MatchesStructure',
    ]

"""Matchers that operate with knowledge of Python data structures."""

from ..helpers import map_values
from ._higherorder import (
    Annotate,
    MatchesAll,
    MismatchesAll,
    )
from ._impl import Mismatch


def ContainsAll(items):
    """Build a matcher checking that every one of ``items`` is contained.

    Each item is wrapped in its own ``Contains`` matcher and the resulting
    matchers are combined with ``MatchesAll`` (with ``first_only=False``), so
    the provided sequence is effectively required to be a subset of the
    matchee.

    :param items: An iterable of things the matchee must contain.
    """
    from ._basic import Contains
    containment_checks = [Contains(item) for item in items]
    return MatchesAll(*containment_checks, first_only=False)


class MatchesListwise:
    """Applies a list of matchers to a list of values, position by position.

    The n-th matcher is run against the n-th value, and the number of values
    is additionally checked against the number of matchers.  An example says
    it best:

    >>> from ._basic import Equals
    >>> MatchesListwise([Equals(1)]).match([1])
    >>> MatchesListwise([Equals(1), Equals(2)]).match([1, 2])
    >>> print (MatchesListwise([Equals(1), Equals(2)]).match([2, 1]).describe())
    Differences: [
    2 != 1
    1 != 2
    ]
    >>> matcher = MatchesListwise([Equals(1), Equals(2)], first_only=True)
    >>> print (matcher.match([3, 4]).describe())
    3 != 1
    """

    def __init__(self, matchers, first_only=False):
        """Construct a MatchesListwise matcher.

        :param matchers: The matchers, one per expected value, in order.
        :param first_only: When True, ``match`` returns the first element
            mismatch it finds on its own; when False (the default) every
            mismatch is collected and reported together.
        """
        self.first_only = first_only
        self.matchers = matchers

    def match(self, values):
        """Match ``values`` element-wise against the configured matchers.

        :param values: A sized iterable of values to check.
        :return: None when everything matches; otherwise either a single
            element mismatch (``first_only``) or a ``MismatchesAll`` wrapping
            the length mismatch, if any, followed by the element mismatches.
        """
        from ._basic import HasLength
        expected_length = HasLength(len(self.matchers))
        length_problem = Annotate(
            "Length mismatch", expected_length).match(values)
        problems = [length_problem] if length_problem else []
        # zip() stops at the shorter sequence; any surplus on either side is
        # already covered by the length check above.
        for element_matcher, element in zip(self.matchers, values):
            problem = element_matcher.match(element)
            if not problem:
                continue
            if self.first_only:
                return problem
            problems.append(problem)
        if not problems:
            return None
        return MismatchesAll(problems)


class MatchesStructure:
    """Matcher that checks an object attribute by attribute.

    Every keyword given to the constructor names an attribute of the matchee
    and supplies the matcher that attribute's value has to satisfy.

    Alternate constructors:

    `fromExample` builds a matcher from a prototype object; modified versions
    can then be derived with `update`.

    `byEquality` works like the constructor, except that each keyword value
    is wrapped in `Equals`.

    `byMatcher` is like `byEquality`, but you choose the matcher used to wrap
    the values instead of always getting `Equals`.
    """

    def __init__(self, **kwargs):
        """Construct a `MatchesStructure`.

        :param kwargs: Attribute names mapped to the matchers for them.
        """
        self.kws = kwargs

    @classmethod
    def byEquality(cls, **kwargs):
        """Match an object whose attributes equal the keyword values.

        Equivalent to `byMatcher` with `Equals` as the matcher.
        """
        from ._basic import Equals
        return cls.byMatcher(Equals, **kwargs)

    @classmethod
    def byMatcher(cls, matcher, **kwargs):
        """Match an object whose attributes match the keyword values.

        :param matcher: Applied to each keyword value (via ``map_values``)
            to produce the matcher for the corresponding attribute.
        """
        per_attribute = map_values(matcher, kwargs)
        return cls(**per_attribute)

    @classmethod
    def fromExample(cls, example, *attributes):
        """Build a matcher from the current attribute values of ``example``.

        :param example: The prototype object to read attributes from.
        :param attributes: Names of the attributes to capture; each is
            compared with `Equals` against the value read from ``example``.
        """
        from ._basic import Equals
        return cls(**{
            name: Equals(getattr(example, name)) for name in attributes})

    def update(self, **kws):
        """Return a new matcher of the same type with adjusted attributes.

        :param kws: Attribute names mapped to replacement matchers.  A value
            of None drops that attribute from the new matcher (and is
            ignored if the attribute was not present).
        """
        merged = self.kws.copy()
        for name, replacement in kws.items():
            if replacement is not None:
                merged[name] = replacement
            else:
                merged.pop(name, None)
        return type(self)(**merged)

    def __str__(self):
        # Attributes are listed in sorted order so the text is stable.
        rendered = ', '.join(
            "{}={}".format(name, matcher)
            for name, matcher in sorted(self.kws.items()))
        return "{}({})".format(type(self).__name__, rendered)

    def match(self, value):
        """Match the attributes of ``value`` against the stored matchers.

        Attributes are visited in sorted name order; each matcher is wrapped
        in `Annotate` with the attribute name and the whole lot is delegated
        to `MatchesListwise`.
        """
        # Build the annotated matcher before reading the attribute, one
        # attribute at a time.
        pairs = [
            (Annotate(name, matcher), getattr(value, name))
            for name, matcher in sorted(self.kws.items())]
        matchers = [annotated for annotated, _ in pairs]
        values = [attribute for _, attribute in pairs]
        return MatchesListwise(matchers).match(values)


class MatchesSetwise:
    """Matches if all the matchers match elements of the value being matched.

    That is, each element in the 'observed' collection must match exactly one
    of the given matchers, with no matchers left over.

    The difference compared to `MatchesListwise` is that the order of the
    matchings does not matter.

    Pairing is greedy: each observed value consumes the first still-unused
    matcher (in the order the matchers were given) that accepts it.
    """

    def __init__(self, *matchers):
        """Construct a `MatchesSetwise` matcher.

        :param matchers: The matchers, each of which must be satisfied by a
            distinct observed value.  The same matcher may be passed more
            than once; every occurrence needs its own value.
        """
        self.matchers = matchers

    def match(self, observed):
        """Match the values in ``observed`` against the matchers, any order.

        :param observed: An iterable of values.
        :return: None if every value was paired with its own matcher and no
            matcher is left unused, otherwise a mismatch describing the
            leftover matchers and/or values.
        """
        # A list, not a set: a set would silently collapse a matcher that
        # was passed several times, would reject unhashable matchers, and
        # would make pairing/reporting order depend on hash ordering.
        remaining_matchers = list(self.matchers)
        not_matched = []
        for value in observed:
            for index, matcher in enumerate(remaining_matchers):
                if matcher.match(value) is None:
                    # Remove by position (not equality); deleting during
                    # iteration is safe because we stop right away.
                    del remaining_matchers[index]
                    break
            else:
                not_matched.append(value)
        if not (not_matched or remaining_matchers):
            return None

        # There are various cases that all should be reported somewhat
        # differently.

        # There are two trivial cases:
        # 1) There are just some matchers left over.
        # 2) There are just some values left over.

        # Then there are three more interesting cases:
        # 3) There are the same number of matchers and values left over.
        # 4) There are more matchers left over than values.
        # 5) There are more values left over than matchers.

        if not not_matched:
            if len(remaining_matchers) > 1:
                msg = "There were {} matchers left over: ".format(
                    len(remaining_matchers))
            else:
                msg = "There was 1 matcher left over: "
            msg += ', '.join(map(str, remaining_matchers))
            return Mismatch(msg)
        if not remaining_matchers:
            if len(not_matched) > 1:
                return Mismatch(
                    "There were {} values left over: {}".format(
                        len(not_matched), not_matched))
            return Mismatch(
                "There was 1 value left over: {}".format(
                    not_matched))

        # Both lists are non-empty here, so common_length is at least 1.
        common_length = min(len(remaining_matchers), len(not_matched))
        if common_length > 1:
            msg = "There were {} mismatches".format(common_length)
        else:
            msg = "There was 1 mismatch"
        if len(remaining_matchers) > len(not_matched):
            extra_matchers = remaining_matchers[common_length:]
            msg += " and {} extra matcher".format(len(extra_matchers))
            if len(extra_matchers) > 1:
                msg += "s"
            msg += ': ' + ', '.join(map(str, extra_matchers))
        elif len(not_matched) > len(remaining_matchers):
            extra_values = not_matched[common_length:]
            msg += " and {} extra value".format(len(extra_values))
            if len(extra_values) > 1:
                msg += "s"
            msg += ': ' + str(extra_values)
        # Pair the leftovers positionally to show how they differ.
        return Annotate(
            msg, MatchesListwise(remaining_matchers[:common_length])
            ).match(not_matched[:common_length])