38 #include "block_radix_rank.cuh"
39 #include "../util_ptx.cuh"
40 #include "../util_arch.cuh"
41 #include "../util_type.cuh"
42 #include "../util_namespace.cuh"
122 int ITEMS_PER_THREAD,
123 typename Value = NullType,
125 bool MEMOIZE_OUTER_SCAN = (
CUB_PTX_ARCH >= 350) ?
true :
false,
127 cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte,
142 BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
150 typedef typename KeyTraits::UnsignedBits UnsignedBits;
153 typedef BlockRadixRank<
158 INNER_SCAN_ALGORITHM,
163 AscendingBlockRadixRank;
166 typedef BlockRadixRank<
171 INNER_SCAN_ALGORITHM,
176 DescendingBlockRadixRank;
189 typename AscendingBlockRadixRank::TempStorage asending_ranking_storage;
190 typename DescendingBlockRadixRank::TempStorage descending_ranking_storage;
202 _TempStorage &temp_storage;
212 __device__ __forceinline__ _TempStorage& PrivateStorage()
214 __shared__ _TempStorage private_storage;
215 return private_storage;
219 __device__ __forceinline__
void RankKeys(
220 UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
221 int (&ranks)[ITEMS_PER_THREAD],
226 AscendingBlockRadixRank(temp_storage.asending_ranking_storage).RankKeys(
234 __device__ __forceinline__
void RankKeys(
235 UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
236 int (&ranks)[ITEMS_PER_THREAD],
241 DescendingBlockRadixRank(temp_storage.descending_ranking_storage).RankKeys(
249 __device__ __forceinline__
void ExchangeValues(
250 Value (&values)[ITEMS_PER_THREAD],
251 int (&ranks)[ITEMS_PER_THREAD],
258 BlockExchangeValues(temp_storage.exchange_values).ScatterToBlocked(values, ranks);
262 __device__ __forceinline__
void ExchangeValues(
263 Value (&values)[ITEMS_PER_THREAD],
264 int (&ranks)[ITEMS_PER_THREAD],
271 BlockExchangeValues(temp_storage.exchange_values).ScatterToStriped(values, ranks);
275 template <
int IS_BLOCKED>
276 __device__ __forceinline__
void ExchangeValues(
277 Value (&values)[ITEMS_PER_THREAD],
278 int (&ranks)[ITEMS_PER_THREAD],
284 template <
int DESCENDING,
int KEYS_ONLY>
285 __device__ __forceinline__
void SortBlocked(
286 Key (&keys)[ITEMS_PER_THREAD],
287 Value (&values)[ITEMS_PER_THREAD],
293 UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] =
294 reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]
>(keys);
298 for (
int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
300 unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
306 int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit);
309 int ranks[ITEMS_PER_THREAD];
310 RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending);
311 begin_bit += RADIX_BITS;
316 BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
322 if (begin_bit >= end_bit)
break;
329 for (
int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
331 unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]);
336 template <
int DESCENDING,
int KEYS_ONLY>
337 __device__ __forceinline__
void SortBlockedToStriped(
338 Key (&keys)[ITEMS_PER_THREAD],
339 Value (&values)[ITEMS_PER_THREAD],
345 UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] =
346 reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]
>(keys);
350 for (
int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
352 unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
358 int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit);
361 int ranks[ITEMS_PER_THREAD];
362 RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending);
363 begin_bit += RADIX_BITS;
368 if (begin_bit >= end_bit)
371 BlockExchangeKeys(temp_storage.exchange_keys).ScatterToStriped(keys, ranks);
381 BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
391 for (
int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
393 unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]);
415 temp_storage(PrivateStorage()),
416 linear_tid(
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
426 temp_storage(temp_storage.Alias()),
427 linear_tid(
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
474 __device__ __forceinline__
void Sort(
475 Key (&keys)[ITEMS_PER_THREAD],
477 int end_bit =
sizeof(Key) * 8)
529 __device__ __forceinline__
void Sort(
530 Key (&keys)[ITEMS_PER_THREAD],
531 Value (&values)[ITEMS_PER_THREAD],
533 int end_bit =
sizeof(Key) * 8)
576 Key (&keys)[ITEMS_PER_THREAD],
578 int end_bit =
sizeof(Key) * 8)
631 Key (&keys)[ITEMS_PER_THREAD],
632 Value (&values)[ITEMS_PER_THREAD],
634 int end_bit =
sizeof(Key) * 8)
686 Key (&keys)[ITEMS_PER_THREAD],
688 int end_bit =
sizeof(Key) * 8)
741 Key (&keys)[ITEMS_PER_THREAD],
742 Value (&values)[ITEMS_PER_THREAD],
744 int end_bit =
sizeof(Key) * 8)
789 Key (&keys)[ITEMS_PER_THREAD],
791 int end_bit =
sizeof(Key) * 8)
844 Key (&keys)[ITEMS_PER_THREAD],
845 Value (&values)[ITEMS_PER_THREAD],
847 int end_bit =
sizeof(Key) * 8)