39 #include "util_namespace.cuh"
58 #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
// Pointer-width selection for inline PTX operands.
// When _WIN64 or __LP64__ is defined: __CUB_LP64__ is 1, pointer operands use
// the "l" asm constraint and the "u64" size suffix.
63 #if defined(_WIN64) || defined(__LP64__)
64 #define __CUB_LP64__ 1
66 #define _CUB_ASM_PTR_ "l"
67 #define _CUB_ASM_PTR_SIZE_ "u64"
// Alternative definitions: __CUB_LP64__ is 0, constraint "r", suffix "u32".
// NOTE(review): presumably the #else branch of the test above -- the
// directive itself is not visible in this excerpt; confirm.
69 #define __CUB_LP64__ 0
71 #define _CUB_ASM_PTR_ "r"
72 #define _CUB_ASM_PTR_SIZE_ "u32"
75 #endif // DOXYGEN_SHOULD_SKIP_THIS
// SHR_ADD: shift-right-then-add on unsigned 32-bit values; the result is
// (x >> shift) + addend.
85 __device__ __forceinline__
unsigned int SHR_ADD(
// Hardware path (CUB_PTX_ARCH >= 200): a single vshr ... .clamp.add PTX
// instruction; x, shift and addend are passed in 32-bit registers and the
// result is written to ret.
91 #if CUB_PTX_ARCH >= 200
92 asm(
"vshr.u32.u32.u32.clamp.add %0, %1, %2, %3;" :
93 "=r"(ret) :
"r"(x),
"r"(shift),
"r"(addend));
// Plain C++ form of the same computation (presumably the #else branch; the
// directive is not visible in this excerpt).
// NOTE(review): the asm uses the .clamp shift mode while a C++ shift by >= 32
// is undefined, so the two forms should only be assumed to agree for
// shift < 32 -- confirm against the PTX ISA.
95 ret = (x >> shift) + addend;
// SHL_ADD: shift-left-then-add on unsigned 32-bit values; the result is
// (x << shift) + addend.
104 __device__ __forceinline__
unsigned int SHL_ADD(
// Hardware path (CUB_PTX_ARCH >= 200): a single vshl ... .clamp.add PTX
// instruction; x, shift and addend are passed in 32-bit registers and the
// result is written to ret.
110 #if CUB_PTX_ARCH >= 200
111 asm(
"vshl.u32.u32.u32.clamp.add %0, %1, %2, %3;" :
112 "=r"(ret) :
"r"(x),
"r"(shift),
"r"(addend));
// Plain C++ form of the same computation (presumably the #else branch; the
// directive is not visible in this excerpt).
// NOTE(review): the asm uses the .clamp shift mode while a C++ shift by >= 32
// is undefined, so the two forms should only be assumed to agree for
// shift < 32 -- confirm against the PTX ISA.
114 ret = (x << shift) + addend;
119 #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
// BFE overload selected by an Int2Type<BYTE_LEN> tag: extracts num_bits bits
// of source beginning at bit position bit_start and returns them as an
// unsigned int. The byte_len argument is only a dispatch tag.
124 template <
typename Un
signedBits,
int BYTE_LEN>
125 __device__ __forceinline__
unsigned int BFE(
127 unsigned int bit_start,
128 unsigned int num_bits,
129 Int2Type<BYTE_LEN> byte_len)
// Hardware path (CUB_PTX_ARCH >= 200): source is cast to unsigned int and
// handed to the bfe.u32 PTX instruction together with bit_start and num_bits.
132 #if CUB_PTX_ARCH >= 200
133 asm(
"bfe.u32 %0, %1, %2, %3;" :
"=r"(bits) :
"r"((
unsigned int) source),
"r"(bit_start),
"r"(num_bits));
// Shift-and-mask form: build a mask of num_bits ones, shift source down by
// bit_start and keep the masked bits (presumably the #else branch; the
// directive is not visible in this excerpt).
// NOTE(review): (1 << num_bits) is a signed int shift and is undefined for
// num_bits >= 32, so a full-width extract is only safe on the asm path --
// confirm callers never reach this with num_bits == 32.
135 const unsigned int MASK = (1 << num_bits) - 1;
136 bits = (source >> bit_start) & MASK;
// BFE overload selected by the Int2Type<8> tag (used by the dispatcher when
// sizeof(UnsignedBits) is 8): extracts num_bits bits of source beginning at
// bit_start using a 64-bit mask. No inline PTX is visible in this overload.
145 template <
typename Un
signedBits>
146 __device__ __forceinline__
unsigned int BFE(
148 unsigned int bit_start,
149 unsigned int num_bits,
150 Int2Type<8> byte_len)
// Mask of num_bits low ones, built in unsigned long long.
// NOTE(review): (1ull << num_bits) is undefined for num_bits >= 64.
152 const unsigned long long MASK = (1ull << num_bits) - 1;
// NOTE(review): the masked value is at least unsigned long long wide but the
// declared return type is unsigned int, so any extracted bits above bit 31
// are dropped on return -- confirm callers only request fields of <= 32 bits.
153 return (source >> bit_start) & MASK;
156 #endif // DOXYGEN_SHOULD_SKIP_THIS
// Public BFE entry point: extracts num_bits bits of source beginning at
// bit_start. It forwards to the tagged overload chosen by
// Int2Type<sizeof(UnsignedBits)>, so the byte width of the source type picks
// the implementation at compile time.
161 template <
typename Un
signedBits>
162 __device__ __forceinline__
unsigned int BFE(
164 unsigned int bit_start,
165 unsigned int num_bits)
167 return BFE(source, bit_start, num_bits,
Int2Type<
sizeof(UnsignedBits)>());
// BFI: bit-field insert. Combines x and y over a field of num_bits bits that
// starts at bit_start and stores the merged word in ret (the function itself
// returns void).
174 __device__ __forceinline__
void BFI(
178 unsigned int bit_start,
179 unsigned int num_bits)
// Hardware path (CUB_PTX_ARCH >= 200): one bfi.b32 PTX instruction with the
// operands in the order y, x, bit_start, num_bits.
181 #if CUB_PTX_ARCH >= 200
182 asm(
"bfi.b32 %0, %1, %2, %3, %4;" :
183 "=r"(ret) :
"r"(y),
"r"(x),
"r"(bit_start),
"r"(num_bits));
// Mask form (presumably the #else branch; the directive is not visible here):
// MASK_X covers the num_bits-wide field at bit_start, MASK_Y covers everything
// else; the result takes x inside the field and y outside it.
// NOTE(review): per the PTX ISA, bfi takes its first source (here y) as the
// field to insert and its second (here x) as the base word, which looks like
// the opposite role assignment from the mask arithmetic below. A line between
// the asm and these statements may be missing from this excerpt (e.g. a
// pre-shift of x) -- confirm that both branches produce the same result.
// NOTE(review): (1 << num_bits) is undefined for num_bits >= 32.
186 unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start;
187 unsigned int MASK_Y = ~MASK_X;
188 ret = (y & MASK_Y) | (x & MASK_X);
// IADD3: three-input unsigned 32-bit addition of x, y and z.
196 __device__ __forceinline__
unsigned int IADD3(
unsigned int x,
unsigned int y,
unsigned int z)
// Hardware path (CUB_PTX_ARCH >= 200): a single vadd ... .add PTX instruction.
// The by-value parameter x is reused as the output register, so after the asm
// x holds the result.
// NOTE(review): the non-asm branch and the return statement are not visible
// in this excerpt.
198 #if CUB_PTX_ARCH >= 200
199 asm(
"vadd.u32.u32.u32.add %0, %1, %2, %3;" :
"=r"(x) :
"r"(x),
"r"(y),
"r"(z));
// PRMT: wraps the prmt.b32 PTX byte-permute instruction. a and b supply the
// source words and index is the selector operand; the instruction's 32-bit
// result is written to ret.
// NOTE(review): the declared return type is int while the asm produces a
// 32-bit register value through an "=r" operand; the declaration of ret is
// not visible here -- confirm whether a signed return is intended.
233 __device__ __forceinline__
int PRMT(
unsigned int a,
unsigned int b,
unsigned int index)
236 asm(
"prmt.b32 %0, %1, %2, %3;" :
"=r"(ret) :
"r"(a),
"r"(b),
"r"(index));
240 #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
// BAR: issues "bar.sync 1, count" -- a barrier on barrier resource number 1
// with count passed as the second operand (the thread count per the PTX ISA).
// The asm is volatile so the compiler will not remove or reorder it as dead
// code; it has no outputs.
245 __device__ __forceinline__
void BAR(
int count)
247 asm volatile(
"bar.sync 1, %0;" : :
"r"(count));
// FMUL_RZ: single-precision multiply a * b issued as mul.rz.f32, i.e. with
// the .rz rounding modifier (round toward zero per the PTX ISA). The product
// is written to d.
254 __device__ __forceinline__
float FMUL_RZ(
float a,
float b)
257 asm(
"mul.rz.f32 %0, %1, %2;" :
"=f"(d) :
"f"(a),
"f"(b));
// FFMA_RZ: single-precision fused multiply-add a * b + c issued as
// fma.rz.f32, i.e. with the .rz rounding modifier (round toward zero per the
// PTX ISA). The result is written to d.
265 __device__ __forceinline__
float FFMA_RZ(
float a,
float b,
float c)
268 asm(
"fma.rz.f32 %0, %1, %2, %3;" :
"=f"(d) :
"f"(a),
"f"(b),
"f"(c));
272 #endif // DOXYGEN_SHOULD_SKIP_THIS
// RowMajorTid: linear thread index within the block for the given block
// dimensions, with z as the slowest-varying coordinate and y next.
// The dimensions are taken as arguments rather than read from blockDim.
285 __device__ __forceinline__
int RowMajorTid(
int block_dim_x,
int block_dim_y,
int block_dim_z)
// z contribution: threadIdx.z * block_dim_x * block_dim_y, replaced by 0 when
// block_dim_z == 1 so that threadIdx.z is not read at all in that case.
287 return ((block_dim_z == 1) ? 0 : (threadIdx.z * block_dim_x * block_dim_y)) +
// y contribution: threadIdx.y * block_dim_x, likewise skipped when
// block_dim_y == 1.
// NOTE(review): the final term of the sum is not visible in this excerpt
// (presumably threadIdx.x).
288 ((block_dim_y == 1) ? 0 : (threadIdx.y * block_dim_x)) +
// LaneId: reads the %laneid PTX special register (the calling thread's lane
// within its warp) into ret.
296 __device__ __forceinline__
unsigned int LaneId()
299 asm(
"mov.u32 %0, %laneid;" :
"=r"(ret) );
// WarpId: reads the %warpid PTX special register into ret.
// NOTE(review): the PTX ISA describes %warpid as a hardware warp identifier
// that may change during execution; callers that need a stable warp index
// within the block should derive it from the thread index instead -- confirm
// how callers use this value.
307 __device__ __forceinline__
unsigned int WarpId()
310 asm(
"mov.u32 %0, %warpid;" :
"=r"(ret) );
// The four asm statements below each read one PTX lane-mask special register
// into ret. NOTE(review): the enclosing function headers are not visible in
// this excerpt; each statement presumably belongs to its own accessor.

// %lanemask_lt: bits set for lanes with an index lower than the caller's.
320 asm(
"mov.u32 %0, %lanemask_lt;" :
"=r"(ret) );
// %lanemask_le: bits set for lanes with an index lower than or equal to the
// caller's.
330 asm(
"mov.u32 %0, %lanemask_le;" :
"=r"(ret) );
// %lanemask_gt: bits set for lanes with an index greater than the caller's.
340 asm(
"mov.u32 %0, %lanemask_gt;" :
"=r"(ret) );
// %lanemask_ge: bits set for lanes with an index greater than or equal to the
// caller's.
350 asm(
"mov.u32 %0, %lanemask_ge;" :
"=r"(ret) );
// Warp shuffle of an arbitrary T using shfl.up.b32.
// NOTE(review): the function name, its parameters (input, output, src_offset)
// and the definition of SHFL_C are not visible in this excerpt.
387 template <
typename T>
// ShuffleWord is the unit type that UnitWord<T> nominates for moving a T
// through 32-bit shuffles.
397 typedef typename UnitWord<T>::ShuffleWord ShuffleWord;
// Number of ShuffleWord units needed to cover sizeof(T), rounded up.
399 const int WORDS = (
sizeof(T) +
sizeof(ShuffleWord) - 1) /
sizeof(ShuffleWord);
// View output and input as arrays of ShuffleWord so they can be processed one
// unit at a time.
401 ShuffleWord *output_alias =
reinterpret_cast<ShuffleWord *
>(&output);
402 ShuffleWord *input_alias =
reinterpret_cast<ShuffleWord *
>(&input);
// One shuffle per unit: load the unit into a 32-bit register, exchange it
// with shfl.up.b32 (operands: value, src_offset, SHFL_C), then cast the
// received register back to ShuffleWord and store it in output.
405 for (
int WORD = 0; WORD < WORDS; ++WORD)
407 unsigned int shuffle_word = input_alias[WORD];
409 " shfl.up.b32 %0, %1, %2, %3;"
410 :
"=r"(shuffle_word) :
"r"(shuffle_word),
"r"(src_offset),
"r"(SHFL_C));
411 output_alias[WORD] = (ShuffleWord) shuffle_word;
// Warp shuffle of an arbitrary T using shfl.down.b32.
// NOTE(review): the function name and its parameters (input, output,
// src_offset) are not visible in this excerpt.
446 template <
typename T>
// Fourth operand passed to shfl.down: the warp width minus one.
453 SHFL_C = CUB_PTX_WARP_THREADS - 1,
// ShuffleWord is the unit type that UnitWord<T> nominates for moving a T
// through 32-bit shuffles.
456 typedef typename UnitWord<T>::ShuffleWord ShuffleWord;
// Number of ShuffleWord units needed to cover sizeof(T), rounded up.
458 const int WORDS = (
sizeof(T) +
sizeof(ShuffleWord) - 1) /
sizeof(ShuffleWord);
// View output and input as arrays of ShuffleWord so they can be processed one
// unit at a time.
460 ShuffleWord *output_alias =
reinterpret_cast<ShuffleWord *
>(&output);
461 ShuffleWord *input_alias =
reinterpret_cast<ShuffleWord *
>(&input);
// One shuffle per unit: load the unit into a 32-bit register, exchange it
// with shfl.down.b32 (operands: value, src_offset, SHFL_C), then cast the
// received register back to ShuffleWord and store it in output.
464 for (
int WORD = 0; WORD < WORDS; ++WORD)
466 unsigned int shuffle_word = input_alias[WORD];
468 " shfl.down.b32 %0, %1, %2, %3;"
469 :
"=r"(shuffle_word) :
"r"(shuffle_word),
"r"(src_offset),
"r"(SHFL_C));
470 output_alias[WORD] = (ShuffleWord) shuffle_word;
476 #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
// Warp shuffle of an arbitrary T from a chosen source lane using
// shfl.idx.b32, parameterised by a logical warp width.
// NOTE(review): the function name and its leading parameters (input, output,
// src_lane) are not visible in this excerpt.
485 template <
typename T>
489 int logical_warp_threads)
// ShuffleWord is the unit type that UnitWord<T> nominates for moving a T
// through 32-bit shuffles.
491 typedef typename UnitWord<T>::ShuffleWord ShuffleWord;
// Number of ShuffleWord units needed to cover sizeof(T), rounded up.
493 const int WORDS = (
sizeof(T) +
sizeof(ShuffleWord) - 1) /
sizeof(ShuffleWord);
// View output and input as arrays of ShuffleWord so they can be processed one
// unit at a time.
495 ShuffleWord *output_alias =
reinterpret_cast<ShuffleWord *
>(&output);
496 ShuffleWord *input_alias =
reinterpret_cast<ShuffleWord *
>(&input);
// One shuffle per unit: load the unit into a 32-bit register, fetch the
// corresponding register from src_lane with shfl.idx.b32, then cast it back
// to ShuffleWord and store it in output. The last operand is
// logical_warp_threads - 1.
// NOTE(review): presumably logical_warp_threads is expected to be a power of
// two no larger than the hardware warp -- confirm against callers.
499 for (
int WORD = 0; WORD < WORDS; ++WORD)
501 unsigned int shuffle_word = input_alias[WORD];
502 asm(
"shfl.idx.b32 %0, %1, %2, %3;"
503 :
"=r"(shuffle_word) :
"r"(shuffle_word),
"r"(src_lane),
"r"(logical_warp_threads - 1));
504 output_alias[WORD] = (ShuffleWord) shuffle_word;
510 #endif // DOXYGEN_SHOULD_SKIP_THIS
541 template <
typename T>
// WarpAll: warp-wide vote on cond; returns an int.
557 __device__ __forceinline__
int WarpAll(
int cond)
// Path for CUB_PTX_ARCH < 120: the vote is emulated through shared memory.
// NOTE(review): the branch for newer architectures is not visible in this
// excerpt.
559 #if CUB_PTX_ARCH < 120
// One volatile int slot per warp, sized CUB_PTX_MAX_SM_THREADS /
// CUB_PTX_WARP_THREADS and indexed by WarpId().
561 __shared__
volatile int warp_signals[CUB_PTX_MAX_SM_THREADS / CUB_PTX_WARP_THREADS];
// The warp's slot is first set to 1 ...
564 warp_signals[
WarpId()] = 1;
// ... then set to 0.
// NOTE(review): any guard around this store is not visible here; presumably
// it executes only in threads whose cond is false -- confirm.
567 warp_signals[
WarpId()] = 0;
// The slot's final value is the vote result.
569 return warp_signals[
WarpId()];
// WarpAny: warp-wide vote on cond; returns an int.
583 __device__ __forceinline__
int WarpAny(
int cond)
// Path for CUB_PTX_ARCH < 120: the vote is emulated through shared memory.
// NOTE(review): the branch for newer architectures is not visible in this
// excerpt.
585 #if CUB_PTX_ARCH < 120
// One volatile int slot per warp, sized CUB_PTX_MAX_SM_THREADS /
// CUB_PTX_WARP_THREADS and indexed by WarpId().
587 __shared__
volatile int warp_signals[CUB_PTX_MAX_SM_THREADS / CUB_PTX_WARP_THREADS];
// The warp's slot is first cleared to 0 ...
590 warp_signals[
WarpId()] = 0;
// ... then set to 1.
// NOTE(review): any guard around this store is not visible here; presumably
// it executes only in threads whose cond is true -- confirm.
593 warp_signals[
WarpId()] = 1;
// The slot's final value is the vote result.
595 return warp_signals[
WarpId()];