diff options
author | Nikolas Klauser <n_klauser@apple.com> | 2023-05-05 09:16:05 -0700 |
---|---|---|
committer | Nikolas Klauser <nikolasklauser@berlin.de> | 2023-05-15 06:48:43 -0700 |
commit | cbd9e5454741ebe6b39521fe1a8ed4eed5c2c801 (patch) | |
tree | 7ba3167f596324562759feb3b372617cf4169384 /libcxx | |
parent | ae8cb6437294ca99ba203607c0dd522db4dbf6b6 (diff) | |
download | llvm-cbd9e5454741ebe6b39521fe1a8ed4eed5c2c801.tar.gz |
[libc++][PSTL] Implement std::transform
Reviewed By: ldionne, #libc
Spies: libcxx-commits
Differential Revision: https://reviews.llvm.org/D149615
Diffstat (limited to 'libcxx')
8 files changed, 290 insertions, 44 deletions
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt index 4dd363de4d17..f304b5dafef8 100644 --- a/libcxx/include/CMakeLists.txt +++ b/libcxx/include/CMakeLists.txt @@ -82,6 +82,7 @@ set(files __algorithm/pstl_find.h __algorithm/pstl_for_each.h __algorithm/pstl_frontend_dispatch.h + __algorithm/pstl_transform.h __algorithm/push_heap.h __algorithm/ranges_adjacent_find.h __algorithm/ranges_all_of.h diff --git a/libcxx/include/__algorithm/pstl_transform.h b/libcxx/include/__algorithm/pstl_transform.h new file mode 100644 index 000000000000..74a869583f51 --- /dev/null +++ b/libcxx/include/__algorithm/pstl_transform.h @@ -0,0 +1,129 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _LIBCPP___ALGORITHM_PSTL_TRANSFORM_H +#define _LIBCPP___ALGORITHM_PSTL_TRANSFORM_H + +#include <__algorithm/transform.h> +#include <__config> +#include <__iterator/iterator_traits.h> +#include <__pstl/internal/parallel_backend.h> +#include <__pstl/internal/unseq_backend_simd.h> +#include <__type_traits/is_execution_policy.h> +#include <__utility/terminate_on_exception.h> + +#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) +# pragma GCC system_header +#endif + +#if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17 + +_LIBCPP_BEGIN_NAMESPACE_STD + +template <class _ExecutionPolicy, + class _ForwardIterator, + class _ForwardOutIterator, + class _UnaryOperation, + enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0> +_LIBCPP_HIDE_FROM_ABI _ForwardOutIterator transform( + _ExecutionPolicy&& __policy, + _ForwardIterator __first, + _ForwardIterator __last, + _ForwardOutIterator __result, + _UnaryOperation __op) { + if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> && + __is_cpp17_random_access_iterator<_ForwardIterator>::value && + __is_cpp17_random_access_iterator<_ForwardOutIterator>::value) { + std::__terminate_on_exception([&] { + __pstl::__par_backend::__parallel_for( + __pstl::__internal::__par_backend_tag{}, + __policy, + __first, + __last, + [&__policy, __op, __first, __result](_ForwardIterator __brick_first, _ForwardIterator __brick_last) { + return std::transform( + std::__remove_parallel_policy(__policy), + __brick_first, + __brick_last, + __result + (__brick_first - __first), + __op); + }); + }); + return __result + (__last - __first); + } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> && + __is_cpp17_random_access_iterator<_ForwardIterator>::value && + __is_cpp17_random_access_iterator<_ForwardOutIterator>::value) { + return __pstl::__unseq_backend::__simd_walk_2( + __first, + __last - __first, + __result, + [&](__iter_reference<_ForwardIterator> __in_value, __iter_reference<_ForwardOutIterator> __out_value) { + __out_value = __op(__in_value); + }); + } else { + return std::transform(__first, __last, __result, __op); + } +} + +template <class _ExecutionPolicy, + class _ForwardIterator1, + class _ForwardIterator2, + class _ForwardOutIterator, + class _BinaryOperation, + enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0> +_LIBCPP_HIDE_FROM_ABI _ForwardOutIterator transform( + _ExecutionPolicy&& __policy, + _ForwardIterator1 __first1, + _ForwardIterator1 __last1, + _ForwardIterator2 __first2, + _ForwardOutIterator __result, + _BinaryOperation __op) { + if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> && + __is_cpp17_random_access_iterator<_ForwardIterator1>::value && + __is_cpp17_random_access_iterator<_ForwardIterator2>::value && + __is_cpp17_random_access_iterator<_ForwardOutIterator>::value) { + std::__terminate_on_exception([&] { + __pstl::__par_backend::__parallel_for( + __pstl::__internal::__par_backend_tag{}, + __policy, + __first1, + __last1, + [&__policy, __op, __first1, __first2, __result]( + _ForwardIterator1 __brick_first, _ForwardIterator1 __brick_last) { + return std::transform( + std::__remove_parallel_policy(__policy), + __brick_first, + __brick_last, + __first2 + (__brick_first - __first1), + __result + (__brick_first - __first1), + __op); + }); + }); + return __result + (__last1 - __first1); + } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> && + __is_cpp17_random_access_iterator<_ForwardIterator1>::value && + __is_cpp17_random_access_iterator<_ForwardIterator2>::value && + __is_cpp17_random_access_iterator<_ForwardOutIterator>::value) { + return __pstl::__unseq_backend::__simd_walk_3( + __first1, + __last1 - __first1, + __first2, + __result, + [&](__iter_reference<_ForwardIterator1> __in1, + __iter_reference<_ForwardIterator2> __in2, + __iter_reference<_ForwardOutIterator> __out) { __out = __op(__in1, __in2); }); + } else { + return std::transform(__first1, __last1, __first2, __result, __op); + } +} + +_LIBCPP_END_NAMESPACE_STD + +#endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17 + +#endif // _LIBCPP___ALGORITHM_PSTL_TRANSFORM_H diff --git a/libcxx/include/__pstl/internal/glue_algorithm_defs.h b/libcxx/include/__pstl/internal/glue_algorithm_defs.h index 82bb3f508d5a..de4501e56b2c 100644 --- a/libcxx/include/__pstl/internal/glue_algorithm_defs.h +++ b/libcxx/include/__pstl/internal/glue_algorithm_defs.h @@ -134,29 +134,6 @@ template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterato __pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator2> swap_ranges( _ExecutionPolicy&& __exec, _ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2); -// [alg.transform] - -template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterator2, class _UnaryOperation> -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator2> -transform(_ExecutionPolicy&& __exec, - _ForwardIterator1 __first, - _ForwardIterator1 __last, - _ForwardIterator2 __result, - _UnaryOperation __op); - -template <class _ExecutionPolicy, - class _ForwardIterator1, - class _ForwardIterator2, - class _ForwardIterator, - class _BinaryOperation> -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator> -transform(_ExecutionPolicy&& __exec, - _ForwardIterator1 __first1, - _ForwardIterator1 __last1, - _ForwardIterator2 __first2, - _ForwardIterator __result, - _BinaryOperation __op); - // [alg.replace] template <class _ExecutionPolicy, class _ForwardIterator, class _UnaryPredicate, class _Tp> diff --git a/libcxx/include/__pstl/internal/glue_algorithm_impl.h b/libcxx/include/__pstl/internal/glue_algorithm_impl.h index db62705233b9..bae5efa7d057 100644 --- a/libcxx/include/__pstl/internal/glue_algorithm_impl.h +++ b/libcxx/include/__pstl/internal/glue_algorithm_impl.h @@ -251,27 +251,6 @@ __pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardItera // [alg.transform] -template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterator2, class _UnaryOperation> -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator2> -transform(_ExecutionPolicy&& __exec, - _ForwardIterator1 __first, - _ForwardIterator1 __last, - _ForwardIterator2 __result, - _UnaryOperation __op) { - typedef typename iterator_traits<_ForwardIterator1>::reference _InputType; - typedef typename iterator_traits<_ForwardIterator2>::reference _OutputType; - - auto __dispatch_tag = __pstl::__internal::__select_backend(__exec, __first, __result); - - return __pstl::__internal::__pattern_walk2( - __dispatch_tag, - std::forward<_ExecutionPolicy>(__exec), - __first, - __last, - __result, - [__op](_InputType __x, _OutputType __y) mutable { __y = __op(__x); }); -} - template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterator2, diff --git a/libcxx/include/algorithm b/libcxx/include/algorithm index 469bf1706628..18a89eb1a4dc 100644 --- a/libcxx/include/algorithm +++ b/libcxx/include/algorithm @@ -1792,6 +1792,7 @@ template <class BidirectionalIterator, class Compare> #include <__algorithm/pstl_fill.h> #include <__algorithm/pstl_find.h> #include <__algorithm/pstl_for_each.h> +#include <__algorithm/pstl_transform.h> #include <__algorithm/push_heap.h> #include <__algorithm/ranges_adjacent_find.h> #include <__algorithm/ranges_all_of.h> diff --git a/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.binary.pass.cpp b/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.binary.pass.cpp new file mode 100644 index 000000000000..1076a1548ee3 --- /dev/null +++ b/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.binary.pass.cpp @@ -0,0 +1,85 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14 + +// UNSUPPORTED: libcpp-has-no-incomplete-pstl + +// <algorithm> + +// template<class ExecutionPolicy, class ForwardIterator1, class ForwardIterator2, +// class ForwardIterator, class BinaryOperation> +// ForwardIterator +// transform(ExecutionPolicy&& exec, +// ForwardIterator1 first1, ForwardIterator1 last1, +// ForwardIterator2 first2, ForwardIterator result, +// BinaryOperation binary_op); + +#include <algorithm> +#include <vector> + +#include "test_macros.h" +#include "test_execution_policies.h" +#include "test_iterators.h" + +EXECUTION_POLICY_SFINAE_TEST(transform); + +static_assert(sfinae_test_transform<int, int*, int*, int*, int*, bool (*)(int)>); +static_assert(!sfinae_test_transform<std::execution::parallel_policy, int*, int*, int*, int*, int (*)(int, int)>); + +template <class Iter1, class Iter2, class Iter3> +struct Test { + template <class Policy> + void operator()(Policy&& policy) { + // simple test + for (const int size : {0, 1, 2, 100, 350}) { + std::vector<int> a(size); + std::vector<int> b(size); + for (int i = 0; i != size; ++i) { + a[i] = i + 1; + b[i] = i - 3; + } + + std::vector<int> out(std::size(a)); + decltype(auto) ret = std::transform( + policy, + Iter1(std::data(a)), + Iter1(std::data(a) + std::size(a)), + Iter2(std::data(b)), + Iter3(std::data(out)), + [](int i, int j) { return i + j + 3; }); + static_assert(std::is_same_v<decltype(ret), Iter3>); + assert(base(ret) == std::data(out) + std::size(out)); + for (int i = 0; i != size; ++i) { + assert(out[i] == i * 2 + 1); + } + } + } +}; + +template <class Iter3> +struct TestIterators2 { + template <class Iter2> + void operator()() { + types::for_each(types::forward_iterator_list<int*>{}, + TestIteratorWithPolicies<types::partial_instantiation<Test, Iter2, Iter3>::template apply>{}); + } +}; + +struct TestIterators1 { + template <class Iter3> + void operator()() { + types::for_each(types::forward_iterator_list<int*>{}, TestIterators2<Iter3>{}); + } +}; + +int main(int, char**) { + types::for_each(types::forward_iterator_list<int*>{}, TestIterators1{}); + + return 0; +} diff --git a/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.unary.pass.cpp b/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.unary.pass.cpp new file mode 100644 index 000000000000..31069de4e523 --- /dev/null +++ b/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.unary.pass.cpp @@ -0,0 +1,67 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14 + +// UNSUPPORTED: libcpp-has-no-incomplete-pstl + +// <algorithm> + +// template<class ExecutionPolicy, class ForwardIterator1, class ForwardIterator2, +// class UnaryOperation> +// ForwardIterator2 +// transform(ExecutionPolicy&& exec, +// ForwardIterator1 first1, ForwardIterator1 last1, +// ForwardIterator2 result, UnaryOperation op); + +#include <algorithm> +#include <vector> + +#include "test_macros.h" +#include "test_execution_policies.h" +#include "test_iterators.h" + +// We can't test the constraint on the execution policy, because that would conflict with the binary +// transform algorithm that doesn't take an execution policy, which is not constrained at all. + +template <class Iter1, class Iter2> +struct Test { + template <class Policy> + void operator()(Policy&& policy) { + // simple test + for (const int size : {0, 1, 2, 100, 350}) { + std::vector<int> a(size); + for (int i = 0; i != size; ++i) + a[i] = i + 1; + + std::vector<int> out(std::size(a)); + decltype(auto) ret = std::transform( + policy, Iter1(std::data(a)), Iter1(std::data(a) + std::size(a)), Iter2(std::data(out)), [](int i) { + return i + 3; + }); + static_assert(std::is_same_v<decltype(ret), Iter2>); + assert(base(ret) == std::data(out) + std::size(out)); + for (int i = 0; i != size; ++i) + assert(out[i] == i + 4); + } + } +}; + +struct TestIterators { + template <class Iter2> + void operator()() { + types::for_each(types::forward_iterator_list<int*>{}, + TestIteratorWithPolicies<types::partial_instantiation<Test, Iter2>::template apply>{}); + } +}; + +int main(int, char**) { + types::for_each(types::forward_iterator_list<int*>{}, TestIterators{}); + + return 0; +} diff --git a/libcxx/test/support/type_algorithms.h b/libcxx/test/support/type_algorithms.h index 95a282b7b0bc..ac3ee60b2ccf 100644 --- a/libcxx/test/support/type_algorithms.h +++ b/libcxx/test/support/type_algorithms.h @@ -52,6 +52,13 @@ TEST_CONSTEXPR_CXX14 void for_each(type_list<Types...>, Functor f) { swallow((f.template operator()<Types>(), 0)...); } + +template <template <class...> class T, class... Args> +struct partial_instantiation { + template <class Other> + using apply = T<Args..., Other>; +}; + // type categories defined in [basic.fundamental] plus extensions (without CV-qualifiers) using character_types = |