[llvm] [SPIRV] Use `Op[S|U]Dot` when possible for integer dot product (PR #115095)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 21 10:56:14 PST 2024


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/115095

>From 087fd8212e866468fc1e5375fc02a69ab3cc28a9 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 5 Nov 2024 23:12:22 +0000
Subject: [PATCH 1/4] [SPIRV] Use `Op[S|U]Dot` when possible for integer dot
 product

- use the new OpSDot/OpUDot instructions when capabilities allow in
SPIRVInstructionSelector.cpp
- correct the capability check to inspect the input operand type rather than
the return operand type in SPIRVModuleAnalysis.cpp

- add test cases to demonstrate use case in idot.ll
---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  30 ++++-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  18 ++-
 .../CodeGen/SPIRV/hlsl-intrinsics/idot.ll     | 124 +++++++++++++-----
 3 files changed, 128 insertions(+), 44 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 414583aea91e64..4cb6f8e0af59ce 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -179,7 +179,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
                            MachineInstr &I, unsigned Opcode) const;
 
   bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
-                        MachineInstr &I) const;
+                        MachineInstr &I, bool Signed) const;
+
+  bool selectIntegerDotExpansion(Register ResVReg, const SPIRVType *ResType,
+                                 MachineInstr &I) const;
 
   template <bool Signed>
   bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
@@ -1681,9 +1684,27 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I,
+                                                bool Signed) const {
+  assert(I.getNumOperands() == 4);
+  assert(I.getOperand(2).isReg());
+  assert(I.getOperand(3).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
+              .addDef(ResVReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(I.getOperand(2).getReg())
+              .addUse(I.getOperand(3).getReg())
+              .constrainAllUses(TII, TRI, RBI);
+}
+
 // Since pre-1.6 SPIRV has no integer dot implementation,
 // expand by piecewise multiplying and adding the results
-bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
+bool SPIRVInstructionSelector::selectIntegerDotExpansion(Register ResVReg,
                                                 const SPIRVType *ResType,
                                                 MachineInstr &I) const {
   assert(I.getNumOperands() == 4);
@@ -2681,7 +2702,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectFloatDot(ResVReg, ResType, I);
   case Intrinsic::spv_udot:
   case Intrinsic::spv_sdot:
-    return selectIntegerDot(ResVReg, ResType, I);
+    if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
+        STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
+      return selectIntegerDot(ResVReg, ResType, I, /*Signed=*/IID == Intrinsic::spv_sdot);
+    return selectIntegerDotExpansion(ResVReg, ResType, I);
   case Intrinsic::spv_dot4add_i8packed:
     if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
         STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index e8641b3a105dec..4f30d7d83b7a7e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1013,10 +1013,11 @@ static void AddDotProductRequirements(const MachineInstr &MI,
   Reqs.addCapability(SPIRV::Capability::DotProduct);
 
   const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
-  const MachineInstr *InstrPtr = &MI;
-  assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
+  assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
+  const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg());
+  assert(InputInstr->getOperand(1).isReg() && "Unexpected operand in dot input");
 
-  Register TypeReg = InstrPtr->getOperand(1).getReg();
+  Register TypeReg = InputInstr->getOperand(1).getReg();
   SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
   if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
     assert(TypeDef->getOperand(1).getImm() == 32);
@@ -1024,10 +1025,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
   } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
     SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
     assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
-    auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
-                          ? SPIRV::Capability::DotProductInput4x8Bit
-                          : SPIRV::Capability::DotProductInputAll;
-    Reqs.addCapability(Capability);
+    if (ScalarTypeDef->getOperand(1).getImm() == 8) {
+      assert(TypeDef->getOperand(2).getImm() == 4
+             && "Dot operand of 8-bit integer type requires 4 components");
+      Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
+    } else {
+      Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
+    }
   }
 }
 
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
index 22b6ed6bdfcbc5..b952cfe24a77db 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
@@ -1,8 +1,20 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
+; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
 
 ; Make sure dxil operation function calls for dot are generated for int/uint vectors.
 
+; CHECK-DAG: OpCapability Int8
+; CHECK-DOT-DAG: OpCapability DotProduct
+; CHECK-DOT-DAG: OpCapability DotProductInputAll
+; CHECK-DOT-DAG: OpCapability DotProductInput4x8Bit
+; CHECK-EXT-DAG: OpExtension "SPV_KHR_integer_dot_product"
+
+; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8
+; CHECK-DAG: %[[#vec4_int_8:]] = OpTypeVector %[[#int_8]] 4
 ; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
 ; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
 ; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
@@ -11,14 +23,32 @@
 ; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
 ; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2
 
+define noundef i8 @dot_int8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]
+
+; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_8]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
+  %dot = call i8 @llvm.spv.sdot.v4i8(<4 x i8> %a, <4 x i8> %b)
+  ret i8 %dot
+}
+
 define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
 entry:
 ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
 ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
-; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
+
+; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_16]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
   %dot = call i16 @llvm.spv.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
   ret i16 %dot
 }
@@ -27,28 +57,49 @@ define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
 entry:
 ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
 ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
-; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
-; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
-; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
-; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
-; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
+
+; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_32]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
+; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
+; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
+; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
+; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
   %dot = call i32 @llvm.spv.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
   ret i32 %dot
 }
 
+define noundef i8 @dot_uint8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]
+
+; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_8]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
+  %dot = call i8 @llvm.spv.udot.v4i8(<4 x i8> %a, <4 x i8> %b)
+  ret i8 %dot
+}
+
 define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
 entry:
 ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
 ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
-; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
-; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
-; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
+
+; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_16]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
+; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
+; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
   %dot = call i16 @llvm.spv.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
   ret i16 %dot
 }
@@ -57,32 +108,37 @@ define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
 entry:
 ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
 ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
-; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
-; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
-; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
-; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
-; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
+
+; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_32]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
+; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
+; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
+; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
+; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
   %dot = call i32 @llvm.spv.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
   ret i32 %dot
 }
 
 define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
 entry:
-; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
-; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
-; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
+; CHECK-EXP: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK-EXP: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
   %dot = call i64 @llvm.spv.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
   ret i64 %dot
 }
 
+declare i8 @llvm.spv.sdot.v4i8(<4 x i8>, <4 x i8>)
 declare i16 @llvm.spv.sdot.v2i16(<2 x i16>, <2 x i16>)
 declare i32 @llvm.spv.sdot.v4i32(<4 x i32>, <4 x i32>)
+declare i8 @llvm.spv.udot.v4i8(<4 x i8>, <4 x i8>)
 declare i16 @llvm.spv.udot.v3i32(<3 x i16>, <3 x i16>)
 declare i32 @llvm.spv.udot.v4i32(<4 x i32>, <4 x i32>)
 declare i64 @llvm.spv.udot.v2i64(<2 x i64>, <2 x i64>)

>From ca9b5a0884fb3e9ac695ec894158981adab43aee Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Wed, 6 Nov 2024 00:03:57 +0000
Subject: [PATCH 2/4] clang format

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp  | 18 +++++++++---------
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp  |  7 ++++---
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 4cb6f8e0af59ce..7fd7ad5a347114 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1695,18 +1695,17 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
 
   auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
-              .addDef(ResVReg)
-              .addUse(GR.getSPIRVTypeID(ResType))
-              .addUse(I.getOperand(2).getReg())
-              .addUse(I.getOperand(3).getReg())
-              .constrainAllUses(TII, TRI, RBI);
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(I.getOperand(2).getReg())
+      .addUse(I.getOperand(3).getReg())
+      .constrainAllUses(TII, TRI, RBI);
 }
 
 // Since pre-1.6 SPIRV has no integer dot implementation,
 // expand by piecewise multiplying and adding the results
-bool SPIRVInstructionSelector::selectIntegerDotExpansion(Register ResVReg,
-                                                const SPIRVType *ResType,
-                                                MachineInstr &I) const {
+bool SPIRVInstructionSelector::selectIntegerDotExpansion(
+    Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
   assert(I.getNumOperands() == 4);
   assert(I.getOperand(2).isReg());
   assert(I.getOperand(3).isReg());
@@ -2704,7 +2703,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_sdot:
     if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
         STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
-      return selectIntegerDot(ResVReg, ResType, I, /*Signed=*/IID == Intrinsic::spv_sdot);
+      return selectIntegerDot(ResVReg, ResType, I,
+                              /*Signed=*/IID == Intrinsic::spv_sdot);
     return selectIntegerDotExpansion(ResVReg, ResType, I);
   case Intrinsic::spv_dot4add_i8packed:
     if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 4f30d7d83b7a7e..63f2ee966030d3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1015,7 +1015,8 @@ static void AddDotProductRequirements(const MachineInstr &MI,
   const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
   assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
   const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg());
-  assert(InputInstr->getOperand(1).isReg() && "Unexpected operand in dot input");
+  assert(InputInstr->getOperand(1).isReg() &&
+         "Unexpected operand in dot input");
 
   Register TypeReg = InputInstr->getOperand(1).getReg();
   SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
@@ -1026,8 +1027,8 @@ static void AddDotProductRequirements(const MachineInstr &MI,
     SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
     assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
     if (ScalarTypeDef->getOperand(1).getImm() == 8) {
-      assert(TypeDef->getOperand(2).getImm() == 4
-             && "Dot operand of 8-bit integer type requires 4 components");
+      assert(TypeDef->getOperand(2).getImm() == 4 &&
+             "Dot operand of 8-bit integer type requires 4 components");
       Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
     } else {
       Reqs.addCapability(SPIRV::Capability::DotProductInputAll);

>From e52cf425c98be7c511f242244f39c3e2224d6126 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Thu, 7 Nov 2024 23:05:48 +0000
Subject: [PATCH 3/4] self review:

- fix spirv version
- add missing check
---
 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
index b952cfe24a77db..8acad352cdc29a 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
@@ -1,9 +1,9 @@
 ; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
 ; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
+; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
 
 ; Make sure dxil operation function calls for dot are generated for int/uint vectors.
 
@@ -125,8 +125,11 @@ entry:
 
 define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
 entry:
-; CHECK-EXP: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
-; CHECK-EXP: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
+
+; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_64]] %[[#arg0]] %[[#arg1]]
+
 ; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
 ; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
 ; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1

>From 0558a836b0d438967bb36df1424d41ef00a917a9 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Wed, 20 Nov 2024 19:45:10 +0000
Subject: [PATCH 4/4] self-review:

- add clarifying comment/code
---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 63f2ee966030d3..765ff69ed38713 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1014,12 +1014,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
 
   const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
   assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
-  const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg());
-  assert(InputInstr->getOperand(1).isReg() &&
-         "Unexpected operand in dot input");
+  // We do not consider what the previous instruction is. This is just used
+  // to get the input register and to check the type.
+  const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
+  assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
+  Register InputReg = Input->getOperand(1).getReg();
 
-  Register TypeReg = InputInstr->getOperand(1).getReg();
-  SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
+  SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
   if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
     assert(TypeDef->getOperand(1).getImm() == 32);
     Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);



More information about the llvm-commits mailing list