summaryrefslogtreecommitdiff
path: root/pstl/test/std/numerics/numeric.ops/scan.pass.cpp
blob: b9eef530f5f59e683f079845651c6658d76c58d8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
// -*- C++ -*-
//===-- scan.pass.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// UNSUPPORTED: c++03, c++11, c++14

#include "support/pstl_test_config.h"

#include <execution>
#include <numeric>

#include "support/utils.h"

using namespace TestUtils;

// Serial (no execution policy) versions of exclusive_scan and inclusive_scan are provided here so that
// the results of the versions taking an execution policy can be checked against them.
//TODO: to add a macro for availability of ver implementations
// Sequential reference implementation of an exclusive scan using operator+.
// For each input element, the running total accumulated so far (starting at init) is written
// to the output BEFORE the element is added to it, so output i does not include input i.
// Returns the output iterator advanced once per element written.
template <class InputIterator, class OutputIterator, class T>
OutputIterator
exclusive_scan_serial(InputIterator first, InputIterator last, OutputIterator result, T init)
{
    while (first != last)
    {
        // Emit the total of everything seen so far, then fold in the current element.
        *result = init;
        init = init + *first;

        ++first;
        ++result;
    }
    return result;
}

// Sequential reference implementation of an exclusive scan with a caller-supplied binary operation.
// The accumulator starts at init; for each input element the current accumulator is written to the
// output first and only then updated as binary_op(accumulator, element).
// Returns the output iterator advanced once per element written.
template <class InputIterator, class OutputIterator, class T, class BinaryOperation>
OutputIterator
exclusive_scan_serial(InputIterator first, InputIterator last, OutputIterator result, T init, BinaryOperation binary_op)
{
    while (first != last)
    {
        // The accumulator is always the left operand, so non-commutative operations keep their order.
        *result = init;
        init = binary_op(init, *first);

        ++first;
        ++result;
    }
    return result;
}

// Note: N4582 is missing the ", class T".  Issue was reported 2016-Apr-11 to cxxeditor@gmail.com
// Sequential reference implementation of an inclusive scan with a caller-supplied binary operation
// and an initial value. For each input element the accumulator (starting at init) is first updated
// as binary_op(accumulator, element) and then written, so output i includes input i.
// Returns the output iterator advanced once per element written.
template <class InputIterator, class OutputIterator, class BinaryOperation, class T>
OutputIterator
inclusive_scan_serial(InputIterator first, InputIterator last, OutputIterator result, BinaryOperation binary_op, T init)
{
    while (first != last)
    {
        // Fold in the current element before emitting, with the accumulator as the left operand.
        init = binary_op(init, *first);
        *result = init;

        ++first;
        ++result;
    }
    return result;
}

// Sequential reference implementation of an inclusive scan with a caller-supplied binary operation
// and no initial value. The first input element is copied to the output unchanged and then used as
// the initial value for scanning the remaining elements via the overload that takes an init.
// An empty range writes nothing and returns result as given.
template <class InputIterator, class OutputIterator, class BinaryOperation>
OutputIterator
inclusive_scan_serial(InputIterator first, InputIterator last, OutputIterator result, BinaryOperation binary_op)
{
    // Nothing to scan.
    if (first == last)
        return result;

    // The first element seeds the accumulator and is also the first output value.
    auto seed = *first;
    *result = seed;

    ++first;
    ++result;
    return inclusive_scan_serial(first, last, result, binary_op, seed);
}

// Sequential reference implementation of an inclusive scan using addition.
// Delegates to the binary-operation overload with std::plus instantiated for the
// input iterator's value_type.
template <class InputIterator, class OutputIterator>
OutputIterator
inclusive_scan_serial(InputIterator first, InputIterator last, OutputIterator result)
{
    using element_type = typename std::iterator_traits<InputIterator>::value_type;
    std::plus<element_type> add;
    return inclusive_scan_serial(first, last, result, add);
}

