[clang] [llvm] Adding splitdouble HLSL function (PR #109331)

Justin Bogner via cfe-commits cfe-commits at lists.llvm.org
Fri Oct 25 17:33:54 PDT 2024


================
@@ -95,6 +99,126 @@ static void initializeAlloca(CodeGenFunction &CGF, AllocaInst *AI, Value *Size,
   I->addAnnotationMetadata("auto-init");
 }
 
+static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
+  Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
+  const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+  const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+  CallArgList Args;
+  LValue Op1TmpLValue =
+      CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+  LValue Op2TmpLValue =
+      CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+  if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
+    Args.reverseWritebacks();
+
+  Value *LowBits = nullptr;
+  Value *HighBits = nullptr;
+
+  if (CGF->CGM.getTarget().getTriple().isDXIL()) {
+
+    llvm::Type *RetElementTy = CGF->Int32Ty;
+    if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
+      RetElementTy = llvm::VectorType::get(
+          CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+    auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
+
+    CallInst *CI = CGF->Builder.CreateIntrinsic(
+        RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
+
+    LowBits = CGF->Builder.CreateExtractValue(CI, 0);
+    HighBits = CGF->Builder.CreateExtractValue(CI, 1);
+
+  } else {
+    // For Non DXIL targets we generate the instructions.
+    // TODO: This code accounts for known limitations in
+    // SPIR-V and splitdouble. Such should be handled,
+    // in a later compilation stage. After
+    // https://github.com/llvm/llvm-project/issues/113597 is fixed, this shall
+    // be refactored.
+
+    // casts `<2 x double>` to `<4 x i32>`, then shuffles into high and low
+    // `<2 x i32>` vectors.
+    auto EmitDouble2Cast =
+        [](CodeGenFunction &CGF,
+           Value *DoubleVec2) -> std::pair<Value *, Value *> {
+      Value *BC = CGF.Builder.CreateBitCast(
+          DoubleVec2, FixedVectorType::get(CGF.Int32Ty, 4));
+      Value *LB = CGF.Builder.CreateShuffleVector(BC, {0, 2});
+      Value *HB = CGF.Builder.CreateShuffleVector(BC, {1, 3});
+      return std::make_pair(LB, HB);
+    };
+
+    if (!Op0->getType()->isVectorTy()) {
+      FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+      Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
+
+      LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
+      HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
+    } else {
+
+      const auto *TargTy = E->getArg(0)->getType()->getAs<clang::VectorType>();
+
+      int NumElements = TargTy->getNumElements();
+
+      FixedVectorType *UintVec2 = FixedVectorType::get(CGF->Int32Ty, 2);
+
+      switch (NumElements) {
+      case 1: {
+        auto *Bitcast = CGF->Builder.CreateBitCast(Op0, UintVec2);
+
+        LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
+        HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
+        break;
+      }
+      case 2: {
+        auto [LB, HB] = EmitDouble2Cast(*CGF, Op0);
+        LowBits = LB;
+        HighBits = HB;
+        break;
+      }
+
+      case 3: {
+        auto *Shuff = CGF->Builder.CreateShuffleVector(Op0, {0, 1});
+        auto [LB, HB] = EmitDouble2Cast(*CGF, Shuff);
+
+        auto *EV = CGF->Builder.CreateExtractElement(Op0, 2);
+        auto *ScalarBitcast = CGF->Builder.CreateBitCast(EV, UintVec2);
+
+        LowBits =
+            CGF->Builder.CreateShuffleVector(LB, ScalarBitcast, {0, 1, 2});
+        HighBits =
+            CGF->Builder.CreateShuffleVector(HB, ScalarBitcast, {0, 1, 3});
+        break;
+      }
+      case 4: {
+
+        auto *Shuff1 = CGF->Builder.CreateShuffleVector(Op0, {0, 1});
+        auto [LB1, HB1] = EmitDouble2Cast(*CGF, Shuff1);
+
+        auto *Shuff2 = CGF->Builder.CreateShuffleVector(Op0, {2, 3});
+        auto [LB2, HB2] = EmitDouble2Cast(*CGF, Shuff2);
+
+        LowBits = CGF->Builder.CreateShuffleVector(LB1, LB2, {0, 1, 2, 3});
+        HighBits = CGF->Builder.CreateShuffleVector(HB1, HB2, {0, 1, 3, 3});
+        break;
+      }
+      default: {
+        CGF->CGM.Error(E->getExprLoc(),
+                       "splitdouble doesn't support vectors larger than 4.");
+        return nullptr;
+      }
+      }
+    }
----------------
bogner wrote:

So IIUC from the comment, #113597, and some of the previous discussion on this PR, what we really want to do is something simple that just bitcasts the scalar or vector of doubles to a scalar or vector of uint32 and then shufflevectors the even and odd parts out, like so:
```c++
    int NumElements = 1;
    if (const auto *VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
      NumElements = VecTy->getNumElements();

    FixedVectorType *Uint32VecTy =
        FixedVectorType::get(CGF->Int32Ty, NumElements * 2);
    Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);
    if (NumElements == 1) {
      LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0);
      HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);
    } else {
      SmallVector<int> EvenMask, OddMask;
      for (int I = 0, E = NumElements; I != E; ++I) {
        EvenMask.push_back(I * 2);
        OddMask.push_back(I * 2 + 1);
      }
      LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);
      HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);
    }
```

but because SPIR-V currently doesn't legalize the shufflevector operation, this kind of thing fails if we try to actually emit SPIR-V. Is that correct?

I'm not really convinced that it's worth putting this giant workaround in, rather than just doing the simple thing directly with the understanding that only the scalar and `double2` cases will fully work for now. As far as the frontend parts go, that would still be fully testable, and we should be able to have high confidence that when we do the work to improve vector legalization in SPIR-V it will "just work", rather than us having to come back and change everything again.

WDYT?

https://github.com/llvm/llvm-project/pull/109331


More information about the cfe-commits mailing list