36 #include "specializations/warp_scan_shfl.cuh"
37 #include "specializations/warp_scan_smem.cuh"
38 #include "../thread/thread_operators.cuh"
39 #include "../util_arch.cuh"
40 #include "../util_type.cuh"
41 #include "../util_namespace.cuh"
// Integer template parameter; when omitted it takes the value CUB_PTX_WARP_THREADS.
143 int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS,
// Bit test: x & (x - 1) clears the lowest set bit, so the comparison is true
// exactly when LOGICAL_WARP_THREADS has at most one bit set.
159 IS_POW_OF_TWO = ((LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0),
// InternalWarpScan is the ::Type obtained from these two candidate specializations
// (WarpScanShfl and WarpScanSmem), each instantiated with <T, LOGICAL_WARP_THREADS, PTX_ARCH>.
167 WarpScanShfl<T, LOGICAL_WARP_THREADS, PTX_ARCH>,
168 WarpScanSmem<T, LOGICAL_WARP_THREADS, PTX_ARCH> >::Type InternalWarpScan;
// Alias for the TempStorage type nested in InternalWarpScan.
171 typedef typename InternalWarpScan::TempStorage _TempStorage;
// Held by reference, so this object does not own a copy of the storage.
179 _TempStorage &temp_storage;
// Initialize the temp_storage reference member from temp_storage.Alias().
204 temp_storage(temp_storage.Alias()),
// lane_id is selected on IS_ARCH_WARP; the alternative shown here reduces
// LaneId() modulo LOGICAL_WARP_THREADS.
205 lane_id(IS_ARCH_WARP ?
207 LaneId() % LOGICAL_WARP_THREADS)
// Inclusive sum: delegate to the internal specialization's InclusiveScan,
// passing a cub::Sum() functor as the scan operator.
255 InternalWarpScan(temp_storage).InclusiveScan(input, output,
cub::Sum());
// Inclusive sum that also passes warp_aggregate to the internal InclusiveScan;
// cub::Sum() is the scan operator.
300 InternalWarpScan(temp_storage).InclusiveScan(input, output,
cub::Sum(), warp_aggregate);
// Inclusive-sum variant seeded by a caller-supplied prefix callback.
// WarpPrefixCallbackOp is the callback's type; it is received by non-const reference.
379 template <
typename WarpPrefixCallbackOp>
384 WarpPrefixCallbackOp &warp_prefix_op)
// The callback is invoked with warp_aggregate and its return value becomes the prefix.
391 prefix = warp_prefix_op(warp_aggregate);
// Broadcast(prefix, 0): presumably propagates lane 0's prefix to every lane --
// TODO confirm against the InternalWarpScan specializations.
392 prefix = InternalWarpScan(temp_storage).Broadcast(prefix, 0);
// Fold the prefix into this thread's result (prefix is the left operand).
395 output = prefix + output;
// Sum overload selected by the Int2Type<true> tag (is_integer).
// The exclusive result is derived by subtracting this thread's input from its
// inclusive result.
403 __device__ __forceinline__
void Sum(T input, T &inclusive_output, T &exclusive_output,
Int2Type<true> is_integer)
407 exclusive_output = inclusive_output - input;
// Sum overload selected by the Int2Type<false> tag (is_integer).
// Rather than subtracting, it runs the internal combined Scan with an explicit identity.
411 __device__ __forceinline__
void Sum(T input, T &inclusive_output, T &exclusive_output, Int2Type<false> is_integer)
// Identity value comes from ZeroInitialize<T>().
414 T identity = ZeroInitialize<T>();
// One call fills both inclusive_output and exclusive_output, using cub::Sum() as the operator.
415 InternalWarpScan(temp_storage).Scan(input, inclusive_output, exclusive_output, identity,
cub::Sum());
// ExclusiveSum overload selected by the Int2Type<true> tag (is_integer).
419 __device__ __forceinline__
void ExclusiveSum(T input, T &output, Int2Type<true> is_integer)
// Exclusive result = inclusive result minus this thread's own input.
424 output = inclusive - input;
// ExclusiveSum overload selected by the Int2Type<false> tag (is_integer).
428 __device__ __forceinline__
void ExclusiveSum(T input, T &output, Int2Type<false> is_integer)
// Identity value comes from ZeroInitialize<T>().
431 T identity = ZeroInitialize<T>();
// ExclusiveSum overload that also takes warp_aggregate by reference; selected by
// the Int2Type<true> tag (is_integer).
436 __device__ __forceinline__
void ExclusiveSum(T input, T &output, T &warp_aggregate, Int2Type<true> is_integer)
// Exclusive result = inclusive result minus this thread's own input.
441 output = inclusive - input;
// ExclusiveSum overload that also takes warp_aggregate by reference; selected by
// the Int2Type<false> tag (is_integer).
445 __device__ __forceinline__
void ExclusiveSum(T input, T &output, T &warp_aggregate, Int2Type<false> is_integer)
// Identity value comes from ZeroInitialize<T>().
448 T identity = ZeroInitialize<T>();
// ExclusiveSum overload taking a prefix callback; selected by the Int2Type<true>
// tag (is_integer). WarpPrefixCallbackOp is the callback's type.
453 template <
typename WarpPrefixCallbackOp>
454 __device__ __forceinline__
void ExclusiveSum(T input, T &output, T &warp_aggregate, WarpPrefixCallbackOp &warp_prefix_op, Int2Type<true> is_integer)
// Compute the callback-seeded inclusive sum first ...
458 InclusiveSum(input, inclusive, warp_aggregate, warp_prefix_op);
// ... then remove this thread's own input to obtain the exclusive result.
459 output = inclusive - input;
// ExclusiveSum overload taking a prefix callback; selected by the Int2Type<false>
// tag (is_integer). WarpPrefixCallbackOp is the callback's type.
463 template <
typename WarpPrefixCallbackOp>
464 __device__ __forceinline__
void ExclusiveSum(T input, T &output, T &warp_aggregate, WarpPrefixCallbackOp &warp_prefix_op, Int2Type<false> is_integer)
// Identity value comes from ZeroInitialize<T>().
467 T identity = ZeroInitialize<T>();
// WarpPrefixCallbackOp is the type of warp_prefix_op, which is received by
// non-const reference.
653 template <
typename WarpPrefixCallbackOp>
658 WarpPrefixCallbackOp &warp_prefix_op)
// Inclusive scan with a caller-supplied operator of type ScanOp.
707 template <
typename ScanOp>
// Forward input, output and scan_op to the internal specialization.
713 InternalWarpScan(temp_storage).InclusiveScan(input, output, scan_op);
// Inclusive scan with a caller-supplied operator of type ScanOp; warp_aggregate
// is additionally passed through to the internal call.
758 template <
typename ScanOp>
765 InternalWarpScan(temp_storage).InclusiveScan(input, output, scan_op, warp_aggregate);
// Scan variant seeded by a prefix callback, received by non-const reference.
848 typename WarpPrefixCallbackOp>
854 WarpPrefixCallbackOp &warp_prefix_op)
// The callback is invoked with warp_aggregate and its return value becomes the prefix.
861 prefix = warp_prefix_op(warp_aggregate);
// Broadcast(prefix, 0): presumably propagates lane 0's prefix to every lane --
// TODO confirm against the InternalWarpScan specializations.
862 prefix = InternalWarpScan(temp_storage).Broadcast(prefix, 0);
// Combine with the caller's operator; prefix is the left operand, so operand
// order is preserved for non-commutative operators.
865 output = scan_op(prefix, output);
// Scan with a caller-supplied operator of type ScanOp and an identity value.
912 template <
typename ScanOp>
// Uses the internal combined Scan entry point: `output` is passed in the argument
// slot that other call sites in this file fill with exclusive_output, and
// inclusive_output is passed in the slot before it.
920 InternalWarpScan(temp_storage).Scan(input, inclusive_output, output, identity, scan_op);
// Exclusive scan with a caller-supplied operator of type ScanOp and an identity
// value; warp_aggregate is additionally passed through to the internal call.
964 template <
typename ScanOp>
972 InternalWarpScan(temp_storage).ExclusiveScan(input, output, identity, scan_op, warp_aggregate);
// Exclusive scan (with identity) seeded by a prefix callback, which is received
// by non-const reference.
1055 typename WarpPrefixCallbackOp>
1062 WarpPrefixCallbackOp &warp_prefix_op)
// Step 1: run the identity-based exclusive scan, which is also passed warp_aggregate.
1065 ExclusiveScan(input, output, identity, scan_op, warp_aggregate);
// Step 2: the callback is invoked with warp_aggregate and returns the prefix.
1068 T prefix = warp_prefix_op(warp_aggregate);
// Broadcast(prefix, 0): presumably propagates lane 0's prefix to every lane --
// TODO confirm against the InternalWarpScan specializations.
1069 prefix = InternalWarpScan(temp_storage).Broadcast(prefix, 0);
// Step 3: lane 0 is special-cased by the conditional; the scan_op operand
// combines prefix (left) with the scanned output (right).
1072 output = (lane_id == 0) ?
1074 scan_op(prefix, output);
// Scan with a caller-supplied operator of type ScanOp and no identity value.
1123 template <
typename ScanOp>
// Uses the internal combined Scan entry point (identity-free form): `output` is
// passed in the argument slot that other call sites in this file fill with
// exclusive_output.
1130 InternalWarpScan(temp_storage).Scan(input, inclusive_output, output, scan_op);
// Exclusive scan with a caller-supplied operator of type ScanOp and no identity
// value; warp_aggregate is additionally passed through to the internal call.
1174 template <
typename ScanOp>
1181 InternalWarpScan(temp_storage).ExclusiveScan(input, output, scan_op, warp_aggregate);
// Exclusive scan (no identity) seeded by a prefix callback, which is received by
// non-const reference.
1264 typename WarpPrefixCallbackOp>
1270 WarpPrefixCallbackOp &warp_prefix_op)
// The callback is invoked with warp_aggregate and returns the prefix.
1276 T prefix = warp_prefix_op(warp_aggregate);
// Broadcast(prefix, 0): presumably propagates lane 0's prefix to every lane --
// TODO confirm against the InternalWarpScan specializations.
1277 prefix = InternalWarpScan(temp_storage).Broadcast(prefix, 0);
// Lane 0 is special-cased by the conditional; the scan_op operand combines
// prefix (left) with the scanned output (right).
1280 output = (lane_id == 0) ?
1282 scan_op(prefix, output);
// Sum entry point that reports two results through reference parameters:
// inclusive_output and exclusive_output.
1333 __device__ __forceinline__
void Sum(
1335 T &inclusive_output,
1336 T &exclusive_output)
// Combined scan with a caller-supplied operator of type ScanOp and an identity
// value; results are written through the inclusive_output and exclusive_output
// references.
1382 template <
typename ScanOp>
1383 __device__ __forceinline__
void Scan(
1385 T &inclusive_output,
1386 T &exclusive_output,
// Thin wrapper: all arguments are forwarded unchanged to the internal specialization.
1390 InternalWarpScan(temp_storage).Scan(input, inclusive_output, exclusive_output, identity, scan_op);
// Combined scan with a caller-supplied operator of type ScanOp and no identity
// value; results are written through the inclusive_output and exclusive_output
// references.
1434 template <
typename ScanOp>
1435 __device__ __forceinline__
void Scan(
1437 T &inclusive_output,
1438 T &exclusive_output,
// Thin wrapper: all arguments are forwarded unchanged to the internal specialization.
1441 InternalWarpScan(temp_storage).Scan(input, inclusive_output, exclusive_output, scan_op);