Skip to content

Commit 23a15c5

Browse files
authored
[SPIR-V] Support asdouble() for uint3 argument type. (#7965)
Fixes #7699. Previously, only `asdouble(uint, uint)` and `asdouble(uint2, uint2)` were supported. The fix is to manually extract each component and compose a vector of doubles. Verified [asdouble.32.test](https://github.com/llvm/offload-test-suite/blob/main/test/Feature/HLSLLib/asdouble.32.test) is passing.
1 parent 2b90cd3 commit 23a15c5

File tree

2 files changed

+50
-12
lines changed

2 files changed

+50
-12
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12386,7 +12386,9 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
1238612386

1238712387
// Method 4: double asdouble(uint lowbits, uint highbits)
1238812388
// Method 5: double2 asdouble(uint2 lowbits, uint2 highbits)
12389-
// Method 6:
12389+
// Method 6: double3 asdouble(uint3 lowbits, uint3 highbits)
12390+
// Method 7: double4 asdouble(uint4 lowbits, uint4 highbits)
12391+
// Method 8:
1239012392
// void asuint(
1239112393
// in double value,
1239212394
// out uint lowbits,
@@ -12435,26 +12437,33 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
1243512437
auto *highbits = doExpr(callExpr->getArg(1));
1243612438
const auto uintType = astContext.UnsignedIntTy;
1243712439
const auto doubleType = astContext.DoubleTy;
12440+
uint32_t vecSize;
1243812441
// Handling Method 4
12439-
if (argType->isUnsignedIntegerType()) {
12442+
if (!isVectorType(argType, nullptr, &vecSize)) {
1244012443
const auto uintVec2Type = astContext.getExtVectorType(uintType, 2);
1244112444
auto *operand = spvBuilder.createCompositeConstruct(
1244212445
uintVec2Type, {lowbits, highbits}, loc, range);
1244312446
return spvBuilder.createUnaryOp(spv::Op::OpBitcast, doubleType, operand,
1244412447
loc, range);
1244512448
}
12446-
// Handling Method 5
12449+
// Handling Method 5, 6, 7
1244712450
else {
12448-
const auto uintVec4Type = astContext.getExtVectorType(uintType, 4);
12449-
const auto doubleVec2Type = astContext.getExtVectorType(doubleType, 2);
12450-
auto *operand = spvBuilder.createVectorShuffle(
12451-
uintVec4Type, lowbits, highbits, {0, 2, 1, 3}, loc, range);
12452-
return spvBuilder.createUnaryOp(spv::Op::OpBitcast, doubleVec2Type,
12453-
operand, loc, range);
12451+
std::vector<SpirvInstruction *> doubles = {};
12452+
const auto uintVec2Type = astContext.getExtVectorType(uintType, 2);
12453+
// For each pair, convert them to double.
12454+
for (uint32_t i = 0; i < vecSize; ++i) {
12455+
auto *operand = spvBuilder.createVectorShuffle(
12456+
uintVec2Type, lowbits, highbits, {i, vecSize + i}, loc, range);
12457+
SpirvInstruction *doubleElem = spvBuilder.createUnaryOp(
12458+
spv::Op::OpBitcast, doubleType, operand, loc, range);
12459+
doubles.push_back(doubleElem);
12460+
}
12461+
return spvBuilder.createCompositeConstruct(returnType, doubles, loc,
12462+
range);
1245412463
}
1245512464
}
1245612465
case 3: {
12457-
// Handling Method 6.
12466+
// Handling Method 8.
1245812467
const Expr *arg1 = callExpr->getArg(1);
1245912468
const Expr *arg2 = callExpr->getArg(2);
1246012469

tools/clang/test/CodeGenSPIRV/intrinsics.asdouble.hlsl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,37 @@ void main() {
1919

2020
// CHECK: [[low2:%[0-9]+]] = OpLoad %v2uint %low2
2121
// CHECK-NEXT: [[high2:%[0-9]+]] = OpLoad %v2uint %high2
22-
// CHECK-NEXT: [[arg3:%[0-9]+]] = OpVectorShuffle %v4uint [[low2]] [[high2]] 0 2 1 3
23-
// CHECK-NEXT: {{%[0-9]+}} = OpBitcast %v2double [[arg3]]
22+
// CHECK-NEXT: [[elem1:%[0-9]+]] = OpVectorShuffle %v2uint [[low2]] [[high2]] 0 2
23+
// CHECK-NEXT: [[double1:%[0-9]+]] = OpBitcast %double [[elem1]]
24+
// CHECK-NEXT: [[elem2:%[0-9]+]] = OpVectorShuffle %v2uint [[low2]] [[high2]] 1 3
25+
// CHECK-NEXT: [[double2:%[0-9]+]] = OpBitcast %double [[elem2]]
26+
// CHECK-NEXT: {{%[0-9]+}} = OpCompositeConstruct %v2double [[double1]] [[double2]]
2427
uint2 low2, high2;
2528
double2 c = asdouble(low2, high2);
29+
30+
// CHECK: [[low3:%[0-9]+]] = OpLoad %v3uint %low3
31+
// CHECK-NEXT: [[high3:%[0-9]+]] = OpLoad %v3uint %high3
32+
// CHECK-NEXT: [[elem1:%[0-9]+]] = OpVectorShuffle %v2uint [[low3]] [[high3]] 0 3
33+
// CHECK-NEXT: [[double1:%[0-9]+]] = OpBitcast %double [[elem1]]
34+
// CHECK-NEXT: [[elem2:%[0-9]+]] = OpVectorShuffle %v2uint [[low3]] [[high3]] 1 4
35+
// CHECK-NEXT: [[double2:%[0-9]+]] = OpBitcast %double [[elem2]]
36+
// CHECK-NEXT: [[elem3:%[0-9]+]] = OpVectorShuffle %v2uint [[low3]] [[high3]] 2 5
37+
// CHECK-NEXT: [[double3:%[0-9]+]] = OpBitcast %double [[elem3]]
38+
// CHECK-NEXT: {{%[0-9]+}} = OpCompositeConstruct %v3double [[double1]] [[double2]] [[double3]]
39+
uint3 low3, high3;
40+
double3 d = asdouble(low3, high3);
41+
42+
// CHECK: [[low4:%[0-9]+]] = OpLoad %v4uint %low4
43+
// CHECK-NEXT: [[high4:%[0-9]+]] = OpLoad %v4uint %high4
44+
// CHECK-NEXT: [[elem1:%[0-9]+]] = OpVectorShuffle %v2uint [[low4]] [[high4]] 0 4
45+
// CHECK-NEXT: [[double1:%[0-9]+]] = OpBitcast %double [[elem1]]
46+
// CHECK-NEXT: [[elem2:%[0-9]+]] = OpVectorShuffle %v2uint [[low4]] [[high4]] 1 5
47+
// CHECK-NEXT: [[double2:%[0-9]+]] = OpBitcast %double [[elem2]]
48+
// CHECK-NEXT: [[elem3:%[0-9]+]] = OpVectorShuffle %v2uint [[low4]] [[high4]] 2 6
49+
// CHECK-NEXT: [[double3:%[0-9]+]] = OpBitcast %double [[elem3]]
50+
// CHECK-NEXT: [[elem4:%[0-9]+]] = OpVectorShuffle %v2uint [[low4]] [[high4]] 3 7
51+
// CHECK-NEXT: [[double4:%[0-9]+]] = OpBitcast %double [[elem4]]
52+
// CHECK-NEXT: {{%[0-9]+}} = OpCompositeConstruct %v4double [[double1]] [[double2]] [[double3]] [[double4]]
53+
uint4 low4, high4;
54+
double4 e = asdouble(low4, high4);
2655
}

0 commit comments

Comments
 (0)