summaryrefslogtreecommitdiff
path: root/pstl/include/pstl/internal/omp/parallel_stable_sort.h
blob: 6f9dce528960c678c23006701a3313cfca606d34 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
// -*- C++ -*-
// -*-===----------------------------------------------------------------------===//
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
//===----------------------------------------------------------------------===//

#ifndef _PSTL_INTERNAL_OMP_PARALLEL_STABLE_SORT_H
#define _PSTL_INTERNAL_OMP_PARALLEL_STABLE_SORT_H

#include "util.h"
#include "parallel_merge.h"

namespace __pstl
{
namespace __omp_backend
{

namespace __sort_details
{
struct __move_value
{
    // Per-element move policy: move-assigns the element referenced by __x into the element
    // referenced by __z. It is handed to the serial merge in __parallel_stable_sort_body as the
    // callable used to move individual values.
    template <typename _Iterator, typename _OutputIterator>
    void
    operator()(_Iterator __x, _OutputIterator __z) const
    {
        auto&& __source = *__x;
        *__z = std::move(__source);
    }
};

template <typename _RandomAccessIterator, typename _OutputIterator>
_OutputIterator
__parallel_move_range(_RandomAccessIterator __first1, _RandomAccessIterator __last1, _OutputIterator __d_first)
{
    // Moves [__first1, __last1) into the range starting at __d_first and returns the iterator one
    // past the last element written (__d_first + number of elements moved).
    // NOTE(review): like std::move, this presumably requires that __d_first does not lie inside
    // the source range.
    const std::size_t __count = __last1 - __first1;

    // Ranges that fit in one chunk are moved serially; tasking would cost more than it saves.
    if (__count <= __default_chunk_size)
        return std::move(__first1, __last1, __d_first);

    // Larger ranges: split into chunks and move each chunk from its own taskloop iteration. Every
    // chunk is written at the same offset in the destination that it occupies in the source, so
    // (assuming __chunk_partitioner yields disjoint chunks) the writes never collide. The
    // taskloop's implicit taskgroup makes all chunks finish before the function returns.
    auto __partition = __pstl::__omp_backend::__chunk_partitioner(__first1, __last1);

    _PSTL_PRAGMA(omp taskloop)
    for (std::size_t __i = 0; __i < __partition.__n_chunks; ++__i)
    {
        __pstl::__omp_backend::__process_chunk(__partition, __first1, __i,
                                               [&](auto __chunk_begin, auto __chunk_end)
                                               {
                                                   std::move(__chunk_begin, __chunk_end,
                                                             __d_first + (__chunk_begin - __first1));
                                               });
    }

    return __d_first + __count;
}

struct __move_range
{
    // Function-object wrapper around __parallel_move_range. It moves [__first1, __last1) to the
    // range starting at __d_first and returns the iterator one past the last element written.
    // Packaging the free function as a type lets it be handed to the merge routine as a callable.
    template <typename _RandomAccessIterator, typename _OutputIterator>
    _OutputIterator
    operator()(_RandomAccessIterator __first1, _RandomAccessIterator __last1, _OutputIterator __d_first) const
    {
        _OutputIterator __past_last =
            __pstl::__omp_backend::__sort_details::__parallel_move_range(__first1, __last1, __d_first);
        return __past_last;
    }
};
} // namespace __sort_details

template <typename _RandomAccessIterator, typename _Compare, typename _LeafSort>
void
__parallel_stable_sort_body(_RandomAccessIterator __xs, _RandomAccessIterator __xe, _Compare __comp,
                            _LeafSort __leaf_sort)
{
    // Recursive parallel merge sort of [__xs, __xe) under __comp:
    //   1. ranges that __should_run_serial accepts go straight to __leaf_sort;
    //   2. otherwise both halves are sorted concurrently (__parallel_invoke_body),
    //   3. merged into a temporary vector (__parallel_merge_body + __serial_move_merge),
    //   4. and the merged values are moved back into [__xs, __xe).
    using _ValueType = typename std::iterator_traits<_RandomAccessIterator>::value_type;
    using _VecType = std::vector<_ValueType>;
    using _OutputIterator = typename _VecType::iterator;
    using _MoveValue = __omp_backend::__sort_details::__move_value;
    using _MoveRange = __omp_backend::__sort_details::__move_range;

    if (__should_run_serial(__xs, __xe))
    {
        __leaf_sort(__xs, __xe, __comp);
        return;
    }

    std::size_t __size = __xe - __xs;
    auto __mid = __xs + (__size / 2);

    // The recursive calls are qualified, like every other internal call in this file, so that
    // argument-dependent lookup on the iterator/comparator/leaf-sort types cannot redirect them.
    __pstl::__omp_backend::__parallel_invoke_body(
        [&]() { __pstl::__omp_backend::__parallel_stable_sort_body(__xs, __mid, __comp, __leaf_sort); },
        [&]() { __pstl::__omp_backend::__parallel_stable_sort_body(__mid, __xe, __comp, __leaf_sort); });

    // Perform a parallel merge of the sorted halves into __output_data. The buffer is
    // value-initialised (so _ValueType must be default-constructible) and the merge then
    // move-assigns into it through __move_value / __move_range.
    _VecType __output_data(__size);
    _MoveValue __move_value;
    _MoveRange __move_range;
    __utils::__serial_move_merge __merge(__size);
    __pstl::__omp_backend::__parallel_merge_body(
        __mid - __xs, __xe - __mid, __xs, __mid, __mid, __xe, __output_data.begin(), __comp,
        [&__merge, &__move_value, &__move_range](_RandomAccessIterator __as, _RandomAccessIterator __ae,
                                                 _RandomAccessIterator __bs, _RandomAccessIterator __be,
                                                 _OutputIterator __cs, _Compare __leaf_comp)
        { __merge(__as, __ae, __bs, __be, __cs, __leaf_comp, __move_value, __move_value, __move_range, __move_range); });

    // Move the values from __output_data back into the original source range.
    __pstl::__omp_backend::__sort_details::__parallel_move_range(__output_data.begin(), __output_data.end(), __xs);
}

template <class _ExecutionPolicy, typename _RandomAccessIterator, typename _Compare, typename _LeafSort>
void
__parallel_stable_sort(_ExecutionPolicy&& /*__exec*/, _RandomAccessIterator __xs, _RandomAccessIterator __xe,
                       _Compare __comp, _LeafSort __leaf_sort, std::size_t __nsort = 0)
{
    // Stable-sorts [__xs, __xe) with __comp, using __leaf_sort for every piece done serially.
    //
    // __nsort == 0 (the default) requests a full sort, as does any __nsort >= the range length.
    // A value in (0, length) is treated as a partial-sort request. That request is delegated to
    // __leaf_sort on the whole range, which is expected to honour it.
    // NOTE(review): the previous check `__nsort < __count` sent the default (__nsort == 0) to the
    // serial path for every non-empty range, so a full stable sort never ran in parallel. It also
    // made the later partial-sort branch unreachable.
    auto __count = static_cast<std::size_t>(__xe - __xs);
    const bool __is_full_sort = (__nsort == 0 || __nsort >= __count);

    // Small inputs, and partial sorts, go to the serial leaf sort.
    // TODO: a parallel partial sort implementation should be shared with the other backends.
    if (__count <= __default_chunk_size || !__is_full_sort)
    {
        __leaf_sort(__xs, __xe, __comp);
        return;
    }

    if (omp_in_parallel())
    {
        // Already inside a parallel region: run the recursive sort on the current team.
        __pstl::__omp_backend::__parallel_stable_sort_body(__xs, __xe, __comp, __leaf_sort);
    }
    else
    {
        // Open a parallel region in which a single thread starts the recursion. `nowait` lets the
        // remaining threads go straight to the region's closing barrier, where they can execute
        // the tasks the sort generates.
        _PSTL_PRAGMA(omp parallel)
        _PSTL_PRAGMA(omp single nowait)
        __pstl::__omp_backend::__parallel_stable_sort_body(__xs, __xe, __comp, __leaf_sort);
    }
}

} // namespace __omp_backend
} // namespace __pstl
#endif // _PSTL_INTERNAL_OMP_PARALLEL_STABLE_SORT_H