[clang] [llvm] [HLSL][SPIRV][DXIL] Implement `dot4add_i8packed` intrinsic (PR #113623)

Finn Plummer via cfe-commits cfe-commits at lists.llvm.org
Mon Nov 4 20:05:10 PST 2024


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/113623

>From 81dfa26a941f7a0926a3126fe3ebbb4d2a67cec1 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Wed, 23 Oct 2024 22:59:15 +0000
Subject: [PATCH 01/13] [HLSL][SPIRV][DXIL] Implement `dot4add_i8packed`
 intrinsic

    - create a clang built-in in Builtins.td
    - link dot4add_i8packed in hlsl_intrinsics.h
    - add lowering to the SPIR-V backend by expanding the operation, as
      OpSDot is not available before SPIR-V 1.6 (SPIRVInstructionSelector.cpp)
    - add dot4add_i8packed intrinsic to IntrinsicsDirectX.td and mapping
      to DXIL.td op Dot4AddI8Packed

    - add tests for HLSL intrinsic lowering to dx/spv intrinsic in
      dot4add_i8packed.hlsl
    - add tests for sema checks in dot4add_i8packed-errors.hlsl
    - add test of spir-v lowering in SPIRV/dot4add_i8packed.ll
    - add test to dxil lowering in DirectX/dot4add_i8packed.ll
---
 clang/include/clang/Basic/Builtins.td         |  6 ++
 clang/lib/CodeGen/CGBuiltin.cpp               | 12 ++-
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 10 +++
 .../builtins/dot4add_i8packed.hlsl            | 17 ++++
 .../BuiltIns/dot4add_i8packed-errors.hlsl     | 28 +++++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  1 +
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  1 +
 llvm/lib/Target/DirectX/DXIL.td               | 10 +++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 84 +++++++++++++++++++
 llvm/test/CodeGen/DirectX/dot4add_i8packed.ll | 10 +++
 .../SPIRV/hlsl-intrinsics/dot4add_i8packed.ll | 48 +++++++++++
 12 files changed, 227 insertions(+), 1 deletion(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/dot4add_i8packed-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/dot4add_i8packed.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..eb6b07e8858602 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4792,6 +4792,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "int(unsigned int, unsigned int, int)";
+}
+
 def HLSLFrac : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_frac"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..13ed0f99da9815 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18722,7 +18722,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/T0->getScalarType(),
         getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
-  } break;
+  }
+  case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
+    Value *A = EmitScalarExpr(E->getArg(0));
+    Value *B = EmitScalarExpr(E->getArg(1));
+    Value *C = EmitScalarExpr(E->getArg(2));
+
+    Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
+        "hlsl.dot4add.i8packed");
+  }
   case Builtin::BI__builtin_hlsl_lerp: {
     Value *X = EmitScalarExpr(E->getArg(0));
     Value *Y = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..8b1141375106cc 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -89,6 +89,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..d10bfcbeed97ea 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -894,6 +894,16 @@ uint64_t dot(uint64_t3, uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 uint64_t dot(uint64_t4, uint64_t4);
 
+//===----------------------------------------------------------------------===//
+// dot4add builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn int dot4add_i8packed(uint A, uint B, int C)
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.4)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_i8packed)
+int dot4add_i8packed(unsigned int, unsigned int, int);
+
 //===----------------------------------------------------------------------===//
 // exp builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl b/clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
