[clang] [llvm] Adding splitdouble HLSL function (PR #109331)

Tex Riddell via cfe-commits cfe-commits at lists.llvm.org
Fri Oct 18 17:55:23 PDT 2024


================
@@ -18952,6 +18955,142 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  case Builtin::BI__builtin_hlsl_splitdouble: {
+
+    assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
+            E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
+            E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
+           "asuint operands types mismatch");
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+    const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+    CallArgList Args;
+    LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+    LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+    if (getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
+      Args.reverseWritebacks();
+
+    auto EmitVectorCode =
+        [](Value *Op, CGBuilderTy *Builder,
+           FixedVectorType *DestTy) -> std::pair<Value *, Value *> {
+      Value *bitcast = Builder->CreateBitCast(Op, DestTy);
+
+      SmallVector<int> LowbitsIndex;
+      SmallVector<int> HighbitsIndex;
+
+      for (unsigned int Idx = 0; Idx < DestTy->getNumElements(); Idx += 2) {
+        LowbitsIndex.push_back(Idx);
+        HighbitsIndex.push_back(Idx + 1);
+      }
+
+      Value *Arg0 = Builder->CreateShuffleVector(bitcast, LowbitsIndex);
+      Value *Arg1 = Builder->CreateShuffleVector(bitcast, HighbitsIndex);
+
+      return std::make_pair(Arg0, Arg1);
+    };
+
+    Value *LastInst = nullptr;
+
+    if (CGM.getTarget().getTriple().isDXIL()) {
+
+      llvm::Type *RetElementTy = Int32Ty;
+      if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>())
+        RetElementTy = llvm::VectorType::get(
+            Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+      auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
+
+      CallInst *CI = Builder.CreateIntrinsic(
+          RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
+
+      Value *Arg0 = Builder.CreateExtractValue(CI, 0);
+      Value *Arg1 = Builder.CreateExtractValue(CI, 1);
+
+      Builder.CreateStore(Arg0, Op1TmpLValue.getAddress());
+      LastInst = Builder.CreateStore(Arg1, Op2TmpLValue.getAddress());
+
+    } else {
+
+      assert(!CGM.getTarget().getTriple().isDXIL() &&
+             "For non-DXIL targets we generate the instructions");
+
+      if (!Op0->getType()->isVectorTy()) {
+        FixedVectorType *DestTy = FixedVectorType::get(Int32Ty, 2);
+        Value *Bitcast = Builder.CreateBitCast(Op0, DestTy);
+
+        Value *Arg0 = Builder.CreateExtractElement(Bitcast, 0.0);
+        Value *Arg1 = Builder.CreateExtractElement(Bitcast, 1.0);
+
+        Builder.CreateStore(Arg0, Op1TmpLValue.getAddress());
+        LastInst = Builder.CreateStore(Arg1, Op2TmpLValue.getAddress());
+      } else {
+
+        const auto *TargTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+        int NumElements = TargTy->getNumElements();
+
+        FixedVectorType *DestTy = FixedVectorType::get(Int32Ty, 4);
+        if (NumElements == 1) {
+          FixedVectorType *DestTy = FixedVectorType::get(Int32Ty, 2);
+          Value *Bitcast = Builder.CreateBitCast(Op0, DestTy);
+
+          Value *Arg0 = Builder.CreateExtractElement(Bitcast, 0.0);
+          Value *Arg1 = Builder.CreateExtractElement(Bitcast, 1.0);
+
+          Builder.CreateStore(Arg0, Op1TmpLValue.getAddress());
+          LastInst = Builder.CreateStore(Arg1, Op2TmpLValue.getAddress());
+        } else if (NumElements == 2) {
+          auto [LowBits, HighBits] = EmitVectorCode(Op0, &Builder, DestTy);
+
+          Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
+          LastInst = Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
+        } else {
+
+          SmallVector<std::pair<Value *, Value *>> EmitedValuePairs;
+
+          for (int It = 0; It < NumElements; It += 2) {
+            // Due to existing restrictions in SPIR-V and splitdouble,
+            // all shufflevector operations should return vectors of
+            // the same size, up to 4. This introduces an edge case
----------------
tex3d wrote:

I think this expansion (if we still have to do it here) would be better done in a different way.

Instead of adding a dummy value to the original shuffle for the cast, add a shuffle that extends the cast result vector when needed, filling the extra lanes with poison values rather than "dummy" values.  This keeps the extra values localized, marks them as poison, and makes them easier to eliminate, instead of passing them through the bitcast.  This can be done as part of generating the final high/low shuffles, when the vector sizes don't match at the end.  That also keeps the logic localized and makes it more obvious why it's needed.

https://github.com/llvm/llvm-project/pull/109331


More information about the cfe-commits mailing list