[clang] [llvm] [HLSL] implement `mad` intrinsic (PR #83826)

Farzon Lotfi via cfe-commits cfe-commits at lists.llvm.org
Mon Mar 4 13:17:26 PST 2024


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/83826

>From 6aa6d2cd8b999a0173130f7675f561595ed65d8d Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.com>
Date: Sat, 2 Mar 2024 16:40:39 -0500
Subject: [PATCH 1/4] [HLSL] implement mad intrinsic. This change implements
 #83736. The dot product lowering needs a tertiary multiply-add operation. DXIL
 has three mad opcodes: fmad(46), imad(48), and umad(49). Dot product in
 DXIL only uses imad/umad, but for completeness fmad was also included. Two
 new intrinsics needed to be created to complete this change; the fmad case
 already existed.

hlsl_intrinsics.h - exposed mad api call.
Builtins.td - exposed a mad builtin.
Sema.h - make the floating-point type check for ternary calls optional.
CGBuiltin.cpp - pick the intrinsic for signed/unsigned & float; also reuse int_fmuladd.
SemaChecking.cpp - type checks for __builtin_hlsl_mad.
IntrinsicsDirectX.td - create the two new intrinsics for imad/umad.
DXIL.td - create the llvm intrinsic to DXIL opcode mapping.
---
 clang/include/clang/Basic/Builtins.td        |   6 +
 clang/include/clang/Sema/Sema.h              |   2 +-
 clang/lib/CodeGen/CGBuiltin.cpp              |  19 ++
 clang/lib/Headers/hlsl/hlsl_intrinsics.h     | 105 ++++++++++
 clang/lib/Sema/SemaChecking.cpp              |  22 ++-
 clang/test/CodeGenHLSL/builtins/mad.hlsl     | 197 +++++++++++++++++++
 clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl |  86 ++++++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td    |   4 +
 llvm/lib/Target/DirectX/DXIL.td              |   6 +
 llvm/test/CodeGen/DirectX/fmad.ll            |  67 +++++++
 llvm/test/CodeGen/DirectX/imad.ll            |  65 ++++++
 llvm/test/CodeGen/DirectX/umad.ll            |  65 ++++++
 12 files changed, 637 insertions(+), 7 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/mad.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/fmad.ll
 create mode 100644 llvm/test/CodeGen/DirectX/imad.ll
 create mode 100644 llvm/test/CodeGen/DirectX/umad.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 2c83dca248fb7d..c4f466208793ea 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4548,6 +4548,12 @@ def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLMad : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_mad"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 29baf542f6cbf5..06db8d45435f42 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14146,7 +14146,7 @@ class Sema final {
   bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
   bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
   bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
-  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
+  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloatingPointCheck = true);
   bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 9ee51ca7142c77..191f6d5a9fab5e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18044,6 +18044,25 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         /*ReturnType*/ Op0->getType(), Intrinsic::dx_frac,
         ArrayRef<Value *>{Op0}, nullptr, "dx.frac");
   }
