[clang] [llvm] [HLSL] implement `mad` intrinsic (PR #83826)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Mar 4 03:21:52 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-clang
Author: Farzon Lotfi (farzonl)
This change implements #83736.
The dot product lowering needs a ternary multiply-add operation. DXIL has three mad opcodes: `fmad` (46), `imad` (48), and `umad` (49). Dot product in DXIL only uses `imad`/`umad`, but `fmad` is included as well, both for completeness and because the HLSL `mad` intrinsic requires it. Two new intrinsics had to be created to complete this change; the `fmad` case is already supported by LLVM via the `fmuladd` intrinsic.
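For intuition, here is a minimal C++ reference model of the three flavors (a sketch, not the real implementation; the `*_ref` names are hypothetical, and the integer semantics assume DXIL-style wraparound on overflow):

```cpp
#include <cstdint>

// fmad (46): m * a + b on floats; the backend is allowed to fuse the
// multiply and add into a single rounding (llvm.fmuladd semantics).
float fmad_ref(float m, float a, float b) { return m * a + b; }

// imad (48): signed multiply-add. Modeled through unsigned math to avoid
// C++ signed-overflow UB; assumes DXIL wraps on overflow.
int32_t imad_ref(int32_t m, int32_t a, int32_t b) {
  return static_cast<int32_t>(static_cast<uint32_t>(m) *
                                  static_cast<uint32_t>(a) +
                              static_cast<uint32_t>(b));
}

// umad (49): unsigned multiply-add, wraps modulo 2^32.
uint32_t umad_ref(uint32_t m, uint32_t a, uint32_t b) { return m * a + b; }
```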
- `hlsl_intrinsics.h` - expose the `mad` API.
- `Builtins.td` - expose a `mad` builtin.
- `Sema.h` - make the float-type check in the ternary math builtin checks optional.
- `CGBuiltin.cpp` - pick the intrinsic for the signed/unsigned and float cases (a standalone sketch of this dispatch follows the list); reuse `int_fmuladd` for floats.
- `SemaChecking.cpp` - type checks for `__builtin_hlsl_mad`.
- `IntrinsicsDirectX.td` - create the two new intrinsics for `imad`/`umad`.
- `DXIL.td` - create the LLVM intrinsic to DXIL opcode mapping.
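The `CGBuiltin.cpp` change (quoted in the patch below) boils down to a three-way dispatch on the element type of the first argument. A standalone sketch of that selection logic, with hypothetical enum and function names:

```cpp
enum class MadIntrinsic { FMulAdd, DxImad, DxUmad };

// Float element types reuse the generic llvm.fmuladd; integer element
// types pick the new DirectX-specific intrinsics, which DXIL.td then
// maps to the DXIL opcodes named above.
MadIntrinsic selectMadIntrinsic(bool IsFloat, bool IsSigned) {
  if (IsFloat)
    return MadIntrinsic::FMulAdd;        // -> DXIL fmad (46)
  return IsSigned ? MadIntrinsic::DxImad // -> DXIL imad (48)
                  : MadIntrinsic::DxUmad; // -> DXIL umad (49)
}
```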
---
Patch is 33.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83826.diff
12 Files Affected:
- (modified) clang/include/clang/Basic/Builtins.td (+6)
- (modified) clang/include/clang/Sema/Sema.h (+1-1)
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+19)
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+105)
- (modified) clang/lib/Sema/SemaChecking.cpp (+16-6)
- (added) clang/test/CodeGenHLSL/builtins/mad.hlsl (+197)
- (added) clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl (+86)
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+4)
- (modified) llvm/lib/Target/DirectX/DXIL.td (+6)
- (added) llvm/test/CodeGen/DirectX/fmad.ll (+67)
- (added) llvm/test/CodeGen/DirectX/imad.ll (+65)
- (added) llvm/test/CodeGen/DirectX/umad.ll (+65)
``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 2c83dca248fb7d..c4f466208793ea 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4548,6 +4548,12 @@ def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLMad : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_mad"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index ef4b93fac95ce5..88be7dd90d21bc 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14132,7 +14132,7 @@ class Sema final {
bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
- bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
+ bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloatingPointCheck = true);
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e90014261217bc..5cb1dd07aab999 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18057,6 +18057,25 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType*/ Op0->getType(), Intrinsic::dx_frac,
ArrayRef<Value *>{Op0}, nullptr, "dx.frac");
}
+ case Builtin::BI__builtin_hlsl_mad: {
+ Value *M = EmitScalarExpr(E->getArg(0));
+ Value *A = EmitScalarExpr(E->getArg(1));
+ Value *B = EmitScalarExpr(E->getArg(2));
+ if (E->getArg(0)->getType()->hasFloatingRepresentation()) {
+ return Builder.CreateIntrinsic(
+ /*ReturnType*/ M->getType(), Intrinsic::fmuladd,
+ ArrayRef<Value *>{M, A, B}, nullptr, "dx.fmad");
+ }
+ if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {
+ return Builder.CreateIntrinsic(
+ /*ReturnType*/ M->getType(), Intrinsic::dx_imad,
+ ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");
+ }
+ assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
+ return Builder.CreateIntrinsic(
+ /*ReturnType*/ M->getType(), Intrinsic::dx_umad,
+ ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");
+ }
}
return nullptr;
}
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 5180530363889f..b5bef78fae72fe 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -511,6 +511,111 @@ double3 log2(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_log2)
double4 log2(double4);
+//===----------------------------------------------------------------------===//
+// mad builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T mad(T M, T A, T B)
+/// \brief The result of \a M * \a A + \a B.
+/// \param M The multiplication value.
+/// \param A The first addition value.
+/// \param B The second addition value.
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half mad(half, half, half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half2 mad(half2, half2, half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half3 mad(half3, half3, half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+half4 mad(half4, half4, half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t mad(int16_t, int16_t, int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t2 mad(int16_t2, int16_t2, int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t3 mad(int16_t3, int16_t3, int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int16_t4 mad(int16_t4, int16_t4, int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t mad(uint16_t, uint16_t, uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t2 mad(uint16_t2, uint16_t2, uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t3 mad(uint16_t3, uint16_t3, uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint16_t4 mad(uint16_t4, uint16_t4, uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int mad(int, int, int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int2 mad(int2, int2, int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int3 mad(int3, int3, int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int4 mad(int4, int4, int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint mad(uint, uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint2 mad(uint2, uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint3 mad(uint3, uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint4 mad(uint4, uint4, uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t mad(int64_t, int64_t, int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t2 mad(int64_t2, int64_t2, int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t3 mad(int64_t3, int64_t3, int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+int64_t4 mad(int64_t4, int64_t4, int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t mad(uint64_t, uint64_t, uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t2 mad(uint64_t2, uint64_t2, uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t3 mad(uint64_t3, uint64_t3, uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+uint64_t4 mad(uint64_t4, uint64_t4, uint64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float mad(float, float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float2 mad(float2, float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float3 mad(float3, float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+float4 mad(float4, float4, float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double mad(double, double, double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double2 mad(double2, double2, double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double3 mad(double3, double3, double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mad)
+double4 mad(double4, double4, double4);
+
//===----------------------------------------------------------------------===//
// max builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 0d4d57db01c93a..795d59ed5fd369 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5298,6 +5298,14 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_mad: {
+ if (checkArgCount(*this, TheCall, 3))
+ return true;
+ if (CheckVectorElementCallArgs(this, TheCall))
+ return true;
+ if (SemaBuiltinElementwiseTernaryMath(TheCall, false))
+ return true;
+ break;
+ }
}
return false;
}
@@ -19800,7 +19808,7 @@ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
return false;
}
-bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
+bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall, bool enforceFloatingPointCheck) {
if (checkArgCount(*this, TheCall, 3))
return true;
@@ -19812,11 +19820,13 @@ bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
Args[I] = Converted.get();
}
- int ArgOrdinal = 1;
- for (Expr *Arg : Args) {
- if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
- ArgOrdinal++))
- return true;
+ if (enforceFloatingPointCheck) {
+ int ArgOrdinal = 1;
+ for (Expr *Arg : Args) {
+ if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+ ArgOrdinal++))
+ return true;
+ }
}
for (int I = 1; I < 3; ++I) {
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
new file mode 100644
index 00000000000000..4dd8f01785afa0
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -0,0 +1,197 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+#ifdef __HLSL_ENABLE_16_BIT
+// NATIVE_HALF: %dx.umad = call i16 @llvm.dx.umad.i16(i16 %0, i16 %1, i16 %2)
+// NATIVE_HALF: ret i16 %dx.umad
+uint16_t test_mad_uint16_t(uint16_t p0, uint16_t p1, uint16_t p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <2 x i16> @llvm.dx.umad.v2i16(<2 x i16> %0, <2 x i16> %1, <2 x i16> %2)
+// NATIVE_HALF: ret <2 x i16> %dx.umad
+uint16_t2 test_mad_uint16_t2(uint16_t2 p0, uint16_t2 p1, uint16_t2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <3 x i16> @llvm.dx.umad.v3i16(<3 x i16> %0, <3 x i16> %1, <3 x i16> %2)
+// NATIVE_HALF: ret <3 x i16> %dx.umad
+uint16_t3 test_mad_uint16_t3(uint16_t3 p0, uint16_t3 p1, uint16_t3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.umad = call <4 x i16> @llvm.dx.umad.v4i16(<4 x i16> %0, <4 x i16> %1, <4 x i16> %2)
+// NATIVE_HALF: ret <4 x i16> %dx.umad
+uint16_t4 test_mad_uint16_t4(uint16_t4 p0, uint16_t4 p1, uint16_t4 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call i16 @llvm.dx.imad.i16(i16 %0, i16 %1, i16 %2)
+// NATIVE_HALF: ret i16 %dx.imad
+int16_t test_mad_int16_t(int16_t p0, int16_t p1, int16_t p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <2 x i16> @llvm.dx.imad.v2i16(<2 x i16> %0, <2 x i16> %1, <2 x i16> %2)
+// NATIVE_HALF: ret <2 x i16> %dx.imad
+int16_t2 test_mad_int16_t2(int16_t2 p0, int16_t2 p1, int16_t2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <3 x i16> @llvm.dx.imad.v3i16(<3 x i16> %0, <3 x i16> %1, <3 x i16> %2)
+// NATIVE_HALF: ret <3 x i16> %dx.imad
+int16_t3 test_mad_int16_t3(int16_t3 p0, int16_t3 p1, int16_t3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.imad = call <4 x i16> @llvm.dx.imad.v4i16(<4 x i16> %0, <4 x i16> %1, <4 x i16> %2)
+// NATIVE_HALF: ret <4 x i16> %dx.imad
+int16_t4 test_mad_int16_t4(int16_t4 p0, int16_t4 p1, int16_t4 p2) { return mad(p0, p1, p2); }
+#endif // __HLSL_ENABLE_16_BIT
+
+// NATIVE_HALF: %dx.fmad = call half @llvm.fmuladd.f16(half %0, half %1, half %2)
+// NATIVE_HALF: ret half %dx.fmad
+// NO_HALF: %dx.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// NO_HALF: ret float %dx.fmad
+half test_mad_half(half p0, half p1, half p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <2 x half> @llvm.fmuladd.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+// NATIVE_HALF: ret <2 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <2 x float> @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// NO_HALF: ret <2 x float> %dx.fmad
+half2 test_mad_half2(half2 p0, half2 p1, half2 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <3 x half> @llvm.fmuladd.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
+// NATIVE_HALF: ret <3 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <3 x float> @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// NO_HALF: ret <3 x float> %dx.fmad
+half3 test_mad_half3(half3 p0, half3 p1, half3 p2) { return mad(p0, p1, p2); }
+
+// NATIVE_HALF: %dx.fmad = call <4 x half> @llvm.fmuladd.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
+// NATIVE_HALF: ret <4 x half> %dx.fmad
+// NO_HALF: %dx.fmad = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// NO_HALF: ret <4 x float> %dx.fmad
+half4 test_mad_half4(half4 p0, half4 p1, half4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// CHECK: ret float %dx.fmad
+float test_mad_float(float p0, float p1, float p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x float> @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2(float2 p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x float> @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3(float3 p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// CHECK: ret <4 x float> %dx.fmad
+float4 test_mad_float4(float4 p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call double @llvm.fmuladd.f64(double %0, double %1, double %2)
+// CHECK: ret double %dx.fmad
+double test_mad_double(double p0, double p1, double p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %0, <2 x double> %1, <2 x double> %2)
+// CHECK: ret <2 x double> %dx.fmad
+double2 test_mad_double2(double2 p0, double2 p1, double2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x double> @llvm.fmuladd.v3f64(<3 x double> %0, <3 x double> %1, <3 x double> %2)
+// CHECK: ret <3 x double> %dx.fmad
+double3 test_mad_double3(double3 p0, double3 p1, double3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <4 x double> @llvm.fmuladd.v4f64(<4 x double> %0, <4 x double> %1, <4 x double> %2)
+// CHECK: ret <4 x double> %dx.fmad
+double4 test_mad_double4(double4 p0, double4 p1, double4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call i32 @llvm.dx.imad.i32(i32 %0, i32 %1, i32 %2)
+// CHECK: ret i32 %dx.imad
+int test_mad_int(int p0, int p1, int p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <2 x i32> @llvm.dx.imad.v2i32(<2 x i32> %0, <2 x i32> %1, <2 x i32> %2)
+// CHECK: ret <2 x i32> %dx.imad
+int2 test_mad_int2(int2 p0, int2 p1, int2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <3 x i32> @llvm.dx.imad.v3i32(<3 x i32> %0, <3 x i32> %1, <3 x i32> %2)
+// CHECK: ret <3 x i32> %dx.imad
+int3 test_mad_int3(int3 p0, int3 p1, int3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <4 x i32> @llvm.dx.imad.v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i32> %2)
+// CHECK: ret <4 x i32> %dx.imad
+int4 test_mad_int4(int4 p0, int4 p1, int4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call i64 @llvm.dx.imad.i64(i64 %0, i64 %1, i64 %2)
+// CHECK: ret i64 %dx.imad
+int64_t test_mad_int64_t(int64_t p0, int64_t p1, int64_t p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <2 x i64> @llvm.dx.imad.v2i64(<2 x i64> %0, <2 x i64> %1, <2 x i64> %2)
+// CHECK: ret <2 x i64> %dx.imad
+int64_t2 test_mad_int64_t2(int64_t2 p0, int64_t2 p1, int64_t2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <3 x i64> @llvm.dx.imad.v3i64(<3 x i64> %0, <3 x i64> %1, <3 x i64> %2)
+// CHECK: ret <3 x i64> %dx.imad
+int64_t3 test_mad_int64_t3(int64_t3 p0, int64_t3 p1, int64_t3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.imad = call <4 x i64> @llvm.dx.imad.v4i64(<4 x i64> %0, <4 x i64> %1, <4 x i64> %2)
+// CHECK: ret <4 x i64> %dx.imad
+int64_t4 test_mad_int64_t4(int64_t4 p0, int64_t4 p1, int64_t4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call i32 @llvm.dx.umad.i32(i32 %0, i32 %1, i32 %2)
+// CHECK: ret i32 %dx.umad
+uint test_mad_uint(uint p0, uint p1, uint p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <2 x i32> @llvm.dx.umad.v2i32(<2 x i32> %0, <2 x i32> %1, <2 x i32> %2)
+// CHECK: ret <2 x i32> %dx.umad
+uint2 test_mad_uint2(uint2 p0, uint2 p1, uint2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <3 x i32> @llvm.dx.umad.v3i32(<3 x i32> %0, <3 x i32> %1, <3 x i32> %2)
+// CHECK: ret <3 x i32> %dx.umad
+uint3 test_mad_uint3(uint3 p0, uint3 p1, uint3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <4 x i32> @llvm.dx.umad.v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i32> %2)
+// CHECK: ret <4 x i32> %dx.umad
+uint4 test_mad_uint4(uint4 p0, uint4 p1, uint4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call i64 @llvm.dx.umad.i64(i64 %0, i64 %1, i64 %2)
+// CHECK: ret i64 %dx.umad
+uint64_t test_mad_uint64_t(uint64_t p0, uint64_t p1, uint64_t p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <2 x i64> @llvm.dx.umad.v2i64(<2 x i64> %0, <2 x i64> %1, <2 x i64> %2)
+// CHECK: ret <2 x i64> %dx.umad
+uint64_t2 test_mad_uint64_t2(uint64_t2 p0, uint64_t2 p1, uint64_t2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <3 x i64> @llvm.dx.umad.v3i64(<3 x i64> %0, <3 x i64> %1, <3 x i64> %2)
+// CHECK: ret <3 x i64> %dx.umad
+uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.umad = call <4 x i64> @llvm.dx.umad.v4i64(<4 x i64> %0, <4 x i64> %1, <4 x i64> %2)
+// CHECK: ret <4 x i64> %dx.umad
+uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <2 x float> @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <3 x float> @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %dx.fmad = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
+// CHECK: ret <4 x float> %dx.fmad
+float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
+
+// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: %dx.fmad = call <2 x float> @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
+// CHECK: ret <2 x float> %dx.fmad
+float2 test_mad_float2_int_splat(float2 p0, float2 p1, int p2) {
+ return mad(p0, p1, p2);
+}
+
+// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK: %dx.fmad = call <3 x float> @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
+// CHECK: ret <3 x float> %dx.fmad
+float3 test_mad_float3_int_splat(float3 p0, float3 p1, int p2) {
+ return mad(p0, p1, p2);
+}
+
+// CHECK: %dx.umad = call i1 @llvm.dx.umad.i1(i1 %tobool, i1 %tobool1, i1 %tobool2)
+// CHECK: ret i1 %dx.umad
+bool test_builtin_mad_bool_type_promotion(bool p0) {
+ return __builtin_hlsl_mad(p0, p0, p0);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
new file mode 100644
index 00000000000000..b60ff1d3aa43e0
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
@@ -0,0 +1,86 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+float2 test_no_second_arg(float2 p0) {
+ ...
[truncated]
``````````
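A note on why the float case maps to `llvm.fmuladd` rather than `llvm.fma`: `fmuladd` permits, but does not require, fusing the multiply and add into a single rounding, while `fma` mandates fusion. Below is a small self-contained demo of a case where fusion changes the answer (inputs chosen so the intermediate product needs more precision than `float`; whether the `separate` line itself gets contracted depends on the compiler's `-ffp-contract` setting, which is exactly the latitude `fmuladd` encodes):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float m = 1.0f + 0x1.0p-12f;      // 1 + 2^-12
  float a = 1.0f + 0x1.0p-12f;      // m * a = 1 + 2^-11 + 2^-24 exactly
  float b = -(1.0f + 0x1.0p-11f);
  float separate = m * a + b;       // product rounded to float first -> 0.0
  float fused = std::fmaf(m, a, b); // single rounding -> 2^-24
  std::printf("separate=%g fused=%g\n", separate, fused);
}
```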
https://github.com/llvm/llvm-project/pull/83826