[clang] [llvm] [HLSL][SPIRV][DXIL] Implement `dot4add_u8packed` intrinsic (PR #115068)

Finn Plummer via cfe-commits cfe-commits at lists.llvm.org
Wed Nov 6 11:03:35 PST 2024


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/115068

>From 414b07fd2276946936dc137fb633b04cd8c12fc4 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 5 Nov 2024 21:15:17 +0000
Subject: [PATCH 1/3] [HLSL][SPIRV][DXIL] Implement `dot4add_u8packed`
 intrinsic

- create a clang built-in in Builtins.td
- link dot4add_u8packed in hlsl_intrinsics.h
- add lowering to spirv backend in SPIRVInstructionSelector.cpp through
expansion of the operation, as OpUDot is not available before SPIRV 1.6
- add lowering to spirv backend using OpUDot in applicable SPIRV version
or if SPV_KHR_integer_dot_product is enabled
- add dot4add_u8packed intrinsic to IntrinsicsDirectX.td and mapping to
DXIL.td op Dot4AddU8Packed

- add tests for HLSL intrinsic lowering to dx/spv intrinsic in
dot4add_u8packed.hlsl
- add tests for sema checks in dot4add_u8packed-errors.hlsl
- add test of spir-v lowering in SPIRV/dot4add_u8packed.ll
- add test of dxil lowering in DirectX/dot4add_u8packed.ll
---
 clang/include/clang/Basic/Builtins.td         |  6 ++
 clang/lib/CodeGen/CGBuiltin.cpp               | 10 +++
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  8 ++-
 .../builtins/dot4add_u8packed.hlsl            | 18 +++++
 .../BuiltIns/dot4add_u8packed-errors.hlsl     | 28 ++++++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  3 +-
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  1 +
 llvm/lib/Target/DirectX/DXIL.td               | 10 +++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 17 +++--
 llvm/test/CodeGen/DirectX/dot4add_u8packed.ll | 10 +++
 .../SPIRV/hlsl-intrinsics/dot4add_u8packed.ll | 65 +++++++++++++++++++
 12 files changed, 169 insertions(+), 8 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/dot4add_u8packed.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/dot4add_u8packed-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/dot4add_u8packed.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_u8packed.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index f5d45fa7b90a40..eba1105536a354 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4798,6 +4798,12 @@ def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "int(unsigned int, unsigned int, int)";
 }
 
+def HLSLDot4AddU8Packed : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_dot4add_u8packed"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "unsigned int(unsigned int, unsigned int, unsigned int)";
+}
+
 def HLSLFrac : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_frac"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index fb731413d8c9ad..43d56452a80071 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18866,6 +18866,16 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
         "hlsl.dot4add.i8packed");
   }
+  case Builtin::BI__builtin_hlsl_dot4add_u8packed: {
+    Value *A = EmitScalarExpr(E->getArg(0));
+    Value *B = EmitScalarExpr(E->getArg(1));
+    Value *C = EmitScalarExpr(E->getArg(2));
+
+    Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddU8PackedIntrinsic();
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
+        "hlsl.dot4add.u8packed");
+  }
   case Builtin::BI__builtin_hlsl_lerp: {
     Value *X = EmitScalarExpr(E->getArg(0));
     Value *Y = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 5e09ab694a79b8..f2a16676da2847 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -90,6 +90,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6cd93c5e9a592d..785f47b33d82a5 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -942,7 +942,13 @@ uint64_t dot(uint64_t4, uint64_t4);
 
 _HLSL_AVAILABILITY(shadermodel, 6.4)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_i8packed)
-int dot4add_i8packed(unsigned int, unsigned int, int);
+int dot4add_i8packed(uint, uint, int);
+
+/// \fn uint dot4add_i8packed(uint A, uint B, uint C)
+
+_HLSL_AVAILABILITY(shadermodel, 6.4)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_u8packed)
+uint dot4add_u8packed(uint, uint, uint);
 
 //===----------------------------------------------------------------------===//
 // exp builtins