// Most of the framework required for testing inclusive and exclusive scan is identical,
// so the tests for both live in this file. Which algorithm is being tested is selected by
// the global flag `inclusive` (true: inclusive_scan, false: exclusive_scan); main() sets it
// to each alternative in turn before running the tests.
static bool inclusive;

// Compares n elements starting at out_first against n elements starting at expected_first,
// reporting a mismatch with a message that names the scan flavour currently under test
// (selected by the global `inclusive` flag). Afterwards the n output elements are overwritten
// with `trash` so that the next run starts from a known state.
template <typename Iterator, typename Size, typename T>
void
check_and_reset(Iterator expected_first, Iterator out_first, Size n, T trash)
{
    if (inclusive)
    {
        EXPECT_EQ_N(expected_first, out_first, n, "wrong result from inclusive_scan");
    }
    else
    {
        EXPECT_EQ_N(expected_first, out_first, n, "wrong result from exclusive_scan");
    }

    // Reset the output range.
    std::fill_n(out_first, n, trash);
}

// Test functor for the "+" forms of the scans. Depending on the global `inclusive` flag it checks
// either inclusive_scan(exec, first, last, out) or exclusive_scan(exec, first, last, out, init)
// against the corresponding serial reference implementation.
struct test_scan_with_plus
{
    // exec            - execution policy under test
    // in_first/last   - input range
    // out_first/last  - output range written by the algorithm under test
    // expected_first  - start of the range receiving the serial reference result
    //                   (the matching end iterator is accepted but unused)
    // n               - number of elements to compare
    // init            - initial value (used by the exclusive form only)
    // trash           - value the output range is reset to after checking
    template <typename Policy, typename Iterator1, typename Iterator2, typename Iterator3, typename Size, typename T>
    void
    operator()(Policy&& exec, Iterator1 in_first, Iterator1 in_last, Iterator2 out_first, Iterator2 out_last,
               Iterator3 expected_first, Iterator3, Size n, T init, T trash)
    {
        using namespace std;

        // Produce the reference result; the iterator returned by the serial helper is not needed.
        if (inclusive)
        {
            inclusive_scan_serial(in_first, in_last, expected_first);
        }
        else
        {
            exclusive_scan_serial(in_first, in_last, expected_first, init);
        }

        // Run the algorithm under test with the given execution policy.
        auto actual_last = inclusive ? inclusive_scan(exec, in_first, in_last, out_first)
                                     : exclusive_scan(exec, in_first, in_last, out_first, init);
        EXPECT_TRUE(out_last == actual_last,
                    inclusive ? "inclusive_scan returned wrong iterator" : "exclusive_scan returned wrong iterator");

        // Compare the results, then restore the output range to the trash value.
        check_and_reset(expected_first, out_first, n, trash);
        fill(out_first, out_last, trash);
    }
};