new file mode 100644
index 00000000000000..ea1a33d6267d2f
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s -DTARGET=dx
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s -DTARGET=spv
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test
+int test(uint a, uint b, int c) {
+  // CHECK:  %[[RET:.*]] = call [[TY:i32]] @llvm.[[TARGET]].dot4add.i8packed([[TY]] %[[#]], [[TY]] %[[#]], [[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return dot4add_i8packed(a, b, c);
+}
+
+// CHECK: declare [[TY]] @llvm.[[TARGET]].dot4add.i8packed([[TY]], [[TY]], [[TY]])
diff --git a/clang/test/SemaHLSL/BuiltIns/dot4add_i8packed-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot4add_i8packed-errors.hlsl
new file mode 100644
index 00000000000000..ac0b430bfaf945
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/dot4add_i8packed-errors.hlsl
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg0() {
+  return __builtin_hlsl_dot4add_i8packed();
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
+}
+
+int test_too_few_arg1(int p0) {
+  return __builtin_hlsl_dot4add_i8packed(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
+}
+
+int test_too_few_arg2(int p0) {
+  return __builtin_hlsl_dot4add_i8packed(p0, p0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+}
+
+int test_too_many_arg(int p0) {
+  return __builtin_hlsl_dot4add_i8packed(p0, p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+}
+
+struct S { float f; };
+
+int test_expr_struct_type_check(S p0, int p1) {
+  return __builtin_hlsl_dot4add_i8packed(p0, p1, p1);
+  // expected-error@-1 {{no viable conversion from 'S' to 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..8cd5ff9006c1b7 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -69,6 +69,7 @@ def int_dx_udot :
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
+  def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
 
 def int_dx_frac  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_degrees : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 6df2eb156a0774..ebea18cd932617 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
+  def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
   def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..f0f40de5009f52 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -779,6 +779,16 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
+  let Doc = "signed dot product of 4 x i8 vectors packed into i32, with "
+            "accumulate to i32";
+  let LLVMIntrinsic = int_dx_dot4add_i8packed;
+  let arguments = [Int32Ty, Int32Ty, Int32Ty];
+  let result = Int32Ty;
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
   let arguments = [HandleTy, ResPropsTy];
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d9377fe4b91a1a..d297b2fa07209f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -164,6 +164,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
                         MachineInstr &I) const;
 
+  template <bool Signed>
+  bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
+                           MachineInstr &I) const;
+
   void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
                    int OpIdx) const;
   void renderFImm64(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -1694,6 +1698,84 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
   return Result;
 }
 
+// Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation,
+// extract the elements of the packed inputs, multiply them and add the result
+// to the accumulator.
+template <bool Signed>
+bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  assert(I.getNumOperands() == 5);
+  assert(I.getOperand(2).isReg());
+  assert(I.getOperand(3).isReg());
+  assert(I.getOperand(4).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  bool Result = false;
+
+  // Acc = C
+  Register Acc = I.getOperand(4).getReg();
+  SPIRVType *EltType = GR.getOrCreateSPIRVIntegerType(8, I, TII);
+  auto ExtractOp =
+      Signed ? SPIRV::OpBitFieldSExtract : SPIRV::OpBitFieldUExtract;
+
+  // Extract the i8 element, multiply and add it to the accumulator
+  for (unsigned i = 0; i < 4; i++) {
+    // A[i]
+    Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(AElt)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(I.getOperand(2).getReg())
+                  .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
+                  .addImm(8)
+                  .constrainAllUses(TII, TRI, RBI);
+
+    // B[i]
+    Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(BElt)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(I.getOperand(3).getReg())
+                  .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
+                  .addImm(8)
+                  .constrainAllUses(TII, TRI, RBI);
+
+    // A[i] * B[i]
+    Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
+                  .addDef(Mul)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(AElt)
+                  .addUse(BElt)
+                  .constrainAllUses(TII, TRI, RBI);
+
+    // Discard 24 highest-bits so that stored i32 register is i8 equivalent
+    Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(MaskMul)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(Mul)
+                  .addUse(GR.getOrCreateConstInt(0, I, EltType, TII))
+                  .addImm(8)
+                  .constrainAllUses(TII, TRI, RBI);
+
+    // Acc = Acc + A[i] * B[i]
+    Register Sum =
+        i < 3 ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg;
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+                  .addDef(Sum)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(Acc)
+                  .addUse(MaskMul)
+                  .constrainAllUses(TII, TRI, RBI);
+
+    Acc = Sum;
+  }
+
+  return Result;
+}
+
 /// Transform saturate(x) to clamp(x, 0.0f, 1.0f) as SPIRV
 /// does not have a saturate builtin.
 bool SPIRVInstructionSelector::selectSaturate(Register ResVReg,
@@ -2527,6 +2609,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_udot:
   case Intrinsic::spv_sdot:
     return selectIntegerDot(ResVReg, ResType, I);
+  case Intrinsic::spv_dot4add_i8packed:
+    return selectDot4AddPacked<true>(ResVReg, ResType, I);
   case Intrinsic::spv_all:
     return selectAll(ResVReg, ResType, I);
   case Intrinsic::spv_any:
diff --git a/llvm/test/CodeGen/DirectX/dot4add_i8packed.ll b/llvm/test/CodeGen/DirectX/dot4add_i8packed.ll
new file mode 100644
index 00000000000000..7df0520505cea6
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot4add_i8packed.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define void @main(i32 %a, i32 %b, i32 %c) {
+entry:
+; CHECK: call i32 @dx.op.dot4AddPacked(i32 163, i32 %a, i32 %b, i32 %c)
+  %0 = call i32 @llvm.dx.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
+  ret void
+}
+
+declare i32 @llvm.dx.dot4add.i8packed(i32, i32, i32)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
new file mode 100644
index 00000000000000..35e2a731071103
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
@@ -0,0 +1,48 @@
+; RUN: llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#zero:]] = OpConstantNull %[[#int_8]]
+; CHECK-DAG: %[[#eight:]] = OpConstant %[[#int_8]] 8
+; CHECK-DAG: %[[#sixteen:]] = OpConstant %[[#int_8]] 16
+; CHECK-DAG: %[[#twentyfour:]] = OpConstant %[[#int_8]] 24
+; CHECK-LABEL: Begin function test_dot
+define noundef i32 @test_dot(i32 noundef %a, i32 noundef %b, i32 noundef %c) {
+entry:
+; CHECK: %[[#A:]] = OpFunctionParameter %[[#int_32]]
+; CHECK: %[[#B:]] = OpFunctionParameter %[[#int_32]]
+; CHECK: %[[#C:]] = OpFunctionParameter %[[#int_32]]
+
+; First element of the packed vector
+; CHECK: %[[#A0:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#zero]] 8
+; CHECK: %[[#B0:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#zero]] 8
+; CHECK: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
+; CHECK: %[[#MASK0:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] 8
+; CHECK: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
+
+; Second element of the packed vector
+; CHECK: %[[#A1:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#eight]] 8
+; CHECK: %[[#B1:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#eight]] 8
+; CHECK: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
+; CHECK: %[[#MASK1:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] 8
+; CHECK: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]
+
+; Third element of the packed vector
+; CHECK: %[[#A2:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#sixteen]] 8
+; CHECK: %[[#B2:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#sixteen]] 8
+; CHECK: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
+; CHECK: %[[#MASK2:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] 8
+; CHECK: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]
+
+; Fourth element of the packed vector
+; CHECK: %[[#A3:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] 8
+; CHECK: %[[#B3:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] 8
+; CHECK: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
+; CHECK: %[[#MASK3:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] 8
+; CHECK: %[[#ACC3:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
+
+; CHECK: OpReturnValue %[[#ACC3]]
+  %spv.dot = call i32 @llvm.spv.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
+  ret i32 %spv.dot
+}

>From bb6602e5840318f9f1dcbde5962d051aa8a9b4dd Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 29 Oct 2024 17:25:56 +0000
Subject: [PATCH 02/13] add SPV_KHR_integer_dot_product flag

Previously this extension could not be enabled through llc's
-spirv-ext command-line option; register it so that it can be.
---
 llvm/docs/SPIRVUsage.rst                   | 2 ++
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 47cb52b1cc684d..0fcaa366c8a3e0 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -179,6 +179,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
      - Provides additional information to a compiler, similar to the llvm.assume and llvm.expect intrinsics.
    * - ``SPV_KHR_float_controls``
      - Provides new execution modes to control floating-point computations by overriding an implementation’s default behavior for rounding modes, denormals, signed zero, and infinities.
+   * - ``SPV_KHR_integer_dot_product``
+     - Adds instructions for dot product operations on integer vectors with optional accumulation. Integer vectors includes 4-component vector of 8 bit integers and 4-component vectors of 8 bit integers packed into 32-bit integers.
    * - ``SPV_KHR_linkonce_odr``
      - Allows to use the LinkOnceODR linkage type that lets a function or global variable to be merged with other functions or global variables of the same name when linkage occurs.
    * - ``SPV_KHR_no_integer_wrap_decoration``
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index dbfc133864bba4..e58366017966a7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -56,6 +56,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
          SPIRV::Extension::Extension::SPV_KHR_expect_assume},
         {"SPV_KHR_bit_instructions",
          SPIRV::Extension::Extension::SPV_KHR_bit_instructions},
+        {"SPV_KHR_integer_dot_product",
+         SPIRV::Extension::Extension::SPV_KHR_integer_dot_product},
         {"SPV_KHR_linkonce_odr",
          SPIRV::Extension::Extension::SPV_KHR_linkonce_odr},
         {"SPV_INTEL_inline_assembly",

>From 65512b6d650455b1eca97ee8fc9b6db7a1691048 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 29 Oct 2024 17:29:22 +0000
Subject: [PATCH 03/13] review comments:

- use OpSDot for the lowering when the target has capabilities to do
so

- fix HLSL_AVAILABILITY
---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  2 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 40 ++++++++++--
 .../SPIRV/hlsl-intrinsics/dot4add_i8packed.ll | 64 +++++++++++--------
 3 files changed, 74 insertions(+), 32 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index d10bfcbeed97ea..90638b0276f14e 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -900,7 +900,7 @@ uint64_t dot(uint64_t4, uint64_t4);
 
 /// \fn int dot4add_i8packed(uint A, uint B, int C)
 
-_HLSL_16BIT_AVAILABILITY(shadermodel, 6.4)
+_HLSL_AVAILABILITY(shadermodel, 6.4)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_i8packed)
 int dot4add_i8packed(unsigned int, unsigned int, int);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d297b2fa07209f..6b95d9a27f7cae 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -167,6 +167,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   template <bool Signed>
   bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
                            MachineInstr &I) const;
+  template <bool Signed>
+  bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
+                                    MachineInstr &I) const;
 
   void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
                    int OpIdx) const;
@@ -1698,9 +1701,6 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
   return Result;
 }
 
-// Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation,
-// extract the elements of the packed inputs, multiply them and add the result
-// to the accumulator.
 template <bool Signed>
 bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
                                                    const SPIRVType *ResType,
@@ -1711,6 +1711,35 @@ bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
   assert(I.getOperand(4).isReg());
   MachineBasicBlock &BB = *I.getParent();
 
+  Register Dot = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot))
+                    .addDef(Dot)
+                    .addUse(GR.getSPIRVTypeID(ResType))
+                    .addUse(I.getOperand(2).getReg())
+                    .addUse(I.getOperand(3).getReg())
+                    .constrainAllUses(TII, TRI, RBI);
+
+  Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+                .addDef(ResVReg)
+                .addUse(GR.getSPIRVTypeID(ResType))
+                .addUse(Dot)
+                .addUse(I.getOperand(4).getReg())
+                .constrainAllUses(TII, TRI, RBI);
+
+  return Result;
+}
+// Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation,
+// extract the elements of the packed inputs, multiply them and add the result
+// to the accumulator.
+template <bool Signed>
+bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
+    Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+  assert(I.getNumOperands() == 5);
+  assert(I.getOperand(2).isReg());
+  assert(I.getOperand(3).isReg());
+  assert(I.getOperand(4).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
   bool Result = false;
 
   // Acc = C
@@ -2610,7 +2639,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_sdot:
     return selectIntegerDot(ResVReg, ResType, I);
   case Intrinsic::spv_dot4add_i8packed:
-    return selectDot4AddPacked<true>(ResVReg, ResType, I);
+    if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
+        STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
+      return selectDot4AddPacked<true>(ResVReg, ResType, I);
+    return selectDot4AddPackedExpansion<true>(ResVReg, ResType, I);
   case Intrinsic::spv_all:
     return selectAll(ResVReg, ResType, I);
   case Intrinsic::spv_any:
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
index 35e2a731071103..1ae85e489d1a35 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
@@ -1,12 +1,15 @@
-; RUN: llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv32v1.6-vulkan-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -O0 -mtriple=spirv32-vulkan-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -O0 -mtriple=spirv32-vulkan-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
-; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8 0
-; CHECK-DAG: %[[#zero:]] = OpConstantNull %[[#int_8]]
-; CHECK-DAG: %[[#eight:]] = OpConstant %[[#int_8]] 8
-; CHECK-DAG: %[[#sixteen:]] = OpConstant %[[#int_8]] 16
-; CHECK-DAG: %[[#twentyfour:]] = OpConstant %[[#int_8]] 24
+; CHECK-EXP-DAG: %[[#int_8:]] = OpTypeInt 8 0
+; CHECK-EXP-DAG: %[[#zero:]] = OpConstantNull %[[#int_8]]
+; CHECK-EXP-DAG: %[[#eight:]] = OpConstant %[[#int_8]] 8
+; CHECK-EXP-DAG: %[[#sixteen:]] = OpConstant %[[#int_8]] 16
+; CHECK-EXP-DAG: %[[#twentyfour:]] = OpConstant %[[#int_8]] 24
+
 ; CHECK-LABEL: Begin function test_dot
 define noundef i32 @test_dot(i32 noundef %a, i32 noundef %b, i32 noundef %c) {
 entry:
@@ -14,35 +17,42 @@ entry:
 ; CHECK: %[[#B:]] = OpFunctionParameter %[[#int_32]]
 ; CHECK: %[[#C:]] = OpFunctionParameter %[[#int_32]]
 
+; Test that we use the dot product op when capabilities allow
+
+; CHECK-DOT: %[[#DOT:]] = OpDot %[[#int_32]] %[[#A]] %[[#B]]
+; CHECK-DOT: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#DOT]] %[[#C]]
+
+; Test expansion is used when spirv dot product capabilities aren't available:
+
 ; First element of the packed vector
-; CHECK: %[[#A0:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#zero]] 8
-; CHECK: %[[#B0:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#zero]] 8
-; CHECK: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
-; CHECK: %[[#MASK0:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] 8
-; CHECK: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
+; CHECK-EXP: %[[#A0:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#zero]] 8
+; CHECK-EXP: %[[#B0:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#zero]] 8
+; CHECK-EXP: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
+; CHECK-EXP: %[[#MASK0:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] 8
+; CHECK-EXP: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
 
 ; Second element of the packed vector
-; CHECK: %[[#A1:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#eight]] 8
-; CHECK: %[[#B1:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#eight]] 8
-; CHECK: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
-; CHECK: %[[#MASK1:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] 8
-; CHECK: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]
+; CHECK-EXP: %[[#A1:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#eight]] 8
+; CHECK-EXP: %[[#B1:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#eight]] 8
+; CHECK-EXP: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
+; CHECK-EXP: %[[#MASK1:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] 8
+; CHECK-EXP: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]
 
 ; Third element of the packed vector
-; CHECK: %[[#A2:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#sixteen]] 8
-; CHECK: %[[#B2:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#sixteen]] 8
-; CHECK: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
-; CHECK: %[[#MASK2:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] 8
-; CHECK: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]
+; CHECK-EXP: %[[#A2:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#sixteen]] 8
+; CHECK-EXP: %[[#B2:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#sixteen]] 8
+; CHECK-EXP: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
+; CHECK-EXP: %[[#MASK2:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] 8
+; CHECK-EXP: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]
 
 ; Fourth element of the packed vector
-; CHECK: %[[#A3:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] 8
-; CHECK: %[[#B3:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] 8
-; CHECK: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
-; CHECK: %[[#MASK3:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] 8
-; CHECK: %[[#ACC3:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
+; CHECK-EXP: %[[#A3:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] 8
+; CHECK-EXP: %[[#B3:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] 8
+; CHECK-EXP: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
+; CHECK-EXP: %[[#MASK3:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] 8
+; CHECK-EXP: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
 
-; CHECK: OpReturnValue %[[#ACC3]]
+; CHECK: OpReturnValue %[[#RES]]
   %spv.dot = call i32 @llvm.spv.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
   ret i32 %spv.dot
 }

>From ad25728469d7beae48841ed642f35ff6d663c6aa Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 29 Oct 2024 17:56:51 +0000
Subject: [PATCH 04/13] review comments

- don't use an immediate for the bit-field count operand; use a constant
---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  6 ++---
 .../SPIRV/hlsl-intrinsics/dot4add_i8packed.ll | 27 ++++++++++---------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 6b95d9a27f7cae..c15a87838b8888 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1757,7 +1757,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(I.getOperand(2).getReg())
                   .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
-                  .addImm(8)
+                  .addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
                   .constrainAllUses(TII, TRI, RBI);
 
     // B[i]
@@ -1767,7 +1767,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(I.getOperand(3).getReg())
                   .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII))
-                  .addImm(8)
+                  .addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
                   .constrainAllUses(TII, TRI, RBI);
 
     // A[i] * B[i]
@@ -1786,7 +1786,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(Mul)
                   .addUse(GR.getOrCreateConstInt(0, I, EltType, TII))
-                  .addImm(8)
+                  .addUse(GR.getOrCreateConstInt(8, I, EltType, TII))
                   .constrainAllUses(TII, TRI, RBI);
 
     // Acc = Acc + A[i] * B[i]
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
index 1ae85e489d1a35..57a94f7ecad423 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
@@ -25,34 +25,35 @@ entry:
 ; Test expansion is used when spirv dot product capabilities aren't available:
 
 ; First element of the packed vector
-; CHECK-EXP: %[[#A0:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#zero]] 8
-; CHECK-EXP: %[[#B0:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#zero]] 8
+; CHECK-EXP: %[[#A0:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#zero]] %[[#eight]]
+; CHECK-EXP: %[[#B0:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#zero]] %[[#eight]]
 ; CHECK-EXP: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
-; CHECK-EXP: %[[#MASK0:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] 8
+; CHECK-EXP: %[[#MASK0:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] %[[#eight]]
 ; CHECK-EXP: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
 
 ; Second element of the packed vector
-; CHECK-EXP: %[[#A1:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#eight]] 8
-; CHECK-EXP: %[[#B1:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#eight]] 8
+; CHECK-EXP: %[[#A1:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#eight]] %[[#eight]]
+; CHECK-EXP: %[[#B1:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#eight]] %[[#eight]]
 ; CHECK-EXP: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
-; CHECK-EXP: %[[#MASK1:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] 8
+; CHECK-EXP: %[[#MASK1:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] %[[#eight]]
 ; CHECK-EXP: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]
 
 ; Third element of the packed vector
-; CHECK-EXP: %[[#A2:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#sixteen]] 8
-; CHECK-EXP: %[[#B2:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#sixteen]] 8
+; CHECK-EXP: %[[#A2:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#sixteen]] %[[#eight]]
+; CHECK-EXP: %[[#B2:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#sixteen]] %[[#eight]]
 ; CHECK-EXP: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
-; CHECK-EXP: %[[#MASK2:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] 8
+; CHECK-EXP: %[[#MASK2:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] %[[#eight]]
 ; CHECK-EXP: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]
 
 ; Fourth element of the packed vector
-; CHECK-EXP: %[[#A3:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] 8
-; CHECK-EXP: %[[#B3:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] 8
+; CHECK-EXP: %[[#A3:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] %[[#eight]]
+; CHECK-EXP: %[[#B3:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] %[[#eight]]
 ; CHECK-EXP: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
-; CHECK-EXP: %[[#MASK3:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] 8
-; CHECK-EXP: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
+; CHECK-EXP: %[[#MASK3:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] %[[#eight]]
 
+; CHECK-EXP: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
 ; CHECK: OpReturnValue %[[#RES]]
   %spv.dot = call i32 @llvm.spv.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
+
   ret i32 %spv.dot
 }

>From 4ee05a31a5434098c8be2900c85648f32507ba9d Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Wed, 30 Oct 2024 20:45:37 +0000
Subject: [PATCH 05/13] define and use Op[S|U]Dot

---
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td            | 3 +++
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 4 +++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index ee6b70a16417f4..130b557f9e8484 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -524,6 +524,9 @@ defm OpISubBorrow: BinOpTypedGen<"OpISubBorrow", 150, subc, 0, 1>;
 def OpUMulExtended: BinOp<"OpUMulExtended", 151>;
 def OpSMulExtended: BinOp<"OpSMulExtended", 152>;
 
+def OpSDot: BinOp<"OpSDot", 4450>;
+def OpUDot: BinOp<"OpUDot", 4451>;
+
 // 3.42.14 Bit Instructions
 
 defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c15a87838b8888..752f35c02b9229 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1711,8 +1711,9 @@ bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
   assert(I.getOperand(4).isReg());
   MachineBasicBlock &BB = *I.getParent();
 
+  auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
   Register Dot = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-  bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot))
+  bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
                     .addDef(Dot)
                     .addUse(GR.getSPIRVTypeID(ResType))
                     .addUse(I.getOperand(2).getReg())
@@ -1728,6 +1729,7 @@ bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
 
   return Result;
 }
+
 // Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation,
 // extract the elements of the packed inputs, multiply them and add the result
 // to the accumulator.

>From a041d44c7e947d3687c5f5e0e356fc3c32ed9cfd Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Wed, 30 Oct 2024 20:50:22 +0000
Subject: [PATCH 06/13] add integer_dot_product capabilities

- define the 4 capabilities and add them to OpenCL init when SPIRV
version is 1.6 or greater
  - require these capabilities during analysis of OpSDot or OpUDot
instructions
  - verify in test case that the capability/extensions are correctly
emitted
---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp   | 16 ++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td  |  4 ++++
 .../SPIRV/hlsl-intrinsics/dot4add_i8packed.ll   | 17 +++++++++++------
 3 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index db5463f5c7abb0..1f9644586ed57f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -672,6 +672,11 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
                       Capability::SignedZeroInfNanPreserve,
                       Capability::RoundingModeRTE,
                       Capability::RoundingModeRTZ});
+  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
+    addAvailableCaps({Capability::DotProductKHR,
+                      Capability::DotProductInputAllKHR,
+                      Capability::DotProductInput4x8BitKHR,
+                      Capability::DotProductInput4x8BitPackedKHR});
   // TODO: verify if this needs some checks.
   addAvailableCaps({Capability::Float16, Capability::Float64});
 
@@ -1218,6 +1223,17 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
     }
     break;
+  case SPIRV::OpSDot:
+  case SPIRV::OpUDot: {
+    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
+    }
+    Reqs.addCapability(SPIRV::Capability::DotProductKHR);
+    Reqs.addCapability(SPIRV::Capability::DotProductInputAllKHR);
+    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitKHR);
+    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPackedKHR);
+    break;
+  }
   default:
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 13ad1eb8e8b337..888525f2ea85a4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -465,6 +465,10 @@ defm ExpectAssumeKHR : CapabilityOperand<5629, 0, 0, [SPV_KHR_expect_assume], []
 defm FunctionPointersINTEL : CapabilityOperand<5603, 0, 0, [SPV_INTEL_function_pointers], []>;
 defm IndirectReferencesINTEL : CapabilityOperand<5604, 0, 0, [SPV_INTEL_function_pointers], []>;
 defm AsmINTEL : CapabilityOperand<5606, 0, 0, [SPV_INTEL_inline_assembly], []>;
+defm DotProductInputAllKHR : CapabilityOperand<6016, 0, 0, [SPV_KHR_integer_dot_product], []>;
+defm DotProductInput4x8BitKHR : CapabilityOperand<6017, 0, 0, [SPV_KHR_integer_dot_product], [Int8]>;
+defm DotProductInput4x8BitPackedKHR : CapabilityOperand<6018, 0, 0, [SPV_KHR_integer_dot_product], []>;
+defm DotProductKHR : CapabilityOperand<6019, 0, 0, [SPV_KHR_integer_dot_product], []>;
 defm GroupNonUniformRotateKHR : CapabilityOperand<6026, 0, 0, [SPV_KHR_subgroup_rotate], [GroupNonUniform]>;
 defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
 defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
index 57a94f7ecad423..b9bbb7bfc0de61 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
@@ -1,9 +1,14 @@
-; RUN: llc -O0 -mtriple=spirv32v1.6-vulkan-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -O0 -mtriple=spirv32-vulkan-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -O0 -mtriple=spirv32-vulkan-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
+; RUN: llc -O0 -mtriple=spirv32v1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK-DOT: OpCapability DotProductInput4x8BitPackedKHR
+; CHECK-EXT: OpExtension "SPV_KHR_integer_dot_product"
+
+; CHECK: %[[#int_32:]] = OpTypeInt 32 0
 ; CHECK-EXP-DAG: %[[#int_8:]] = OpTypeInt 8 0
 ; CHECK-EXP-DAG: %[[#zero:]] = OpConstantNull %[[#int_8]]
 ; CHECK-EXP-DAG: %[[#eight:]] = OpConstant %[[#int_8]] 8
@@ -19,7 +24,7 @@ entry:
 
 ; Test that we use the dot product op when capabilities allow
 
-; CHECK-DOT: %[[#DOT:]] = OpDot %[[#int_32]] %[[#A]] %[[#B]]
+; CHECK-DOT: %[[#DOT:]] = OpSDot %[[#int_32]] %[[#A]] %[[#B]]
 ; CHECK-DOT: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#DOT]] %[[#C]]
 
 ; Test expansion is used when spirv dot product capabilities aren't available:

>From 6431ad0ae4688697c7dff3e1b7ae5d7fca1ec809 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 1 Nov 2024 18:07:51 +0000
Subject: [PATCH 07/13] move out common capabilities

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 1f9644586ed57f..33c25c1804dee5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -626,6 +626,16 @@ void SPIRV::RequirementHandler::removeCapabilityIf(
 namespace llvm {
 namespace SPIRV {
 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
+  // Provided by both all supported Vulkan versions and OpenCl.
+  addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
+                    Capability::Int16});
+
+  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
+    addAvailableCaps({Capability::DotProductKHR,
+                      Capability::DotProductInputAllKHR,
+                      Capability::DotProductInput4x8BitKHR,
+                      Capability::DotProductInput4x8BitPackedKHR});
+
   if (ST.isOpenCLEnv()) {
     initAvailableCapabilitiesForOpenCL(ST);
     return;
@@ -643,10 +653,8 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
     const SPIRVSubtarget &ST) {
   // Add the min requirements for different OpenCL and SPIR-V versions.
   addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
-                    Capability::Int16, Capability::Int8, Capability::Kernel,
-                    Capability::Linkage, Capability::Vector16,
-                    Capability::Groups, Capability::GenericPointer,
-                    Capability::Shader});
+                    Capability::Kernel, Capability::Vector16,
+                    Capability::Groups, Capability::GenericPointer});
   if (ST.hasOpenCLFullProfile())
     addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
   if (ST.hasOpenCLImageSupport()) {
@@ -672,11 +680,6 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
                       Capability::SignedZeroInfNanPreserve,
                       Capability::RoundingModeRTE,
                       Capability::RoundingModeRTZ});
-  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
-    addAvailableCaps({Capability::DotProductKHR,
-                      Capability::DotProductInputAllKHR,
-                      Capability::DotProductInput4x8BitKHR,
-                      Capability::DotProductInput4x8BitPackedKHR});
   // TODO: verify if this needs some checks.
   addAvailableCaps({Capability::Float16, Capability::Float64});
 
@@ -692,13 +695,11 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
 
 void RequirementHandler::initAvailableCapabilitiesForVulkan(
     const SPIRVSubtarget &ST) {
-  addAvailableCaps({Capability::Shader, Capability::Linkage});
 
   // Provided by all supported Vulkan versions.
-  addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
-                    Capability::Float64, Capability::GroupNonUniform,
-                    Capability::Image1D, Capability::SampledBuffer,
-                    Capability::ImageBuffer});
+  addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
+                    Capability::GroupNonUniform, Capability::Image1D,
+                    Capability::SampledBuffer, Capability::ImageBuffer});
 }
 
 } // namespace SPIRV

>From e9d6309e255f7a13bae40fb429d4e0266f103dff Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 1 Nov 2024 18:55:54 +0000
Subject: [PATCH 08/13] add extension and capabilities to vulkan env

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 14 +++++++-------
 llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp      |  2 +-
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 33c25c1804dee5..d9170b552644b5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -636,6 +636,13 @@ void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
                       Capability::DotProductInput4x8BitKHR,
                       Capability::DotProductInput4x8BitPackedKHR});
 
+  // Add capabilities enabled by extensions.
+  for (auto Extension : ST.getAllAvailableExtensions()) {
+    CapabilityList EnabledCapabilities =
+        getCapabilitiesEnabledByExtension(Extension);
+    addAvailableCaps(EnabledCapabilities);
+  }
+
   if (ST.isOpenCLEnv()) {
     initAvailableCapabilitiesForOpenCL(ST);
     return;
@@ -683,13 +690,6 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
   // TODO: verify if this needs some checks.
   addAvailableCaps({Capability::Float16, Capability::Float64});
 
-  // Add capabilities enabled by extensions.
-  for (auto Extension : ST.getAllAvailableExtensions()) {
-    CapabilityList EnabledCapabilities =
-        getCapabilitiesEnabledByExtension(Extension);
-    addAvailableCaps(EnabledCapabilities);
-  }
-
   // TODO: add OpenCL extensions.
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index d31673bff5947c..fc3f5527da3867 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -129,7 +129,7 @@ bool SPIRVSubtarget::canDirectlyComparePointers() const {
 
 void SPIRVSubtarget::initAvailableExtensions() {
   AvailableExtensions.clear();
-  if (!isOpenCLEnv())
+  if (!(isOpenCLEnv() || isVulkanEnv()))
     return;
 
   AvailableExtensions.insert(Extensions.begin(), Extensions.end());

>From 580742e2a5835926d848b241f1c7949ce5f8246a Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 1 Nov 2024 18:56:39 +0000
Subject: [PATCH 09/13] only add required capabilities

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 25 +++++++++++++++----
 .../SPIRV/hlsl-intrinsics/dot4add_i8packed.ll | 13 +++++-----
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index d9170b552644b5..283b1707817bfd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1226,13 +1226,28 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   case SPIRV::OpSDot:
   case SPIRV::OpUDot: {
-    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product)) {
+    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
       Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
-    }
     Reqs.addCapability(SPIRV::Capability::DotProductKHR);
-    Reqs.addCapability(SPIRV::Capability::DotProductInputAllKHR);
-    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitKHR);
-    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPackedKHR);
+
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    const MachineInstr *InstrPtr = &MI;
+    assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
+
+    Register TypeReg = InstrPtr->getOperand(1).getReg();
+    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
+    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
+      assert(TypeDef->getOperand(1).getImm() == 32);
+      Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPackedKHR);
+    } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
+      SPIRVType *ScalarTypeDef =
+          MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+      assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
+      auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
+                            ? SPIRV::Capability::DotProductInput4x8BitKHR
+                            : SPIRV::Capability::DotProductInputAllKHR;
+      Reqs.addCapability(Capability);
+    }
     break;
   }
   default:
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
index b9bbb7bfc0de61..2a03d430e2c02f 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
@@ -1,10 +1,11 @@
-; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
-; RUN: llc -O0 -mtriple=spirv32v1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -O0 -mtriple=spirv32-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
+; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
 
+; CHECK-DOT: OpCapability DotProductKHR
 ; CHECK-DOT: OpCapability DotProductInput4x8BitPackedKHR
 ; CHECK-EXT: OpExtension "SPV_KHR_integer_dot_product"
 

>From 33f284488835b583748b8b7499dc7cb15b2f32d7 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Sat, 2 Nov 2024 01:40:01 +0000
Subject: [PATCH 10/13] remove unneeded flags

---
 clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl b/clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
index ea1a33d6267d2f..7cc31512844d16 100644
--- a/clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
@@ -1,7 +1,7 @@
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
 // RUN:   FileCheck %s -DTARGET=dx
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
 // RUN:   FileCheck %s -DTARGET=spv
 

>From e1bfa192079fb28aa3bf937cfd23a35962eb6e76 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 5 Nov 2024 01:12:52 +0000
Subject: [PATCH 11/13] move dot product requirements out of switch

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 52 ++++++++++---------
 1 file changed, 28 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 283b1707817bfd..84dccb7b8efc46 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -854,6 +854,32 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
   }
 }
 
+static void AddDotProductRequirements(const MachineInstr &MI,
+                                      SPIRV::RequirementHandler &Reqs,
+                                      const SPIRVSubtarget &ST) {
+  if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
+    Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
+  Reqs.addCapability(SPIRV::Capability::DotProductKHR);
+
+  const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+  const MachineInstr *InstrPtr = &MI;
+  assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
+
+  Register TypeReg = InstrPtr->getOperand(1).getReg();
+  SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
+  if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
+    assert(TypeDef->getOperand(1).getImm() == 32);
+    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPackedKHR);
+  } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
+    SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+    assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
+    auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
+                          ? SPIRV::Capability::DotProductInput4x8BitKHR
+                          : SPIRV::Capability::DotProductInputAllKHR;
+    Reqs.addCapability(Capability);
+  }
+}
+
 void addInstrRequirements(const MachineInstr &MI,
                           SPIRV::RequirementHandler &Reqs,
                           const SPIRVSubtarget &ST) {
@@ -1225,31 +1251,9 @@ void addInstrRequirements(const MachineInstr &MI,
     }
     break;
   case SPIRV::OpSDot:
-  case SPIRV::OpUDot: {
-    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
-      Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
-    Reqs.addCapability(SPIRV::Capability::DotProductKHR);
-
-    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
-    const MachineInstr *InstrPtr = &MI;
-    assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
-
-    Register TypeReg = InstrPtr->getOperand(1).getReg();
-    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
-    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
-      assert(TypeDef->getOperand(1).getImm() == 32);
-      Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPackedKHR);
-    } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
-      SPIRVType *ScalarTypeDef =
-          MRI.getVRegDef(TypeDef->getOperand(1).getReg());
-      assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
-      auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
-                            ? SPIRV::Capability::DotProductInput4x8BitKHR
-                            : SPIRV::Capability::DotProductInputAllKHR;
-      Reqs.addCapability(Capability);
-    }
+  case SPIRV::OpUDot:
+    AddDotProductRequirements(MI, Reqs, ST);
     break;
-  }
   default:
     break;
   }

>From a291cab33f94b1d2dab18dd8893bfd742b799a69 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 5 Nov 2024 02:33:34 +0000
Subject: [PATCH 12/13] review comments

  - use getRegClass instead of IDRegClass
  - remove unneeded check of isvulkan/opencl
  - fix testcase to use fixed version and -verify-machineinstrs
  - fix Result bool to be accumulated with and (&=) instead of or (|=)
---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp  | 18 +++++++++---------
 llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp       |  3 ---
 .../SPIRV/hlsl-intrinsics/dot4add_i8packed.ll  |  8 ++++----
 3 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 752f35c02b9229..0c7f6a76cc575c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1652,7 +1652,7 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
   // Multiply the vectors, then sum the results
   Register Vec0 = I.getOperand(2).getReg();
   Register Vec1 = I.getOperand(3).getReg();
-  Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  Register TmpVec = MRI->createVirtualRegister(GR.getRegClass(ResType));
   SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0);
 
   bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV))
@@ -1666,8 +1666,8 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
          GR.getScalarOrVectorComponentCount(VecType) > 1 &&
          "dot product requires a vector of at least 2 components");
 
-  Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-  Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+  Register Res = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
                 .addDef(Res)
                 .addUse(GR.getSPIRVTypeID(ResType))
                 .addUse(TmpVec)
@@ -1675,9 +1675,9 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
                 .constrainAllUses(TII, TRI, RBI);
 
   for (unsigned i = 1; i < GR.getScalarOrVectorComponentCount(VecType); i++) {
-    Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    Register Elt = MRI->createVirtualRegister(GR.getRegClass(ResType));
 
-    Result |=
+    Result &=
         BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
             .addDef(Elt)
             .addUse(GR.getSPIRVTypeID(ResType))
@@ -1686,10 +1686,10 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
             .constrainAllUses(TII, TRI, RBI);
 
     Register Sum = i < GR.getScalarOrVectorComponentCount(VecType) - 1
-                       ? MRI->createVirtualRegister(&SPIRV::IDRegClass)
+                       ? MRI->createVirtualRegister(GR.getRegClass(ResType))
                        : ResVReg;
 
-    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
                   .addDef(Sum)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(Res)
@@ -1712,7 +1712,7 @@ bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
   MachineBasicBlock &BB = *I.getParent();
 
   auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
-  Register Dot = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  Register Dot = MRI->createVirtualRegister(GR.getRegClass(ResType));
   bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
                     .addDef(Dot)
                     .addUse(GR.getSPIRVTypeID(ResType))
@@ -1720,7 +1720,7 @@ bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
                     .addUse(I.getOperand(3).getReg())
                     .constrainAllUses(TII, TRI, RBI);
 
-  Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+  Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
                 .addDef(ResVReg)
                 .addUse(GR.getSPIRVTypeID(ResType))
                 .addUse(Dot)
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index fc3f5527da3867..fc35a3e06c43f7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -129,9 +129,6 @@ bool SPIRVSubtarget::canDirectlyComparePointers() const {
 
 void SPIRVSubtarget::initAvailableExtensions() {
   AvailableExtensions.clear();
-  if (!(isOpenCLEnv() || isVulkanEnv()))
-    return;
-
   AvailableExtensions.insert(Extensions.begin(), Extensions.end());
 }
 
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
index 2a03d430e2c02f..39ed8d061efefd 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
@@ -1,9 +1,9 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
+; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
 ; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
 ; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DOT: OpCapability DotProductKHR
 ; CHECK-DOT: OpCapability DotProductInput4x8BitPackedKHR

>From 92539cd40d4c71292601f52450cd324b8e94a495 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 5 Nov 2024 02:42:07 +0000
Subject: [PATCH 13/13] Rename DotProduct capabilities to use the non-KHR (core) names

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp     | 15 +++++++--------
 llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td    |  8 ++++----
 .../SPIRV/hlsl-intrinsics/dot4add_i8packed.ll     |  4 ++--
 3 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 84dccb7b8efc46..f2ff65acea7d92 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -631,10 +631,9 @@ void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
                     Capability::Int16});
 
   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
-    addAvailableCaps({Capability::DotProductKHR,
-                      Capability::DotProductInputAllKHR,
-                      Capability::DotProductInput4x8BitKHR,
-                      Capability::DotProductInput4x8BitPackedKHR});
+    addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
+                      Capability::DotProductInput4x8Bit,
+                      Capability::DotProductInput4x8BitPacked});
 
   // Add capabilities enabled by extensions.
   for (auto Extension : ST.getAllAvailableExtensions()) {
@@ -859,7 +858,7 @@ static void AddDotProductRequirements(const MachineInstr &MI,
                                       const SPIRVSubtarget &ST) {
   if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
-  Reqs.addCapability(SPIRV::Capability::DotProductKHR);
+  Reqs.addCapability(SPIRV::Capability::DotProduct);
 
   const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
   const MachineInstr *InstrPtr = &MI;
@@ -869,13 +868,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
   SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
   if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
     assert(TypeDef->getOperand(1).getImm() == 32);
-    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPackedKHR);
+    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
   } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
     SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
     assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
     auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
-                          ? SPIRV::Capability::DotProductInput4x8BitKHR
-                          : SPIRV::Capability::DotProductInputAllKHR;
+                          ? SPIRV::Capability::DotProductInput4x8Bit
+                          : SPIRV::Capability::DotProductInputAll;
     Reqs.addCapability(Capability);
   }
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 888525f2ea85a4..1f00dbeb6418ef 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -465,10 +465,10 @@ defm ExpectAssumeKHR : CapabilityOperand<5629, 0, 0, [SPV_KHR_expect_assume], []
 defm FunctionPointersINTEL : CapabilityOperand<5603, 0, 0, [SPV_INTEL_function_pointers], []>;
 defm IndirectReferencesINTEL : CapabilityOperand<5604, 0, 0, [SPV_INTEL_function_pointers], []>;
 defm AsmINTEL : CapabilityOperand<5606, 0, 0, [SPV_INTEL_inline_assembly], []>;
-defm DotProductInputAllKHR : CapabilityOperand<6016, 0, 0, [SPV_KHR_integer_dot_product], []>;
-defm DotProductInput4x8BitKHR : CapabilityOperand<6017, 0, 0, [SPV_KHR_integer_dot_product], [Int8]>;
-defm DotProductInput4x8BitPackedKHR : CapabilityOperand<6018, 0, 0, [SPV_KHR_integer_dot_product], []>;
-defm DotProductKHR : CapabilityOperand<6019, 0, 0, [SPV_KHR_integer_dot_product], []>;
+defm DotProductInputAll : CapabilityOperand<6016, 0, 0, [SPV_KHR_integer_dot_product], []>;
+defm DotProductInput4x8Bit : CapabilityOperand<6017, 0, 0, [SPV_KHR_integer_dot_product], [Int8]>;
+defm DotProductInput4x8BitPacked : CapabilityOperand<6018, 0, 0, [SPV_KHR_integer_dot_product], []>;
+defm DotProduct : CapabilityOperand<6019, 0, 0, [SPV_KHR_integer_dot_product], []>;
 defm GroupNonUniformRotateKHR : CapabilityOperand<6026, 0, 0, [SPV_KHR_subgroup_rotate], [GroupNonUniform]>;
 defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
 defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
index 39ed8d061efefd..2ac557b43916b9 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
@@ -5,8 +5,8 @@
 ; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 ; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-DOT: OpCapability DotProductKHR
-; CHECK-DOT: OpCapability DotProductInput4x8BitPackedKHR
+; CHECK-DOT: OpCapability DotProduct
+; CHECK-DOT: OpCapability DotProductInput4x8BitPacked
 ; CHECK-EXT: OpExtension "SPV_KHR_integer_dot_product"
 
 ; CHECK: %[[#int_32:]] = OpTypeInt 32 0



More information about the cfe-commits mailing list