[flang-commits] [flang] ce7700e - [flang][runtime] Address PRODUCT numeric discrepancy, folding vs runtime (#90125)
via flang-commits
flang-commits at lists.llvm.org
Thu May 2 05:43:58 PDT 2024
Author: Peter Klausler
Date: 2024-05-02T05:43:54-07:00
New Revision: ce7700e29dfbc85348942d74d0ca2ba9ac8d8cf5
URL: https://github.com/llvm/llvm-project/commit/ce7700e29dfbc85348942d74d0ca2ba9ac8d8cf5
DIFF: https://github.com/llvm/llvm-project/commit/ce7700e29dfbc85348942d74d0ca2ba9ac8d8cf5.diff
LOG: [flang][runtime] Address PRODUCT numeric discrepancy, folding vs runtime (#90125)
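Ensure that the runtime implementations of the floating-point PRODUCT and
SUM reductions accumulate intermediate results in the same precision as
their operands, rather than promoting REAL(4)/COMPLEX(4) intermediates to
double. This makes runtime results match those produced by constant
folding. (SUM uses Kahan summation in both the runtime and the folder.)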
Added:
Modified:
flang/runtime/product.cpp
flang/runtime/reduction-templates.h
flang/runtime/sum.cpp
Removed:
################################################################################
diff --git a/flang/runtime/product.cpp b/flang/runtime/product.cpp
index 4c3b8c33a12e0f..7fc0fcd3b107de 100644
--- a/flang/runtime/product.cpp
+++ b/flang/runtime/product.cpp
@@ -107,7 +107,7 @@ CppTypeFor<TypeCategory::Integer, 16> RTDEF(ProductInteger16)(
CppTypeFor<TypeCategory::Real, 4> RTDEF(ProductReal4)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, mask,
- NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
+ NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
"PRODUCT");
}
CppTypeFor<TypeCategory::Real, 8> RTDEF(ProductReal8)(const Descriptor &x,
@@ -137,7 +137,7 @@ void RTDEF(CppProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
result = GetTotalReduction<TypeCategory::Complex, 4>(x, source, line, dim,
- mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
+ mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
"PRODUCT");
}
void RTDEF(CppProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
@@ -169,8 +169,8 @@ void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
void RTDEF(ProductDim)(Descriptor &result, const Descriptor &x, int dim,
const char *source, int line, const Descriptor *mask) {
TypedPartialNumericReduction<NonComplexProductAccumulator,
- NonComplexProductAccumulator, ComplexProductAccumulator>(
- result, x, dim, source, line, mask, "PRODUCT");
+ NonComplexProductAccumulator, ComplexProductAccumulator,
+ /*MIN_REAL_KIND=*/4>(result, x, dim, source, line, mask, "PRODUCT");
}
RT_EXT_API_GROUP_END
diff --git a/flang/runtime/reduction-templates.h b/flang/runtime/reduction-templates.h
index f8e6f6095509e6..d102e5642547d6 100644
--- a/flang/runtime/reduction-templates.h
+++ b/flang/runtime/reduction-templates.h
@@ -240,11 +240,10 @@ inline RT_API_ATTRS void PartialIntegerReduction(Descriptor &result,
kind, terminator, result, x, dim, mask, terminator, intrinsic);
}
-template <TypeCategory CAT, template <typename> class ACCUM>
+template <TypeCategory CAT, template <typename> class ACCUM, int MIN_KIND>
struct PartialFloatingReductionHelper {
template <int KIND> struct Functor {
- static constexpr int Intermediate{
- std::max(KIND, 8)}; // use at least "double" for intermediate results
+ static constexpr int Intermediate{std::max(KIND, MIN_KIND)};
RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
int dim, const Descriptor *mask, Terminator &terminator,
const char *intrinsic) const {
@@ -260,7 +259,7 @@ struct PartialFloatingReductionHelper {
template <template <typename> class INTEGER_ACCUM,
template <typename> class REAL_ACCUM,
- template <typename> class COMPLEX_ACCUM>
+ template <typename> class COMPLEX_ACCUM, int MIN_REAL_KIND>
inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
const Descriptor &x, int dim, const char *source, int line,
const Descriptor *mask, const char *intrinsic) {
@@ -274,13 +273,13 @@ inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
break;
case TypeCategory::Real:
ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Real,
- REAL_ACCUM>::template Functor,
+ REAL_ACCUM, MIN_REAL_KIND>::template Functor,
void>(catKind->second, terminator, result, x, dim, mask, terminator,
intrinsic);
break;
case TypeCategory::Complex:
ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Complex,
- COMPLEX_ACCUM>::template Functor,
+ COMPLEX_ACCUM, MIN_REAL_KIND>::template Functor,
void>(catKind->second, terminator, result, x, dim, mask, terminator,
intrinsic);
break;
diff --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index d2495e3e956fe6..63d8c9029a0ef5 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -134,7 +134,7 @@ CppTypeFor<TypeCategory::Integer, 16> RTDEF(SumInteger16)(const Descriptor &x,
CppTypeFor<TypeCategory::Real, 4> RTDEF(SumReal4)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
return GetTotalReduction<TypeCategory::Real, 4>(
- x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
+ x, source, line, dim, mask, RealSumAccumulator<float>{x}, "SUM");
}
CppTypeFor<TypeCategory::Real, 8> RTDEF(SumReal8)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
@@ -160,7 +160,7 @@ void RTDEF(CppSumComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
result = GetTotalReduction<TypeCategory::Complex, 4>(
- x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
+ x, source, line, dim, mask, ComplexSumAccumulator<float>{x}, "SUM");
}
void RTDEF(CppSumComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
const Descriptor &x, const char *source, int line, int dim,
@@ -188,7 +188,8 @@ void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
void RTDEF(SumDim)(Descriptor &result, const Descriptor &x, int dim,
const char *source, int line, const Descriptor *mask) {
TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
- ComplexSumAccumulator>(result, x, dim, source, line, mask, "SUM");
+ ComplexSumAccumulator, /*MIN_REAL_KIND=*/4>(
+ result, x, dim, source, line, mask, "SUM");
}
RT_EXT_API_GROUP_END
More information about the flang-commits mailing list