aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDmitriy Sobolev <Dmitriy.Sobolev@intel.com>2021-02-20 19:30:57 +0300
committerMikhail Dvorskiy <mikhail.dvorskiy@intel.com>2021-02-24 11:36:24 +0300
commit03231bed853daf9317f6e0d9e2f9484c65476b96 (patch)
treeff86169af125a3f903ca80ff9f464ff420ca0cfe
parentRemove explicit default copy constructor in copy_constructible_value_holder (... (diff)
downloadllvm-project-03231bed853daf9317f6e0d9e2f9484c65476b96.tar.gz
llvm-project-03231bed853daf9317f6e0d9e2f9484c65476b96.tar.bz2
llvm-project-03231bed853daf9317f6e0d9e2f9484c65476b96.zip
Avoid divergence of work items in the same SIMD before calling collectives (#129)
-rw-r--r--include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h82
1 file changed, 39 insertions(+), 43 deletions(-)
diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h
index 13d21dd380b4..dbb88446ee5a 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h
@@ -182,22 +182,6 @@ __convert_to_ordered(_T __value)
// radix sort: run-time device info functions
//------------------------------------------------------------------------
-// get item id in sub-group
-inline ::std::uint32_t
-__get_sg_item_idx(const sycl::nd_item<1>& __idx)
-{
- // technically sycl::id<1>::operator[int] returns a value that always fits in uint8_t (no overflow)
- // and since 64-bit arithmetic is more expensive, the return type is set to ::std::uint32_t
- return static_cast<::std::uint32_t>(__idx.get_sub_group().get_local_id()[0]);
-}
-
-// get number of items in sub-group
-inline ::std::uint32_t
-__get_sg_item_num(const sycl::nd_item<1>& __idx)
-{
- return __idx.get_sub_group().get_local_range()[0];
-}
-
// get rounded up result of (__number / __divisor)
template <typename _T1, typename _T2>
inline auto
@@ -275,6 +259,20 @@ __get_bucket_value(_T __value, ::std::uint32_t __radix_iter)
return (__value >> __bucket_offset) & __bucket_mask;
}
+template <typename _T, bool __is_comp_asc>
+inline __enable_if_t<__is_comp_asc, _T>
+__get_last_value()
+{
+ return ::std::numeric_limits<_T>::max();
+};
+
+template <typename _T, bool __is_comp_asc>
+inline __enable_if_t<!__is_comp_asc, _T>
+__get_last_value()
+{
+ return ::std::numeric_limits<_T>::min();
+};
+
//-----------------------------------------------------------------------
// radix sort: count kernel (per iteration)
//-----------------------------------------------------------------------
@@ -517,36 +515,34 @@ __radix_sort_reorder_submit(_ExecutionPolicy&& __exec, ::std::size_t __segments,
for (::std::size_t __block_idx = 0; __block_idx < __blocks_per_segment * __it_size; ++__block_idx)
{
const ::std::size_t __val_idx = __start_idx + __sg_size * __block_idx;
- // TODO: profile how it affects performance
- if (__val_idx < __inout_buf_size)
+
+ // get value, convert it to ordered (in terms of bitness)
+ // if the index is outside of the range, use fake value which will not affect other values
+ __ordered_t<_InputT> __batch_val = __val_idx < __inout_buf_size
+ ? __convert_to_ordered(__input_rng[__val_idx])
+ : __get_last_value<__ordered_t<_InputT>, __is_comp_asc>();
+
+ // get bit values in a certain bucket of a value
+ ::std::uint32_t __bucket_val =
+ __get_bucket_value<__radix_bits, __is_comp_asc>(__batch_val, __radix_iter);
+
+ _OffsetT __new_offset_idx = 0;
+ // TODO: most computation-heavy code segment - find a better optimized solution
+ for (::std::uint32_t __radix_state_idx = 0; __radix_state_idx < __radix_states; ++__radix_state_idx)
{
- // get value, convert it to ordered (in terms of bitness)
- __ordered_t<_InputT> __batch_val = __convert_to_ordered(__input_rng[__val_idx]);
- // get bit values in a certain bucket of a value
- ::std::uint32_t __bucket_val =
- __get_bucket_value<__radix_bits, __is_comp_asc>(__batch_val, __radix_iter);
-
- _OffsetT __new_offset_idx = 0;
- // TODO: most computation-heavy code segment - find a better optimized solution
- for (::std::uint32_t __radix_state_idx = 0; __radix_state_idx < __radix_states;
- ++__radix_state_idx)
- {
- ::std::uint32_t __is_current_bucket = __bucket_val == __radix_state_idx;
- ::std::uint32_t __sg_item_offset =
- sycl::ONEAPI::exclusive_scan(__self_item.get_sub_group(), __is_current_bucket,
- sycl::ONEAPI::plus<::std::uint32_t>());
-
- __new_offset_idx |=
- __is_current_bucket * (__offset_arr[__radix_state_idx] + __sg_item_offset);
- ::std::uint32_t __sg_total_offset =
- sycl::ONEAPI::reduce(__self_item.get_sub_group(), __is_current_bucket,
- sycl::ONEAPI::plus<::std::uint32_t>());
-
- __offset_arr[__radix_state_idx] = __offset_arr[__radix_state_idx] + __sg_total_offset;
- }
+ ::std::uint32_t __is_current_bucket = __bucket_val == __radix_state_idx;
+ ::std::uint32_t __sg_item_offset = sycl::ONEAPI::exclusive_scan(
+ __self_item.get_sub_group(), __is_current_bucket, sycl::ONEAPI::plus<::std::uint32_t>());
- __output_rng[__new_offset_idx] = __input_rng[__val_idx];
+ __new_offset_idx |= __is_current_bucket * (__offset_arr[__radix_state_idx] + __sg_item_offset);
+ ::std::uint32_t __sg_total_offset = sycl::ONEAPI::reduce(
+ __self_item.get_sub_group(), __is_current_bucket, sycl::ONEAPI::plus<::std::uint32_t>());
+
+ __offset_arr[__radix_state_idx] = __offset_arr[__radix_state_idx] + __sg_total_offset;
}
+
+ if (__val_idx < __inout_buf_size)
+ __output_rng[__new_offset_idx] = __input_rng[__val_idx];
}
});
});