diff --git a/clang/test/CodeGenHLSL/builtins/dot4add_u8packed.hlsl b/clang/test/CodeGenHLSL/builtins/dot4add_u8packed.hlsl
new file mode 100644
index 00000000000000..7e5fd8ee888f20
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot4add_u8packed.hlsl
@@ -0,0 +1,18 @@
+
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s -DTARGET=dx
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s -DTARGET=spv
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test
+uint test(uint a, uint b, uint c) {
+  // CHECK:  %[[RET:.*]] = call [[TY:i32]] @llvm.[[TARGET]].dot4add.u8packed([[TY]] %[[#]], [[TY]] %[[#]], [[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return dot4add_u8packed(a, b, c);
+}
+
+// CHECK: declare [[TY]] @llvm.[[TARGET]].dot4add.u8packed([[TY]], [[TY]], [[TY]])
diff --git a/clang/test/SemaHLSL/BuiltIns/dot4add_u8packed-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot4add_u8packed-errors.hlsl
new file mode 100644
index 00000000000000..f1fa41902b9687
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/dot4add_u8packed-errors.hlsl
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg0() {
+  return __builtin_hlsl_dot4add_u8packed();
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
+}
+
+int test_too_few_arg1(int p0) {
+  return __builtin_hlsl_dot4add_u8packed(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
+}
+
+int test_too_few_arg2(uint p0) {
+  return __builtin_hlsl_dot4add_u8packed(p0, p0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+}
+
+int test_too_many_arg(uint p0) {
+  return __builtin_hlsl_dot4add_u8packed(p0, p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+}
+
+struct S { float f; };
+
+int test_expr_struct_type_check(S p0, uint p1) {
+  return __builtin_hlsl_dot4add_u8packed(p1, p1, p0);
+  // expected-error@-1 {{no viable conversion from 'S' to 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 8cd5ff9006c1b7..90f394b02d90b3 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -69,7 +69,8 @@ def int_dx_udot :
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
-  def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
+def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
+def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
 
 def int_dx_frac  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_degrees : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 72fee94908db72..85665c1c5965a7 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -84,6 +84,7 @@ let TargetPrefix = "spv" in {
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
   def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
+  def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
   def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index aa372b4e31f4b9..603b7f328de4b1 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -798,6 +798,16 @@ def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
+def Dot4AddU8Packed : DXILOp<164, dot4AddPacked> {
+  let Doc = "unsigned dot product of 4 x i8 vectors packed into i32, with "
+            "accumulate to i32";
+  let LLVMIntrinsic = int_dx_dot4add_u8packed;
+  let arguments = [Int32Ty, Int32Ty, Int32Ty];
+  let result = Int32Ty;
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def AnnotateHandle : DXILOp<216, annotateHandle> {
   let Doc = "annotate handle with resource properties";
   let arguments = [HandleTy, ResPropsTy];
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7aa5f4f2b1a8f1..ca256bbc86ddcd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1743,7 +1743,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
   assert(I.getOperand(4).isReg());
   MachineBasicBlock &BB = *I.getParent();
 
-  bool Result = false;
+  bool Result = true;
 
   // Acc = C
   Register Acc = I.getOperand(4).getReg();
@@ -1755,7 +1755,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
   for (unsigned i = 0; i < 4; i++) {
     // A[i]
     Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
                   .addDef(AElt)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(I.getOperand(2).getReg())
@@ -1765,7 +1765,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
 
     // B[i]
     Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
                   .addDef(BElt)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(I.getOperand(3).getReg())
@@ -1775,7 +1775,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
 
     // A[i] * B[i]
     Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
                   .addDef(Mul)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(AElt)
@@ -1784,7 +1784,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
 
     // Discard 24 highest-bits so that stored i32 register is i8 equivalent
     Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
                   .addDef(MaskMul)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(Mul)
@@ -1795,7 +1795,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
     // Acc = Acc + A[i] * B[i]
     Register Sum =
         i < 3 ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg;
-    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
                   .addDef(Sum)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(Acc)
@@ -2646,6 +2646,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
         STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
       return selectDot4AddPacked<true>(ResVReg, ResType, I);
     return selectDot4AddPackedExpansion<true>(ResVReg, ResType, I);
+  case Intrinsic::spv_dot4add_u8packed:
+    if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
+        STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
+      return selectDot4AddPacked<false>(ResVReg, ResType, I);
+    return selectDot4AddPackedExpansion<false>(ResVReg, ResType, I);
   case Intrinsic::spv_all:
     return selectAll(ResVReg, ResType, I);
   case Intrinsic::spv_any:
diff --git a/llvm/test/CodeGen/DirectX/dot4add_u8packed.ll b/llvm/test/CodeGen/DirectX/dot4add_u8packed.ll
new file mode 100644
index 00000000000000..3836b4a4bc16c9
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot4add_u8packed.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define void @main(i32 %a, i32 %b, i32 %c) {
+entry:
+; CHECK: call i32 @dx.op.dot4AddPacked(i32 164, i32 %a, i32 %b, i32 %c)
+  %0 = call i32 @llvm.dx.dot4add.u8packed(i32 %a, i32 %b, i32 %c)
+  ret void
+}
+
+declare i32 @llvm.dx.dot4add.u8packed(i32, i32, i32)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_u8packed.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_u8packed.ll
new file mode 100644
index 00000000000000..8acdb7d2ea3241
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_u8packed.ll
@@ -0,0 +1,65 @@
+; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
+; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DOT: OpCapability DotProduct
+; CHECK-DOT: OpCapability DotProductInput4x8BitPacked
+; CHECK-EXT: OpExtension "SPV_KHR_integer_dot_product"
+
+; CHECK: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK-EXP-DAG: %[[#int_8:]] = OpTypeInt 8 0
+; CHECK-EXP-DAG: %[[#zero:]] = OpConstantNull %[[#int_8]]
+; CHECK-EXP-DAG: %[[#eight:]] = OpConstant %[[#int_8]] 8
+; CHECK-EXP-DAG: %[[#sixteen:]] = OpConstant %[[#int_8]] 16
+; CHECK-EXP-DAG: %[[#twentyfour:]] = OpConstant %[[#int_8]] 24
+
+; CHECK-LABEL: Begin function test_dot
+define noundef i32 @test_dot(i32 noundef %a, i32 noundef %b, i32 noundef %c) {
+entry:
+; CHECK: %[[#A:]] = OpFunctionParameter %[[#int_32]]
+; CHECK: %[[#B:]] = OpFunctionParameter %[[#int_32]]
+; CHECK: %[[#C:]] = OpFunctionParameter %[[#int_32]]
+
+; Test that we use the dot product op when capabilities allow
+
+; CHECK-DOT: %[[#DOT:]] = OpUDot %[[#int_32]] %[[#A]] %[[#B]]
+; CHECK-DOT: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#DOT]] %[[#C]]
+
+; Test expansion is used when spirv dot product capabilities aren't available:
+
+; First element of the packed vector
+; CHECK-EXP: %[[#A0:]] = OpBitFieldUExtract %[[#int_32]] %[[#A]] %[[#zero]] %[[#eight]]
+; CHECK-EXP: %[[#B0:]] = OpBitFieldUExtract %[[#int_32]] %[[#B]] %[[#zero]] %[[#eight]]
+; CHECK-EXP: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
+; CHECK-EXP: %[[#MASK0:]] = OpBitFieldUExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] %[[#eight]]
+; CHECK-EXP: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
+
+; Second element of the packed vector
+; CHECK-EXP: %[[#A1:]] = OpBitFieldUExtract %[[#int_32]] %[[#A]] %[[#eight]] %[[#eight]]
+; CHECK-EXP: %[[#B1:]] = OpBitFieldUExtract %[[#int_32]] %[[#B]] %[[#eight]] %[[#eight]]
+; CHECK-EXP: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
+; CHECK-EXP: %[[#MASK1:]] = OpBitFieldUExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] %[[#eight]]
+; CHECK-EXP: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]
+
+; Third element of the packed vector
+; CHECK-EXP: %[[#A2:]] = OpBitFieldUExtract %[[#int_32]] %[[#A]] %[[#sixteen]] %[[#eight]]
+; CHECK-EXP: %[[#B2:]] = OpBitFieldUExtract %[[#int_32]] %[[#B]] %[[#sixteen]] %[[#eight]]
+; CHECK-EXP: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
+; CHECK-EXP: %[[#MASK2:]] = OpBitFieldUExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] %[[#eight]]
+; CHECK-EXP: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]
+
+; Fourth element of the packed vector
+; CHECK-EXP: %[[#A3:]] = OpBitFieldUExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] %[[#eight]]
+; CHECK-EXP: %[[#B3:]] = OpBitFieldUExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] %[[#eight]]
+; CHECK-EXP: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
+; CHECK-EXP: %[[#MASK3:]] = OpBitFieldUExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] %[[#eight]]
+
+; CHECK-EXP: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
+; CHECK: OpReturnValue %[[#RES]]
+  %spv.dot = call i32 @llvm.spv.dot4add.u8packed(i32 %a, i32 %b, i32 %c)
+
+  ret i32 %spv.dot
+}

>From 8d097c8a40667e509e7952dd66b47690e3c8a89a Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 5 Nov 2024 21:18:39 +0000
Subject: [PATCH 2/3] fix typo

---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 785f47b33d82a5..5a228300957fba 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -944,7 +944,7 @@ _HLSL_AVAILABILITY(shadermodel, 6.4)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_i8packed)
 int dot4add_i8packed(uint, uint, int);
 
-/// \fn uint dot4add_i8packed(uint A, uint B, uint C)
+/// \fn uint dot4add_u8packed(uint A, uint B, uint C)
 
 _HLSL_AVAILABILITY(shadermodel, 6.4)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot4add_u8packed)

>From 462fae7436ffc672874d470c1c8a8f13a44931a2 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Wed, 6 Nov 2024 19:02:01 +0000
Subject: [PATCH 3/3] move the Result bug fix to separate PR

---
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ca256bbc86ddcd..fa4b00d48c0ebe 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1743,7 +1743,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
   assert(I.getOperand(4).isReg());
   MachineBasicBlock &BB = *I.getParent();
 
-  bool Result = true;
+  bool Result = false;
 
   // Acc = C
   Register Acc = I.getOperand(4).getReg();
@@ -1755,7 +1755,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
   for (unsigned i = 0; i < 4; i++) {
     // A[i]
     Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
                   .addDef(AElt)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(I.getOperand(2).getReg())
@@ -1765,7 +1765,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
 
     // B[i]
     Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
                   .addDef(BElt)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(I.getOperand(3).getReg())
@@ -1775,7 +1775,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
 
     // A[i] * B[i]
     Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
                   .addDef(Mul)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(AElt)
@@ -1784,7 +1784,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
 
     // Discard 24 highest-bits so that stored i32 register is i8 equivalent
     Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
                   .addDef(MaskMul)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(Mul)
@@ -1795,7 +1795,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
     // Acc = Acc + A[i] * B[i]
     Register Sum =
         i < 3 ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg;
-    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
                   .addDef(Sum)
                   .addUse(GR.getSPIRVTypeID(ResType))
                   .addUse(Acc)



More information about the cfe-commits mailing list