// Drives test_scan_with_plus over a range of sequence sizes from 0 up to at most 100000.
// For each size n the input is built from `convert`, the expected buffer starts as a copy of the
// input, and the output buffer is pre-filled with `trash`. Each size is run over all execution
// policies twice: once with mutable input iterators and once with const input iterators.
template <typename T, typename Convert>
void
test_with_plus(T init, T trash, Convert convert)
{
    size_t n = 0;
    while (n <= 100000)
    {
        Sequence<T> in(n, convert);
        Sequence<T> expected(in);
        auto make_trash = [&](int32_t) { return trash; };
        Sequence<T> out(n, make_trash);

        // Mutable input iterators.
        invoke_on_all_policies(test_scan_with_plus(), in.begin(), in.end(), out.begin(), out.end(), expected.begin(),
                               expected.end(), in.size(), init, trash);
        // Const input iterators.
        invoke_on_all_policies(test_scan_with_plus(), in.cbegin(), in.cend(), out.begin(), out.end(), expected.begin(),
                               expected.end(), in.size(), init, trash);

        // Every size up to 17 is tested individually; after that the size grows by a factor of 3.1415.
        if (n <= 16)
        {
            n = n + 1;
        }
        else
        {
            n = size_t(3.1415 * n);
        }
    }
}
struct test_scan_with_binary_op
{
    template <typename Policy, typename Iterator1, typename Iterator2, typename Iterator3, typename Size, typename T,
              typename BinaryOp>
    typename std::enable_if<!TestUtils::isReverse<Iterator1>::value, void>::type
    operator()(Policy&& exec, Iterator1 in_first, Iterator1 in_last, Iterator2 out_first, Iterator2 out_last,
               Iterator3 expected_first, Iterator3, Size n, T init, BinaryOp binary_op, T trash)
    {
        using namespace std;

        auto orr1 = inclusive ? inclusive_scan_serial(in_first, in_last, expected_first, binary_op, init)
                              : exclusive_scan_serial(in_first, in_last, expected_first, init, binary_op);
        (void)orr1;
        auto orr = inclusive ? inclusive_scan(exec, in_first, in_last, out_first, binary_op, init)
                             : exclusive_scan(exec, in_first, in_last, out_first, init, binary_op);

        EXPECT_TRUE(out_last == orr, "scan returned wrong iterator");
        check_and_reset(expected_first, out_first, n, trash);
    }

    template <typename Policy, typename Iterator1, typename Iterator2, typename Iterator3, typename Size, typename T,
              typename BinaryOp>
    typename std::enable_if<TestUtils::isReverse<Iterator1>::value, void>::type
    operator()(Policy&&, Iterator1, Iterator1, Iterator2, Iterator2, Iterator3, Iterator3, Size, T, BinaryOp, T)
    {
    }
};

// Drives test_scan_with_binary_op over a range of sequence sizes from 0 up to at most 100000.
// Input element k is constructed as In(k, k + 1); the output and expected buffers are both
// pre-filled with `trash`. Each size is run over all execution policies twice: once with mutable
// input iterators and once with const input iterators.
template <typename In, typename Out, typename BinaryOp>
void
test_matrix(Out init, BinaryOp binary_op, Out trash)
{
    size_t n = 0;
    while (n <= 100000)
    {
        Sequence<In> in(n, [](size_t k) { return In(k, k + 1); });

        auto make_trash = [&](size_t) { return trash; };
        Sequence<Out> out(n, make_trash);
        Sequence<Out> expected(n, make_trash);

        // Mutable input iterators.
        invoke_on_all_policies(test_scan_with_binary_op(), in.begin(), in.end(), out.begin(), out.end(),
                               expected.begin(), expected.end(), in.size(), init, binary_op, trash);
        // Const input iterators.
        invoke_on_all_policies(test_scan_with_binary_op(), in.cbegin(), in.cend(), out.begin(), out.end(),
                               expected.begin(), expected.end(), in.size(), init, binary_op, trash);

        // Every size up to 17 is tested individually; after that the size grows by a factor of 3.1415.
        if (n <= 16)
        {
            n = n + 1;
        }
        else
        {
            n = size_t(3.1415 * n);
        }
    }
}

int
main()
{
    for (int32_t mode = 0; mode < 2; ++mode)
    {
        inclusive = mode != 0;
#if !_PSTL_ICC_19_TEST_SIMD_UDS_WINDOWS_RELEASE_BROKEN
        // Test with highly restricted type and associative but not commutative operation
        test_matrix<Matrix2x2<int32_t>, Matrix2x2<int32_t>>(Matrix2x2<int32_t>(), multiply_matrix<int32_t>,
                                                            Matrix2x2<int32_t>(-666, 666));
#endif

        // Since the implict "+" forms of the scan delegate to the generic forms,
        // there's little point in using a highly restricted type, so just use double.
        test_with_plus<float64_t>(inclusive ? 0.0 : -1.0, -666.0,
                                  [](uint32_t k) { return float64_t((k % 991 + 1) ^ (k % 997 + 2)); });
    }
    std::cout << done() << std::endl;
    return 0;
}