CUB
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups
warp_scan.cuh
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2014, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of the NVIDIA CORPORATION nor the
13  * names of its contributors may be used to endorse or promote products
14  * derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
34 #pragma once
35 
36 #include "specializations/warp_scan_shfl.cuh"
37 #include "specializations/warp_scan_smem.cuh"
38 #include "../thread/thread_operators.cuh"
39 #include "../util_arch.cuh"
40 #include "../util_type.cuh"
41 #include "../util_namespace.cuh"
42 
44 CUB_NS_PREFIX
45 
47 namespace cub {
48 
/**
 * \brief WarpScan computes inclusive, exclusive, and combined prefix scans
 * (with a caller-supplied binary operator, or with addition) across the
 * threads of a logical warp.
 *
 * \tparam T                     Element type being scanned
 * \tparam LOGICAL_WARP_THREADS  [optional] Threads per logical warp (default: CUB_PTX_WARP_THREADS)
 * \tparam PTX_ARCH              [optional] PTX compute capability being compiled for (default: CUB_PTX_ARCH)
 *
 * \par
 * Every operation is forwarded to a backend selected at compile time:
 * WarpScanShfl when PTX_ARCH >= 300 and LOGICAL_WARP_THREADS is a power of
 * two, WarpScanSmem in all other cases.
 */
template <
    typename    T,
    int         LOGICAL_WARP_THREADS    = CUB_PTX_WARP_THREADS,
    int         PTX_ARCH                = CUB_PTX_ARCH>
class WarpScan
{
private:

    //---------------------------------------------------------------------
    // Compile-time configuration
    //---------------------------------------------------------------------

    enum
    {
        /// Non-zero when the logical warp is exactly as wide as a hardware warp of PTX_ARCH
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// Non-zero when LOGICAL_WARP_THREADS is a power of two
        IS_POW_OF_TWO = ((LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0),

        /// Non-zero when Traits<T> classifies T as a signed or unsigned integer
        IS_INTEGER = ((Traits<T>::CATEGORY == SIGNED_INTEGER) || (Traits<T>::CATEGORY == UNSIGNED_INTEGER))
    };

    /// Tag used to dispatch to the integer / non-integer helper overloads
    typedef Int2Type<IS_INTEGER> IsIntegerTag;

    /// Backend: shuffle-based when the target and warp width allow it, shared-memory otherwise
    typedef typename If<(PTX_ARCH >= 300) && (IS_POW_OF_TWO),
        WarpScanShfl<T, LOGICAL_WARP_THREADS, PTX_ARCH>,
        WarpScanSmem<T, LOGICAL_WARP_THREADS, PTX_ARCH> >::Type InternalWarpScan;

    /// Storage layout required by the backend
    typedef typename InternalWarpScan::TempStorage _TempStorage;


    //---------------------------------------------------------------------
    // Per-thread state
    //---------------------------------------------------------------------

    /// Backend storage, obtained by aliasing the caller's TempStorage
    _TempStorage &temp_storage;

    /// LaneId(), reduced modulo LOGICAL_WARP_THREADS unless the logical warp is a full hardware warp
    int lane_id;


public:

    /// Opaque storage type the caller must provide to the constructor (wraps the backend layout in Uninitialized<>)
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Construction
    //---------------------------------------------------------------------

    /**
     * \brief Binds this object to \p temp_storage and records the calling thread's lane index.
     */
    __device__ __forceinline__ WarpScan(
        TempStorage &temp_storage)
    :
        temp_storage(temp_storage.Alias()),
        lane_id(IS_ARCH_WARP ? LaneId() : (LaneId() % LOGICAL_WARP_THREADS))
    {}


    //---------------------------------------------------------------------
    // Inclusive prefix sums
    //---------------------------------------------------------------------

    /**
     * \brief Inclusive prefix sum of \p input across the logical warp.
     *
     * \param[in]  input   This thread's input item
     * \param[out] output  This thread's inclusive prefix sum
     */
    __device__ __forceinline__ void InclusiveSum(
        T   input,
        T   &output)
    {
        // Addition is just InclusiveScan with cub::Sum
        InclusiveScan(input, output, cub::Sum());
    }


    /**
     * \brief Inclusive prefix sum that also reports the warp-wide total.
     *
     * \param[in]  input           This thread's input item
     * \param[out] output          This thread's inclusive prefix sum
     * \param[out] warp_aggregate  Warp-wide aggregate as produced by the backend
     */
    __device__ __forceinline__ void InclusiveSum(
        T   input,
        T   &output,
        T   &warp_aggregate)
    {
        InclusiveScan(input, output, cub::Sum(), warp_aggregate);
    }


    /**
     * \brief Inclusive prefix sum seeded by an external, callback-supplied prefix.
     *
     * \par
     * Every calling thread invokes \p warp_prefix_op with the warp aggregate. Only
     * the value returned in lane 0 is used: it is broadcast from lane 0 and added
     * in front of each thread's inclusive sum.
     *
     * \param[in]     input           This thread's input item
     * \param[out]    output          This thread's inclusive prefix sum, offset by the broadcast prefix
     * \param[out]    warp_aggregate  Warp-wide aggregate (does not include the callback prefix)
     * \param[in,out] warp_prefix_op  Functor invoked as warp_prefix_op(warp_aggregate)
     */
    template <typename WarpPrefixCallbackOp>
    __device__ __forceinline__ void InclusiveSum(
        T                       input,
        T                       &output,
        T                       &warp_aggregate,
        WarpPrefixCallbackOp    &warp_prefix_op)
    {
        // Warp-local inclusive sum, which also yields the aggregate
        InclusiveSum(input, output, warp_aggregate);

        // Ask the callback for the running prefix, then take lane 0's answer
        T warp_prefix = warp_prefix_op(warp_aggregate);
        warp_prefix = InternalWarpScan(temp_storage).Broadcast(warp_prefix, 0);

        // Offset the local result by the external prefix
        output = warp_prefix + output;
    }


private:

    //---------------------------------------------------------------------
    // Integer / non-integer dispatch helpers
    //
    // For integer T the exclusive result is derived by subtracting each
    // thread's own input from its inclusive result. For other types this
    // subtract-the-input shortcut is not used; a zero-seeded exclusive scan
    // is run instead.
    //---------------------------------------------------------------------

    /// Combined sum, integer T: exclusive = inclusive - input
    __device__ __forceinline__ void Sum(T input, T &inclusive_output, T &exclusive_output, Int2Type<true> /*is_integer*/)
    {
        InclusiveSum(input, inclusive_output);
        exclusive_output = inclusive_output - input;
    }

    /// Combined sum, non-integer T: zero-seeded combined scan with cub::Sum
    __device__ __forceinline__ void Sum(T input, T &inclusive_output, T &exclusive_output, Int2Type<false> /*is_integer*/)
    {
        T zero = ZeroInitialize<T>();
        Scan(input, inclusive_output, exclusive_output, zero, cub::Sum());
    }

    /// Exclusive sum, integer T: reuse the combined integer sum and discard its inclusive half
    __device__ __forceinline__ void ExclusiveSum(T input, T &output, Int2Type<true> /*is_integer*/)
    {
        T inclusive_result;
        Sum(input, inclusive_result, output, Int2Type<true>());
    }

    /// Exclusive sum, non-integer T: zero-seeded exclusive scan
    __device__ __forceinline__ void ExclusiveSum(T input, T &output, Int2Type<false> /*is_integer*/)
    {
        T zero = ZeroInitialize<T>();
        ExclusiveScan(input, output, zero, cub::Sum());
    }

    /// Exclusive sum with aggregate, integer T
    __device__ __forceinline__ void ExclusiveSum(T input, T &output, T &warp_aggregate, Int2Type<true> /*is_integer*/)
    {
        T inclusive_result;
        InclusiveSum(input, inclusive_result, warp_aggregate);
        output = inclusive_result - input;
    }

    /// Exclusive sum with aggregate, non-integer T
    __device__ __forceinline__ void ExclusiveSum(T input, T &output, T &warp_aggregate, Int2Type<false> /*is_integer*/)
    {
        T zero = ZeroInitialize<T>();
        ExclusiveScan(input, output, zero, cub::Sum(), warp_aggregate);
    }

    /// Exclusive sum with aggregate and prefix callback, integer T
    template <typename WarpPrefixCallbackOp>
    __device__ __forceinline__ void ExclusiveSum(T input, T &output, T &warp_aggregate, WarpPrefixCallbackOp &warp_prefix_op, Int2Type<true> /*is_integer*/)
    {
        T inclusive_result;
        InclusiveSum(input, inclusive_result, warp_aggregate, warp_prefix_op);
        output = inclusive_result - input;
    }

    /// Exclusive sum with aggregate and prefix callback, non-integer T
    template <typename WarpPrefixCallbackOp>
    __device__ __forceinline__ void ExclusiveSum(T input, T &output, T &warp_aggregate, WarpPrefixCallbackOp &warp_prefix_op, Int2Type<false> /*is_integer*/)
    {
        T zero = ZeroInitialize<T>();
        ExclusiveScan(input, output, zero, cub::Sum(), warp_aggregate, warp_prefix_op);
    }


public:

    //---------------------------------------------------------------------
    // Exclusive prefix sums
    //---------------------------------------------------------------------

    /**
     * \brief Exclusive prefix sum of \p input across the logical warp.
     *
     * \param[in]  input   This thread's input item
     * \param[out] output  This thread's exclusive prefix sum
     */
    __device__ __forceinline__ void ExclusiveSum(
        T   input,
        T   &output)
    {
        ExclusiveSum(input, output, IsIntegerTag());
    }


    /**
     * \brief Exclusive prefix sum that also reports the warp-wide total.
     *
     * \param[in]  input           This thread's input item
     * \param[out] output          This thread's exclusive prefix sum
     * \param[out] warp_aggregate  Warp-wide aggregate as produced by the underlying scan
     */
    __device__ __forceinline__ void ExclusiveSum(
        T   input,
        T   &output,
        T   &warp_aggregate)
    {
        ExclusiveSum(input, output, warp_aggregate, IsIntegerTag());
    }


    /**
     * \brief Exclusive prefix sum seeded by an external, callback-supplied prefix.
     *
     * \par
     * \p warp_prefix_op is invoked with the warp aggregate by every calling
     * thread; only lane 0's return value is applied.
     *
     * \param[in]     input           This thread's input item
     * \param[out]    output          This thread's exclusive prefix sum, offset by the callback prefix
     * \param[out]    warp_aggregate  Warp-wide aggregate (does not include the callback prefix)
     * \param[in,out] warp_prefix_op  Functor invoked as warp_prefix_op(warp_aggregate)
     */
    template <typename WarpPrefixCallbackOp>
    __device__ __forceinline__ void ExclusiveSum(
        T                       input,
        T                       &output,
        T                       &warp_aggregate,
        WarpPrefixCallbackOp    &warp_prefix_op)
    {
        ExclusiveSum(input, output, warp_aggregate, warp_prefix_op, IsIntegerTag());
    }


    //---------------------------------------------------------------------
    // Inclusive prefix scans
    //---------------------------------------------------------------------

    /**
     * \brief Inclusive prefix scan of \p input using \p scan_op.
     *
     * \param[in]  input    This thread's input item
     * \param[out] output   This thread's inclusive scan result
     * \param[in]  scan_op  Binary scan operator
     */
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T       input,
        T       &output,
        ScanOp  scan_op)
    {
        InternalWarpScan backend(temp_storage);
        backend.InclusiveScan(input, output, scan_op);
    }


    /**
     * \brief Inclusive prefix scan that also reports the warp-wide aggregate.
     *
     * \param[in]  input           This thread's input item
     * \param[out] output          This thread's inclusive scan result
     * \param[in]  scan_op         Binary scan operator
     * \param[out] warp_aggregate  Warp-wide aggregate as produced by the backend
     */
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T       input,
        T       &output,
        ScanOp  scan_op,
        T       &warp_aggregate)
    {
        InternalWarpScan backend(temp_storage);
        backend.InclusiveScan(input, output, scan_op, warp_aggregate);
    }


    /**
     * \brief Inclusive prefix scan seeded by an external, callback-supplied prefix.
     *
     * \par
     * Every calling thread invokes \p warp_prefix_op with the warp aggregate; the
     * value from lane 0 is broadcast and combined in front of each thread's result
     * as scan_op(prefix, output).
     *
     * \param[in]     input           This thread's input item
     * \param[out]    output          This thread's inclusive scan result, combined with the broadcast prefix
     * \param[in]     scan_op         Binary scan operator
     * \param[out]    warp_aggregate  Warp-wide aggregate (does not include the callback prefix)
     * \param[in,out] warp_prefix_op  Functor invoked as warp_prefix_op(warp_aggregate)
     */
    template <
        typename ScanOp,
        typename WarpPrefixCallbackOp>
    __device__ __forceinline__ void InclusiveScan(
        T                       input,
        T                       &output,
        ScanOp                  scan_op,
        T                       &warp_aggregate,
        WarpPrefixCallbackOp    &warp_prefix_op)
    {
        // Warp-local inclusive scan plus aggregate
        InclusiveScan(input, output, scan_op, warp_aggregate);

        // Obtain the external prefix and take lane 0's value
        T warp_prefix = warp_prefix_op(warp_aggregate);
        warp_prefix = InternalWarpScan(temp_storage).Broadcast(warp_prefix, 0);

        // Prepend the external prefix
        output = scan_op(warp_prefix, output);
    }


    //---------------------------------------------------------------------
    // Exclusive prefix scans (caller-supplied identity)
    //---------------------------------------------------------------------

    /**
     * \brief Exclusive prefix scan of \p input using \p scan_op, seeded with \p identity.
     *
     * \par
     * Implemented via the combined Scan; the inclusive result it also produces is discarded.
     *
     * \param[in]  input     This thread's input item
     * \param[out] output    This thread's exclusive scan result
     * \param[in]  identity  Seed value for the scan
     * \param[in]  scan_op   Binary scan operator
     */
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(
        T       input,
        T       &output,
        T       identity,
        ScanOp  scan_op)
    {
        T discarded_inclusive;
        Scan(input, discarded_inclusive, output, identity, scan_op);
    }


    /**
     * \brief Exclusive prefix scan seeded with \p identity that also reports the warp-wide aggregate.
     *
     * \param[in]  input           This thread's input item
     * \param[out] output          This thread's exclusive scan result
     * \param[in]  identity        Seed value for the scan
     * \param[in]  scan_op         Binary scan operator
     * \param[out] warp_aggregate  Warp-wide aggregate as produced by the backend
     */
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(
        T       input,
        T       &output,
        T       identity,
        ScanOp  scan_op,
        T       &warp_aggregate)
    {
        InternalWarpScan backend(temp_storage);
        backend.ExclusiveScan(input, output, identity, scan_op, warp_aggregate);
    }


    /**
     * \brief Exclusive prefix scan seeded with \p identity, then offset by a callback-supplied prefix.
     *
     * \par
     * Every calling thread invokes \p warp_prefix_op with the warp aggregate; lane 0's
     * value is broadcast. Lane 0 then outputs that prefix itself, while every other
     * lane outputs scan_op(prefix, output).
     *
     * \param[in]     input           This thread's input item
     * \param[out]    output          This thread's exclusive scan result, combined with the broadcast prefix
     * \param[in]     identity        Seed value for the warp-local scan
     * \param[in]     scan_op         Binary scan operator
     * \param[out]    warp_aggregate  Warp-wide aggregate (does not include the callback prefix)
     * \param[in,out] warp_prefix_op  Functor invoked as warp_prefix_op(warp_aggregate)
     */
    template <
        typename ScanOp,
        typename WarpPrefixCallbackOp>
    __device__ __forceinline__ void ExclusiveScan(
        T                       input,
        T                       &output,
        T                       identity,
        ScanOp                  scan_op,
        T                       &warp_aggregate,
        WarpPrefixCallbackOp    &warp_prefix_op)
    {
        // Warp-local exclusive scan plus aggregate
        ExclusiveScan(input, output, identity, scan_op, warp_aggregate);

        // Obtain the external prefix and take lane 0's value
        T warp_prefix = warp_prefix_op(warp_aggregate);
        warp_prefix = InternalWarpScan(temp_storage).Broadcast(warp_prefix, 0);

        // Lane 0 takes the prefix verbatim; the rest fold it in front of their result
        if (lane_id == 0)
        {
            output = warp_prefix;
        }
        else
        {
            output = scan_op(warp_prefix, output);
        }
    }


    //---------------------------------------------------------------------
    // Exclusive prefix scans (no identity)
    //---------------------------------------------------------------------

    /**
     * \brief Exclusive prefix scan of \p input using \p scan_op without a seed value.
     *
     * \par
     * Implemented via the combined Scan; the inclusive result it also produces is discarded.
     * NOTE(review): no seed is supplied, so lane 0's \p output is whatever the backend
     * yields (likely undefined) — confirm against the backend.
     *
     * \param[in]  input    This thread's input item
     * \param[out] output   This thread's exclusive scan result
     * \param[in]  scan_op  Binary scan operator
     */
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(
        T       input,
        T       &output,
        ScanOp  scan_op)
    {
        T discarded_inclusive;
        Scan(input, discarded_inclusive, output, scan_op);
    }


    /**
     * \brief Unseeded exclusive prefix scan that also reports the warp-wide aggregate.
     *
     * \param[in]  input           This thread's input item
     * \param[out] output          This thread's exclusive scan result
     * \param[in]  scan_op         Binary scan operator
     * \param[out] warp_aggregate  Warp-wide aggregate as produced by the backend
     */
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(
        T       input,
        T       &output,
        ScanOp  scan_op,
        T       &warp_aggregate)
    {
        InternalWarpScan backend(temp_storage);
        backend.ExclusiveScan(input, output, scan_op, warp_aggregate);
    }


    /**
     * \brief Unseeded exclusive prefix scan, then offset by a callback-supplied prefix.
     *
     * \par
     * Every calling thread invokes \p warp_prefix_op with the warp aggregate; lane 0's
     * value is broadcast. Lane 0 outputs that prefix itself, which replaces whatever the
     * unseeded scan left in lane 0. Every other lane outputs scan_op(prefix, output).
     *
     * \param[in]     input           This thread's input item
     * \param[out]    output          This thread's exclusive scan result, combined with the broadcast prefix
     * \param[in]     scan_op         Binary scan operator
     * \param[out]    warp_aggregate  Warp-wide aggregate (does not include the callback prefix)
     * \param[in,out] warp_prefix_op  Functor invoked as warp_prefix_op(warp_aggregate)
     */
    template <
        typename ScanOp,
        typename WarpPrefixCallbackOp>
    __device__ __forceinline__ void ExclusiveScan(
        T                       input,
        T                       &output,
        ScanOp                  scan_op,
        T                       &warp_aggregate,
        WarpPrefixCallbackOp    &warp_prefix_op)
    {
        // Warp-local exclusive scan plus aggregate
        ExclusiveScan(input, output, scan_op, warp_aggregate);

        // Obtain the external prefix and take lane 0's value
        T warp_prefix = warp_prefix_op(warp_aggregate);
        warp_prefix = InternalWarpScan(temp_storage).Broadcast(warp_prefix, 0);

        // Lane 0 takes the prefix verbatim; the rest fold it in front of their result
        if (lane_id == 0)
        {
            output = warp_prefix;
        }
        else
        {
            output = scan_op(warp_prefix, output);
        }
    }


    //---------------------------------------------------------------------
    // Combined inclusive + exclusive scans
    //---------------------------------------------------------------------

    /**
     * \brief Computes both the inclusive and the exclusive prefix sum of \p input.
     *
     * \param[in]  input             This thread's input item
     * \param[out] inclusive_output  This thread's inclusive prefix sum
     * \param[out] exclusive_output  This thread's exclusive prefix sum
     */
    __device__ __forceinline__ void Sum(
        T   input,
        T   &inclusive_output,
        T   &exclusive_output)
    {
        Sum(input, inclusive_output, exclusive_output, IsIntegerTag());
    }


    /**
     * \brief Computes both the inclusive and the \p identity-seeded exclusive scan of \p input.
     *
     * \param[in]  input             This thread's input item
     * \param[out] inclusive_output  This thread's inclusive scan result
     * \param[out] exclusive_output  This thread's exclusive scan result
     * \param[in]  identity          Seed value for the exclusive scan
     * \param[in]  scan_op           Binary scan operator
     */
    template <typename ScanOp>
    __device__ __forceinline__ void Scan(
        T       input,
        T       &inclusive_output,
        T       &exclusive_output,
        T       identity,
        ScanOp  scan_op)
    {
        InternalWarpScan backend(temp_storage);
        backend.Scan(input, inclusive_output, exclusive_output, identity, scan_op);
    }


    /**
     * \brief Computes both the inclusive and the unseeded exclusive scan of \p input.
     *
     * \param[in]  input             This thread's input item
     * \param[out] inclusive_output  This thread's inclusive scan result
     * \param[out] exclusive_output  This thread's exclusive scan result (lane 0's value comes from the backend without a seed)
     * \param[in]  scan_op           Binary scan operator
     */
    template <typename ScanOp>
    __device__ __forceinline__ void Scan(
        T       input,
        T       &inclusive_output,
        T       &exclusive_output,
        ScanOp  scan_op)
    {
        InternalWarpScan backend(temp_storage);
        backend.Scan(input, inclusive_output, exclusive_output, scan_op);
    }

};
1447  // end group WarpModule
1449 
1450 } // CUB namespace
1451 CUB_NS_POSTFIX // Optional outer namespace(s)