[clang] [clang] constexpr built-in fma function. (PR #113020)

via cfe-commits cfe-commits at lists.llvm.org
Sat Oct 19 01:34:20 PDT 2024


https://github.com/c8ef updated https://github.com/llvm/llvm-project/pull/113020

>From 93c625ad60fc834e72df667addc6eec83247fc8c Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Sat, 19 Oct 2024 03:45:17 +0000
Subject: [PATCH 1/2] constexpr fma

---
 clang/docs/ReleaseNotes.rst                   |  1 +
 clang/include/clang/Basic/Builtins.td         |  1 +
 clang/lib/AST/ByteCode/InterpBuiltin.cpp      | 37 +++++++++++++++++++
 clang/lib/AST/ExprConstant.cpp                | 16 ++++++++
 clang/test/AST/ByteCode/builtin-functions.cpp |  9 +++++
 clang/test/Sema/constant-builtins-2.c         |  7 ++++
 6 files changed, 71 insertions(+)

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index b7a6ace8bb895d..605d55a9e51f37 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -273,6 +273,7 @@ Non-comprehensive list of changes in this release
 - Plugins can now define custom attributes that apply to statements
   as well as declarations.
 - ``__builtin_abs`` function can now be used in constant expressions.
+- ``__builtin_fma``, ``__builtin_fmaf``, ``__builtin_fmal`` and ``__builtin_fmaf128`` can now be used in constant expressions.
 
 New Compiler Flags
 ------------------
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..55f470a9f715b9 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -3723,6 +3723,7 @@ def Fma : FPMathTemplate, LibBuiltin<"math.h"> {
   let Attributes = [NoThrow, ConstIgnoringErrnoAndExceptions];
   let Prototype = "T(T, T, T)";
   let AddBuiltinPrefixedAlias = 1;
+  let OnlyBuiltinPrefixedAliasIsConstexpr = 1;
 }
 
 def Fmax : FPMathTemplate, LibBuiltin<"math.h"> {
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index d4a8e6c2035ee5..145f4627dd73da 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -142,6 +142,19 @@ static bool retPrimValue(InterpState &S, CodePtr OpPC, APValue &Result,
 #undef RET_CASE
 }
 
+/// Get the rounding mode to use when evaluating the specified expression.
+///
+/// A dynamic rounding mode is unknown at compile time; evaluation is still
+/// attempted using "tonearest" instead, which is only sound if the result is
+/// exact (an exact result does not depend on the rounding mode).
+static llvm::RoundingMode getActiveRoundingMode(InterpState &S, const Expr *E) {
+  llvm::RoundingMode RM =
+      E->getFPFeaturesInEffect(S.getLangOpts()).getRoundingMode();
+  if (RM == llvm::RoundingMode::Dynamic)
+    RM = llvm::RoundingMode::NearestTiesToEven;
+  return RM;
+}
+
 static bool interp__builtin_is_constant_evaluated(InterpState &S, CodePtr OpPC,
                                                   const InterpFrame *Frame,
                                                   const CallExpr *Call) {
@@ -549,6 +562,22 @@ static bool interp__builtin_fpclassify(InterpState &S, CodePtr OpPC,
   return true;
 }
 
+static bool interp__builtin_fma(InterpState &S, CodePtr OpPC,
+                                const InterpFrame *Frame, const Function *Func,
+                                const CallExpr *Call) {
+  const Floating &X = getParam<Floating>(Frame, 0);
+  const Floating &Y = getParam<Floating>(Frame, 1);
+  const Floating &Z = getParam<Floating>(Frame, 2);
+  Floating Result;
+
+  llvm::RoundingMode RM = getActiveRoundingMode(S, Call);
+  Floating::mul(X, Y, RM, &Result);
+  Floating::add(Result, Z, RM, &Result);
+
+  S.Stk.push<Floating>(Result);
+  return true;
+}
+
 // The C standard says "fabs raises no floating-point exceptions,
 // even if x is a signaling NaN. The returned value is independent of
 // the current rounding direction mode."  Therefore constant folding can
@@ -1814,6 +1843,14 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const Function *F,
       return false;
     break;
 
+  case Builtin::BI__builtin_fma:
+  case Builtin::BI__builtin_fmaf:
+  case Builtin::BI__builtin_fmal:
+  case Builtin::BI__builtin_fmaf128:
+    if (!interp__builtin_fma(S, OpPC, Frame, F, Call))
+      return false;
+    break;
+
   case Builtin::BI__builtin_fabs:
   case Builtin::BI__builtin_fabsf:
   case Builtin::BI__builtin_fabsl:
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 8e36cad2d2c6e7..685ce8a63f6c9e 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15314,6 +15314,22 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
       Result.changeSign();
     return true;
 
+  case Builtin::BI__builtin_fma:
+  case Builtin::BI__builtin_fmaf:
+  case Builtin::BI__builtin_fmal:
+  case Builtin::BI__builtin_fmaf128: {
+    APFloat Y(0.), Z(0.);
+    if (!EvaluateFloat(E->getArg(0), Result, Info) ||
+        !EvaluateFloat(E->getArg(1), Y, Info) ||
+        !EvaluateFloat(E->getArg(2), Z, Info))
+      return false;
+
+    llvm::RoundingMode RM = getActiveRoundingMode(Info, E);
+    Result.multiply(Y, RM);
+    Result.add(Z, RM);
+    return true;
+  }
+
   case Builtin::BI__arithmetic_fence:
     return EvaluateFloat(E->getArg(0), Result, Info);
 
diff --git a/clang/test/AST/ByteCode/builtin-functions.cpp b/clang/test/AST/ByteCode/builtin-functions.cpp
index b5d334178f8213..0dba62e252b3d7 100644
--- a/clang/test/AST/ByteCode/builtin-functions.cpp
+++ b/clang/test/AST/ByteCode/builtin-functions.cpp
@@ -265,6 +265,22 @@ namespace fpclassify {
   char classify_subnorm [__builtin_fpclassify(-1, -1, -1, +1, -1, 1.0e-38f)];
 }
 
+namespace fma {
+  static_assert(__builtin_fma(1.0, 1.0, 1.0) == 2.0);
+  static_assert(__builtin_fma(1.0, -1.0, 1.0) == 0.0);
+  static_assert(__builtin_fmaf(1.0f, 1.0f, 1.0f) == 2.0f);
+  static_assert(__builtin_fmaf(1.0f, -1.0f, 1.0f) == 0.0f);
+  static_assert(__builtin_fmal(1.0L, 1.0L, 1.0L) == 2.0L);
+  static_assert(__builtin_fmal(1.0L, -1.0L, 1.0L) == 0.0L);
+
+  // X * X == 1 + 2^-26 + 2^-54 is not representable as a double. A fused
+  // multiply-add keeps the low 2^-54 term; a separately rounded multiply
+  // followed by an add would lose it and produce 0.
+  constexpr double X = 1.0 + 1.0 / 134217728.0; // 1 + 2^-27
+  static_assert(__builtin_fma(X, X, -(1.0 + 1.0 / 67108864.0)) ==
+                1.0 / 18014398509481984.0); // 2^-54
+} // namespace fma
+
 namespace abs {
   static_assert(__builtin_abs(14) == 14, "");
   static_assert(__builtin_labs(14L) == 14L, "");
diff --git a/clang/test/Sema/constant-builtins-2.c b/clang/test/Sema/constant-builtins-2.c
index e465a3c5f0ad86..2a2dbc2caee1d1 100644
--- a/clang/test/Sema/constant-builtins-2.c
+++ b/clang/test/Sema/constant-builtins-2.c
@@ -54,6 +54,13 @@ long double  g18 = __builtin_copysignl(1.0L, -1.0L);
 __float128   g18_2 = __builtin_copysignf128(1.0q, -1.0q);
 #endif
 
+double g19 = __builtin_fma(1.0, 1.0, 1.0);
+float g20 = __builtin_fmaf(1.0f, 1.0f, 1.0f);
+long double g21 = __builtin_fmal(1.0L, 1.0L, 1.0L);
+#if defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)
+__float128 g21_2 = __builtin_fma(1.0q, 1.0q, 1.0q);
+#endif
+
 char classify_nan     [__builtin_fpclassify(+1, -1, -1, -1, -1, __builtin_nan(""))];
 char classify_snan    [__builtin_fpclassify(+1, -1, -1, -1, -1, __builtin_nans(""))];
 char classify_inf     [__builtin_fpclassify(-1, +1, -1, -1, -1, __builtin_inf())];

>From 89a87ca6bb11bf152259365731afabaeed5d0f24 Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Sat, 19 Oct 2024 08:34:10 +0000
Subject: [PATCH 2/2] address review comments

---
 clang/include/clang/Basic/Builtins.td    | 2 +-
 clang/lib/AST/ByteCode/Floating.h        | 7 +++++++
 clang/lib/AST/ByteCode/InterpBuiltin.cpp | 4 ++--
 clang/lib/AST/ExprConstant.cpp           | 4 +---
 clang/test/Sema/constant-builtins-2.c    | 2 +-
 5 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 55f470a9f715b9..76d54c10f8edb9 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -199,7 +199,7 @@ def FloorF16F128 : Builtin, F16F128MathTemplate {
 
 def FmaF16F128 : Builtin, F16F128MathTemplate {
   let Spellings = ["__builtin_fma"];
-  let Attributes = [FunctionWithBuiltinPrefix, NoThrow, ConstIgnoringErrnoAndExceptions];
+  let Attributes = [FunctionWithBuiltinPrefix, NoThrow, ConstIgnoringErrnoAndExceptions, Constexpr];
   let Prototype = "T(T, T, T)";
 }
 
diff --git a/clang/lib/AST/ByteCode/Floating.h b/clang/lib/AST/ByteCode/Floating.h
index 114487821880fb..134490249c05b6 100644
--- a/clang/lib/AST/ByteCode/Floating.h
+++ b/clang/lib/AST/ByteCode/Floating.h
@@ -203,6 +203,13 @@ class Floating final {
     return R->F.divide(B.F, RM);
   }
 
+  static APFloat::opStatus fma(const Floating &A, const Floating &B,
+                               const Floating &C, llvm::RoundingMode RM,
+                               Floating *R) {
+    *R = Floating(A.F);
+    return R->F.fusedMultiplyAdd(B.F, C.F, RM);
+  }
+
   static bool neg(const Floating &A, Floating *R) {
     *R = -A;
     return false;
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index 145f4627dd73da..d5fd67b3f0c8a8 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -571,8 +571,8 @@ static bool interp__builtin_fma(InterpState &S, CodePtr OpPC,
   Floating Result;
 
   llvm::RoundingMode RM = getActiveRoundingMode(S, Call);
-  Floating::mul(X, Y, RM, &Result);
-  Floating::add(Result, Z, RM, &Result);
+  // FIXME: Diagnose a non-OK status, e.g. opInexact under dynamic rounding.
+  Floating::fma(X, Y, Z, RM, &Result);
 
   S.Stk.push<Floating>(Result);
   return true;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 685ce8a63f6c9e..01022f6ac40c63 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15325,9 +15325,7 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
       return false;
 
     llvm::RoundingMode RM = getActiveRoundingMode(Info, E);
-    Result.multiply(Y, RM);
-    Result.add(Z, RM);
-    return true;
+    return checkFloatingPointResult(Info, E, Result.fusedMultiplyAdd(Y, Z, RM));
   }
 
   case Builtin::BI__arithmetic_fence:
diff --git a/clang/test/Sema/constant-builtins-2.c b/clang/test/Sema/constant-builtins-2.c
index 2a2dbc2caee1d1..2a47d189c23c7b 100644
--- a/clang/test/Sema/constant-builtins-2.c
+++ b/clang/test/Sema/constant-builtins-2.c
@@ -58,7 +58,7 @@ double g19 = __builtin_fma(1.0, 1.0, 1.0);
 float g20 = __builtin_fmaf(1.0f, 1.0f, 1.0f);
 long double g21 = __builtin_fmal(1.0L, 1.0L, 1.0L);
 #if defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)
-__float128 g21_2 = __builtin_fma(1.0q, 1.0q, 1.0q);
+__float128 g21_2 = __builtin_fmaf128(1.0q, 1.0q, 1.0q);
 #endif
 
 char classify_nan     [__builtin_fpclassify(+1, -1, -1, -1, -1, __builtin_nan(""))];



More information about the cfe-commits mailing list