[clang] [llvm] [DXIL] implement dot intrinsic lowering for integers (PR #85662)

Farzon Lotfi via cfe-commits cfe-commits at lists.llvm.org
Tue Mar 19 08:10:17 PDT 2024


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/85662

>From e7738ae379375ed40558b2e93cc67a5a726aadbc Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 15 Mar 2024 18:19:52 -0400
Subject: [PATCH 1/2] [DXIL] implement dot intrinsic lowering for integers. This
 implements part 1 of 2 for #83626 - `CGBuiltin.cpp` - modified to have
 separate cases for signed and unsigned integers. - `SemaChecking.cpp` -
 modified to prevent the generation of a double dot product intrinsic if the
 builtin were to be called directly. - `IntrinsicsDirectX.td` - creation of the
 signed and unsigned dot intrinsics needed for instruction expansion. -
 `DXILIntrinsicExpansion.cpp` - handle instruction expansion cases for
 integer dot product.

---
 clang/lib/CodeGen/CGBuiltin.cpp               |  14 ++-
 clang/lib/Sema/SemaChecking.cpp               |  15 +++
 clang/test/CodeGenHLSL/builtins/dot.hlsl      |  36 +++----
 clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl  |   9 ++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  11 +-
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  36 +++++++
 llvm/test/CodeGen/DirectX/idot.ll             | 100 ++++++++++++++++++
 7 files changed, 201 insertions(+), 20 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/idot.ll

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e965df810add54..e89691ab7921c3 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18036,6 +18036,17 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
+Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+  if (QT->hasSignedIntegerRepresentation())
+    return Intrinsic::dx_sdot;
+  if (QT->hasUnsignedIntegerRepresentation())
+    return Intrinsic::dx_udot;
+
+  assert(QT->hasFloatingRepresentation());
+  return Intrinsic::dx_dot;
+  ;
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E) {
   if (!getLangOpts().HLSL)
@@ -18096,7 +18107,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
            "Dot product requires vectors to be of the same size.");
 
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/T0->getScalarType(), Intrinsic::dx_dot,
+        /*ReturnType=*/T0->getScalarType(),
+        getDotProductIntrinsic(E->getArg(0)->getType()),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
   } break;
   case Builtin::BI__builtin_hlsl_lerp: {
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a0b256ab5579ee..384b929d37bc82 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5484,6 +5484,19 @@ bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
                                   checkFloatorHalf);
 }
 
+bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
+  auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
+    if (PassedType->isVectorType() && PassedType->hasFloatingRepresentation()) {
+      clang::QualType BaseType =
+          PassedType->getAs<clang::VectorType>()->getElementType();
+      return !BaseType->isHalfType() && !BaseType->isFloat32Type();
+    }
+    return false;
+  };
+  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+                                  checkDoubleVector);
+}
+
 void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
                                 QualType ReturnType) {
   auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
@@ -5520,6 +5533,8 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     if (SemaBuiltinVectorToScalarMath(TheCall))
       return true;
+    if (CheckNoDoubleVectors(this, TheCall))
+      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index c064d118caf3e7..0f993193c00cce 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -11,15 +11,15 @@
 // NATIVE_HALF: ret i16 %dx.dot
 int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
 // NATIVE_HALF: ret i16 %dx.dot
 int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
 // NATIVE_HALF: ret i16 %dx.dot
 int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
 // NATIVE_HALF: ret i16 %dx.dot
 int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
 
@@ -27,15 +27,15 @@ int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
 // NATIVE_HALF: ret i16 %dx.dot
 uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
 // NATIVE_HALF: ret i16 %dx.dot
 uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
 // NATIVE_HALF: ret i16 %dx.dot
 uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
 // NATIVE_HALF: ret i16 %dx.dot
 uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
 #endif
@@ -44,15 +44,15 @@ uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
 // CHECK: ret i32 %dx.dot
 int test_dot_int(int p0, int p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
 // CHECK: ret i32 %dx.dot
 int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
 // CHECK: ret i32 %dx.dot
 int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
 // CHECK: ret i32 %dx.dot
 int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
 
@@ -60,15 +60,15 @@ int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
 // CHECK: ret i32 %dx.dot
 uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
 // CHECK: ret i32 %dx.dot
 uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
 // CHECK: ret i32 %dx.dot
 uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
 // CHECK: ret i32 %dx.dot
 uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
 
@@ -76,15 +76,15 @@ uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
 // CHECK: ret i64 %dx.dot
 int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
 // CHECK: ret i64 %dx.dot
 int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
 // CHECK: ret i64 %dx.dot
 int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
 // CHECK: ret i64 %dx.dot
 int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
 
@@ -92,15 +92,15 @@ int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
 // CHECK: ret i64 %dx.dot
 uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
 // CHECK: ret i64 %dx.dot
 uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
 // CHECK: ret i64 %dx.dot
 uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
 // CHECK: ret i64 %dx.dot
 uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
 
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
index 59eb9482b9ef92..ba7ffc20484ae0 100644
--- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -108,3 +108,12 @@ int test_builtin_dot_bool_type_promotion(bool p0, bool p1) {
   return __builtin_hlsl_dot(p0, p1);
   // expected-error at -1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
 }
+
+double test_dot_double(double2 p0, double2 p1) {
+  return dot(p0, p1);
+  // expected-error at -1 {{call to 'dot' is ambiguous}}
+}
+double test_dot_double_builtin(double2 p0, double2 p1) {
+  return __builtin_hlsl_dot(p0, p1);
+  // expected-error at -1 {{passing 'double2' (aka 'vector<double, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 5c72f06f96ed12..1164b241ba7b0d 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -23,9 +23,18 @@ def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
 def int_dx_any  : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; 
+
 def int_dx_dot : 
     Intrinsic<[LLVMVectorElementType<0>], 
-    [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+    [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_sdot : 
+    Intrinsic<[LLVMVectorElementType<0>], 
+    [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+    [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_udot : 
+    Intrinsic<[LLVMVectorElementType<0>], 
+    [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
 
 def int_dx_frac  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index bc38c10a1fceb0..0db42bc0a0fb64 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -39,11 +39,44 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_lerp:
   case Intrinsic::dx_rcp:
+  case Intrinsic::dx_sdot:
+  case Intrinsic::dx_udot:
     return true;
   }
   return false;
 }
 
+static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+  assert(DotIntrinsic == Intrinsic::dx_sdot ||
+         DotIntrinsic == Intrinsic::dx_udot);
+  Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
+                                   ? Intrinsic::dx_imad
+                                   : Intrinsic::dx_umad;
+  Value *A = Orig->getOperand(0);
+  Value *B = Orig->getOperand(1);
+  Type *ATy = A->getType();
+  Type *BTy = B->getType();
+  assert(ATy->isVectorTy() && BTy->isVectorTy());
+
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+
+  auto *AVec = dyn_cast<FixedVectorType>(A->getType());
+  Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
+  Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
+  Value *Result = Builder.CreateMul(Elt0, Elt1);
+  for (unsigned I = 1; I < AVec->getNumElements(); I++) {
+    Elt0 = Builder.CreateExtractElement(A, I);
+    Elt1 = Builder.CreateExtractElement(B, I);
+    Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
+                                     ArrayRef<Value *>{Elt0, Elt1, Result},
+                                     nullptr, "dx.mad");
+  }
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
 static bool expandExpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
@@ -191,6 +224,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     return expandLerpIntrinsic(Orig);
   case Intrinsic::dx_rcp:
     return expandRcpIntrinsic(Orig);
+  case Intrinsic::dx_sdot:
+  case Intrinsic::dx_udot:
+    return expandIntegerDot(Orig, F.getIntrinsicID());
   }
   return false;
 }
diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll
new file mode 100644
index 00000000000000..286bef3f655116
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/idot.ll
@@ -0,0 +1,100 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dxil operation function calls for dot are generated for int/uint vectors.
+
+; CHECK-LABEL: dot_int16_t2
+define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
+entry:
+; CHECK: extractelement <2 x i16> %a, i64 0
+; CHECK: extractelement <2 x i16> %b, i64 0
+; CHECK: mul i16 %{{.*}}, %{{.*}}
+; CHECK: extractelement <2 x i16> %a, i64 1
+; CHECK: extractelement <2 x i16> %b, i64 1
+; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+  %dx.dot = call i16 @llvm.dx.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
+  ret i16 %dx.dot
+}
+
+; CHECK-LABEL: sdot_int4
+define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+entry:
+; CHECK: extractelement <4 x i32> %a, i64 0
+; CHECK: extractelement <4 x i32> %b, i64 0
+; CHECK: mul i32 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i32> %a, i64 1
+; CHECK: extractelement <4 x i32> %b, i64 1
+; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK: extractelement <4 x i32> %a, i64 2
+; CHECK: extractelement <4 x i32> %b, i64 2
+; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK: extractelement <4 x i32> %a, i64 3
+; CHECK: extractelement <4 x i32> %b, i64 3
+; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
+  ret i32 %dx.dot
+}
+
+; CHECK-LABEL: dot_uint16_t3
+define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
+entry:
+; CHECK: extractelement <3 x i16> %a, i64 0
+; CHECK: extractelement <3 x i16> %b, i64 0
+; CHECK: mul i16 %{{.*}}, %{{.*}}
+; CHECK: extractelement <3 x i16> %a, i64 1
+; CHECK: extractelement <3 x i16> %b, i64 1
+; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; CHECK: extractelement <3 x i16> %a, i64 2
+; CHECK: extractelement <3 x i16> %b, i64 2
+; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+  %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
+  ret i16 %dx.dot
+}
+
+; CHECK-LABEL: dot_uint4
+define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+entry:
+; CHECK: extractelement <4 x i32> %a, i64 0
+; CHECK: extractelement <4 x i32> %b, i64 0
+; CHECK: mul i32 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i32> %a, i64 1
+; CHECK: extractelement <4 x i32> %b, i64 1
+; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK: extractelement <4 x i32> %a, i64 2
+; CHECK: extractelement <4 x i32> %b, i64 2
+; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK: extractelement <4 x i32> %a, i64 3
+; CHECK: extractelement <4 x i32> %b, i64 3
+; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
+  ret i32 %dx.dot
+}
+
+; CHECK-LABEL: dot_uint64_t4
+define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
+entry:
+; CHECK: extractelement <2 x i64> %a, i64 0
+; CHECK: extractelement <2 x i64> %b, i64 0
+; CHECK: mul i64 %{{.*}}, %{{.*}}
+; CHECK: extractelement <2 x i64> %a, i64 1
+; CHECK: extractelement <2 x i64> %b, i64 1
+; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
+; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
+  %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
+  ret i64 %dx.dot
+}
+
+declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i32>)
+declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>)
+declare i16 @llvm.dx.udot.v3i32(<3 x i16>, <3 x i32>)
+declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>)
+declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>)

