[flang-commits] [flang] [flang][runtime] Address PRODUCT numeric discrepancy, folding vs runtime (PR #90125)

via flang-commits flang-commits at lists.llvm.org
Thu Apr 25 14:17:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-runtime

Author: Peter Klausler (klausler)

<details>
<summary>Changes</summary>

Ensure that the runtime implementations of floating-point reductions use intermediate results of the same precision as the operands, so that runtime results match those produced by constant folding. (The SUM reduction uses Kahan summation in both cases.)

---
Full diff: https://github.com/llvm/llvm-project/pull/90125.diff


3 Files Affected:

- (modified) flang/runtime/product.cpp (+4-4) 
- (modified) flang/runtime/reduction-templates.h (+5-6) 
- (modified) flang/runtime/sum.cpp (+2-1) 


``````````diff
diff --git a/flang/runtime/product.cpp b/flang/runtime/product.cpp
index 4c3b8c33a12e0f..7fc0fcd3b107de 100644
--- a/flang/runtime/product.cpp
+++ b/flang/runtime/product.cpp
@@ -107,7 +107,7 @@ CppTypeFor<TypeCategory::Integer, 16> RTDEF(ProductInteger16)(
 CppTypeFor<TypeCategory::Real, 4> RTDEF(ProductReal4)(const Descriptor &x,
     const char *source, int line, int dim, const Descriptor *mask) {
   return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, mask,
-      NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
+      NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
       "PRODUCT");
 }
 CppTypeFor<TypeCategory::Real, 8> RTDEF(ProductReal8)(const Descriptor &x,
@@ -137,7 +137,7 @@ void RTDEF(CppProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
     const Descriptor &x, const char *source, int line, int dim,
     const Descriptor *mask) {
   result = GetTotalReduction<TypeCategory::Complex, 4>(x, source, line, dim,
-      mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
+      mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
       "PRODUCT");
 }
 void RTDEF(CppProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
@@ -169,8 +169,8 @@ void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
 void RTDEF(ProductDim)(Descriptor &result, const Descriptor &x, int dim,
     const char *source, int line, const Descriptor *mask) {
   TypedPartialNumericReduction<NonComplexProductAccumulator,
-      NonComplexProductAccumulator, ComplexProductAccumulator>(
-      result, x, dim, source, line, mask, "PRODUCT");
+      NonComplexProductAccumulator, ComplexProductAccumulator,
+      /*MIN_REAL_KIND=*/4>(result, x, dim, source, line, mask, "PRODUCT");
 }
 
 RT_EXT_API_GROUP_END
diff --git a/flang/runtime/reduction-templates.h b/flang/runtime/reduction-templates.h
index f8e6f6095509e6..d102e5642547d6 100644
--- a/flang/runtime/reduction-templates.h
+++ b/flang/runtime/reduction-templates.h
@@ -240,11 +240,10 @@ inline RT_API_ATTRS void PartialIntegerReduction(Descriptor &result,
       kind, terminator, result, x, dim, mask, terminator, intrinsic);
 }
 
-template <TypeCategory CAT, template <typename> class ACCUM>
+template <TypeCategory CAT, template <typename> class ACCUM, int MIN_KIND>
 struct PartialFloatingReductionHelper {
   template <int KIND> struct Functor {
-    static constexpr int Intermediate{
-        std::max(KIND, 8)}; // use at least "double" for intermediate results
+    static constexpr int Intermediate{std::max(KIND, MIN_KIND)};
     RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
         int dim, const Descriptor *mask, Terminator &terminator,
         const char *intrinsic) const {
@@ -260,7 +259,7 @@ struct PartialFloatingReductionHelper {
 
 template <template <typename> class INTEGER_ACCUM,
     template <typename> class REAL_ACCUM,
-    template <typename> class COMPLEX_ACCUM>
+    template <typename> class COMPLEX_ACCUM, int MIN_REAL_KIND>
 inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
     const Descriptor &x, int dim, const char *source, int line,
     const Descriptor *mask, const char *intrinsic) {
@@ -274,13 +273,13 @@ inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
     break;
   case TypeCategory::Real:
     ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Real,
-                               REAL_ACCUM>::template Functor,
+                               REAL_ACCUM, MIN_REAL_KIND>::template Functor,
         void>(catKind->second, terminator, result, x, dim, mask, terminator,
         intrinsic);
     break;
   case TypeCategory::Complex:
     ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Complex,
-                               COMPLEX_ACCUM>::template Functor,
+                               COMPLEX_ACCUM, MIN_REAL_KIND>::template Functor,
         void>(catKind->second, terminator, result, x, dim, mask, terminator,
         intrinsic);
     break;
diff --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index d2495e3e956fe6..1d64554c2785dc 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -188,7 +188,8 @@ void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
 void RTDEF(SumDim)(Descriptor &result, const Descriptor &x, int dim,
     const char *source, int line, const Descriptor *mask) {
   TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
-      ComplexSumAccumulator>(result, x, dim, source, line, mask, "SUM");
+      ComplexSumAccumulator, /*MIN_REAL_KIND=*/4>(
+      result, x, dim, source, line, mask, "SUM");
 }
 
 RT_EXT_API_GROUP_END

``````````

</details>


https://github.com/llvm/llvm-project/pull/90125


More information about the flang-commits mailing list