[clang] [llvm] Adding splitdouble HLSL function (PR #109331)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 23 17:33:36 PDT 2024


================
@@ -95,6 +99,144 @@ static void initializeAlloca(CodeGenFunction &CGF, AllocaInst *AI, Value *Size,
   I->addAnnotationMetadata("auto-init");
 }
 
+static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
+  Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
+  const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+  const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+  CallArgList Args;
+  LValue Op1TmpLValue =
+      CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+  LValue Op2TmpLValue =
+      CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+  if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
+    Args.reverseWritebacks();
+
+  auto EmitVectorCode =
+      [](Value *Op, CGBuilderTy *Builder,
+         FixedVectorType *DestTy) -> std::pair<Value *, Value *> {
+    Value *bitcast = Builder->CreateBitCast(Op, DestTy);
+
+    SmallVector<int> LowbitsIndex;
+    SmallVector<int> HighbitsIndex;
+
+    for (unsigned int Idx = 0; Idx < DestTy->getNumElements(); Idx += 2) {
+      LowbitsIndex.push_back(Idx);
+      HighbitsIndex.push_back(Idx + 1);
+    }
+
+    Value *Arg0 = Builder->CreateShuffleVector(bitcast, LowbitsIndex);
+    Value *Arg1 = Builder->CreateShuffleVector(bitcast, HighbitsIndex);
+
+    return std::make_pair(Arg0, Arg1);
+  };
+
+  Value *LowBits = nullptr;
+  Value *HighBits = nullptr;
+
+  if (CGF->CGM.getTarget().getTriple().isDXIL()) {
+
+    llvm::Type *RetElementTy = CGF->Int32Ty;
+    if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
+      RetElementTy = llvm::VectorType::get(
+          CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+    auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
+
+    CallInst *CI = CGF->Builder.CreateIntrinsic(
+        RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
+
+    LowBits = CGF->Builder.CreateExtractValue(CI, 0);
+    HighBits = CGF->Builder.CreateExtractValue(CI, 1);
+
+  } else {
+    // For Non DXIL targets we generate the instructions.
+    // TODO: This code accounts for known limitations in
+    // SPIR-V and splitdouble. Such should be handled,
+    // in a later compilation stage. After [issue link here]
+    // is fixed, this shall be refactored.
+
+    if (!Op0->getType()->isVectorTy()) {
+      FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+      Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
+
+      LowBits = CGF->Builder.CreateExtractElement(Bitcast, 0.0);
+      HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1.0);
+    } else {
+
+      const auto *TargTy = E->getArg(0)->getType()->getAs<clang::VectorType>();
+
+      int NumElements = TargTy->getNumElements();
+
+      FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 4);
+
+      if (NumElements == 1) {
+        FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+        auto *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
+
+        LowBits = CGF->Builder.CreateExtractElement(Bitcast, 0.0);
+        HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1.0);
+      } else if (NumElements == 2) {
+        auto [LB, HB] = EmitVectorCode(Op0, &CGF->Builder, DestTy);
+        LowBits = LB;
+        HighBits = HB;
+      } else {
+
+        SmallVector<std::pair<Value *, Value *>> EmitedValuePairs;
+
+        int isOdd = NumElements % 2;
+        int NumEvenElements = NumElements - isOdd;
+
+        Value *FinalElementCast = nullptr;
+        for (int It = 0; It < NumEvenElements; It += 2) {
+          auto Shuff = CGF->Builder.CreateShuffleVector(Op0, {It, It + 1});
+          std::pair<Value *, Value *> ValuePair =
+              EmitVectorCode(Shuff, &CGF->Builder, DestTy);
+          EmitedValuePairs.push_back(ValuePair);
+        }
+
+        if (isOdd == 1) {
+          FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+          auto *EV = CGF->Builder.CreateExtractElement(Op0, NumElements - 1);
+          FinalElementCast = CGF->Builder.CreateBitCast(EV, DestTy);
+        }
+
+        SmallVector<int> Index = {0, 1};
+
+        auto lb = EmitedValuePairs[0].first;
+        auto hb = EmitedValuePairs[0].second;
+
+        int EvenSizedPairs = EmitedValuePairs.size() - isOdd;
+
+        for (int It = 1; It < EvenSizedPairs; It++) {
----------------
joaosaffran wrote:

Refactor this to handle only vectors of up to 4 elements.

https://github.com/llvm/llvm-project/pull/109331


More information about the llvm-commits mailing list