[clang] [llvm] Adding splitdouble HLSL function (PR #109331)

via cfe-commits cfe-commits at lists.llvm.org
Fri Oct 18 09:48:13 PDT 2024


================
@@ -18957,6 +18957,134 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  case Builtin::BI__builtin_hlsl_splitdouble: {
+
+    assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
+            E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
+            E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
+           "asuint operands types mismatch");
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+    const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+    CallArgList Args;
+    auto [Op1BaseLValue, Op1TmpLValue] =
+        EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+    auto [Op2BaseLValue, Op2TmpLValue] =
+        EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+    if (getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
+      Args.reverseWritebacks();
+
+    auto EmitVectorCode =
+        [](Value *Op, CGBuilderTy *Builder,
+           FixedVectorType *DestTy) -> std::pair<Value *, Value *> {
+      Value *bitcast = Builder->CreateBitCast(Op, DestTy);
+
+      SmallVector<int> LowbitsIndex;
+      SmallVector<int> HighbitsIndex;
+
+      for (unsigned int Idx = 0; Idx < DestTy->getNumElements(); Idx += 2) {
+        LowbitsIndex.push_back(Idx);
+        HighbitsIndex.push_back(Idx + 1);
+      }
+
+      Value *Arg0 = Builder->CreateShuffleVector(bitcast, LowbitsIndex);
+      Value *Arg1 = Builder->CreateShuffleVector(bitcast, HighbitsIndex);
+
+      return std::make_pair(Arg0, Arg1);
+    };
+
+    Value *LastInst = nullptr;
+
+    if (CGM.getTarget().getTriple().isDXIL()) {
+
+      llvm::Type *RetElementTy = Int32Ty;
+      if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>())
+        RetElementTy = llvm::VectorType::get(
+            Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+      auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
+
+      CallInst *CI = Builder.CreateIntrinsic(
+          RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
+
+      Value *Arg0 = Builder.CreateExtractValue(CI, 0);
+      Value *Arg1 = Builder.CreateExtractValue(CI, 1);
+
+      Builder.CreateStore(Arg0, Op1TmpLValue.getAddress());
+      LastInst = Builder.CreateStore(Arg1, Op2TmpLValue.getAddress());
+
+    } else {
+
+      assert(!CGM.getTarget().getTriple().isDXIL() &&
+             "For non-DXIL targets we generate the instructions");
+
+      if (!Op0->getType()->isVectorTy()) {
+        FixedVectorType *DestTy = FixedVectorType::get(Int32Ty, 2);
+        Value *Bitcast = Builder.CreateBitCast(Op0, DestTy);
+
+        Value *Arg0 = Builder.CreateExtractElement(Bitcast, 0.0);
+        Value *Arg1 = Builder.CreateExtractElement(Bitcast, 1.0);
+
+        Builder.CreateStore(Arg0, Op1TmpLValue.getAddress());
+        LastInst = Builder.CreateStore(Arg1, Op2TmpLValue.getAddress());
+      } else {
+
+        const auto *TargTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+        int NumElements = TargTy->getNumElements();
+
+        FixedVectorType *DestTy = FixedVectorType::get(Int32Ty, 4);
+        if (NumElements == 2) {
+          std::pair<Value *, Value *> Vec2res =
+              EmitVectorCode(Op0, &Builder, DestTy);
+
+          Builder.CreateStore(Vec2res.first, Op1TmpLValue.getAddress());
+          LastInst =
+              Builder.CreateStore(Vec2res.second, Op2TmpLValue.getAddress());
+        } else {
+
+          SmallVector<std::pair<Value *, Value *>> EmitedValuePairs;
+
+          for (int It = 0; It < NumElements; It += 2) {
+            // Second element in the index mask might be useless if NumElements
----------------
joaosaffran wrote:

Sure, a little context first: (i) the native SPIR-V vector size is limited to 4 elements; (ii) `shufflevector` requires both vector operands to have the same type (and therefore the same number of elements).

So, when the input vector has an odd size — 3, for example — we need to perform two `bitcast`s. Bitcasting the whole `<3 x double>` at once would produce a `<6 x i32>`, which exceeds the 4-element limit. Here is some IR to illustrate that:
```
define <3 x i32> @f(<3 x double> noundef %D) local_unnamed_addr {
entry:
  %0 = shufflevector <3 x double> %D, <3 x double> poison, <2 x i32> <i32 0, i32 1>
  %1 = shufflevector <3 x double> %D, <3 x double> poison, <1 x i32> <i32 2>
  %2 = bitcast <2 x double> %0 to <4 x i32>
  %3 = bitcast <1 x double> %1 to <2 x i32>
  ...
}
```

The next step is to extract the correct low bits and high bits. For that we can use two `shufflevector`s, one for the low bits and another for the high bits. However, `%2` and `%3` have different sizes (`<4 x i32>` vs. `<2 x i32>`), and operands of different sizes are not allowed in a single `shufflevector`.
```
define <3 x i32> @f(<3 x double> noundef %D) local_unnamed_addr {
entry:
  %0 = shufflevector <3 x double> %D, <3 x double> poison, <2 x i32> <i32 0, i32 1>
  %1 = shufflevector <3 x double> %D, <3 x double> poison, <1 x i32> <i32 2>
  %2 = bitcast <2 x double> %0 to <4 x i32>
  %3 = bitcast <1 x double> %1 to <2 x i32>
  ; The operations below are not allowed, because %2 and %3 have different sizes.
  %low = shufflevector <4 x i32> %2, <2 x i32> %3, <3 x i32> <i32 0, i32 2, i32 4>
  %high = shufflevector <4 x i32> %2, <2 x i32> %3, <3 x i32> <i32 1, i32 3, i32 5>
...
}
```

So the solution I came up with is to add a dummy value to `%1`, so that the following vectors `%2` and `%3` have the same size, which allows the use of `shufflevector`. Then, when computing `%low` and `%high`, the shuffle masks simply never select the dummy elements, so they are effectively discarded.

```
define <3 x i32> @f(<3 x double> noundef %D) local_unnamed_addr {
entry:
  %0 = shufflevector <3 x double> %D, <3 x double> poison, <2 x i32> <i32 0, i32 1>
  ; adding a dummy value in the mask below
  %1 = shufflevector <3 x double> %D, <3 x double> poison, <2 x i32> <i32 2, i32 0>
  %2 = bitcast <2 x double> %0 to <4 x i32>
  %3 = bitcast <2 x double> %1 to <4 x i32>
  ; The operations below only read indices 0-5, ignoring the dummy values at indices 6 and 7.
  %low = shufflevector <4 x i32> %2, <4 x i32> %3, <3 x i32> <i32 0, i32 2, i32 4>
  %high = shufflevector <4 x i32> %2, <4 x i32> %3, <3 x i32> <i32 1, i32 3, i32 5>
...
}
```

Hopefully this clarified the comment a little bit.

https://github.com/llvm/llvm-project/pull/109331


More information about the cfe-commits mailing list