CUB
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups
block_radix_sort.cuh
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2014, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of the NVIDIA CORPORATION nor the
13  * names of its contributors may be used to endorse or promote products
14  * derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
35 #pragma once
36 
37 #include "block_exchange.cuh"
38 #include "block_radix_rank.cuh"
39 #include "../util_ptx.cuh"
40 #include "../util_arch.cuh"
41 #include "../util_type.cuh"
42 #include "../util_namespace.cuh"
43 
45 CUB_NS_PREFIX
46 
48 namespace cub {
49 
119 template <
120  typename Key,
121  int BLOCK_DIM_X,
122  int ITEMS_PER_THREAD,
123  typename Value = NullType,
124  int RADIX_BITS = 4,
125  bool MEMOIZE_OUTER_SCAN = (CUB_PTX_ARCH >= 350) ? true : false,
126  BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
127  cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte,
128  int BLOCK_DIM_Y = 1,
129  int BLOCK_DIM_Z = 1,
130  int PTX_ARCH = CUB_PTX_ARCH>
132 {
133 private:
134 
135  /******************************************************************************
136  * Constants and type definitions
137  ******************************************************************************/
138 
// Compile-time constants derived from the class template parameters
139  enum
140  {
141  // Total number of threads in the block: the product of the three block dimensions
142  BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
143 
144  // Whether the sort is keys-only (Value is NullType), i.e. there are no values to carry along with the keys
145  KEYS_ONLY = Equals<Value, NullType>::VALUE,
146  };
147 
148  // Unsigned bits type taken from KeyTraits; the sort routines reinterpret the key array as this type
150  typedef typename KeyTraits::UnsignedBits UnsignedBits;
151 
// BlockRadixRank instantiation used by the Int2Type<false> overload of RankKeys.
// Differs from DescendingBlockRadixRank only in the third template argument.
153  typedef BlockRadixRank<
154  BLOCK_DIM_X,
155  RADIX_BITS,
156  false, // false here, true in DescendingBlockRadixRank (presumably the descending flag; confirm in block_radix_rank.cuh)
157  MEMOIZE_OUTER_SCAN,
158  INNER_SCAN_ALGORITHM,
159  SMEM_CONFIG,
160  BLOCK_DIM_Y,
161  BLOCK_DIM_Z,
162  PTX_ARCH>
163  AscendingBlockRadixRank;
164 
// BlockRadixRank instantiation used by the Int2Type<true> overload of RankKeys.
// Differs from AscendingBlockRadixRank only in the third template argument.
166  typedef BlockRadixRank<
167  BLOCK_DIM_X,
168  RADIX_BITS,
169  true, // true here, false in AscendingBlockRadixRank (presumably the descending flag; confirm in block_radix_rank.cuh)
170  MEMOIZE_OUTER_SCAN,
171  INNER_SCAN_ALGORITHM,
172  SMEM_CONFIG,
173  BLOCK_DIM_Y,
174  BLOCK_DIM_Z,
175  PTX_ARCH>
176  DescendingBlockRadixRank;
177 
180 
183 
// Scratch storage used by the sort. All four members sit in one anonymous union and therefore
// occupy the same bytes; the sort routines call __syncthreads() between the ranking phase and
// the exchange phases that reuse this storage.
185  struct _TempStorage
186  {
187  union
188  {
189  typename AscendingBlockRadixRank::TempStorage asending_ranking_storage; // used by RankKeys(..., Int2Type<false>); the identifier is spelled "asending"
190  typename DescendingBlockRadixRank::TempStorage descending_ranking_storage; // used by RankKeys(..., Int2Type<true>)
191  typename BlockExchangeKeys::TempStorage exchange_keys; // used when scattering keys
192  typename BlockExchangeValues::TempStorage exchange_values; // used when scattering values
193  };
194  };
195 
196 
197  /******************************************************************************
198  * Thread fields
199  ******************************************************************************/
200 
// Scratch storage reference: bound by the constructors either to PrivateStorage()'s __shared__ instance or to the caller's TempStorage via Alias()
202  _TempStorage &temp_storage;
203 
// Thread identifier set by both constructors from RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z); not read by any method visible in this listing
205  int linear_tid;
206 
207  /******************************************************************************
208  * Utility methods
209  ******************************************************************************/
210 
212  __device__ __forceinline__ _TempStorage& PrivateStorage()
213  {
214  __shared__ _TempStorage private_storage;
215  return private_storage;
216  }
217 
219  __device__ __forceinline__ void RankKeys(
220  UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
221  int (&ranks)[ITEMS_PER_THREAD],
222  int begin_bit,
223  int pass_bits,
224  Int2Type<false> is_descending)
225  {
226  AscendingBlockRadixRank(temp_storage.asending_ranking_storage).RankKeys(
227  unsigned_keys,
228  ranks,
229  begin_bit,
230  pass_bits);
231  }
232 
234  __device__ __forceinline__ void RankKeys(
235  UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
236  int (&ranks)[ITEMS_PER_THREAD],
237  int begin_bit,
238  int pass_bits,
239  Int2Type<true> is_descending)
240  {
241  DescendingBlockRadixRank(temp_storage.descending_ranking_storage).RankKeys(
242  unsigned_keys,
243  ranks,
244  begin_bit,
245  pass_bits);
246  }
247 
249  __device__ __forceinline__ void ExchangeValues(
250  Value (&values)[ITEMS_PER_THREAD],
251  int (&ranks)[ITEMS_PER_THREAD],
252  Int2Type<false> is_keys_only,
253  Int2Type<true> is_blocked)
254  {
255  __syncthreads();
256 
257  // Exchange values through shared memory in blocked arrangement
258  BlockExchangeValues(temp_storage.exchange_values).ScatterToBlocked(values, ranks);
259  }
260 
262  __device__ __forceinline__ void ExchangeValues(
263  Value (&values)[ITEMS_PER_THREAD],
264  int (&ranks)[ITEMS_PER_THREAD],
265  Int2Type<false> is_keys_only,
266  Int2Type<false> is_blocked)
267  {
268  __syncthreads();
269 
270  // Exchange values through shared memory in blocked arrangement
271  BlockExchangeValues(temp_storage.exchange_values).ScatterToStriped(values, ranks);
272  }
273 
275  template <int IS_BLOCKED>
276  __device__ __forceinline__ void ExchangeValues(
277  Value (&values)[ITEMS_PER_THREAD],
278  int (&ranks)[ITEMS_PER_THREAD],
279  Int2Type<true> is_keys_only,
280  Int2Type<IS_BLOCKED> is_blocked)
281  {}
282 
284  template <int DESCENDING, int KEYS_ONLY>
285  __device__ __forceinline__ void SortBlocked(
286  Key (&keys)[ITEMS_PER_THREAD],
287  Value (&values)[ITEMS_PER_THREAD],
288  int begin_bit,
289  int end_bit,
290  Int2Type<DESCENDING> is_descending,
291  Int2Type<KEYS_ONLY> is_keys_only)
292  {
293  UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] =
294  reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]>(keys);
295 
296  // Twiddle bits if necessary
297  #pragma unroll
298  for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
299  {
300  unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
301  }
302 
303  // Radix sorting passes
304  while (true)
305  {
306  int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit);
307 
308  // Rank the blocked keys
309  int ranks[ITEMS_PER_THREAD];
310  RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending);
311  begin_bit += RADIX_BITS;
312 
313  __syncthreads();
314 
315  // Exchange keys through shared memory in blocked arrangement
316  BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
317 
318  // Exchange values through shared memory in blocked arrangement
319  ExchangeValues(values, ranks, is_keys_only, Int2Type<true>());
320 
321  // Quit if done
322  if (begin_bit >= end_bit) break;
323 
324  __syncthreads();
325  }
326 
327  // Untwiddle bits if necessary
328  #pragma unroll
329  for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
330  {
331  unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]);
332  }
333  }
334 
336  template <int DESCENDING, int KEYS_ONLY>
337  __device__ __forceinline__ void SortBlockedToStriped(
338  Key (&keys)[ITEMS_PER_THREAD],
339  Value (&values)[ITEMS_PER_THREAD],
340  int begin_bit,
341  int end_bit,
342  Int2Type<DESCENDING> is_descending,
343  Int2Type<KEYS_ONLY> is_keys_only)
344  {
345  UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] =
346  reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]>(keys);
347 
348  // Twiddle bits if necessary
349  #pragma unroll
350  for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
351  {
352  unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
353  }
354 
355  // Radix sorting passes
356  while (true)
357  {
358  int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit);
359 
360  // Rank the blocked keys
361  int ranks[ITEMS_PER_THREAD];
362  RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending);
363  begin_bit += RADIX_BITS;
364 
365  __syncthreads();
366 
367  // Check if this is the last pass
368  if (begin_bit >= end_bit)
369  {
370  // Last pass exchanges keys through shared memory in striped arrangement
371  BlockExchangeKeys(temp_storage.exchange_keys).ScatterToStriped(keys, ranks);
372 
373  // Last pass exchanges through shared memory in striped arrangement
374  ExchangeValues(values, ranks, is_keys_only, Int2Type<false>());
375 
376  // Quit
377  break;
378  }
379 
380  // Exchange keys through shared memory in blocked arrangement
381  BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
382 
383  // Exchange values through shared memory in blocked arrangement
384  ExchangeValues(values, ranks, is_keys_only, Int2Type<true>());
385 
386  __syncthreads();
387  }
388 
389  // Untwiddle bits if necessary
390  #pragma unroll
391  for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
392  {
393  unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]);
394  }
395  }
396 
397 
398 
399 public:
400 
// Public storage type accepted by the BlockRadixSort(TempStorage&) constructor; wraps _TempStorage in Uninitialized<> and is unwrapped there with Alias()
402  struct TempStorage : Uninitialized<_TempStorage> {};
403 
404 
405  /******************************************************************/
409 
413  __device__ __forceinline__ BlockRadixSort()
414  :
415  temp_storage(PrivateStorage()),
416  linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
417  {}
418 
419 
423  __device__ __forceinline__ BlockRadixSort(
424  TempStorage &temp_storage)
425  :
426  temp_storage(temp_storage.Alias()),
427  linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
428  {}
429 
430 
432  /******************************************************************/
436 
474  __device__ __forceinline__ void Sort(
475  Key (&keys)[ITEMS_PER_THREAD],
476  int begin_bit = 0,
477  int end_bit = sizeof(Key) * 8)
478  {
479  NullType values[ITEMS_PER_THREAD];
480 
481  SortBlocked(keys, values, begin_bit, end_bit, Int2Type<false>(), Int2Type<KEYS_ONLY>());
482  }
483 
484 
529  __device__ __forceinline__ void Sort(
530  Key (&keys)[ITEMS_PER_THREAD],
531  Value (&values)[ITEMS_PER_THREAD],
532  int begin_bit = 0,
533  int end_bit = sizeof(Key) * 8)
534  {
535  SortBlocked(keys, values, begin_bit, end_bit, Int2Type<false>(), Int2Type<KEYS_ONLY>());
536  }
537 
575  __device__ __forceinline__ void SortDescending(
576  Key (&keys)[ITEMS_PER_THREAD],
577  int begin_bit = 0,
578  int end_bit = sizeof(Key) * 8)
579  {
580  NullType values[ITEMS_PER_THREAD];
581 
582  SortBlocked(keys, values, begin_bit, end_bit, Int2Type<true>(), Int2Type<KEYS_ONLY>());
583  }
584 
585 
630  __device__ __forceinline__ void SortDescending(
631  Key (&keys)[ITEMS_PER_THREAD],
632  Value (&values)[ITEMS_PER_THREAD],
633  int begin_bit = 0,
634  int end_bit = sizeof(Key) * 8)
635  {
636  SortBlocked(keys, values, begin_bit, end_bit, Int2Type<true>(), Int2Type<KEYS_ONLY>());
637  }
638 
639 
641  /******************************************************************/
645 
646 
685  __device__ __forceinline__ void SortBlockedToStriped(
686  Key (&keys)[ITEMS_PER_THREAD],
687  int begin_bit = 0,
688  int end_bit = sizeof(Key) * 8)
689  {
690  NullType values[ITEMS_PER_THREAD];
691 
692  SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type<false>(), Int2Type<KEYS_ONLY>());
693  }
694 
695 
740  __device__ __forceinline__ void SortBlockedToStriped(
741  Key (&keys)[ITEMS_PER_THREAD],
742  Value (&values)[ITEMS_PER_THREAD],
743  int begin_bit = 0,
744  int end_bit = sizeof(Key) * 8)
745  {
746  SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type<false>(), Int2Type<KEYS_ONLY>());
747  }
748 
749 
788  __device__ __forceinline__ void SortDescendingBlockedToStriped(
789  Key (&keys)[ITEMS_PER_THREAD],
790  int begin_bit = 0,
791  int end_bit = sizeof(Key) * 8)
792  {
793  NullType values[ITEMS_PER_THREAD];
794 
795  SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type<true>(), Int2Type<KEYS_ONLY>());
796  }
797 
798 
843  __device__ __forceinline__ void SortDescendingBlockedToStriped(
844  Key (&keys)[ITEMS_PER_THREAD],
845  Value (&values)[ITEMS_PER_THREAD],
846  int begin_bit = 0,
847  int end_bit = sizeof(Key) * 8)
848  {
849  SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type<true>(), Int2Type<KEYS_ONLY>());
850  }
851 
852 
854 
855 };
856 
861 } // CUB namespace
862 CUB_NS_POSTFIX // Optional outer namespace(s)
863