[flang-commits] [flang] [flang] Lower REDUCE intrinsic for scalar result (PR #94652)
via flang-commits
flang-commits at lists.llvm.org
Thu Jun 6 11:14:08 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
This patch lowers the `REDUCE` intrinsic call to the runtime equivalent for scalar results. Calls with an array result will follow in a later patch.
---
Patch is 36.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94652.diff
6 Files Affected:
- (modified) flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h (+177-5)
- (modified) flang/include/flang/Optimizer/Builder/Runtime/Reduction.h (+8)
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+61-2)
- (modified) flang/lib/Optimizer/Builder/Runtime/Reduction.cpp (+178)
- (removed) flang/test/Lower/Intrinsics/Todo/reduce.f90 (-13)
- (added) flang/test/Lower/Intrinsics/reduce.f90 (+379)
``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
index 575746374fcc4..1367e6147f9f9 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
@@ -22,6 +22,7 @@
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Runtime/reduce.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
+#include <limits>
@@ -52,6 +53,34 @@ namespace fir::runtime {
using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
using FuncTypeBuilderFunc = mlir::FunctionType (*)(mlir::MLIRContext *);
+/// Defines the type model for Fortran::runtime::ReductionOperation<T>: a
+/// function type taking two references to the element type T and returning
+/// a reference to T.
+/// NOTE(review): the result is modeled as a reference as well; confirm this
+/// matches the ReductionOperation alias declared in flang/Runtime/reduce.h.
+#define REDUCTION_OPERATION_MODEL(T) \
+ template <> \
+ constexpr TypeBuilderFunc \
+ getModel<Fortran::runtime::ReductionOperation<T>>() { \
+ return [](mlir::MLIRContext *context) -> mlir::Type { \
+ TypeBuilderFunc f{getModel<T>()}; \
+ auto refTy = fir::ReferenceType::get(f(context)); \
+ return mlir::FunctionType::get(context, {refTy, refTy}, refTy); \
+ }; \
+ }
+
+/// Defines the type model for Fortran::runtime::ReductionCharOperation<T>
+/// (T is the CHARACTER code unit type). The function type has six arguments
+/// in the order (ref T, size_t, ref T, ref T, size_t, size_t) and an opaque
+/// i8-pointer result; size_t is modeled as a host-sized integer.
+/// Presumably the references are the result buffer and the two operands and
+/// the size_t values their lengths -- confirm against flang/Runtime/reduce.h.
+#define REDUCTION_CHAR_OPERATION_MODEL(T) \
+ template <> \
+ constexpr TypeBuilderFunc \
+ getModel<Fortran::runtime::ReductionCharOperation<T>>() { \
+ return [](mlir::MLIRContext *context) -> mlir::Type { \
+ TypeBuilderFunc f{getModel<T>()}; \
+ auto voidTy = fir::LLVMPointerType::get( \
+ context, mlir::IntegerType::get(context, 8)); \
+ auto size_tTy = \
+ mlir::IntegerType::get(context, 8 * sizeof(std::size_t)); \
+ auto refTy = fir::ReferenceType::get(f(context)); \
+ return mlir::FunctionType::get( \
+ context, {refTy, size_tTy, refTy, refTy, size_tTy, size_tTy}, \
+ voidTy); \
+ }; \
+ }
+
//===----------------------------------------------------------------------===//
// Type builder models
//===----------------------------------------------------------------------===//
@@ -75,7 +104,6 @@ constexpr TypeBuilderFunc getModel<unsigned int>() {
return mlir::IntegerType::get(context, 8 * sizeof(unsigned int));
};
}
-
template <>
constexpr TypeBuilderFunc getModel<short int>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
@@ -83,6 +111,17 @@ constexpr TypeBuilderFunc getModel<short int>() {
};
}
template <>
+constexpr TypeBuilderFunc getModel<short int *>() {
+  // Pointer to `short int`: a fir.ref wrapping the `short int` model.
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    mlir::Type pointee = getModel<short int>()(ctx);
+    return fir::ReferenceType::get(pointee);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<const short int *>() {
+  // Constness does not change the FIR model.
+  return getModel<short int *>();
+}
+template <>
constexpr TypeBuilderFunc getModel<int>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, 8 * sizeof(int));
@@ -96,6 +135,17 @@ constexpr TypeBuilderFunc getModel<int &>() {
};
}
template <>
+constexpr TypeBuilderFunc getModel<int *>() {
+  // Pointers to int reuse the `int &` reference model.
+  return getModel<int &>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<const int *>() {
+  // A fir.ref wrapping the `int` model.
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    mlir::Type pointee = getModel<int>()(ctx);
+    return fir::ReferenceType::get(pointee);
+  };
+}
+template <>
constexpr TypeBuilderFunc getModel<char *>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return fir::ReferenceType::get(mlir::IntegerType::get(context, 8));
@@ -130,6 +180,43 @@ constexpr TypeBuilderFunc getModel<signed char>() {
};
}
template <>
+constexpr TypeBuilderFunc getModel<signed char *>() {
+  // A fir.ref wrapping the `signed char` model.
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    mlir::Type pointee = getModel<signed char>()(ctx);
+    return fir::ReferenceType::get(pointee);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<const signed char *>() {
+  // Constness does not change the FIR model.
+  return getModel<signed char *>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<char16_t>() {
+  // Integer type as wide as the host char16_t.
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    constexpr unsigned width = 8 * sizeof(char16_t);
+    return mlir::IntegerType::get(ctx, width);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<char16_t *>() {
+  // A fir.ref wrapping the `char16_t` model.
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    mlir::Type pointee = getModel<char16_t>()(ctx);
+    return fir::ReferenceType::get(pointee);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<char32_t>() {
+  // Integer type as wide as the host char32_t.
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    constexpr unsigned width = 8 * sizeof(char32_t);
+    return mlir::IntegerType::get(ctx, width);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<char32_t *>() {
+  // A fir.ref wrapping the `char32_t` model.
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    mlir::Type pointee = getModel<char32_t>()(ctx);
+    return fir::ReferenceType::get(pointee);
+  };
+}
+template <>
constexpr TypeBuilderFunc getModel<unsigned char>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, 8 * sizeof(unsigned char));
@@ -175,6 +262,10 @@ constexpr TypeBuilderFunc getModel<long *>() {
return getModel<long &>();
}
template <>
+constexpr TypeBuilderFunc getModel<const long *>() {
+  // Same model as `long *`, which itself forwards to `long &`.
+  return getModel<long &>();
+}
+template <>
constexpr TypeBuilderFunc getModel<long long>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, 8 * sizeof(long long));
@@ -198,6 +289,7 @@ template <>
constexpr TypeBuilderFunc getModel<long long *>() {
return getModel<long long &>();
}
+
template <>
constexpr TypeBuilderFunc getModel<unsigned long>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
@@ -228,6 +320,27 @@ constexpr TypeBuilderFunc getModel<double *>() {
return getModel<double &>();
}
template <>
+constexpr TypeBuilderFunc getModel<const double *>() {
+  return getModel<double *>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<long double>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    // The FIR type must match the host `long double` layout, which is not the
+    // same everywhere: IEEE binary128 has a 113-bit significand (e.g. AArch64
+    // Linux), plain double has 53 bits (e.g. MSVC), and x87 extended
+    // precision has 64 bits (x86). Hard-coding F80 mis-types every runtime
+    // entry point taking a `long double` on the first two kinds of hosts.
+    // This describes the host, not necessarily the compilation target.
+    constexpr int digits = std::numeric_limits<long double>::digits;
+    if constexpr (digits == 113)
+      return mlir::FloatType::getF128(context);
+    else if constexpr (digits == 53)
+      return mlir::FloatType::getF64(context);
+    // x87 extended precision (digits == 64). F80 also stays the fallback for
+    // layouts with no direct FIR equivalent (e.g. PowerPC double-double).
+    return mlir::FloatType::getF80(context);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<long double *>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    TypeBuilderFunc f{getModel<long double>()};
+    return fir::ReferenceType::get(f(context));
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<const long double *>() {
+  return getModel<long double *>();
+}
+template <>
constexpr TypeBuilderFunc getModel<float>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::FloatType::getF32(context);
@@ -245,6 +358,10 @@ constexpr TypeBuilderFunc getModel<float *>() {
return getModel<float &>();
}
template <>
+constexpr TypeBuilderFunc getModel<const float *>() {
+  // Same model as `float *`, which itself forwards to `float &`.
+  return getModel<float &>();
+}
+template <>
constexpr TypeBuilderFunc getModel<bool>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, 1);
@@ -258,20 +375,48 @@ constexpr TypeBuilderFunc getModel<bool &>() {
};
}
template <>
+constexpr TypeBuilderFunc getModel<std::complex<float>>() {
+ return [](mlir::MLIRContext *context) -> mlir::Type {
+ // std::complex<float> maps to the builtin mlir complex type over f32.
+ return mlir::ComplexType::get(mlir::FloatType::getF32(context));
+ };
+}
+template <>
constexpr TypeBuilderFunc getModel<std::complex<float> &>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- auto ty = mlir::ComplexType::get(mlir::FloatType::getF32(context));
- return fir::ReferenceType::get(ty);
+ // Build the reference on top of the value model so the two stay in sync.
+ TypeBuilderFunc f{getModel<std::complex<float>>()};
+ return fir::ReferenceType::get(f(context));
+ };
+}
+// Pointer forms (const or not) reuse the reference model.
+template <>
+constexpr TypeBuilderFunc getModel<std::complex<float> *>() {
+ return getModel<std::complex<float> &>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<const std::complex<float> *>() {
+ return getModel<std::complex<float> *>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<std::complex<double>>() {
+ return [](mlir::MLIRContext *context) -> mlir::Type {
+ // std::complex<double> maps to the builtin mlir complex type over f64.
+ return mlir::ComplexType::get(mlir::FloatType::getF64(context));
};
}
template <>
constexpr TypeBuilderFunc getModel<std::complex<double> &>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- auto ty = mlir::ComplexType::get(mlir::FloatType::getF64(context));
- return fir::ReferenceType::get(ty);
+ // Build the reference on top of the value model so the two stay in sync.
+ TypeBuilderFunc f{getModel<std::complex<double>>()};
+ return fir::ReferenceType::get(f(context));
};
}
template <>
+constexpr TypeBuilderFunc getModel<std::complex<double> *>() {
+ return getModel<std::complex<double> &>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<const std::complex<double> *>() {
+ return getModel<std::complex<double> *>();
+}
+template <>
constexpr TypeBuilderFunc getModel<c_float_complex_t>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return fir::ComplexType::get(context, sizeof(float));
@@ -332,6 +477,33 @@ constexpr TypeBuilderFunc getModel<void>() {
};
}
+// Operation-type models for INTEGER element types (8- to 128-bit).
+REDUCTION_OPERATION_MODEL(std::int8_t)
+REDUCTION_OPERATION_MODEL(std::int16_t)
+REDUCTION_OPERATION_MODEL(std::int32_t)
+REDUCTION_OPERATION_MODEL(std::int64_t)
+REDUCTION_OPERATION_MODEL(Fortran::common::int128_t)
+
+// Operation-type models for REAL element types.
+REDUCTION_OPERATION_MODEL(float)
+REDUCTION_OPERATION_MODEL(double)
+REDUCTION_OPERATION_MODEL(long double)
+
+// Operation-type models for COMPLEX element types.
+REDUCTION_OPERATION_MODEL(std::complex<float>)
+REDUCTION_OPERATION_MODEL(std::complex<double>)
+
+// Operation-type models for CHARACTER code unit types.
+REDUCTION_CHAR_OPERATION_MODEL(char)
+REDUCTION_CHAR_OPERATION_MODEL(char16_t)
+REDUCTION_CHAR_OPERATION_MODEL(char32_t)
+
+/// Derived-type reduction operations are modeled opaquely: three i8 pointers
+/// as arguments and an i8 pointer as the result.
+/// NOTE(review): confirm the result type against the
+/// ReductionDerivedTypeOperation alias in flang/Runtime/reduce.h.
+template <>
+constexpr TypeBuilderFunc
+getModel<Fortran::runtime::ReductionDerivedTypeOperation>() {
+ return [](mlir::MLIRContext *context) -> mlir::Type {
+ auto voidTy =
+ fir::LLVMPointerType::get(context, mlir::IntegerType::get(context, 8));
+ return mlir::FunctionType::get(context, {voidTy, voidTy, voidTy}, voidTy);
+ };
+}
+
template <typename...>
struct RuntimeTableKey;
template <typename RT, typename... ATs>
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h b/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
index 667ea9081a893..a4adaa72fa41a 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
@@ -224,6 +224,14 @@ void genIParityDim(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim,
mlir::Value maskBox);
+/// Generate call to `Reduce` intrinsic runtime routine for the case where the
+/// result is a scalar. \p dim, \p maskBox, \p identity and \p ordered are the
+/// corresponding (possibly defaulted or absent) REDUCE arguments. For element
+/// types whose result is produced in memory (e.g. complex, derived type,
+/// character), the result is written through \p resultBox; otherwise the
+/// scalar result is the returned value.
+mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
+                      mlir::Value arrayBox, mlir::Value operation,
+                      mlir::Value dim, mlir::Value maskBox,
+                      mlir::Value identity, mlir::Value ordered,
+                      mlir::Value resultBox);
+
} // namespace fir::runtime
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_REDUCTION_H
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 1cd3976d0afbe..b1d0be6a3ec4c 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -522,7 +522,7 @@ static constexpr IntrinsicHandler handlers[]{
{"operation", asAddr},
{"dim", asValue},
{"mask", asBox, handleDynamicOptional},
- {"identity", asValue},
+ {"identity", asAddr},
{"ordered", asValue}}},
/*isElemental=*/false},
{"repeat",
@@ -5705,7 +5705,66 @@ void IntrinsicLibrary::genRandomSeed(llvm::ArrayRef<fir::ExtendedValue> args) {
fir::ExtendedValue
IntrinsicLibrary::genReduce(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args) {
- TODO(loc, "intrinsic: reduce");
+  // args: ARRAY, OPERATION, DIM, MASK, IDENTITY, ORDERED.
+  assert(args.size() == 6);
+
+  fir::BoxValue arrayTmp = builder.createBox(loc, args[0]);
+  mlir::Value array = fir::getBase(arrayTmp);
+  mlir::Value operation = fir::getBase(args[1]);
+  int rank = arrayTmp.rank();
+  assert(rank >= 1);
+
+  mlir::Type ty = array.getType();
+  mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
+  mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
+
+  // Handle optional dim argument (args[2]); default to 1 when statically
+  // absent. The presence test must be on DIM itself. Keying it off MASK would
+  // drop a user-provided DIM when MASK is absent, and read an absent DIM when
+  // MASK is present.
+  mlir::Value dim =
+      isStaticallyAbsent(args[2])
+          ? builder.createIntegerConstant(loc, builder.getI32Type(), 1)
+          : fir::getBase(args[2]);
+
+  // Handle optional mask argument (args[3]).
+  auto mask = isStaticallyAbsent(args[3])
+                  ? builder.create<fir::AbsentOp>(
+                        loc, fir::BoxType::get(builder.getI1Type()))
+                  : builder.createBox(loc, args[3]);
+
+  // Handle optional identity argument (args[4]), which is passed by address.
+  mlir::Value identity =
+      isStaticallyAbsent(args[4])
+          ? builder.create<fir::AbsentOp>(loc, fir::ReferenceType::get(eleTy))
+          : fir::getBase(args[4]);
+
+  // Handle optional ordered argument (args[5]); request an ordered reduction
+  // when it is statically absent.
+  mlir::Value ordered = isStaticallyAbsent(args[5])
+                            ? builder.createBool(loc, true)
+                            : fir::getBase(args[5]);
+
+  // A rank-1 ARRAY yields a scalar result, which is what the type-specific
+  // runtime entry points below handle.
+  if (rank == 1) {
+    // Complex and derived-type results are produced in a temporary.
+    if (fir::isa_complex(eleTy) || fir::isa_derived(eleTy)) {
+      mlir::Value result = builder.createTemporary(loc, eleTy);
+      fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
+                              identity, ordered, result);
+      if (fir::isa_derived(eleTy))
+        return result;
+      return builder.create<fir::LoadOp>(loc, result);
+    }
+    if (fir::isa_char(eleTy)) {
+      // Create mutable fir.box to be passed to the runtime for the result.
+      fir::MutableBoxValue resultMutableBox =
+          fir::factory::createTempMutableBox(builder, loc, eleTy);
+      mlir::Value resultIrBox =
+          fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+      fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
+                              identity, ordered, resultIrBox);
+      // Handle cleanup of allocatable result descriptor and return
+      return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
+    }
+    // Remaining element types: the runtime call's return value is the scalar
+    // result, so only an absent placeholder is passed for the result box.
+    auto resultBox = builder.create<fir::AbsentOp>(
+        loc, fir::BoxType::get(builder.getI1Type()));
+    return fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
+                                   identity, ordered, resultBox);
+  }
+
+  TODO(loc, "intrinsic: reduce with non scalar result");
}
// REPEAT
diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index d4076067bf103..0c1af6159c939 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -12,6 +12,7 @@
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Support/Utils.h"
+#include "flang/Runtime/reduce.h"
#include "flang/Runtime/reduction.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -466,6 +467,85 @@ struct ForcedIParity16 {
}
};
+/// Placeholder for real*16 version of Reduce Intrinsic. The result is
+/// returned by value.
+struct ForcedReduceReal16 {
+  static constexpr const char *name = ExpandAndQuoteKey(RTNAME(ReduceReal16));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::FloatType::getF128(ctx);
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto refTy = fir::ReferenceType::get(ty);
+      // The user operation takes its operands by reference, consistent with
+      // REDUCTION_OPERATION_MODEL used for the other element types.
+      auto opTy = mlir::FunctionType::get(ctx, {refTy, refTy}, refTy);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      auto i1Ty = mlir::IntegerType::get(ctx, 1);
+      return mlir::FunctionType::get(
+          ctx, {boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty}, {ty});
+    };
+  }
+};
+
+/// Placeholder for integer*16 version of Reduce Intrinsic. The result is
+/// returned by value.
+struct ForcedReduceInteger16 {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(ReduceInteger16));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::IntegerType::get(ctx, 128);
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto refTy = fir::ReferenceType::get(ty);
+      // The user operation takes its operands by reference, consistent with
+      // REDUCTION_OPERATION_MODEL used for the other element types.
+      auto opTy = mlir::FunctionType::get(ctx, {refTy, refTy}, refTy);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      auto i1Ty = mlir::IntegerType::get(ctx, 1);
+      return mlir::FunctionType::get(
+          ctx, {boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty}, {ty});
+    };
+  }
+};
+
+/// Placeholder for complex(10) version of Reduce Intrinsic. The result is
+/// written through the leading reference argument (the lowering passes a
+/// memory temporary for complex results), hence the empty result list.
+struct ForcedReduceComplex10 {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(CppReduceComplex10));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::ComplexType::get(mlir::FloatType::getF80(ctx));
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto refTy = fir::ReferenceType::get(ty);
+      // The user operation takes its operands by reference, consistent with
+      // REDUCTION_OPERATION_MODEL used for the other element types.
+      auto opTy = mlir::FunctionType::get(ctx, {refTy, refTy}, refTy);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      auto i1Ty = mlir::IntegerType::get(ctx, 1);
+      return mlir::FunctionType::get(
+          ctx, {refTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
+          {});
+    };
+  }
+};
+
+/// Placeholder for complex(16) version of Reduce Intrinsic. The result is
+/// written through the leading reference argument (the lowering passes a
+/// memory temporary for complex results), hence the empty result list.
+struct ForcedReduceComplex16 {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(CppReduceComplex16));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::ComplexType::get(mlir::FloatType::getF128(ctx));
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto refTy = fir::ReferenceType::get(ty);
+      // The user operation takes its operands by reference, consistent with
+      // REDUCTION_OPERATION_MODEL used for the other element types.
+      auto opTy = mlir::FunctionType::get(ctx, {refTy, refTy}, refTy);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      auto i1Ty = mlir::IntegerType::get(ctx, 1);
+      return mlir::FunctionType::get(
+          ctx, {refTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
+          {});
+    };
+  }
+};
+
/// Generate call to specialized runtime function that takes a mask and
/// dim argument. The All, Any, and Count intrinsics use this pattern.
template <typename FN>
@@ -1237,3 +1317,101 @@ void fir::runtime::genIParityDim(fir::FirOpBuilder &builder, mlir::Location loc,
/// Generate call to `IParity` intrinsic runtime routine. This is the version
/// that does not take a dim argument.
GEN_IALL_IANY_IPARITY(IParity)
+
+/// Generate call to `Reduce` intrinsic runtime routine. This is the version
+/// that does have a scalar result.
+mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value arrayBox,
+ mlir::Value operation, mlir::Value dim,
+ mlir::Value maskBox, mlir::Value identity,
+ mlir::Value ordered,
+ mlir::Value resultBox) {
+ mlir::func::FuncOp func;
+ a...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/94652
More information about the flang-commits
mailing list