+  case Builtin::BI__builtin_hlsl_mad: {
+    Value *M = EmitScalarExpr(E->getArg(0));
+    Value *A = EmitScalarExpr(E->getArg(1));
+    Value *B = EmitScalarExpr(E->getArg(2));
+    if (E->getArg(0)->getType()->hasFloatingRepresentation()) {
+      return Builder.CreateIntrinsic(
+          /*ReturnType*/ M->getType(), Intrinsic::fmuladd,
+          ArrayRef<Value *>{M, A, B}, nullptr, "dx.fmad");
+    }
+    if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {
+      return Builder.CreateIntrinsic(
+          /*ReturnType*/ M->getType(), Intrinsic::dx_imad,
+          ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");
+    }
+    assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
+    return Builder.CreateIntrinsic(
+        /*ReturnType*/ M->getType(), Intrinsic::dx_umad,
+        ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 5180530363889f..b5bef78fae72fe 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -511,6 +511,111 @@ double3 log2(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_log2)
 double4 log2(double4);
 
+//===----------------------------------------------------------------------===//
+// mad builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T mad(T M, T A, T B)
+/// \brief The result of \a M * \a A + \a B.
+/// \param M The multiplicand.
+/// \param A The multiplier.
+/// \param B The addend.
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half mad(half, half, half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half2 mad(half2, half2, half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half3 mad(half3, half3, half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half4 mad(half4, half4, half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t mad(int16_t, int16_t, int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t2 mad(int16_t2, int16_t2, int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t3 mad(int16_t3, int16_t3, int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t4 mad(int16_t4, int16_t4, int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t mad(uint16_t, uint16_t, uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t2 mad(uint16_t2, uint16_t2, uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t3 mad(uint16_t3, uint16_t3, uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t4 mad(uint16_t4, uint16_t4, uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int mad(int, int, int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int2 mad(int2, int2, int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int3 mad(int3, int3, int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int4 mad(int4, int4, int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint mad(uint, uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint2 mad(uint2, uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint3 mad(uint3, uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint4 mad(uint4, uint4, uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t mad(int64_t, int64_t, int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t2 mad(int64_t2, int64_t2, int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t3 mad(int64_t3, int64_t3, int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t4 mad(int64_t4, int64_t4, int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t mad(uint64_t, uint64_t, uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t2 mad(uint64_t2, uint64_t2, uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t3 mad(uint64_t3, uint64_t3, uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t4 mad(uint64_t4, uint64_t4, uint64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float mad(float, float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float2 mad(float2, float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float3 mad(float3, float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float4 mad(float4, float4, float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double mad(double, double, double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double2 mad(double2, double2, double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double3 mad(double3, double3, double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double4 mad(double4, double4, double4);
+
 //===----------------------------------------------------------------------===//
 // max builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 0d4d57db01c93a..795d59ed5fd369 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5298,6 +5298,14 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_mad: {
+    if (checkArgCount(*this, TheCall, 3))
+      return true;
+    if (CheckVectorElementCallArgs(this, TheCall))
+      return true;
+    if (SemaBuiltinElementwiseTernaryMath(TheCall, false))
+      return true;
+  }
   }
   return false;
 }
@@ -19800,7 +19808,7 @@ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
   return false;
 }
 
-bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
+bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloatingPointCheck) {
   if (checkArgCount(*this, TheCall, 3))
     return true;
 
@@ -19812,11 +19820,13 @@ bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
     Args[I] = Converted.get();
   }
 
-  int ArgOrdinal = 1;
-  for (Expr *Arg : Args) {
-    if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
-                                      ArgOrdinal++))
-      return true;
+  if(enforceFloatingPointCheck) {
+    int ArgOrdinal = 1;
+    for (Expr *Arg : Args) {
+      if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+                                        ArgOrdinal++))
+        return true;
+    }
   }
 
   for (int I = 1; I < 3; ++I) {
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
new file mode 100644
index 00000000000000..4dd8f01785afa0
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -0,0 +1,197 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+#ifdef __HLSL_ENABLE_16_BIT
+// NATIVE_HALF: %dx.umad = call i16 @llvm.dx.umad.i16(i16 %0, i16 %1, i16 %2)
+// NATIVE_HALF: ret i16 %dx.umad
+uint16_t test_mad_uint16_t(uint16_t p0, uint16_t p1, uint16_t p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <2 x i16>  @llvm.dx.umad.v2i16(<2 x i16> %0, <2 x i16> %1, <2 x i16> %2)
+// NATIVE_HALF: ret <2 x i16> %dx.umad
+uint16_t2 test_mad_uint16_t2(uint16_t2 p0, uint16_t2 p1, uint16_t2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <3 x i16>  @llvm.dx.umad.v3i16(<3 x i16> %0, <3 x i16> %1, <3 x i16> %2)
+// NATIVE_HALF: ret <3 x i16> %dx.umad
+uint16_t3 test_mad_uint16_t3(uint16_t3 p0, uint16_t3 p1, uint16_t3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <4 x i16>  @llvm.dx.umad.v4i16(<4 x i16> %0, <4 x i16> %1, <4 x i16> %2)
+// NATIVE_HALF: ret <4 x i16> %dx.umad
+uint16_t4 test_mad_uint16_t4(uint16_t4 p0, uint16_t4 p1, uint16_t4 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call i16 @llvm.dx.imad.i16(i16 %0, i16 %1, i16 %2)
+// NATIVE_HALF: ret i16 %dx.imad
+int16_t test_mad_int16_t(int16_t p0, int16_t p1, int16_t p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <2 x i16>  @llvm.dx.imad.v2i16(<2 x i16> %0, <2 x i16> %1, <2 x i16> %2)
+// NATIVE_HALF: ret <2 x i16> %dx.imad
+int16_t2 test_mad_int16_t2(int16_t2 p0, int16_t2 p1, int16_t2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <3 x i16>  @llvm.dx.imad.v3i16(<3 x i16> %0, <3 x i16> %1, <3 x i16> %2)
+// NATIVE_HALF: ret <3 x i16> %dx.imad
+int16_t3 test_mad_int16_t3(int16_t3 p0, int16_t3 p1, int16_t3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <4 x i16>  @llvm.dx.imad.v4i16(<4 x i16> %0, <4 x i16> %1, <4 x i16> %2)
+// NATIVE_HALF: ret <4 x i16> %dx.imad
+int16_t4 test_mad_int16_t4(int16_t4 p0, int16_t4 p1, int16_t4 p2) { return mad(p0, p1, p2); }
+#endif // __HLSL_ENABLE_16_BIT
+
+// NATIVE_HALF: %dx.fmad = call half @llvm.fmuladd.f16(half %0, half %1, half %2)
+// NATIVE_HALF: ret half %dx.fmad
+// NO_HALF: %dx.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// NO_HALF: ret float %dx.fmad
+half test_mad_half(half p0, half p1, half p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <2 x half>  @llvm.fmuladd.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+// NATIVE_HALF: ret <2 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// NO_HALF: ret <2 x float> %dx.fmad
+half2 test_mad_half2(half2 p0, half2 p1, half2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <3 x half>  @llvm.fmuladd.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
+// NATIVE_HALF: ret <3 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// NO_HALF: ret <3 x float> %dx.fmad
+half3 test_mad_half3(half3 p0, half3 p1, half3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <4 x half>  @llvm.fmuladd.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
+// NATIVE_HALF: ret <4 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// NO_HALF: ret <4 x float> %dx.fmad
+half4 test_mad_half4(half4 p0, half4 p1, half4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// CHECK: ret float %dx.fmad
+float test_mad_float(float p0, float p1, float p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2(float2 p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3(float3 p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// CHECK: ret <4 x float> %dx.fmad
+float4 test_mad_float4(float4 p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call double @llvm.fmuladd.f64(double %0, double %1, double %2)
+// CHECK: ret double %dx.fmad
+double test_mad_double(double p0, double p1, double p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x double>  @llvm.fmuladd.v2f64(<2 x double> %0, <2 x double> %1, <2 x double> %2)
+// CHECK: ret <2 x double> %dx.fmad
+double2 test_mad_double2(double2 p0, double2 p1, double2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x double>  @llvm.fmuladd.v3f64(<3 x double> %0, <3 x double> %1, <3 x double> %2)
+// CHECK: ret <3 x double> %dx.fmad
+double3 test_mad_double3(double3 p0, double3 p1, double3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <4 x double>  @llvm.fmuladd.v4f64(<4 x double> %0, <4 x double> %1, <4 x double> %2)
+// CHECK: ret <4 x double> %dx.fmad
+double4 test_mad_double4(double4 p0, double4 p1, double4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call i32 @llvm.dx.imad.i32(i32 %0, i32 %1, i32 %2)
+// CHECK: ret i32 %dx.imad
+int test_mad_int(int p0, int p1, int p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <2 x i32>  @llvm.dx.imad.v2i32(<2 x i32> %0, <2 x i32> %1, <2 x i32> %2)
+// CHECK: ret <2 x i32> %dx.imad
+int2 test_mad_int2(int2 p0, int2 p1, int2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <3 x i32>  @llvm.dx.imad.v3i32(<3 x i32> %0, <3 x i32> %1, <3 x i32> %2)
+// CHECK: ret <3 x i32> %dx.imad
+int3 test_mad_int3(int3 p0, int3 p1, int3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <4 x i32>  @llvm.dx.imad.v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i32> %2)
+// CHECK: ret <4 x i32> %dx.imad
+int4 test_mad_int4(int4 p0, int4 p1, int4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call i64 @llvm.dx.imad.i64(i64 %0, i64 %1, i64 %2)
+// CHECK: ret i64 %dx.imad
+int64_t test_mad_int64_t(int64_t p0, int64_t p1, int64_t p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <2 x i64>  @llvm.dx.imad.v2i64(<2 x i64> %0, <2 x i64> %1, <2 x i64> %2)
+// CHECK: ret <2 x i64> %dx.imad
+int64_t2 test_mad_int64_t2(int64_t2 p0, int64_t2 p1, int64_t2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <3 x i64>  @llvm.dx.imad.v3i64(<3 x i64> %0, <3 x i64> %1, <3 x i64> %2)
+// CHECK: ret <3 x i64> %dx.imad
+int64_t3 test_mad_int64_t3(int64_t3 p0, int64_t3 p1, int64_t3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <4 x i64>  @llvm.dx.imad.v4i64(<4 x i64> %0, <4 x i64> %1, <4 x i64> %2)
+// CHECK: ret <4 x i64> %dx.imad
+int64_t4 test_mad_int64_t4(int64_t4 p0, int64_t4 p1, int64_t4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call i32 @llvm.dx.umad.i32(i32 %0, i32 %1, i32 %2)
+// CHECK: ret i32 %dx.umad
+uint test_mad_uint(uint p0, uint p1, uint p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <2 x i32>  @llvm.dx.umad.v2i32(<2 x i32> %0, <2 x i32> %1, <2 x i32> %2)
+// CHECK: ret <2 x i32> %dx.umad
+uint2 test_mad_uint2(uint2 p0, uint2 p1, uint2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <3 x i32>  @llvm.dx.umad.v3i32(<3 x i32> %0, <3 x i32> %1, <3 x i32> %2)
+// CHECK: ret <3 x i32> %dx.umad
+uint3 test_mad_uint3(uint3 p0, uint3 p1, uint3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <4 x i32>  @llvm.dx.umad.v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i32> %2)
+// CHECK: ret <4 x i32> %dx.umad
+uint4 test_mad_uint4(uint4 p0, uint4 p1, uint4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call i64 @llvm.dx.umad.i64(i64 %0, i64 %1, i64 %2)
+// CHECK: ret i64 %dx.umad
+uint64_t test_mad_uint64_t(uint64_t p0, uint64_t p1, uint64_t p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <2 x i64>  @llvm.dx.umad.v2i64(<2 x i64> %0, <2 x i64> %1, <2 x i64> %2)
+// CHECK: ret <2 x i64> %dx.umad
+uint64_t2 test_mad_uint64_t2(uint64_t2 p0, uint64_t2 p1, uint64_t2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <3 x i64>  @llvm.dx.umad.v3i64(<3 x i64> %0, <3 x i64> %1, <3 x i64> %2)
+// CHECK: ret <3 x i64> %dx.umad
+uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <4 x i64>  @llvm.dx.umad.v4i64(<4 x i64> %0, <4 x i64> %1, <4 x i64> %2)
+// CHECK: ret <4 x i64> %dx.umad
+uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
+
+// CHECK:  %dx.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
+// CHECK:  ret <4 x float> %dx.fmad
+float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: %dx.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2_int_splat(float2 p0, float2 p1, int p2) {
+  return mad(p0, p1, p2);
+}
+
+// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK:  %dx.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3_int_splat(float3 p0, float3 p1, int p2) {
+  return mad(p0, p1, p2);
+}
+
+// CHECK: %dx.umad = call i1 @llvm.dx.umad.i1(i1 %tobool, i1 %tobool1, i1 %tobool2)
+// CHECK: ret i1 %dx.umad
+bool test_builtin_mad_bool_type_promotion(bool p0) {
+  return __builtin_hlsl_mad(p0, p0, p0);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
new file mode 100644
index 00000000000000..b60ff1d3aa43e0
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
@@ -0,0 +1,86 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+float2 test_no_second_arg(float2 p0) {
+  return __builtin_hlsl_mad(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
+}
+
+float2 test_no_third_arg(float2 p0) {
+  return __builtin_hlsl_mad(p0, p0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_mad(p0, p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+}
+
+float2 test_mad_no_second_arg(float2 p0) {
+  return mad(p0);
+  // expected-error@-1 {{no matching function for call to 'mad'}}
+}
+
+float2 test_mad_vector_size_mismatch(float3 p0, float2 p1) {
+  return mad(p0, p0, p1);
+  // expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float __attribute__((ext_vector_type(2)))' (vector of 2 'float' values)}}
+}
+
+float2 test_mad_builtin_vector_size_mismatch(float3 p0, float2 p1) {
+  return __builtin_hlsl_mad(p0, p1, p1);
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_mad' must have the same type}}
+}
+
+float test_mad_scalar_mismatch(float p0, half p1) {
+  return mad(p1, p0, p1);
+  // expected-error@-1 {{call to 'mad' is ambiguous}}
+}
+
+float2 test_mad_element_type_mismatch(half2 p0, float2 p1) {
+  return mad(p1, p0, p1);
+  // expected-error@-1 {{call to 'mad' is ambiguous}}
+}
+
+float2 test_builtin_mad_float2_splat(float p0, float2 p1) {
+  return __builtin_hlsl_mad(p0, p1, p1);
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_mad' must be vectors}}
+}
+
+float3 test_builtin_mad_float3_splat(float p0, float3 p1) {
+  return __builtin_hlsl_mad(p0, p1, p1);
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_mad' must be vectors}}
+}
+
+float4 test_builtin_mad_float4_splat(float p0, float4 p1) {
+  return __builtin_hlsl_mad(p0, p1, p1);
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_mad' must be vectors}}
+}
+
+float2 test_mad_float2_int_splat(float2 p0, int p1) {
+  return __builtin_hlsl_mad(p0, p1, p1);
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_mad' must be vectors}}
+}
+
+float3 test_mad_float3_int_splat(float3 p0, int p1) {
+  return __builtin_hlsl_mad(p0, p1, p1);
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_mad' must be vectors}}
+}
+
+float2 test_builtin_mad_int_vect_to_float_vec_promotion(int2 p0, float p1) {
+  return __builtin_hlsl_mad(p0, p1, p1);
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_mad' must be vectors}}
+}
+
+float builtin_bool_to_float_type_promotion(float p0, bool p1) {
+  return __builtin_hlsl_mad(p0, p0, p1);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'bool')}}
+}
+
+float builtin_bool_to_float_type_promotion2(bool p0, float p1) {
+  return __builtin_hlsl_mad(p1, p0, p1);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'bool')}}
+}
+
+float builtin_mad_int_to_float_promotion(float p0, int p1) {
+  return __builtin_hlsl_mad(p0, p0, p1);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'int')}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index b44d1c6d3d2f06..08faa17d234d09 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -31,4 +31,8 @@ def int_dx_lerp :
     Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn] >;
+
+
+def int_dx_imad  : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_umad  : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 33b08ed93e3d0a..8f73c08a658099 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -226,6 +226,12 @@ def Round : DXILOpMapping<26, unary, int_round,
                          "within a floating-point type.">;
 def UMax : DXILOpMapping<39, binary, int_umax,
                          "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+def FMad : DXILOpMapping<46, tertiary, int_fmuladd,
+                         "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m * a + b.">;
+def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
+                         "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
+def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
+                         "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
 def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
                              "Reads the thread ID">;
 def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/test/CodeGen/DirectX/fmad.ll b/llvm/test/CodeGen/DirectX/fmad.ll
new file mode 100644
index 00000000000000..ab0a5cfc8ec5a3
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/fmad.ll
@@ -0,0 +1,67 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for fmad are generated for half, float, and double.
+; CHECK:call half @dx.op.tertiary.f16(i32 46, half %{{.*}}, half %{{.*}}, half %{{.*}})
+; CHECK:call float @dx.op.tertiary.f32(i32 46, float %{{.*}}, float %{{.*}}, float %{{.*}})
+; CHECK:call double @dx.op.tertiary.f64(i32 46, double %{{.*}}, double %{{.*}}, double %{{.*}})
+
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @"?test_mad_half@@YA$f16@$f16 at 00@Z"(half noundef %p0, half noundef %p1, half noundef %p2) #0 {
+entry:
+  %p2.addr = alloca half, align 2
+  %p1.addr = alloca half, align 2
+  %p0.addr = alloca half, align 2
+  store half %p2, ptr %p2.addr, align 2
+  store half %p1, ptr %p1.addr, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %1 = load half, ptr %p1.addr, align 2
+  %2 = load half, ptr %p2.addr, align 2
+  %dx.fmad = call half @llvm.fmuladd.f16(half %0, half %1, half %2)
+  ret half %dx.fmad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare half @llvm.fmuladd.f16(half, half, half) #2
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @"?test_mad_float@@YAMMMM at Z"(float noundef %p0, float noundef %p1, float noundef %p2) #0 {
+entry:
+  %p2.addr = alloca float, align 4
+  %p1.addr = alloca float, align 4
+  %p0.addr = alloca float, align 4
+  store float %p2, ptr %p2.addr, align 4
+  store float %p1, ptr %p1.addr, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  %1 = load float, ptr %p1.addr, align 4
+  %2 = load float, ptr %p2.addr, align 4
+  %dx.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+  ret float %dx.fmad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare float @llvm.fmuladd.f32(float, float, float) #2
+
+; Function Attrs: noinline nounwind optnone
+define noundef double @"?test_mad_double@@YANNNN at Z"(double noundef %p0, double noundef %p1, double noundef %p2) #0 {
+entry:
+  %p2.addr = alloca double, align 8
+  %p1.addr = alloca double, align 8
+  %p0.addr = alloca double, align 8
+  store double %p2, ptr %p2.addr, align 8
+  store double %p1, ptr %p1.addr, align 8
+  store double %p0, ptr %p0.addr, align 8
+  %0 = load double, ptr %p0.addr, align 8
+  %1 = load double, ptr %p1.addr, align 8
+  %2 = load double, ptr %p2.addr, align 8
+  %dx.fmad = call double @llvm.fmuladd.f64(double %0, double %1, double %2)
+  ret double %dx.fmad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare double @llvm.fmuladd.f64(double, double, double) #2
diff --git a/llvm/test/CodeGen/DirectX/imad.ll b/llvm/test/CodeGen/DirectX/imad.ll
new file mode 100644
index 00000000000000..f7874fa5843405
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/imad.ll
@@ -0,0 +1,65 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for imad are generated for i16, i32, and i64.
+; CHECK:call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; CHECK:call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK:call i64 @dx.op.tertiary.i64(i32 48, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+; Function Attrs: noinline nounwind optnone
+define noundef i16 @"?test_mad_int16_t@@YAFFFF at Z"(i16 noundef %p0, i16 noundef %p1, i16 noundef %p2) #0 {
+entry:
+  %p2.addr = alloca i16, align 2
+  %p1.addr = alloca i16, align 2
+  %p0.addr = alloca i16, align 2
+  store i16 %p2, ptr %p2.addr, align 2
+  store i16 %p1, ptr %p1.addr, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  %1 = load i16, ptr %p1.addr, align 2
+  %2 = load i16, ptr %p2.addr, align 2
+  %dx.imad = call i16 @llvm.dx.imad.i16(i16 %0, i16 %1, i16 %2)
+  ret i16 %dx.imad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i16 @llvm.dx.imad.i16(i16, i16, i16) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef i32 @"?test_mad_uint@@YAIIII at Z"(i32 noundef %p0, i32 noundef %p1, i32 noundef %p2) #0 {
+entry:
+  %p2.addr = alloca i32, align 4
+  %p1.addr = alloca i32, align 4
+  %p0.addr = alloca i32, align 4
+  store i32 %p2, ptr %p2.addr, align 4
+  store i32 %p1, ptr %p1.addr, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  %1 = load i32, ptr %p1.addr, align 4
+  %2 = load i32, ptr %p2.addr, align 4
+  %dx.imad = call i32 @llvm.dx.imad.i32(i32 %0, i32 %1, i32 %2)
+  ret i32 %dx.imad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i32 @llvm.dx.imad.i32(i32, i32, i32) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef i64 @"?test_mad_uint64_t@@YAKKKK at Z"(i64 noundef %p0, i64 noundef %p1, i64 noundef %p2) #0 {
+entry:
+  %p2.addr = alloca i64, align 8
+  %p1.addr = alloca i64, align 8
+  %p0.addr = alloca i64, align 8
+  store i64 %p2, ptr %p2.addr, align 8
+  store i64 %p1, ptr %p1.addr, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  %1 = load i64, ptr %p1.addr, align 8
+  %2 = load i64, ptr %p2.addr, align 8
+  %dx.imad = call i64 @llvm.dx.imad.i64(i64 %0, i64 %1, i64 %2)
+  ret i64 %dx.imad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i64 @llvm.dx.imad.i64(i64, i64, i64) #1
diff --git a/llvm/test/CodeGen/DirectX/umad.ll b/llvm/test/CodeGen/DirectX/umad.ll
new file mode 100644
index 00000000000000..986ac2c57c08d0
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/umad.ll
@@ -0,0 +1,65 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for umad are generated for i16, i32, and i64.
+; CHECK:call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; CHECK:call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK:call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+; Function Attrs: noinline nounwind optnone
+define noundef i16 @"?test_mad_uint16_t@@YAGGGG at Z"(i16 noundef %p0, i16 noundef %p1, i16 noundef %p2) #0 {
+entry:
+  %p2.addr = alloca i16, align 2
+  %p1.addr = alloca i16, align 2
+  %p0.addr = alloca i16, align 2
+  store i16 %p2, ptr %p2.addr, align 2
+  store i16 %p1, ptr %p1.addr, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  %1 = load i16, ptr %p1.addr, align 2
+  %2 = load i16, ptr %p2.addr, align 2
+  %dx.umad = call i16 @llvm.dx.umad.i16(i16 %0, i16 %1, i16 %2)
+  ret i16 %dx.umad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i16 @llvm.dx.umad.i16(i16, i16, i16) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef i32 @"?test_mad_uint@@YAIIII at Z"(i32 noundef %p0, i32 noundef %p1, i32 noundef %p2) #0 {
+entry:
+  %p2.addr = alloca i32, align 4
+  %p1.addr = alloca i32, align 4
+  %p0.addr = alloca i32, align 4
+  store i32 %p2, ptr %p2.addr, align 4
+  store i32 %p1, ptr %p1.addr, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  %1 = load i32, ptr %p1.addr, align 4
+  %2 = load i32, ptr %p2.addr, align 4
+  %dx.umad = call i32 @llvm.dx.umad.i32(i32 %0, i32 %1, i32 %2)
+  ret i32 %dx.umad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i32 @llvm.dx.umad.i32(i32, i32, i32) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef i64 @"?test_mad_uint64_t@@YAKKKK at Z"(i64 noundef %p0, i64 noundef %p1, i64 noundef %p2) #0 {
+entry:
+  %p2.addr = alloca i64, align 8
+  %p1.addr = alloca i64, align 8
+  %p0.addr = alloca i64, align 8
+  store i64 %p2, ptr %p2.addr, align 8
+  store i64 %p1, ptr %p1.addr, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  %1 = load i64, ptr %p1.addr, align 8
+  %2 = load i64, ptr %p2.addr, align 8
+  %dx.umad = call i64 @llvm.dx.umad.i64(i64 %0, i64 %1, i64 %2)
+  ret i64 %dx.umad
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i64 @llvm.dx.umad.i64(i64, i64, i64) #1

>From 55e2e87424bf57f155ff79f0bd375a9d9052beb9 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Mon, 4 Mar 2024 06:26:10 -0500
Subject: [PATCH 2/4] fix formatting

---
 clang/include/clang/Sema/Sema.h | 3 ++-
 clang/lib/Sema/SemaChecking.cpp | 9 +++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 06db8d45435f42..2ac8b8bd4347d5 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14146,7 +14146,8 @@ class Sema final {
   bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
   bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
   bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
-  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloatingPointCheck = true);
+  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
+                                         bool enforceFloatingPointCheck = true);
   bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 795d59ed5fd369..f5f6617fb27a2d 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -19808,7 +19808,8 @@ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
   return false;
 }
 
-bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloatingPointCheck) {
+bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
+                                             bool enforceFloatingPointCheck) {
   if (checkArgCount(*this, TheCall, 3))
     return true;
 
@@ -19820,11 +19821,11 @@ bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloa
     Args[I] = Converted.get();
   }
 
-  if(enforceFloatingPointCheck) {
+  if (enforceFloatingPointCheck) {
     int ArgOrdinal = 1;
     for (Expr *Arg : Args) {
-      if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
-                                        ArgOrdinal++))
+      if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
+                                        Arg->getType(), ArgOrdinal++))
         return true;
     }
   }

>From 4360a273d57c9639f70304ca6e4a10b11f4d67f9 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Mon, 4 Mar 2024 12:24:20 -0500
Subject: [PATCH 3/4] address pr comments

---
 llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 -
 llvm/test/CodeGen/DirectX/fmad.ll         | 6 +++---
 llvm/test/CodeGen/DirectX/imad.ll         | 6 +++---
 llvm/test/CodeGen/DirectX/umad.ll         | 6 +++---
 4 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 08faa17d234d09..3096442335ce4e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -32,7 +32,6 @@ def int_dx_lerp :
     [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn] >;
 
-
 def int_dx_imad  : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad  : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 }
diff --git a/llvm/test/CodeGen/DirectX/fmad.ll b/llvm/test/CodeGen/DirectX/fmad.ll
index ab0a5cfc8ec5a3..693e237e70dc02 100644
--- a/llvm/test/CodeGen/DirectX/fmad.ll
+++ b/llvm/test/CodeGen/DirectX/fmad.ll
@@ -10,7 +10,7 @@ target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32
 target triple = "dxil-pc-shadermodel6.7-library"
 
 ; Function Attrs: noinline nounwind optnone
-define noundef half @"?test_mad_half@@YA$f16@$f16 at 00@Z"(half noundef %p0, half noundef %p1, half noundef %p2) #0 {
+define noundef half @fmad_half(half noundef %p0, half noundef %p1, half noundef %p2) #0 {
 entry:
   %p2.addr = alloca half, align 2
   %p1.addr = alloca half, align 2
@@ -29,7 +29,7 @@ entry:
 declare half @llvm.fmuladd.f16(half, half, half) #2
 
 ; Function Attrs: noinline nounwind optnone
-define noundef float @"?test_mad_float@@YAMMMM at Z"(float noundef %p0, float noundef %p1, float noundef %p2) #0 {
+define noundef float @fmad_float(float noundef %p0, float noundef %p1, float noundef %p2) #0 {
 entry:
   %p2.addr = alloca float, align 4
   %p1.addr = alloca float, align 4
@@ -48,7 +48,7 @@ entry:
 declare float @llvm.fmuladd.f32(float, float, float) #2
 
 ; Function Attrs: noinline nounwind optnone
-define noundef double @"?test_mad_double@@YANNNN at Z"(double noundef %p0, double noundef %p1, double noundef %p2) #0 {
+define noundef double @fmad_double(double noundef %p0, double noundef %p1, double noundef %p2) #0 {
 entry:
   %p2.addr = alloca double, align 8
   %p1.addr = alloca double, align 8
diff --git a/llvm/test/CodeGen/DirectX/imad.ll b/llvm/test/CodeGen/DirectX/imad.ll
index f7874fa5843405..5b818f86bc7f25 100644
--- a/llvm/test/CodeGen/DirectX/imad.ll
+++ b/llvm/test/CodeGen/DirectX/imad.ll
@@ -8,7 +8,7 @@
 target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
 target triple = "dxil-pc-shadermodel6.7-library"
 ; Function Attrs: noinline nounwind optnone
-define noundef i16 @"?test_mad_int16_t@@YAFFFF at Z"(i16 noundef %p0, i16 noundef %p1, i16 noundef %p2) #0 {
+define noundef i16 @imad_short(i16 noundef %p0, i16 noundef %p1, i16 noundef %p2) #0 {
 entry:
   %p2.addr = alloca i16, align 2
   %p1.addr = alloca i16, align 2
@@ -27,7 +27,7 @@ entry:
 declare i16 @llvm.dx.imad.i16(i16, i16, i16) #1
 
 ; Function Attrs: noinline nounwind optnone
-define noundef i32 @"?test_mad_uint@@YAIIII at Z"(i32 noundef %p0, i32 noundef %p1, i32 noundef %p2) #0 {
+define noundef i32 @imad_int(i32 noundef %p0, i32 noundef %p1, i32 noundef %p2) #0 {
 entry:
   %p2.addr = alloca i32, align 4
   %p1.addr = alloca i32, align 4
@@ -46,7 +46,7 @@ entry:
 declare i32 @llvm.dx.imad.i32(i32, i32, i32) #1
 
 ; Function Attrs: noinline nounwind optnone
-define noundef i64 @"?test_mad_uint64_t@@YAKKKK at Z"(i64 noundef %p0, i64 noundef %p1, i64 noundef %p2) #0 {
+define noundef i64 @imad_int64(i64 noundef %p0, i64 noundef %p1, i64 noundef %p2) #0 {
 entry:
   %p2.addr = alloca i64, align 8
   %p1.addr = alloca i64, align 8
diff --git a/llvm/test/CodeGen/DirectX/umad.ll b/llvm/test/CodeGen/DirectX/umad.ll
index 986ac2c57c08d0..583fdddfe03f34 100644
--- a/llvm/test/CodeGen/DirectX/umad.ll
+++ b/llvm/test/CodeGen/DirectX/umad.ll
@@ -8,7 +8,7 @@
 target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
 target triple = "dxil-pc-shadermodel6.7-library"
 ; Function Attrs: noinline nounwind optnone
-define noundef i16 @"?test_mad_uint16_t@@YAGGGG at Z"(i16 noundef %p0, i16 noundef %p1, i16 noundef %p2) #0 {
+define noundef i16 @umad_ushort(i16 noundef %p0, i16 noundef %p1, i16 noundef %p2) #0 {
 entry:
   %p2.addr = alloca i16, align 2
   %p1.addr = alloca i16, align 2
@@ -27,7 +27,7 @@ entry:
 declare i16 @llvm.dx.umad.i16(i16, i16, i16) #1
 
 ; Function Attrs: noinline nounwind optnone
-define noundef i32 @"?test_mad_uint@@YAIIII at Z"(i32 noundef %p0, i32 noundef %p1, i32 noundef %p2) #0 {
+define noundef i32 @umad_uint(i32 noundef %p0, i32 noundef %p1, i32 noundef %p2) #0 {
 entry:
   %p2.addr = alloca i32, align 4
   %p1.addr = alloca i32, align 4
@@ -46,7 +46,7 @@ entry:
 declare i32 @llvm.dx.umad.i32(i32, i32, i32) #1
 
 ; Function Attrs: noinline nounwind optnone
-define noundef i64 @"?test_mad_uint64_t@@YAKKKK at Z"(i64 noundef %p0, i64 noundef %p1, i64 noundef %p2) #0 {
+define noundef i64 @umad_uint64(i64 noundef %p0, i64 noundef %p1, i64 noundef %p2) #0 {
 entry:
   %p2.addr = alloca i64, align 8
   %p1.addr = alloca i64, align 8

>From ca0a146ef2f84963ff36ab97502a5c58b4a305bd Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Mon, 4 Mar 2024 13:42:30 -0500
Subject: [PATCH 4/4] address pr comments

---
 clang/include/clang/Sema/Sema.h              |  2 +-
 clang/lib/Sema/SemaChecking.cpp              | 13 ++++++++++---
 clang/test/CodeGenHLSL/builtins/mad.hlsl     |  6 ------
 clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl |  4 ++--
 4 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 2ac8b8bd4347d5..25c4c58ad4ae43 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14147,7 +14147,7 @@ class Sema final {
   bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
   bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
   bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
-                                         bool enforceFloatingPointCheck = true);
+                                         bool CheckForFloatArgs = true);
   bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index f5f6617fb27a2d..f0707fcb0b9126 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5303,7 +5303,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     if (CheckVectorElementCallArgs(this, TheCall))
       return true;
-    if (SemaBuiltinElementwiseTernaryMath(TheCall, false))
+    if (SemaBuiltinElementwiseTernaryMath(TheCall, /*CheckForFloatArgs*/ false))
       return true;
   }
   }
@@ -19809,7 +19809,7 @@ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
 }
 
 bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
-                                             bool enforceFloatingPointCheck) {
+                                             bool CheckForFloatArgs) {
   if (checkArgCount(*this, TheCall, 3))
     return true;
 
@@ -19821,13 +19821,20 @@ bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall,
     Args[I] = Converted.get();
   }
 
-  if (enforceFloatingPointCheck) {
+  if (CheckForFloatArgs) {
     int ArgOrdinal = 1;
     for (Expr *Arg : Args) {
       if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
                                         Arg->getType(), ArgOrdinal++))
         return true;
     }
+  } else {
+    int ArgOrdinal = 1;
+    for (Expr *Arg : Args) {
+      if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+                                      ArgOrdinal++))
+        return true;
+    }
   }
 
   for (int I = 1; I < 3; ++I) {
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
index 4dd8f01785afa0..749eac6d64736d 100644
--- a/clang/test/CodeGenHLSL/builtins/mad.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -189,9 +189,3 @@ float2 test_mad_float2_int_splat(float2 p0, float2 p1, int p2) {
 float3 test_mad_float3_int_splat(float3 p0, float3 p1, int p2) {
   return mad(p0, p1, p2);
 }
-
-// CHECK: %dx.umad = call i1 @llvm.dx.umad.i1(i1 %tobool, i1 %tobool1, i1 %tobool2)
-// CHECK: ret i1 %dx.umad
-bool test_builtin_mad_bool_type_promotion(bool p0) {
-  return __builtin_hlsl_mad(p0, p0, p0);
-}
diff --git a/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
index b60ff1d3aa43e0..0b6843591455bd 100644
--- a/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
@@ -72,12 +72,12 @@ float2 test_builtin_mad_int_vect_to_float_vec_promotion(int2 p0, float p1) {
 
 float builtin_bool_to_float_type_promotion(float p0, bool p1) {
   return __builtin_hlsl_mad(p0, p0, p1);
-  // expected-error@-1 {{arguments are of different types ('double' vs 'bool')}}
+  // expected-error@-1 {{3rd argument must be a vector, integer or floating point type (was 'bool')}}
 }
 
 float builtin_bool_to_float_type_promotion2(bool p0, float p1) {
   return __builtin_hlsl_mad(p1, p0, p1);
-  // expected-error@-1 {{arguments are of different types ('double' vs 'bool')}}
+  // expected-error@-1 {{2nd argument must be a vector, integer or floating point type (was 'bool')}}
 }
 
 float builtin_mad_int_to_float_promotion(float p0, int p1) {



More information about the cfe-commits mailing list