Skip to content

Commit 5f29bff

Browse files
authored
[SPIR-V] countbit on 16+64 bit types (#7997)
The SPIR-V instruction can work on non-32-bit components, but only with VK maintenance 9, which seems to be hidden behind a feature bit. This change implements countbits on such types by converting to a 32-bit integer first. Should be ok. Fixes #7494
1 parent 23a15c5 commit 5f29bff

File tree

3 files changed

+293
-12
lines changed

3 files changed

+293
-12
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9763,7 +9763,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
97639763
retVal = processReverseBitsIntrinsic(callExpr, srcLoc);
97649764
break;
97659765
}
9766-
INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
9766+
case hlsl::IntrinsicOp::IOP_countbits: {
9767+
retVal = processCountBitsIntrinsic(callExpr, srcLoc);
9768+
break;
9769+
}
97679770
INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
97689771
INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
97699772
INTRINSIC_SPIRV_OP_CASE(and, LogicalAnd, false);
@@ -9922,6 +9925,91 @@ SpirvInstruction *SpirvEmitter::processDerivativeIntrinsic(
99229925
return result;
99239926
}
99249927

9928+
// Lowers the HLSL `countbits` intrinsic to SPIR-V.
//
// Dispatches on the element bit width of the call's first argument:
//   - 32 bits: emits OpBitCount directly on the argument.
//   - 16 bits: delegates to generateCountBits16.
//   - 64 bits: delegates to generateCountBits64.
// For any other width an error is reported at srcLoc and nullptr is returned.
SpirvInstruction *
9929+
SpirvEmitter::processCountBitsIntrinsic(const CallExpr *callExpr,
9930+
clang::SourceLocation srcLoc) {
9931+
const QualType argType = callExpr->getArg(0)->getType();
9932+
const uint32_t bitwidth = getElementSpirvBitwidth(
9933+
astContext, argType, spirvOptions.enable16BitTypes);
9934+
9935+
// The intrinsic should always return a uint or a vector of uint.
// For a vector return type, isVectorType stores the element type in retType;
// otherwise the scalar return type is used as-is.
9936+
QualType retType = {};
9937+
if (!isVectorType(callExpr->getCallReturnType(astContext), &retType))
9938+
retType = callExpr->getCallReturnType(astContext);
9939+
assert(retType == astContext.UnsignedIntTy);
9940+
9941+
// SPIR-V only supports 32-bit integers for `OpBitCount` until maintenance9.
9942+
// We need to unfold and add extra instructions to support this on
9943+
// non-32-bit integers.
9944+
if (bitwidth == 32) {
9945+
return processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpBitCount,
9946+
/* actPerRowForMatrices= */ false);
9947+
} else if (bitwidth == 16) {
9948+
return generateCountBits16(callExpr, srcLoc);
9949+
} else if (bitwidth == 64) {
9950+
return generateCountBits64(callExpr, srcLoc);
9951+
}
9952+
emitError("countbits currently only supports 16, 32, and 64-bit "
9953+
"width components when targeting SPIR-V",
9954+
srcLoc);
9955+
return nullptr;
9956+
}
9957+
9958+
// Emits `countbits` for an argument with 16-bit components.
//
// The argument is evaluated, widened with OpUConvert to uint (or to a vector
// of uint with the same number of components), and OpBitCount is applied to
// the widened value. Returns the OpBitCount instruction.
SpirvInstruction *
9959+
SpirvEmitter::generateCountBits16(const CallExpr *callExpr,
9960+
clang::SourceLocation srcLoc) {
9961+
const QualType argType = callExpr->getArg(0)->getType();
9962+
// Load the 16-bit value
9963+
auto *loadInst = doExpr(callExpr->getArg(0));
9964+
bool isVector = isVectorType(argType);
9965+
uint32_t count = isVector ? hlsl::GetHLSLVecSize(argType) : 1;
// 32-bit unsigned type with the same shape (scalar or N-vector) as the
// argument; used both for the widened operand and for the result.
9966+
QualType uintType =
9967+
isVector ? astContext.getExtVectorType(astContext.UnsignedIntTy, count)
9968+
: astContext.UnsignedIntTy;
9969+
9970+
// Widen to 32 bits so that OpBitCount operates on a 32-bit operand.
auto *extended =
9971+
spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, loadInst, srcLoc);
9972+
return spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, extended,
9973+
srcLoc);
9974+
}
9975+
9976+
// Emits `countbits` for an argument with 64-bit components.
//
// The 64-bit value is handled as two 32-bit halves: one OpBitCount is applied
// to the value converted to 32 bits, another to the value shifted right by 32
// and then converted to 32 bits. The two counts are summed with OpIAdd, and
// that uint (or vector of uint) sum is returned.
SpirvInstruction *
9977+
SpirvEmitter::generateCountBits64(const CallExpr *callExpr,
9978+
clang::SourceLocation srcLoc) {
9979+
const QualType argType = callExpr->getArg(0)->getType();
9980+
// Load the 64-bit value
9981+
auto *loadInst = doExpr(callExpr->getArg(0));
9982+
bool isVector = isVectorType(argType);
9983+
uint32_t count = isVector ? hlsl::GetHLSLVecSize(argType) : 1;
// 32-bit unsigned type with the same shape (scalar or N-vector) as the
// argument; used for both halves and for the result.
9984+
QualType uintType =
9985+
isVector ? astContext.getExtVectorType(astContext.UnsignedIntTy, count)
9986+
: astContext.UnsignedIntTy;
9987+
9988+
// Low half: convert the 64-bit value down to 32 bits and count its bits.
auto *lhs =
9989+
spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, loadInst, srcLoc);
9990+
auto *lhs_count =
9991+
spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, lhs, srcLoc);
9992+
9993+
// Shift amount of 32, as a 32-bit unsigned constant. For vector arguments
// the constant is replicated into a composite with one entry per component.
auto *shiftAmount =
9994+
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));
9995+
if (isVector) {
9996+
SmallVector<SpirvConstant *, 4> Components;
9997+
for (unsigned I = 0; I < count; ++I)
9998+
Components.push_back(shiftAmount);
9999+
shiftAmount = spvBuilder.getConstantComposite(uintType, Components);
10000+
}
10001+
10002+
// High half: the shift is performed in the original 64-bit type (argType),
// then the result is converted down to 32 bits and counted.
SpirvInstruction *rhs = spvBuilder.createBinaryOp(
10003+
spv::Op::OpShiftRightLogical, argType, loadInst, shiftAmount, srcLoc);
10004+
auto *rhs32 =
10005+
spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, rhs, srcLoc);
10006+
auto *rhs_count =
10007+
spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, rhs32, srcLoc);
10008+
10009+
// Total population count is the sum of the counts of the two halves.
return spvBuilder.createBinaryOp(spv::Op::OpIAdd, uintType, rhs_count,
10010+
lhs_count, srcLoc);
10011+
}
10012+
992510013
SpirvInstruction *
992610014
SpirvEmitter::processReverseBitsIntrinsic(const CallExpr *callExpr,
992710015
clang::SourceLocation srcLoc) {

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,13 @@ class SpirvEmitter : public ASTConsumer {
852852
SourceLocation loc,
853853
SourceRange range);
854854

855+
// Processes the `countbits` intrinsic
SpirvInstruction *processCountBitsIntrinsic(const CallExpr *callExpr,
856+
clang::SourceLocation srcLoc);
857+
// Generates the `countbits` instruction sequence for 16-bit component
// arguments
SpirvInstruction *generateCountBits16(const CallExpr *callExpr,
858+
clang::SourceLocation srcLoc);
859+
// Generates the `countbits` instruction sequence for 64-bit component
// arguments
SpirvInstruction *generateCountBits64(const CallExpr *callExpr,
860+
clang::SourceLocation srcLoc);
861+
855862
// Processes the `reversebits` intrinsic
856863
SpirvInstruction *processReverseBitsIntrinsic(const CallExpr *expr,
857864
clang::SourceLocation srcLoc);
Lines changed: 197 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,203 @@
1-
// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s
1+
// RUN: %dxc -T vs_6_2 -E main -fcgl -enable-16bit-types %s -spirv | FileCheck %s
22

33
// According to HLSL reference:
44
// The 'countbits' function can only operate on scalar or vector of uints.
55

66
void main() {
7-
uint a;
8-
uint4 b;
9-
10-
// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
11-
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %uint [[a]]
12-
uint cb = countbits(a);
13-
14-
// CHECK: [[b:%[0-9]+]] = OpLoad %v4uint %b
15-
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[b]]
16-
uint4 cb4 = countbits(b);
7+
// CHECK: [[v4_32:%[0-9]+]] = OpConstantComposite %v4uint %uint_32 %uint_32 %uint_32 %uint_32
8+
9+
uint16_t u16;
10+
uint32_t u32;
11+
uint64_t u64;
12+
13+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort %u16
14+
// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
15+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
16+
// CHECK: [[cast:%[0-9]+]] = OpUConvert %ushort [[res]]
17+
// CHECK: OpStore %u16ru16 [[cast]]
18+
uint16_t u16ru16 = countbits(u16);
19+
20+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort %u16
21+
// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
22+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
23+
// CHECK: OpStore %u32ru16 [[res]]
24+
uint32_t u32ru16 = countbits(u16);
25+
26+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort %u16
27+
// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
28+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
29+
// CHECK: [[cast:%[0-9]+]] = OpUConvert %ulong [[res]]
30+
// CHECK: OpStore %u64ru16 [[cast]]
31+
uint64_t u64ru16 = countbits(u16);
32+
33+
// CHECK: [[ext:%[0-9]+]] = OpLoad %uint %u32
34+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
35+
// CHECK: [[cast:%[0-9]+]] = OpUConvert %ushort [[res]]
36+
// CHECK: OpStore %u16ru32 [[cast]]
37+
uint16_t u16ru32 = countbits(u32);
38+
// CHECK: [[ext:%[0-9]+]] = OpLoad %uint %u32
39+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
40+
// CHECK: OpStore %u32ru32 [[res]]
41+
uint32_t u32ru32 = countbits(u32);
42+
// CHECK: [[ext:%[0-9]+]] = OpLoad %uint %u32
43+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
44+
// CHECK: [[cast:%[0-9]+]] = OpUConvert %ulong [[res]]
45+
// CHECK: OpStore %u64ru32 [[cast]]
46+
uint64_t u64ru32 = countbits(u32);
47+
48+
// CHECK: [[ld:%[0-9]+]] = OpLoad %ulong %u64
49+
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
50+
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %ulong [[ld]] %uint_32
51+
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
52+
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
53+
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
54+
// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
55+
// CHECK-DAG: [[cast:%[0-9]+]] = OpUConvert %ushort [[re]]
56+
// CHECK-DAG: OpStore %u16ru64 [[cast]]
57+
uint16_t u16ru64 = countbits(u64);
58+
59+
// CHECK: [[ld:%[0-9]+]] = OpLoad %ulong %u64
60+
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
61+
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %ulong [[ld]] %uint_32
62+
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
63+
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
64+
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
65+
// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
66+
// CHECK-DAG: OpStore %u32ru64 [[re]]
67+
uint32_t u32ru64 = countbits(u64);
68+
69+
// CHECK: [[ld:%[0-9]+]] = OpLoad %ulong %u64
70+
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
71+
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %ulong [[ld]] %uint_32
72+
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
73+
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
74+
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
75+
// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
76+
// CHECK-DAG: [[cast:%[0-9]+]] = OpUConvert %ulong [[re]]
77+
// CHECK-DAG: OpStore %u64ru64 [[cast]]
78+
uint64_t u64ru64 = countbits(u64);
79+
80+
int16_t s16;
81+
int32_t s32;
82+
int64_t s64;
83+
84+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %short %s16
85+
// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
86+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
87+
// CHECK: [[cast:%[0-9]+]] = OpUConvert %ushort [[res]]
88+
// CHECK: [[bc:%[0-9]+]] = OpBitcast %short [[cast]]
89+
// CHECK: OpStore %s16rs16 [[bc]]
90+
int16_t s16rs16 = countbits(s16);
91+
92+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %short %s16
93+
// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
94+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
95+
// CHECK: [[bc:%[0-9]+]] = OpBitcast %int [[res]]
96+
// CHECK: OpStore %s32rs16 [[bc]]
97+
int32_t s32rs16 = countbits(s16);
98+
99+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %short %s16
100+
// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
101+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
102+
// CHECK: [[cast:%[0-9]+]] = OpUConvert %ulong [[res]]
103+
// CHECK: [[bc:%[0-9]+]] = OpBitcast %long [[cast]]
104+
// CHECK: OpStore %s64rs16 [[bc]]
105+
int64_t s64rs16 = countbits(s16);
106+
107+
// CHECK: [[ext:%[0-9]+]] = OpLoad %int %s32
108+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
109+
// CHECK: [[cast:%[0-9]+]] = OpUConvert %ushort [[res]]
110+
// CHECK: [[bc:%[0-9]+]] = OpBitcast %short [[cast]]
111+
// CHECK: OpStore %s16rs32 [[bc]]
112+
int16_t s16rs32 = countbits(s32);
113+
// CHECK: [[ext:%[0-9]+]] = OpLoad %int %s32
114+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
115+
// CHECK: [[bc:%[0-9]+]] = OpBitcast %int [[res]]
116+
// CHECK: OpStore %s32rs32 [[bc]]
117+
int32_t s32rs32 = countbits(s32);
118+
// CHECK: [[ext:%[0-9]+]] = OpLoad %int %s32
119+
// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
120+
// CHECK: [[cast:%[0-9]+]] = OpUConvert %ulong [[res]]
121+
// CHECK: [[bc:%[0-9]+]] = OpBitcast %long [[cast]]
122+
// CHECK: OpStore %s64rs32 [[bc]]
123+
int64_t s64rs32 = countbits(s32);
124+
125+
// CHECK: [[ld:%[0-9]+]] = OpLoad %long %s64
126+
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
127+
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %long [[ld]] %uint_32
128+
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
129+
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
130+
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
131+
// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
132+
// CHECK-DAG: [[cast:%[0-9]+]] = OpUConvert %ushort [[re]]
133+
// CHECK-DAG: [[bc:%[0-9]+]] = OpBitcast %short [[cast]]
134+
// CHECK-DAG: OpStore %s16rs64 [[bc]]
135+
int16_t s16rs64 = countbits(s64);
136+
137+
// CHECK: [[ld:%[0-9]+]] = OpLoad %long %s64
138+
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
139+
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %long [[ld]] %uint_32
140+
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
141+
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
142+
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
143+
// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
144+
// CHECK-DAG: [[bc:%[0-9]+]] = OpBitcast %int [[re]]
145+
// CHECK-DAG: OpStore %s32rs64 [[bc]]
146+
int32_t s32rs64 = countbits(s64);
147+
148+
// CHECK: [[ld:%[0-9]+]] = OpLoad %long %s64
149+
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
150+
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %long [[ld]] %uint_32
151+
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
152+
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
153+
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
154+
// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
155+
// CHECK-DAG: [[cast:%[0-9]+]] = OpUConvert %ulong [[re]]
156+
// CHECK-DAG: [[bc:%[0-9]+]] = OpBitcast %long [[cast]]
157+
// CHECK-DAG: OpStore %s64rs64 [[bc]]
158+
int64_t s64rs64 = countbits(s64);
159+
160+
uint16_t4 vu16;
161+
uint32_t4 vu32;
162+
uint64_t4 vu64;
163+
164+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4ushort %vu16
165+
// CHECK-DAG: [[ext:%[0-9]+]] = OpUConvert %v4uint [[tmp]]
166+
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[ext]]
167+
uint4 rvu16 = countbits(vu16);
168+
169+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4uint %vu32
170+
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[tmp]]
171+
uint4 rvu32 = countbits(vu32);
172+
173+
// CHECK: [[ld:%[0-9]+]] = OpLoad %v4ulong %vu64
174+
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %v4uint [[ld]]
175+
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %v4ulong [[ld]] [[v4_32]]
176+
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %v4uint [[sh]]
177+
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %v4uint [[lo]]
178+
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %v4uint [[hi]]
179+
// CHECK-DAG: {{%[0-9]+}} = OpIAdd %v4uint [[cb]] [[ca]]
180+
uint4 rvu64 = countbits(vu64);
181+
182+
int16_t4 vs16;
183+
int32_t4 vs32;
184+
int64_t4 vs64;
185+
186+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4short %vs16
187+
// CHECK-DAG: [[ext:%[0-9]+]] = OpUConvert %v4uint [[tmp]]
188+
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[ext]]
189+
uint4 rvs16 = countbits(vs16);
190+
191+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4int %vs32
192+
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[tmp]]
193+
uint4 rvs32 = countbits(vs32);
194+
195+
// CHECK: [[ld:%[0-9]+]] = OpLoad %v4long %vs64
196+
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %v4uint [[ld]]
197+
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %v4long [[ld]] [[v4_32]]
198+
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %v4uint [[sh]]
199+
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %v4uint [[lo]]
200+
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %v4uint [[hi]]
201+
// CHECK-DAG: {{%[0-9]+}} = OpIAdd %v4uint [[cb]] [[ca]]
202+
uint4 rvs64 = countbits(vs64);
17203
}

0 commit comments

Comments
 (0)