[clang] [llvm] [HLSL][SPIRV]Add SPIRV generation for HLSL dot (PR #104656)

Greg Roth via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 19 11:34:02 PDT 2024


================
@@ -68,28 +69,65 @@ static Value *expandAbs(CallInst *Orig) {
                                  "dx.max");
 }
 
-static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+// Create DXIL dot intrinsics for floating point dot operations.
+// The call's two vector operands are forwarded unchanged to dx.dot2,
+// dx.dot3 or dx.dot4, chosen by the operand's element count; the result
+// has the operand's scalar element type.
+static Value *expandFloatDotIntrinsic(CallInst *Orig) {
+  Value *A = Orig->getOperand(0);
+  Value *B = Orig->getOperand(1);
+  Type *ATy = A->getType();
+  [[maybe_unused]] Type *BTy = B->getType();
+  assert(ATy->isVectorTy() && BTy->isVectorTy());
+  assert(ATy->getScalarType()->isFloatingPointTy());
+
+  IRBuilder<> Builder(Orig);
+
+  // cast<> rather than dyn_cast<>: the result is dereferenced
+  // unconditionally below, so a non-fixed (e.g. scalable) vector must trip
+  // an assertion instead of yielding a null pointer.
+  auto *AVec = cast<FixedVectorType>(ATy);
+  assert(cast<FixedVectorType>(BTy)->getNumElements() ==
+             AVec->getNumElements() &&
+         "dot product operands must have the same number of elements");
+
+  Intrinsic::ID DotIntrinsic;
+  switch (AVec->getNumElements()) {
+  case 2:
+    DotIntrinsic = Intrinsic::dx_dot2;
+    break;
+  case 3:
+    DotIntrinsic = Intrinsic::dx_dot3;
+    break;
+  case 4:
+    DotIntrinsic = Intrinsic::dx_dot4;
+    break;
+  default:
+    llvm_unreachable("dot product with vector outside 2-4 range");
+  }
+  return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
+                                 ArrayRef<Value *>{A, B}, nullptr, "dot");
+}
+
+// Expand integer dot product to multiply and add ops
+static Value *expandIntegerDotIntrinsic(CallInst *Orig,
+                                        Intrinsic::ID DotIntrinsic) {
   assert(DotIntrinsic == Intrinsic::dx_sdot ||
          DotIntrinsic == Intrinsic::dx_udot);
-  Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
-                                   ? Intrinsic::dx_imad
-                                   : Intrinsic::dx_umad;
   Value *A = Orig->getOperand(0);
   Value *B = Orig->getOperand(1);
-  [[maybe_unused]] Type *ATy = A->getType();
+  Type *ATy = A->getType();
   [[maybe_unused]] Type *BTy = B->getType();
   assert(ATy->isVectorTy() && BTy->isVectorTy());
 
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
+
+  auto *AVec = dyn_cast<FixedVectorType>(ATy);
 
-  auto *AVec = dyn_cast<FixedVectorType>(A->getType());
+  assert(ATy->getScalarType()->isIntegerTy());
+
+  Value *Result;
+  Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
+                                   ? Intrinsic::dx_imad
+                                   : Intrinsic::dx_umad;
   Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
   Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
-  Value *Result = Builder.CreateMul(Elt0, Elt1);
-  for (unsigned I = 1; I < AVec->getNumElements(); I++) {
-    Elt0 = Builder.CreateExtractElement(A, I);
-    Elt1 = Builder.CreateExtractElement(B, I);
+  Result = Builder.CreateMul(Elt0, Elt1);
+  for (unsigned i = 1; i < AVec->getNumElements(); i++) {
----------------
pow2clk wrote:

I try so hard to be agnostic about style guides, but stuff like this challenges my lack of creed. 😣

https://github.com/llvm/llvm-project/pull/104656


More information about the llvm-commits mailing list