[clang] [clang] Enable constexpr handling for __builtin_elementwise_fma (PR #152919)
Chaitanya Koparkar via cfe-commits
cfe-commits at lists.llvm.org
Wed Aug 20 05:41:41 PDT 2025
https://github.com/ckoparkar updated https://github.com/llvm/llvm-project/pull/152919
>From c314cb7b2aa3afa9b5e9079d6cfa0526ff7bafc9 Mon Sep 17 00:00:00 2001
From: Chaitanya Koparkar <ckoparkar at gmail.com>
Date: Sat, 9 Aug 2025 07:18:20 -0400
Subject: [PATCH] [clang] Enable constexpr handling for
__builtin_elementwise_fma
---
clang/docs/LanguageExtensions.rst | 8 +--
clang/include/clang/Basic/Builtins.td | 2 +-
clang/lib/AST/ByteCode/InterpBuiltin.cpp | 58 ++++++++++++++++++++
clang/lib/AST/ExprConstant.cpp | 39 +++++++++++++
clang/test/CodeGen/rounding-math.cpp | 52 ++++++++++++++++++
clang/test/Sema/constant-builtins-vector.cpp | 21 +++++++
6 files changed, 175 insertions(+), 5 deletions(-)
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index 97413588fea15..ac80eb0809cd7 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -757,11 +757,11 @@ elementwise to the input.
Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity
-The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
+The elementwise intrinsics ``__builtin_elementwise_popcount``,
``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
``__builtin_elementwise_sub_sat``, ``__builtin_elementwise_max``,
-``__builtin_elementwise_min``, and ``__builtin_elementwise_abs``
-can be called in a ``constexpr`` context.
+``__builtin_elementwise_min``, ``__builtin_elementwise_abs``, and
+``__builtin_elementwise_fma`` can be called in a ``constexpr`` context.
No implicit promotion of integer types takes place. The mixing of integer types
of different sizes and signs is forbidden in binary and ternary builtins.
@@ -4370,7 +4370,7 @@ fall into one of the specified floating-point classes.
if (__builtin_isfpclass(x, 448)) {
// `x` is positive finite value
- ...
+ ...
}
**Description**:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 604c9cddfe051..103a85f28927d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {
def ElementwiseFma : Builtin {
let Spellings = ["__builtin_elementwise_fma"];
- let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
let Prototype = "void(...)";
}
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index 729fef92770eb..a179217482ae7 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -2627,6 +2627,62 @@ static bool interp__builtin_ia32_pmul(InterpState &S, CodePtr OpPC,
return true;
}
+/// Constant-evaluates __builtin_elementwise_fma(X, Y, Z), i.e. X * Y + Z, for
+/// scalar and vector floating-point operands.
+///
+/// The three operands are popped from the interpreter stack in reverse order
+/// (Z, then Y, then X). A scalar result is pushed back onto the stack; a vector
+/// result is written element by element into the destination pointer that is
+/// left on top of the stack. The rounding mode comes from the floating-point
+/// options in effect at the call expression. Always returns true.
+static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
+                                            const CallExpr *Call) {
+  assert(Call->getNumArgs() == 3);
+
+  FPOptions FPO = Call->getFPFeaturesInEffect(S.Ctx.getLangOpts());
+  llvm::RoundingMode RM = getRoundingMode(FPO);
+  const QualType Arg1Type = Call->getArg(0)->getType();
+  // The remaining operand types are only inspected by assertions; mark them
+  // [[maybe_unused]] so NDEBUG builds do not warn about unused variables.
+  [[maybe_unused]] const QualType Arg2Type = Call->getArg(1)->getType();
+  [[maybe_unused]] const QualType Arg3Type = Call->getArg(2)->getType();
+
+  // Non-vector floating point types.
+  if (!Arg1Type->isVectorType()) {
+    assert(!Arg2Type->isVectorType());
+    assert(!Arg3Type->isVectorType());
+
+    const Floating &Z = S.Stk.pop<Floating>();
+    const Floating &Y = S.Stk.pop<Floating>();
+    const Floating &X = S.Stk.pop<Floating>();
+    APFloat F = X.getAPFloat();
+    // The returned status is intentionally ignored, as on the vector path.
+    (void)F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
+    Floating Result = S.allocFloat(X.getSemantics());
+    Result.copy(F);
+    S.Stk.push<Floating>(Result);
+    return true;
+  }
+
+  // Vector type.
+  assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
+         Arg3Type->isVectorType());
+
+  const VectorType *VecT = Arg1Type->castAs<VectorType>();
+  // Only consulted by the assertions below.
+  [[maybe_unused]] const QualType ElemT = VecT->getElementType();
+  unsigned NumElems = VecT->getNumElements();
+
+  assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
+         ElemT == Arg3Type->castAs<VectorType>()->getElementType());
+  assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
+         NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
+  assert(ElemT->isRealFloatingType());
+
+  const Pointer &VZ = S.Stk.pop<Pointer>();
+  const Pointer &VY = S.Stk.pop<Pointer>();
+  const Pointer &VX = S.Stk.pop<Pointer>();
+  // The destination is peeked, not popped, so it stays on the stack.
+  const Pointer &Dst = S.Stk.peek<Pointer>();
+  for (unsigned I = 0; I != NumElems; ++I) {
+    using T = PrimConv<PT_Float>::T;
+    APFloat X = VX.elem<T>(I).getAPFloat();
+    APFloat Y = VY.elem<T>(I).getAPFloat();
+    APFloat Z = VZ.elem<T>(I).getAPFloat();
+    (void)X.fusedMultiplyAdd(Y, Z, RM);
+    Dst.elem<Floating>(I) = Floating(X);
+  }
+  Dst.initializeAllElements();
+  return true;
+}
+
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
uint32_t BuiltinID) {
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -3053,6 +3109,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
case clang::X86::BI__builtin_ia32_pmuludq128:
case clang::X86::BI__builtin_ia32_pmuludq256:
return interp__builtin_ia32_pmul(S, OpPC, Call, BuiltinID);
+ case Builtin::BI__builtin_elementwise_fma:
+ return interp__builtin_elementwise_fma(S, OpPC, Call);
default:
S.FFDiag(S.Current->getLocation(OpPC),
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 40c56501b0c14..54f3b93347439 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11827,6 +11827,30 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
}
+
+ case Builtin::BI__builtin_elementwise_fma: {
+   // Evaluate the three vector operands in argument order, giving up at the
+   // first one that cannot be evaluated.
+   APValue Multiplicand, Multiplier, Addend;
+   if (!EvaluateAsRValue(Info, E->getArg(0), Multiplicand))
+     return false;
+   if (!EvaluateAsRValue(Info, E->getArg(1), Multiplier))
+     return false;
+   if (!EvaluateAsRValue(Info, E->getArg(2), Addend))
+     return false;
+
+   // The first operand's length determines how many lanes are produced.
+   unsigned NumElts = Multiplicand.getVectorLength();
+   SmallVector<APValue> Fused;
+   Fused.reserve(NumElts);
+   llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+
+   for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
+     // Start from a copy of the multiplicand lane and fuse in place; the
+     // returned status is deliberately ignored.
+     APFloat Lane(Multiplicand.getVectorElt(Idx).getFloat());
+     (void)Lane.fusedMultiplyAdd(Multiplier.getVectorElt(Idx).getFloat(),
+                                 Addend.getVectorElt(Idx).getFloat(), RM);
+     Fused.push_back(APValue(Lane));
+   }
+
+   return Success(APValue(Fused.data(), Fused.size()), E);
+ }
}
}
@@ -16077,6 +16101,21 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
Result = minimumnum(Result, RHS);
return true;
}
+
+ case Builtin::BI__builtin_elementwise_fma: {
+   // Bail out unless every one of the three operands is a prvalue.
+   for (unsigned ArgNo = 0; ArgNo != 3; ++ArgNo)
+     if (!E->getArg(ArgNo)->isPRValue())
+       return false;
+
+   // The first operand is evaluated straight into Result, which then serves
+   // as the accumulator for the fused operation.
+   APFloat Multiplier(0.), Addend(0.);
+   if (!EvaluateFloat(E->getArg(0), Result, Info))
+     return false;
+   if (!EvaluateFloat(E->getArg(1), Multiplier, Info))
+     return false;
+   if (!EvaluateFloat(E->getArg(2), Addend, Info))
+     return false;
+
+   llvm::RoundingMode ActiveRM = getActiveRoundingMode(getEvalInfo(), E);
+   // The returned status is deliberately ignored.
+   (void)Result.fusedMultiplyAdd(Multiplier, Addend, ActiveRM);
+   return true;
+ }
}
}
diff --git a/clang/test/CodeGen/rounding-math.cpp b/clang/test/CodeGen/rounding-math.cpp
index 264031dc9daa9..5c44fd31242c6 100644
--- a/clang/test/CodeGen/rounding-math.cpp
+++ b/clang/test/CodeGen/rounding-math.cpp
@@ -11,3 +11,55 @@ float V3 = func_01(1.0F, 2.0F);
// CHECK: @V1 = {{.*}}global float 1.000000e+00, align 4
// CHECK: @V2 = {{.*}}global float 1.000000e+00, align 4
// CHECK: @V3 = {{.*}}global float 3.000000e+00, align 4
+
+void test_builtin_elementwise_fma_round_upward() { // constexpr fma under FE_UPWARD
+ #pragma STDC FENV_ACCESS ON
+ #pragma STDC FENV_ROUND FE_UPWARD
+
+ // CHECK: store float 0x4018000100000000, ptr %f1
+ // CHECK: store float 0x4018000100000000, ptr %f2
+ constexpr float f1 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
+ constexpr float f2 = 2.0F * 3.000001F + 0.000001F; // same operands as a plain multiply-add
+ static_assert(f1 == f2); // builtin must agree with the plain expression
+ static_assert(f1 == 6.00000381F);
+ // CHECK: store double 0x40180000C9539B89, ptr %d1
+ // CHECK: store double 0x40180000C9539B89, ptr %d2
+ constexpr double d1 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
+ constexpr double d2 = 2.0 * 3.000001 + 0.000001; // double-precision counterpart of f2
+ static_assert(d1 == d2); // builtin must agree with the plain expression
+ static_assert(d1 == 6.0000030000000004);
+}
+
+void test_builtin_elementwise_fma_round_downward() { // constexpr fma under FE_DOWNWARD
+ #pragma STDC FENV_ACCESS ON
+ #pragma STDC FENV_ROUND FE_DOWNWARD
+
+ // CHECK: store float 0x40180000C0000000, ptr %f3
+ // CHECK: store float 0x40180000C0000000, ptr %f4
+ constexpr float f3 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
+ constexpr float f4 = 2.0F * 3.000001F + 0.000001F; // same operands as a plain multiply-add
+ static_assert(f3 == f4); // builtin must agree with the plain expression
+ // CHECK: store double 0x40180000C9539B87, ptr %d3
+ // CHECK: store double 0x40180000C9539B87, ptr %d4
+ constexpr double d3 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
+ constexpr double d4 = 2.0 * 3.000001 + 0.000001; // double-precision counterpart of f4
+ static_assert(d3 == d4); // builtin must agree with the plain expression
+}
+
+void test_builtin_elementwise_fma_round_nearest() { // constexpr fma under FE_TONEAREST
+ #pragma STDC FENV_ACCESS ON
+ #pragma STDC FENV_ROUND FE_TONEAREST
+
+ // CHECK: store float 0x40180000C0000000, ptr %f5
+ // CHECK: store float 0x40180000C0000000, ptr %f6
+ constexpr float f5 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
+ constexpr float f6 = 2.0F * 3.000001F + 0.000001F; // same operands as a plain multiply-add
+ static_assert(f5 == f6); // builtin must agree with the plain expression
+ static_assert(f5 == 6.00000286F);
+ // CHECK: store double 0x40180000C9539B89, ptr %d5
+ // CHECK: store double 0x40180000C9539B89, ptr %d6
+ constexpr double d5 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
+ constexpr double d6 = 2.0 * 3.000001 + 0.000001; // double-precision counterpart of f6
+ static_assert(d5 == d6); // builtin must agree with the plain expression
+ static_assert(d5 == 6.0000030000000004);
+}
diff --git a/clang/test/Sema/constant-builtins-vector.cpp b/clang/test/Sema/constant-builtins-vector.cpp
index 2b7d76e36ce96..ba07fea336b81 100644
--- a/clang/test/Sema/constant-builtins-vector.cpp
+++ b/clang/test/Sema/constant-builtins-vector.cpp
@@ -894,3 +894,24 @@ CHECK_FOUR_FLOAT_VEC(__builtin_elementwise_abs((vector4float){-1.123, 2.123, -3.
CHECK_FOUR_FLOAT_VEC(__builtin_elementwise_abs((vector4double){-1.123, 2.123, -3.123, 4.123}), ((vector4double){1.123, 2.123, 3.123, 4.123}))
static_assert(__builtin_elementwise_abs((float)-1.123) - (float)1.123 < 1e-6); // making sure one element works
#undef CHECK_FOUR_FLOAT_VEC
+
+// __builtin_elementwise_fma on scalar (non-vector) floating-point operands.
+static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0);
+static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0);
+// __builtin_elementwise_fma on float and double vectors; each lane is checked individually.
+constexpr vector4float fmaFloat1 =
+ __builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
+ (vector4float){2.0, 3.0, 4.0, 5.0},
+ (vector4float){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaFloat1[0] == 5.0);
+static_assert(fmaFloat1[1] == 10.0);
+static_assert(fmaFloat1[2] == 17.0);
+static_assert(fmaFloat1[3] == 26.0);
+constexpr vector4double fmaDouble1 =
+ __builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
+ (vector4double){2.0, 3.0, 4.0, 5.0},
+ (vector4double){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaDouble1[0] == 5.0);
+static_assert(fmaDouble1[1] == 10.0);
+static_assert(fmaDouble1[2] == 17.0);
+static_assert(fmaDouble1[3] == 26.0);
More information about the cfe-commits
mailing list