[clang] [llvm] [HLSL] implement `mad` intrinsic (PR #83826)

via cfe-commits cfe-commits at lists.llvm.org
Mon Mar 4 03:21:52 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-clang

Author: Farzon Lotfi (farzonl)

<details>
<summary>Changes</summary>

This change implements #<!-- -->83736
The dot product lowering needs a ternary multiply-add operation. DXIL has three mad opcodes: `fmad`(46), `imad`(48), and `umad`(49). Dot product in DXIL only uses `imad`/`umad`, but for completeness, and because the HLSL `mad` intrinsic requires it, `fmad` was also included. Two new intrinsics needed to be created to complete this change. The `fmad` case is already supported by LLVM via the `fmuladd` intrinsic.

- `hlsl_intrinsics.h` - exposed mad api call.
- `Builtins.td` - exposed a `mad` builtin.
- `Sema.h` - make the floating-point type check for ternary calls optional.
- `CGBuiltin.cpp` - pick the intrinsic for signed/unsigned and float; also reuse `int_fmuladd`.
- `SemaChecking.cpp` - type checks for `__builtin_hlsl_mad`. 
- `IntrinsicsDirectX.td` - create the two new intrinsics for `imad`/`umad`.
- `DXIL.td` - create the LLVM intrinsic to `DXIL` opcode mapping.

---

Patch is 33.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83826.diff


12 Files Affected:

- (modified) clang/include/clang/Basic/Builtins.td (+6) 
- (modified) clang/include/clang/Sema/Sema.h (+1-1) 
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+19) 
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+105) 
- (modified) clang/lib/Sema/SemaChecking.cpp (+16-6) 
- (added) clang/test/CodeGenHLSL/builtins/mad.hlsl (+197) 
- (added) clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl (+86) 
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+4) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+6) 
- (added) llvm/test/CodeGen/DirectX/fmad.ll (+67) 
- (added) llvm/test/CodeGen/DirectX/imad.ll (+65) 
- (added) llvm/test/CodeGen/DirectX/umad.ll (+65) 


``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 2c83dca248fb7d..c4f466208793ea 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4548,6 +4548,12 @@ def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLMad : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_mad"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index ef4b93fac95ce5..88be7dd90d21bc 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14132,7 +14132,7 @@ class Sema final {
   bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
   bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
   bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
-  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
+  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloatingPointCheck = true);
   bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e90014261217bc..5cb1dd07aab999 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18057,6 +18057,25 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         /*ReturnType*/ Op0->getType(), Intrinsic::dx_frac,
         ArrayRef<Value *>{Op0}, nullptr, "dx.frac");
   }