>From 6919d4591b123598d2201f70af3e1598bbcc553d Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 18 Mar 2024 17:25:31 -0400
Subject: [PATCH 2/2] - make pr suggestion - fix typo

---
 clang/lib/Sema/SemaChecking.cpp   | 5 ++---
 llvm/test/CodeGen/DirectX/idot.ll | 4 ++--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 384b929d37bc82..f9112a29027acd 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5486,9 +5486,8 @@ bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
 
 bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
   auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
-    if (PassedType->isVectorType() && PassedType->hasFloatingRepresentation()) {
-      clang::QualType BaseType =
-          PassedType->getAs<clang::VectorType>()->getElementType();
+    if (const auto *VecTy = dyn_cast<VectorType>(PassedType)) {
+      clang::QualType BaseType = VecTy->getElementType();
       return !BaseType->isHalfType() && !BaseType->isFloat32Type();
     }
     return false;
diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll
index 286bef3f655116..9f89a8d6d340d5 100644
--- a/llvm/test/CodeGen/DirectX/idot.ll
+++ b/llvm/test/CodeGen/DirectX/idot.ll
@@ -93,8 +93,8 @@ entry:
   ret i64 %dx.dot
 }
 
-declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i32>)
+declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>)
 declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>)
-declare i16 @llvm.dx.udot.v3i32(<3 x i16>, <3 x i32>)
+declare i16 @llvm.dx.udot.v3i32(<3 x i16>, <3 x i16>)
 declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>)
 declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>)



More information about the cfe-commits mailing list