[flang-commits] [flang] [flang][runtime] Address PRODUCT numeric discrepancy, folding vs runtime (PR #90125)
via flang-commits
flang-commits at lists.llvm.org
Thu Apr 25 14:17:52 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
Ensure that the runtime implementations of floating-point reductions use intermediate results of the same precision as the operands (with a minimum of kind 4 in the DIM= variants), so that their results match those produced by constant folding. (The SUM reduction uses Kahan summation in both cases.)
---
Full diff: https://github.com/llvm/llvm-project/pull/90125.diff
3 Files Affected:
- (modified) flang/runtime/product.cpp (+4-4)
- (modified) flang/runtime/reduction-templates.h (+5-6)
- (modified) flang/runtime/sum.cpp (+2-1)
``````````diff
diff --git a/flang/runtime/product.cpp b/flang/runtime/product.cpp
index 4c3b8c33a12e0f..7fc0fcd3b107de 100644
--- a/flang/runtime/product.cpp
+++ b/flang/runtime/product.cpp
@@ -107,7 +107,7 @@ CppTypeFor<TypeCategory::Integer, 16> RTDEF(ProductInteger16)(
CppTypeFor<TypeCategory::Real, 4> RTDEF(ProductReal4)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, mask,
- NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
+ NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
"PRODUCT");
}
CppTypeFor<TypeCategory::Real, 8> RTDEF(ProductReal8)(const Descriptor &x,
@@ -137,7 +137,7 @@ void RTDEF(CppProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
result = GetTotalReduction<TypeCategory::Complex, 4>(x, source, line, dim,
- mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
+ mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
"PRODUCT");
}
void RTDEF(CppProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
@@ -169,8 +169,8 @@ void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
void RTDEF(ProductDim)(Descriptor &result, const Descriptor &x, int dim,
const char *source, int line, const Descriptor *mask) {
TypedPartialNumericReduction<NonComplexProductAccumulator,
- NonComplexProductAccumulator, ComplexProductAccumulator>(
- result, x, dim, source, line, mask, "PRODUCT");
+ NonComplexProductAccumulator, ComplexProductAccumulator,
+ /*MIN_REAL_KIND=*/4>(result, x, dim, source, line, mask, "PRODUCT");
}
RT_EXT_API_GROUP_END
diff --git a/flang/runtime/reduction-templates.h b/flang/runtime/reduction-templates.h
index f8e6f6095509e6..d102e5642547d6 100644
--- a/flang/runtime/reduction-templates.h
+++ b/flang/runtime/reduction-templates.h
@@ -240,11 +240,10 @@ inline RT_API_ATTRS void PartialIntegerReduction(Descriptor &result,
kind, terminator, result, x, dim, mask, terminator, intrinsic);
}
-template <TypeCategory CAT, template <typename> class ACCUM>
+template <TypeCategory CAT, template <typename> class ACCUM, int MIN_KIND>
struct PartialFloatingReductionHelper {
template <int KIND> struct Functor {
- static constexpr int Intermediate{
- std::max(KIND, 8)}; // use at least "double" for intermediate results
+ static constexpr int Intermediate{std::max(KIND, MIN_KIND)};
RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
int dim, const Descriptor *mask, Terminator &terminator,
const char *intrinsic) const {
@@ -260,7 +259,7 @@ struct PartialFloatingReductionHelper {
template <template <typename> class INTEGER_ACCUM,
template <typename> class REAL_ACCUM,
- template <typename> class COMPLEX_ACCUM>
+ template <typename> class COMPLEX_ACCUM, int MIN_REAL_KIND>
inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
const Descriptor &x, int dim, const char *source, int line,
const Descriptor *mask, const char *intrinsic) {
@@ -274,13 +273,13 @@ inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
break;
case TypeCategory::Real:
ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Real,
- REAL_ACCUM>::template Functor,
+ REAL_ACCUM, MIN_REAL_KIND>::template Functor,
void>(catKind->second, terminator, result, x, dim, mask, terminator,
intrinsic);
break;
case TypeCategory::Complex:
ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Complex,
- COMPLEX_ACCUM>::template Functor,
+ COMPLEX_ACCUM, MIN_REAL_KIND>::template Functor,
void>(catKind->second, terminator, result, x, dim, mask, terminator,
intrinsic);
break;
diff --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index d2495e3e956fe6..1d64554c2785dc 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -188,7 +188,8 @@ void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
void RTDEF(SumDim)(Descriptor &result, const Descriptor &x, int dim,
const char *source, int line, const Descriptor *mask) {
TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
- ComplexSumAccumulator>(result, x, dim, source, line, mask, "SUM");
+ ComplexSumAccumulator, /*MIN_REAL_KIND=*/4>(
+ result, x, dim, source, line, mask, "SUM");
}
RT_EXT_API_GROUP_END
``````````
</details>
https://github.com/llvm/llvm-project/pull/90125
More information about the flang-commits mailing list