+  case Builtin::BI__builtin_hlsl_mad: {
+    Value *M = EmitScalarExpr(E->getArg(0));
+    Value *A = EmitScalarExpr(E->getArg(1));
+    Value *B = EmitScalarExpr(E->getArg(2));
+    if (E->getArg(0)->getType()->hasFloatingRepresentation()) {
+      return Builder.CreateIntrinsic(
+          /*ReturnType*/ M->getType(), Intrinsic::fmuladd,
+          ArrayRef<Value *>{M, A, B}, nullptr, "dx.fmad");
+    }
+    if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {
+      return Builder.CreateIntrinsic(
+          /*ReturnType*/ M->getType(), Intrinsic::dx_imad,
+          ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");
+    }
+    assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
+    return Builder.CreateIntrinsic(
+        /*ReturnType*/ M->getType(), Intrinsic::dx_umad,
+        ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 5180530363889f..b5bef78fae72fe 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -511,6 +511,111 @@ double3 log2(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_log2)
 double4 log2(double4);
 
+//===----------------------------------------------------------------------===//
+// mad builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T mad(T M, T A, T B)
+/// \brief The result of \a M * \a A + \a B.
+/// \param M The multiplication value.
+/// \param A The first addition value.
+/// \param B The second addition value.
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half mad(half, half, half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half2 mad(half2, half2, half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half3 mad(half3, half3, half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half4 mad(half4, half4, half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t mad(int16_t, int16_t, int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t2 mad(int16_t2, int16_t2, int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t3 mad(int16_t3, int16_t3, int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t4 mad(int16_t4, int16_t4, int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t mad(uint16_t, uint16_t, uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t2 mad(uint16_t2, uint16_t2, uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t3 mad(uint16_t3, uint16_t3, uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t4 mad(uint16_t4, uint16_t4, uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int mad(int, int, int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int2 mad(int2, int2, int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int3 mad(int3, int3, int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int4 mad(int4, int4, int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint mad(uint, uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint2 mad(uint2, uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint3 mad(uint3, uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint4 mad(uint4, uint4, uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t mad(int64_t, int64_t, int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t2 mad(int64_t2, int64_t2, int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t3 mad(int64_t3, int64_t3, int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t4 mad(int64_t4, int64_t4, int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t mad(uint64_t, uint64_t, uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t2 mad(uint64_t2, uint64_t2, uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t3 mad(uint64_t3, uint64_t3, uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t4 mad(uint64_t4, uint64_t4, uint64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float mad(float, float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float2 mad(float2, float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float3 mad(float3, float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float4 mad(float4, float4, float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double mad(double, double, double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double2 mad(double2, double2, double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double3 mad(double3, double3, double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double4 mad(double4, double4, double4);
+
 //===----------------------------------------------------------------------===//
 // max builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 0d4d57db01c93a..795d59ed5fd369 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5298,6 +5298,14 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_mad: {
+    if (checkArgCount(*this, TheCall, 3))
+      return true;
+    if (CheckVectorElementCallArgs(this, TheCall))
+      return true;
+    if (SemaBuiltinElementwiseTernaryMath(TheCall, false))
+      return true;
+  }
   }
   return false;
 }
@@ -19800,7 +19808,7 @@ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
   return false;
 }
 
-bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
+bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloatingPointCheck) {
   if (checkArgCount(*this, TheCall, 3))
     return true;
 
@@ -19812,11 +19820,13 @@ bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
     Args[I] = Converted.get();
   }
 
-  int ArgOrdinal = 1;
-  for (Expr *Arg : Args) {
-    if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
-                                      ArgOrdinal++))
-      return true;
+  if(enforceFloatingPointCheck) {
+    int ArgOrdinal = 1;
+    for (Expr *Arg : Args) {
+      if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+                                        ArgOrdinal++))
+        return true;
+    }
   }
 
   for (int I = 1; I < 3; ++I) {
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
new file mode 100644
index 00000000000000..4dd8f01785afa0
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -0,0 +1,197 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+#ifdef __HLSL_ENABLE_16_BIT
+// NATIVE_HALF: %dx.umad = call i16 @llvm.dx.umad.i16(i16 %0, i16 %1, i16 %2)
+// NATIVE_HALF: ret i16 %dx.umad
+uint16_t test_mad_uint16_t(uint16_t p0, uint16_t p1, uint16_t p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <2 x i16>  @llvm.dx.umad.v2i16(<2 x i16> %0, <2 x i16> %1, <2 x i16> %2)
+// NATIVE_HALF: ret <2 x i16> %dx.umad
+uint16_t2 test_mad_uint16_t2(uint16_t2 p0, uint16_t2 p1, uint16_t2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <3 x i16>  @llvm.dx.umad.v3i16(<3 x i16> %0, <3 x i16> %1, <3 x i16> %2)
+// NATIVE_HALF: ret <3 x i16> %dx.umad
+uint16_t3 test_mad_uint16_t3(uint16_t3 p0, uint16_t3 p1, uint16_t3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <4 x i16>  @llvm.dx.umad.v4i16(<4 x i16> %0, <4 x i16> %1, <4 x i16> %2)
+// NATIVE_HALF: ret <4 x i16> %dx.umad
+uint16_t4 test_mad_uint16_t4(uint16_t4 p0, uint16_t4 p1, uint16_t4 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call i16 @llvm.dx.imad.i16(i16 %0, i16 %1, i16 %2)
+// NATIVE_HALF: ret i16 %dx.imad
+int16_t test_mad_int16_t(int16_t p0, int16_t p1, int16_t p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <2 x i16>  @llvm.dx.imad.v2i16(<2 x i16> %0, <2 x i16> %1, <2 x i16> %2)
+// NATIVE_HALF: ret <2 x i16> %dx.imad
+int16_t2 test_mad_int16_t2(int16_t2 p0, int16_t2 p1, int16_t2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <3 x i16>  @llvm.dx.imad.v3i16(<3 x i16> %0, <3 x i16> %1, <3 x i16> %2)
+// NATIVE_HALF: ret <3 x i16> %dx.imad
+int16_t3 test_mad_int16_t3(int16_t3 p0, int16_t3 p1, int16_t3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <4 x i16>  @llvm.dx.imad.v4i16(<4 x i16> %0, <4 x i16> %1, <4 x i16> %2)
+// NATIVE_HALF: ret <4 x i16> %dx.imad
+int16_t4 test_mad_int16_t4(int16_t4 p0, int16_t4 p1, int16_t4 p2) { return mad(p0, p1, p2); }
+#endif // __HLSL_ENABLE_16_BIT
+
+// NATIVE_HALF: %dx.fmad = call half @llvm.fmuladd.f16(half %0, half %1, half %2)
+// NATIVE_HALF: ret half %dx.fmad
+// NO_HALF: %dx.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// NO_HALF: ret float %dx.fmad
+half test_mad_half(half p0, half p1, half p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <2 x half>  @llvm.fmuladd.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+// NATIVE_HALF: ret <2 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// NO_HALF: ret <2 x float> %dx.fmad
+half2 test_mad_half2(half2 p0, half2 p1, half2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <3 x half>  @llvm.fmuladd.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
+// NATIVE_HALF: ret <3 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// NO_HALF: ret <3 x float> %dx.fmad
+half3 test_mad_half3(half3 p0, half3 p1, half3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <4 x half>  @llvm.fmuladd.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
+// NATIVE_HALF: ret <4 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// NO_HALF: ret <4 x float> %dx.fmad
+half4 test_mad_half4(half4 p0, half4 p1, half4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// CHECK: ret float %dx.fmad
+float test_mad_float(float p0, float p1, float p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2(float2 p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3(float3 p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// CHECK: ret <4 x float> %dx.fmad
+float4 test_mad_float4(float4 p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call double @llvm.fmuladd.f64(double %0, double %1, double %2)
+// CHECK: ret double %dx.fmad
+double test_mad_double(double p0, double p1, double p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x double>  @llvm.fmuladd.v2f64(<2 x double> %0, <2 x double> %1, <2 x double> %2)
+// CHECK: ret <2 x double> %dx.fmad
+double2 test_mad_double2(double2 p0, double2 p1, double2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x double>  @llvm.fmuladd.v3f64(<3 x double> %0, <3 x double> %1, <3 x double> %2)
+// CHECK: ret <3 x double> %dx.fmad
+double3 test_mad_double3(double3 p0, double3 p1, double3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <4 x double>  @llvm.fmuladd.v4f64(<4 x double> %0, <4 x double> %1, <4 x double> %2)
+// CHECK: ret <4 x double> %dx.fmad
+double4 test_mad_double4(double4 p0, double4 p1, double4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call i32 @llvm.dx.imad.i32(i32 %0, i32 %1, i32 %2)
+// CHECK: ret i32 %dx.imad
+int test_mad_int(int p0, int p1, int p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <2 x i32>  @llvm.dx.imad.v2i32(<2 x i32> %0, <2 x i32> %1, <2 x i32> %2)
+// CHECK: ret <2 x i32> %dx.imad
+int2 test_mad_int2(int2 p0, int2 p1, int2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <3 x i32>  @llvm.dx.imad.v3i32(<3 x i32> %0, <3 x i32> %1, <3 x i32> %2)
+// CHECK: ret <3 x i32> %dx.imad
+int3 test_mad_int3(int3 p0, int3 p1, int3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <4 x i32>  @llvm.dx.imad.v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i32> %2)
+// CHECK: ret <4 x i32> %dx.imad
+int4 test_mad_int4(int4 p0, int4 p1, int4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call i64 @llvm.dx.imad.i64(i64 %0, i64 %1, i64 %2)
+// CHECK: ret i64 %dx.imad
+int64_t test_mad_int64_t(int64_t p0, int64_t p1, int64_t p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <2 x i64>  @llvm.dx.imad.v2i64(<2 x i64> %0, <2 x i64> %1, <2 x i64> %2)
+// CHECK: ret <2 x i64> %dx.imad
+int64_t2 test_mad_int64_t2(int64_t2 p0, int64_t2 p1, int64_t2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <3 x i64>  @llvm.dx.imad.v3i64(<3 x i64> %0, <3 x i64> %1, <3 x i64> %2)
+// CHECK: ret <3 x i64> %dx.imad
+int64_t3 test_mad_int64_t3(int64_t3 p0, int64_t3 p1, int64_t3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <4 x i64>  @llvm.dx.imad.v4i64(<4 x i64> %0, <4 x i64> %1, <4 x i64> %2)
+// CHECK: ret <4 x i64> %dx.imad
+int64_t4 test_mad_int64_t4(int64_t4 p0, int64_t4 p1, int64_t4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call i32 @llvm.dx.umad.i32(i32 %0, i32 %1, i32 %2)
+// CHECK: ret i32 %dx.umad
+uint test_mad_uint(uint p0, uint p1, uint p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <2 x i32>  @llvm.dx.umad.v2i32(<2 x i32> %0, <2 x i32> %1, <2 x i32> %2)
+// CHECK: ret <2 x i32> %dx.umad
+uint2 test_mad_uint2(uint2 p0, uint2 p1, uint2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <3 x i32>  @llvm.dx.umad.v3i32(<3 x i32> %0, <3 x i32> %1, <3 x i32> %2)
+// CHECK: ret <3 x i32> %dx.umad
+uint3 test_mad_uint3(uint3 p0, uint3 p1, uint3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <4 x i32>  @llvm.dx.umad.v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i32> %2)
+// CHECK: ret <4 x i32> %dx.umad
+uint4 test_mad_uint4(uint4 p0, uint4 p1, uint4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call i64 @llvm.dx.umad.i64(i64 %0, i64 %1, i64 %2)
+// CHECK: ret i64 %dx.umad
+uint64_t test_mad_uint64_t(uint64_t p0, uint64_t p1, uint64_t p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <2 x i64>  @llvm.dx.umad.v2i64(<2 x i64> %0, <2 x i64> %1, <2 x i64> %2)
+// CHECK: ret <2 x i64> %dx.umad
+uint64_t2 test_mad_uint64_t2(uint64_t2 p0, uint64_t2 p1, uint64_t2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <3 x i64>  @llvm.dx.umad.v3i64(<3 x i64> %0, <3 x i64> %1, <3 x i64> %2)
+// CHECK: ret <3 x i64> %dx.umad
+uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <4 x i64>  @llvm.dx.umad.v4i64(<4 x i64> %0, <4 x i64> %1, <4 x i64> %2)
+// CHECK: ret <4 x i64> %dx.umad
+uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
+
+// CHECK:  %dx.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
+// CHECK:  ret <4 x float> %dx.fmad
+float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: %dx.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2_int_splat(float2 p0, float2 p1, int p2) {
+  return mad(p0, p1, p2);
+}
+
+// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK:  %dx.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3_int_splat(float3 p0, float3 p1, int p2) {
+  return mad(p0, p1, p2);
+}
+
+// CHECK: %dx.umad = call i1 @llvm.dx.umad.i1(i1 %tobool, i1 %tobool1, i1 %tobool2)
+// CHECK: ret i1 %dx.umad
+bool test_builtin_mad_bool_type_promotion(bool p0) {
+  return __builtin_hlsl_mad(p0, p0, p0);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
new file mode 100644
index 00000000000000..b60ff1d3aa43e0
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
@@ -0,0 +1,86 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+float2 test_no_second_arg(float2 p0) {
+ ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/83826


More information about the cfe-commits mailing list