[clang] [clang] constexpr built-in fma function. (PR #113020)

via cfe-commits cfe-commits at lists.llvm.org
Fri Oct 18 21:18:11 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: None (c8ef)

<details>
<summary>Changes</summary>

According to [P0533R9](https://wg21.link/P0533R9), the C++ standard library `fma` functions are now `constexpr`:

```c++
  constexpr floating-point-type fma(floating-point-type x, floating-point-type y,
                                    floating-point-type z);
  constexpr float               fmaf(float x, float y, float z);
  constexpr long double         fmal(long double x, long double y, long double z);
```

To implement this feature in libc++, we must make the built-in `fma` function `constexpr`. This patch adds the implementation of a `constexpr` fma function for the current constant evaluator and the new bytecode interpreter.

---
Full diff: https://github.com/llvm/llvm-project/pull/113020.diff


6 Files Affected:

- (modified) clang/docs/ReleaseNotes.rst (+1) 
- (modified) clang/include/clang/Basic/Builtins.td (+1) 
- (modified) clang/lib/AST/ByteCode/InterpBuiltin.cpp (+37) 
- (modified) clang/lib/AST/ExprConstant.cpp (+16) 
- (modified) clang/test/AST/ByteCode/builtin-functions.cpp (+9) 
- (modified) clang/test/Sema/constant-builtins-2.c (+7) 


``````````diff
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index b7a6ace8bb895d..605d55a9e51f37 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -273,6 +273,7 @@ Non-comprehensive list of changes in this release
 - Plugins can now define custom attributes that apply to statements
   as well as declarations.
 - ``__builtin_abs`` function can now be used in constant expressions.
+- ``__builtin_fma`` function can now be used in constant expressions.
 
 New Compiler Flags
 ------------------
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..55f470a9f715b9 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -3723,6 +3723,7 @@ def Fma : FPMathTemplate, LibBuiltin<"math.h"> {
   let Attributes = [NoThrow, ConstIgnoringErrnoAndExceptions];
   let Prototype = "T(T, T, T)";
   let AddBuiltinPrefixedAlias = 1;
+  let OnlyBuiltinPrefixedAliasIsConstexpr = 1;
 }
 
 def Fmax : FPMathTemplate, LibBuiltin<"math.h"> {
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index d4a8e6c2035ee5..145f4627dd73da 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -142,6 +142,19 @@ static bool retPrimValue(InterpState &S, CodePtr OpPC, APValue &Result,
 #undef RET_CASE
 }
 
+/// Get the rounding mode to use when constant-evaluating expression \p E.
+///
+/// If the rounding mode is dynamic (unknown at compile time), still try to
+/// evaluate the expression: an exact result does not depend on the rounding
+/// mode, so return "tonearest" (NearestTiesToEven) instead of "dynamic".
+static llvm::RoundingMode getActiveRoundingMode(InterpState &S, const Expr *E) {
+  llvm::RoundingMode RM =
+      E->getFPFeaturesInEffect(S.getLangOpts()).getRoundingMode();
+  if (RM == llvm::RoundingMode::Dynamic)
+    RM = llvm::RoundingMode::NearestTiesToEven;
+  return RM;
+}
+
 static bool interp__builtin_is_constant_evaluated(InterpState &S, CodePtr OpPC,
                                                   const InterpFrame *Frame,
                                                   const CallExpr *Call) {
@@ -549,6 +562,22 @@ static bool interp__builtin_fpclassify(InterpState &S, CodePtr OpPC,
   return true;
 }
 
+/// Constant-evaluate fma(X, Y, Z): X * Y + Z computed with a single rounding.
+static bool interp__builtin_fma(InterpState &S, CodePtr OpPC,
+                                const InterpFrame *Frame, const Function *Func,
+                                const CallExpr *Call) {
+  const Floating &X = getParam<Floating>(Frame, 0);
+  const Floating &Y = getParam<Floating>(Frame, 1);
+  const Floating &Z = getParam<Floating>(Frame, 2);
+
+  // A separate mul followed by add would round twice; fusedMultiplyAdd rounds once.
+  llvm::RoundingMode RM = getActiveRoundingMode(S, Call);
+  llvm::APFloat F = X.getAPFloat();
+  F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
+  S.Stk.push<Floating>(Floating(F));
+  return true;
+}
+
 // The C standard says "fabs raises no floating-point exceptions,
 // even if x is a signaling NaN. The returned value is independent of
 // the current rounding direction mode."  Therefore constant folding can
@@ -1814,6 +1843,14 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const Function *F,
       return false;
     break;
 
+  case Builtin::BI__builtin_fma:
+  case Builtin::BI__builtin_fmaf:
+  case Builtin::BI__builtin_fmal:
+  case Builtin::BI__builtin_fmaf128:
+    if (!interp__builtin_fma(S, OpPC, Frame, F, Call))
+      return false;
+    break;
+
   case Builtin::BI__builtin_fabs:
   case Builtin::BI__builtin_fabsf:
   case Builtin::BI__builtin_fabsl:
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 8e36cad2d2c6e7..685ce8a63f6c9e 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15314,6 +15314,22 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
       Result.changeSign();
     return true;
 
+  case Builtin::BI__builtin_fma:
+  case Builtin::BI__builtin_fmaf:
+  case Builtin::BI__builtin_fmal:
+  case Builtin::BI__builtin_fmaf128: {
+    APFloat Y(0.), Z(0.);
+    if (!EvaluateFloat(E->getArg(0), Result, Info) ||
+        !EvaluateFloat(E->getArg(1), Y, Info) ||
+        !EvaluateFloat(E->getArg(2), Z, Info))
+      return false;
+
+    // fma rounds X * Y + Z once; multiply() then add() would round twice.
+    llvm::RoundingMode RM = getActiveRoundingMode(Info, E);
+    Result.fusedMultiplyAdd(Y, Z, RM);
+    return true;
+  }
+
   case Builtin::BI__arithmetic_fence:
     return EvaluateFloat(E->getArg(0), Result, Info);
 
diff --git a/clang/test/AST/ByteCode/builtin-functions.cpp b/clang/test/AST/ByteCode/builtin-functions.cpp
index b5d334178f8213..0dba62e252b3d7 100644
--- a/clang/test/AST/ByteCode/builtin-functions.cpp
+++ b/clang/test/AST/ByteCode/builtin-functions.cpp
@@ -265,6 +265,15 @@ namespace fpclassify {
   char classify_subnorm [__builtin_fpclassify(-1, -1, -1, +1, -1, 1.0e-38f)];
 }
 
+namespace fma {
+  static_assert(__builtin_fma(1.0, 1.0, 1.0) == 2.0);
+  static_assert(__builtin_fma(1.0, -1.0, 1.0) == 0.0);
+  static_assert(__builtin_fmaf(1.0f, 1.0f, 1.0f) == 2.0f);
+  static_assert(__builtin_fmaf(1.0f, -1.0f, 1.0f) == 0.0f);
+  static_assert(__builtin_fmal(1.0L, 1.0L, 1.0L) == 2.0L);
+  static_assert(__builtin_fmal(1.0L, -1.0L, 1.0L) == 0.0L);
+} // namespace fma
+
 namespace abs {
   static_assert(__builtin_abs(14) == 14, "");
   static_assert(__builtin_labs(14L) == 14L, "");
diff --git a/clang/test/Sema/constant-builtins-2.c b/clang/test/Sema/constant-builtins-2.c
index e465a3c5f0ad86..2a2dbc2caee1d1 100644
--- a/clang/test/Sema/constant-builtins-2.c
+++ b/clang/test/Sema/constant-builtins-2.c
@@ -54,6 +54,13 @@ long double  g18 = __builtin_copysignl(1.0L, -1.0L);
 __float128   g18_2 = __builtin_copysignf128(1.0q, -1.0q);
 #endif
 
+double g19 = __builtin_fma(1.0, 1.0, 1.0);
+float g20 = __builtin_fmaf(1.0f, 1.0f, 1.0f);
+long double g21 = __builtin_fmal(1.0L, 1.0L, 1.0L);
+#if defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)
+__float128 g21_2 = __builtin_fma(1.0q, 1.0q, 1.0q);
+#endif
+
 char classify_nan     [__builtin_fpclassify(+1, -1, -1, -1, -1, __builtin_nan(""))];
 char classify_snan    [__builtin_fpclassify(+1, -1, -1, -1, -1, __builtin_nans(""))];
 char classify_inf     [__builtin_fpclassify(-1, +1, -1, -1, -1, __builtin_inf())];

``````````

</details>


https://github.com/llvm/llvm-project/pull/113020


More information about the cfe-commits mailing list