[flang-commits] [flang] [flang] Lower REDUCE intrinsic with no DIM argument and rank 1 (PR #94652)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Jun 6 13:55:05 PDT 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/94652

>From 5718e0483b5fe5204905cdbfff8a0175d85fc869 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 4 Jun 2024 07:55:25 -0700
Subject: [PATCH 1/3] [flang] Lower REDUCE intrinsic for scalar result

---
 .../Optimizer/Builder/Runtime/RTBuilder.h     | 182 ++++++++-
 .../Optimizer/Builder/Runtime/Reduction.h     |   8 +
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp |  63 ++-
 .../Optimizer/Builder/Runtime/Reduction.cpp   | 178 ++++++++
 flang/test/Lower/Intrinsics/Todo/reduce.f90   |  13 -
 flang/test/Lower/Intrinsics/reduce.f90        | 379 ++++++++++++++++++
 6 files changed, 803 insertions(+), 20 deletions(-)
 delete mode 100644 flang/test/Lower/Intrinsics/Todo/reduce.f90
 create mode 100644 flang/test/Lower/Intrinsics/reduce.f90

diff --git a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
index 575746374fcc4..1367e6147f9f9 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
@@ -22,6 +22,7 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Runtime/reduce.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "llvm/ADT/SmallVector.h"
@@ -52,6 +53,36 @@ namespace fir::runtime {
 using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
 using FuncTypeBuilderFunc = mlir::FunctionType (*)(mlir::MLIRContext *);
 
+/// Type model of Fortran::runtime::ReductionOperation<T>: a function type
+/// taking two !fir.ref<T> and returning !fir.ref<T>.
+#define REDUCTION_OPERATION_MODEL(T)                                           \
+  template <>                                                                  \
+  constexpr TypeBuilderFunc                                                    \
+  getModel<Fortran::runtime::ReductionOperation<T>>() {                        \
+    return [](mlir::MLIRContext *ctx) -> mlir::Type {                          \
+      mlir::Type elemRefTy = fir::ReferenceType::get(getModel<T>()(ctx));      \
+      return mlir::FunctionType::get(ctx, {elemRefTy, elemRefTy}, elemRefTy);  \
+    };                                                                         \
+  }
+
+/// Type model of Fortran::runtime::ReductionCharOperation<T>: inputs are
+/// (ref, len, ref, ref, len, len) with ref = !fir.ref<T> and len = an integer
+/// as wide as std::size_t; the result is a fir::LLVMPointerType to i8.
+#define REDUCTION_CHAR_OPERATION_MODEL(T)                                      \
+  template <>                                                                  \
+  constexpr TypeBuilderFunc                                                    \
+  getModel<Fortran::runtime::ReductionCharOperation<T>>() {                    \
+    return [](mlir::MLIRContext *ctx) -> mlir::Type {                          \
+      mlir::Type charRefTy = fir::ReferenceType::get(getModel<T>()(ctx));      \
+      mlir::Type lenTy = mlir::IntegerType::get(ctx, 8 * sizeof(std::size_t)); \
+      mlir::Type opaquePtrTy = fir::LLVMPointerType::get(                      \
+          ctx, mlir::IntegerType::get(ctx, 8));                                \
+      return mlir::FunctionType::get(                                          \
+          ctx, {charRefTy, lenTy, charRefTy, charRefTy, lenTy, lenTy},         \
+          opaquePtrTy);                                                        \
+    };                                                                         \
+  }
+
 //===----------------------------------------------------------------------===//
 // Type builder models
 //===----------------------------------------------------------------------===//
@@ -75,7 +104,6 @@ constexpr TypeBuilderFunc getModel<unsigned int>() {
     return mlir::IntegerType::get(context, 8 * sizeof(unsigned int));
   };
 }
-
 template <>
 constexpr TypeBuilderFunc getModel<short int>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
@@ -83,6 +111,17 @@ constexpr TypeBuilderFunc getModel<short int>() {
   };
 }
 template <>
+constexpr TypeBuilderFunc getModel<short int *>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    // Pointers are modeled as a !fir.ref of the pointee model.
+    return fir::ReferenceType::get(getModel<short int>()(ctx));
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<const short int *>() {
+  return getModel<short int *>();
+}
+template <>
 constexpr TypeBuilderFunc getModel<int>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     return mlir::IntegerType::get(context, 8 * sizeof(int));
@@ -96,6 +135,17 @@ constexpr TypeBuilderFunc getModel<int &>() {
   };
 }
 template <>
+constexpr TypeBuilderFunc getModel<int *>() {
+  // Pointers reuse the `int &` model.
+  return getModel<int &>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<const int *>() {
+  // Constness does not change the FIR type: reuse the `int *` model, like the
+  // other `const T *` specializations.
+  return getModel<int *>();
+}
+template <>
 constexpr TypeBuilderFunc getModel<char *>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     return fir::ReferenceType::get(mlir::IntegerType::get(context, 8));
@@ -130,6 +180,43 @@ constexpr TypeBuilderFunc getModel<signed char>() {
   };
 }
 template <>
+constexpr TypeBuilderFunc getModel<signed char *>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    // Pointers are modeled as a !fir.ref of the pointee model.
+    return fir::ReferenceType::get(getModel<signed char>()(ctx));
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<const signed char *>() {
+  return getModel<signed char *>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<char16_t>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    return mlir::IntegerType::get(ctx, 8 * sizeof(char16_t));
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<char16_t *>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    // Pointers are modeled as a !fir.ref of the pointee model.
+    return fir::ReferenceType::get(getModel<char16_t>()(ctx));
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<char32_t>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    return mlir::IntegerType::get(ctx, 8 * sizeof(char32_t));
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<char32_t *>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    // Pointers are modeled as a !fir.ref of the pointee model.
+    return fir::ReferenceType::get(getModel<char32_t>()(ctx));
+  };
+}
+template <>
 constexpr TypeBuilderFunc getModel<unsigned char>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     return mlir::IntegerType::get(context, 8 * sizeof(unsigned char));
@@ -175,6 +262,10 @@ constexpr TypeBuilderFunc getModel<long *>() {
   return getModel<long &>();
 }
 template <>
+constexpr TypeBuilderFunc getModel<const long *>() {
+  return getModel<long &>(); // Same model as `long *` and `long &`.
+}
+template <>
 constexpr TypeBuilderFunc getModel<long long>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     return mlir::IntegerType::get(context, 8 * sizeof(long long));
@@ -198,6 +289,7 @@ template <>
 constexpr TypeBuilderFunc getModel<long long *>() {
   return getModel<long long &>();
 }
+
 template <>
 constexpr TypeBuilderFunc getModel<unsigned long>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
@@ -228,6 +320,27 @@ constexpr TypeBuilderFunc getModel<double *>() {
   return getModel<double &>();
 }
 template <>
+constexpr TypeBuilderFunc getModel<const double *>() {
+  return getModel<double &>(); // Same model as `double *` and `double &`.
+}
+template <>
+constexpr TypeBuilderFunc getModel<long double>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    return mlir::FloatType::getF80(ctx); // Assumes an 80-bit host long double.
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<long double *>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    // Pointers are modeled as a !fir.ref of the pointee model.
+    return fir::ReferenceType::get(getModel<long double>()(ctx));
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<const long double *>() {
+  return getModel<long double *>();
+}
+template <>
 constexpr TypeBuilderFunc getModel<float>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     return mlir::FloatType::getF32(context);
@@ -245,6 +358,10 @@ constexpr TypeBuilderFunc getModel<float *>() {
   return getModel<float &>();
 }
 template <>
+constexpr TypeBuilderFunc getModel<const float *>() {
+  return getModel<float &>(); // Same model as `float *` and `float &`.
+}
+template <>
 constexpr TypeBuilderFunc getModel<bool>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     return mlir::IntegerType::get(context, 1);
@@ -258,20 +375,48 @@ constexpr TypeBuilderFunc getModel<bool &>() {
   };
 }
 template <>
+constexpr TypeBuilderFunc getModel<std::complex<float>>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
+  };
+}
+template <>
 constexpr TypeBuilderFunc getModel<std::complex<float> &>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
-    auto ty = mlir::ComplexType::get(mlir::FloatType::getF32(context));
-    return fir::ReferenceType::get(ty);
+    mlir::Type complexTy = getModel<std::complex<float>>()(context);
+    return fir::ReferenceType::get(complexTy);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<std::complex<float> *>() {
+  return getModel<std::complex<float> &>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<const std::complex<float> *>() {
+  return getModel<std::complex<float> &>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<std::complex<double>>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
   };
 }
 template <>
 constexpr TypeBuilderFunc getModel<std::complex<double> &>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
-    auto ty = mlir::ComplexType::get(mlir::FloatType::getF64(context));
-    return fir::ReferenceType::get(ty);
+    mlir::Type complexTy = getModel<std::complex<double>>()(context);
+    return fir::ReferenceType::get(complexTy);
   };
 }
 template <>
+constexpr TypeBuilderFunc getModel<std::complex<double> *>() {
+  return getModel<std::complex<double> &>();
+}
+template <>
+constexpr TypeBuilderFunc getModel<const std::complex<double> *>() {
+  return getModel<std::complex<double> &>();
+}
+template <>
 constexpr TypeBuilderFunc getModel<c_float_complex_t>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     return fir::ComplexType::get(context, sizeof(float));
@@ -332,6 +477,33 @@ constexpr TypeBuilderFunc getModel<void>() {
   };
 }
 
+// Operation models for the integer, real and complex element types.
+REDUCTION_OPERATION_MODEL(std::int8_t)
+REDUCTION_OPERATION_MODEL(std::int16_t)
+REDUCTION_OPERATION_MODEL(std::int32_t)
+REDUCTION_OPERATION_MODEL(std::int64_t)
+REDUCTION_OPERATION_MODEL(Fortran::common::int128_t)
+REDUCTION_OPERATION_MODEL(float)
+REDUCTION_OPERATION_MODEL(double)
+REDUCTION_OPERATION_MODEL(long double)
+REDUCTION_OPERATION_MODEL(std::complex<float>)
+REDUCTION_OPERATION_MODEL(std::complex<double>)
+
+// Operation models for the character element types.
+REDUCTION_CHAR_OPERATION_MODEL(char)
+REDUCTION_CHAR_OPERATION_MODEL(char16_t)
+REDUCTION_CHAR_OPERATION_MODEL(char32_t)
+
+// Derived types: three i8 fir::LLVMPointerType inputs and the same result.
+template <>
+constexpr TypeBuilderFunc
+getModel<Fortran::runtime::ReductionDerivedTypeOperation>() {
+  return [](mlir::MLIRContext *ctx) -> mlir::Type {
+    auto ptrTy = fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
+    return mlir::FunctionType::get(ctx, {ptrTy, ptrTy, ptrTy}, ptrTy);
+  };
+}
+
 template <typename...>
 struct RuntimeTableKey;
 template <typename RT, typename... ATs>
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h b/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
index 667ea9081a893..a4adaa72fa41a 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
@@ -224,6 +224,14 @@ void genIParityDim(fir::FirOpBuilder &builder, mlir::Location loc,
                    mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim,
                    mlir::Value maskBox);
 
+/// Generate call to `Reduce` intrinsic runtime routine for a scalar result.
+/// Complex, character and derived results are returned through \p resultBox.
+mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
+                      mlir::Value arrayBox, mlir::Value operation,
+                      mlir::Value dim, mlir::Value maskBox,
+                      mlir::Value identity, mlir::Value ordered,
+                      mlir::Value resultBox);
+
 } // namespace fir::runtime
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_REDUCTION_H
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 1cd3976d0afbe..b1d0be6a3ec4c 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -522,7 +522,7 @@ static constexpr IntrinsicHandler handlers[]{
        {"operation", asAddr},
        {"dim", asValue},
        {"mask", asBox, handleDynamicOptional},
-       {"identity", asValue},
+       {"identity", asAddr},
        {"ordered", asValue}}},
      /*isElemental=*/false},
     {"repeat",
@@ -5705,7 +5705,75 @@ void IntrinsicLibrary::genRandomSeed(llvm::ArrayRef<fir::ExtendedValue> args) {
 fir::ExtendedValue
 IntrinsicLibrary::genReduce(mlir::Type resultType,
                             llvm::ArrayRef<fir::ExtendedValue> args) {
-  TODO(loc, "intrinsic: reduce");
+  // Arguments: array, operation, dim, mask, identity, ordered.
+  assert(args.size() == 6);
+
+  fir::BoxValue arrayTmp = builder.createBox(loc, args[0]);
+  mlir::Value array = fir::getBase(arrayTmp);
+  mlir::Value operation = fir::getBase(args[1]);
+  int rank = arrayTmp.rank();
+  assert(rank >= 1);
+
+  mlir::Type ty = array.getType();
+  mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
+  mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
+
+  // Handle optional dim argument: DIM defaults to 1 when absent. Presence is
+  // tested on DIM itself (args[2]), independently of MASK (args[3]).
+  auto dim = isStaticallyAbsent(args[2])
+                 ? builder.createIntegerConstant(loc, builder.getI32Type(), 1)
+                 : fir::getBase(args[2]);
+
+  // Handle optional mask argument.
+  auto mask = isStaticallyAbsent(args[3])
+                  ? builder.create<fir::AbsentOp>(
+                        loc, fir::BoxType::get(builder.getI1Type()))
+                  : builder.createBox(loc, args[3]);
+
+  // An absent IDENTITY is passed to the runtime as an absent reference.
+  mlir::Value identity =
+      isStaticallyAbsent(args[4])
+          ? builder.create<fir::AbsentOp>(loc, fir::ReferenceType::get(eleTy))
+          : fir::getBase(args[4]);
+
+  // ORDERED defaults to .true. when absent.
+  mlir::Value ordered = isStaticallyAbsent(args[5])
+                            ? builder.createBool(loc, true)
+                            : fir::getBase(args[5]);
+
+  // We call the type specific versions because the result is scalar
+  // in the case below.
+  if (rank == 1) {
+    // Complex and derived type results go through a temporary passed to
+    // the runtime: derived types return it, complex values are loaded.
+    if (fir::isa_complex(eleTy) || fir::isa_derived(eleTy)) {
+      mlir::Value result = builder.createTemporary(loc, eleTy);
+      fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
+                              identity, ordered, result);
+      if (fir::isa_derived(eleTy))
+        return result;
+      return builder.create<fir::LoadOp>(loc, result);
+    }
+    if (fir::isa_char(eleTy)) {
+      // Create mutable fir.box to be passed to the runtime for the result.
+      fir::MutableBoxValue resultMutableBox =
+          fir::factory::createTempMutableBox(builder, loc, eleTy);
+      mlir::Value resultIrBox =
+          fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+      fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
+                              identity, ordered, resultIrBox);
+      // Handle cleanup of allocatable result descriptor and return
+      return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
+    }
+    // Other element types: the runtime call returns the result; the
+    // absent result box is not used by that call.
+    auto resultBox = builder.create<fir::AbsentOp>(
+        loc, fir::BoxType::get(builder.getI1Type()));
+    return fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
+                                   identity, ordered, resultBox);
+  }
+
+  TODO(loc, "intrinsic: reduce with non scalar result");
 }
 
 // REPEAT
diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index d4076067bf103..0c1af6159c939 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -12,6 +12,7 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
 #include "flang/Optimizer/Support/Utils.h"
+#include "flang/Runtime/reduce.h"
 #include "flang/Runtime/reduction.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 
@@ -466,6 +467,80 @@ struct ForcedIParity16 {
   }
 };
 
+/// Build the FIR function type of a REDUCE runtime entry point whose element
+/// type is \p ty. As with REDUCTION_OPERATION_MODEL in RTBuilder.h, the user
+/// operation is modeled as taking two !fir.ref<ty> and returning !fir.ref<ty>.
+/// When \p resultByRef is true, the result is passed as a !fir.ref<ty> first
+/// argument and nothing is returned; otherwise the result is returned as \p ty.
+static mlir::FunctionType genReduceRuntimeFuncType(mlir::MLIRContext *ctx,
+                                                   mlir::Type ty,
+                                                   bool resultByRef) {
+  auto boxTy =
+      fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+  auto refTy = fir::ReferenceType::get(ty);
+  auto opTy = mlir::FunctionType::get(ctx, {refTy, refTy}, refTy);
+  auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+  auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+  auto i1Ty = mlir::IntegerType::get(ctx, 1);
+  if (resultByRef)
+    return mlir::FunctionType::get(
+        ctx, {refTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
+        {});
+  return mlir::FunctionType::get(
+      ctx, {boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty}, {ty});
+}
+
+/// Placeholder for real*16 version of Reduce Intrinsic
+struct ForcedReduceReal16 {
+  static constexpr const char *name = ExpandAndQuoteKey(RTNAME(ReduceReal16));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      return genReduceRuntimeFuncType(ctx, mlir::FloatType::getF128(ctx),
+                                      /*resultByRef=*/false);
+    };
+  }
+};
+
+/// Placeholder for integer*16 version of Reduce Intrinsic
+struct ForcedReduceInteger16 {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(ReduceInteger16));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      return genReduceRuntimeFuncType(ctx, mlir::IntegerType::get(ctx, 128),
+                                      /*resultByRef=*/false);
+    };
+  }
+};
+
+/// Placeholder for complex(10) version of Reduce Intrinsic. The result is
+/// passed by reference as the first argument.
+struct ForcedReduceComplex10 {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(CppReduceComplex10));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      return genReduceRuntimeFuncType(
+          ctx, mlir::ComplexType::get(mlir::FloatType::getF80(ctx)),
+          /*resultByRef=*/true);
+    };
+  }
+};
+
+/// Placeholder for complex(16) version of Reduce Intrinsic. The result is
+/// passed by reference as the first argument.
+struct ForcedReduceComplex16 {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(CppReduceComplex16));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      return genReduceRuntimeFuncType(
+          ctx, mlir::ComplexType::get(mlir::FloatType::getF128(ctx)),
+          /*resultByRef=*/true);
+    };
+  }
+};
+
 /// Generate call to specialized runtime function that takes a mask and
 /// dim argument. The All, Any, and Count intrinsics use this pattern.
 template <typename FN>
@@ -1237,3 +1317,110 @@ void fir::runtime::genIParityDim(fir::FirOpBuilder &builder, mlir::Location loc,
 /// Generate call to `IParity` intrinsic runtime routine. This is the version
 /// that does not take a dim argument.
 GEN_IALL_IANY_IPARITY(IParity)
+
+/// Select the `Reduce` runtime entry point matching the element type
+/// \p eleTy. Unsupported element types reach fir::intrinsicTypeTODO.
+static mlir::func::FuncOp getReduceRuntimeFunc(fir::FirOpBuilder &builder,
+                                               mlir::Location loc,
+                                               mlir::Type eleTy) {
+  mlir::MLIRContext *ctx = builder.getContext();
+  fir::factory::CharacterExprHelper charHelper{builder, loc};
+
+  if (eleTy.isF16())
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal2)>(loc, builder);
+  if (eleTy.isBF16())
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal3)>(loc, builder);
+  if (eleTy.isF32())
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal4)>(loc, builder);
+  if (eleTy.isF64())
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal8)>(loc, builder);
+  if (eleTy.isF80())
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal10)>(loc, builder);
+  if (eleTy.isF128())
+    return fir::runtime::getRuntimeFunc<ForcedReduceReal16>(loc, builder);
+  if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger1)>(loc, builder);
+  if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger2)>(loc, builder);
+  if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger4)>(loc, builder);
+  if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger8)>(loc, builder);
+  if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
+    return fir::runtime::getRuntimeFunc<ForcedReduceInteger16>(loc, builder);
+  if (eleTy == fir::ComplexType::get(ctx, 2))
+    return fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex2)>(
+        loc, builder);
+  if (eleTy == fir::ComplexType::get(ctx, 3))
+    return fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex3)>(
+        loc, builder);
+  if (eleTy == fir::ComplexType::get(ctx, 4))
+    return fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex4)>(
+        loc, builder);
+  if (eleTy == fir::ComplexType::get(ctx, 8))
+    return fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex8)>(
+        loc, builder);
+  if (eleTy == fir::ComplexType::get(ctx, 10))
+    return fir::runtime::getRuntimeFunc<ForcedReduceComplex10>(loc, builder);
+  if (eleTy == fir::ComplexType::get(ctx, 16))
+    return fir::runtime::getRuntimeFunc<ForcedReduceComplex16>(loc, builder);
+  if (eleTy == fir::LogicalType::get(ctx, 1))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical1)>(loc, builder);
+  if (eleTy == fir::LogicalType::get(ctx, 2))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical2)>(loc, builder);
+  if (eleTy == fir::LogicalType::get(ctx, 4))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical4)>(loc, builder);
+  if (eleTy == fir::LogicalType::get(ctx, 8))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical8)>(loc, builder);
+  if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 1)
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceChar1)>(loc, builder);
+  if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 2)
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceChar2)>(loc, builder);
+  if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 4)
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceChar4)>(loc, builder);
+  if (fir::isa_derived(eleTy))
+    return fir::runtime::getRuntimeFunc<mkRTKey(ReduceDerivedType)>(
+        loc, builder);
+  fir::intrinsicTypeTODO(builder, eleTy, loc, "REDUCE");
+  return {};
+}
+
+/// Generate call to `Reduce` intrinsic runtime routine for the cases producing
+/// a scalar result. For complex, character and derived type elements,
+/// \p resultBox is passed as the first runtime argument and is returned; for
+/// the other element types the value returned by the runtime call is returned
+/// and \p resultBox is unused.
+mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
+                                    mlir::Location loc, mlir::Value arrayBox,
+                                    mlir::Value operation, mlir::Value dim,
+                                    mlir::Value maskBox, mlir::Value identity,
+                                    mlir::Value ordered,
+                                    mlir::Value resultBox) {
+  auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(arrayBox.getType());
+  auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
+  mlir::func::FuncOp func = getReduceRuntimeFunc(builder, loc, eleTy);
+
+  auto fTy = func.getFunctionType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  if (fir::isa_complex(eleTy) || fir::isa_char(eleTy) ||
+      fir::isa_derived(eleTy)) {
+    // The result comes first, so the remaining inputs are shifted by one.
+    auto sourceLine =
+        fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+    auto opAddr =
+        builder.create<fir::BoxAddrOp>(loc, fTy.getInput(2), operation);
+    auto args = fir::runtime::createArguments(
+        builder, loc, fTy, resultBox, arrayBox, opAddr, sourceFile, sourceLine,
+        dim, maskBox, identity, ordered);
+    builder.create<fir::CallOp>(loc, func, args);
+    return resultBox;
+  }
+
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+  auto opAddr = builder.create<fir::BoxAddrOp>(loc, fTy.getInput(1), operation);
+  auto args = fir::runtime::createArguments(builder, loc, fTy, arrayBox, opAddr,
+                                            sourceFile, sourceLine, dim,
+                                            maskBox, identity, ordered);
+  return builder.create<fir::CallOp>(loc, func, args).getResult(0);
+}
diff --git a/flang/test/Lower/Intrinsics/Todo/reduce.f90 b/flang/test/Lower/Intrinsics/Todo/reduce.f90
deleted file mode 100644
index 7aa6f4a9f3ad3..0000000000000
--- a/flang/test/Lower/Intrinsics/Todo/reduce.f90
+++ /dev/null
@@ -1,13 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir %s -o -  2>&1 | FileCheck %s
-
-interface
-  pure function chfunc(a,b)
-    character(*),intent(in) :: a,b
-    character(3) :: chfunc
-  end function
-  end interface
-  character(3) x(5)
-  print*, reduce(x,chfunc)
-end program
-
-! CHECK: not yet implemented: intrinsic: reduce
diff --git a/flang/test/Lower/Intrinsics/reduce.f90 b/flang/test/Lower/Intrinsics/reduce.f90
new file mode 100644
index 0000000000000..3c548816b77c9
--- /dev/null
+++ b/flang/test/Lower/Intrinsics/reduce.f90
@@ -0,0 +1,379 @@
+! RUN: bbc -emit-hlfir %s -o - | FileCheck %s
+
+module reduce_mod
+
+type :: t1
+  integer :: a
+end type
+
+contains
+
+pure function red_int1(a,b)
+  integer(1), intent(in) :: a, b
+  integer(1) :: red_int1
+  red_int1 = a + b
+end function
+
+subroutine integer1(a, id)
+  integer(1), intent(in) :: a(:)
+  integer(1) :: res, id
+
+  res = reduce(a, red_int1)
+
+  res = reduce(a, red_int1, 1, identity=id)
+  
+  res = reduce(a, red_int1, 1, identity=id, ordered = .false.)
+
+  res = reduce(a, red_int1, 1, [.true., .true., .false.])
+end subroutine
+
+! CHECK-LABEL: func.func @_QMreduce_modPinteger1(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi8>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<i8> {fir.bindc_name = "id"})
+! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<intent_in>, uniq_name = "_QMreduce_modFinteger1Ea"} : (!fir.box<!fir.array<?xi8>>, !fir.dscope) -> (!fir.box<!fir.array<?xi8>>, !fir.box<!fir.array<?xi8>>)
+! CHECK: %[[ID:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{.*}} {uniq_name = "_QMreduce_modFinteger1Eid"} : (!fir.ref<i8>, !fir.dscope) -> (!fir.ref<i8>, !fir.ref<i8>)
+! CHECK: %[[ALLOC_RES:.*]] = fir.alloca i8 {bindc_name = "res", uniq_name = "_QMreduce_modFinteger1Eres"}
+! CHECK: %[[RES:.*]]:2 = hlfir.declare %[[ALLOC_RES]] {uniq_name = "_QMreduce_modFinteger1Eres"} : (!fir.ref<i8>) -> (!fir.ref<i8>, !fir.ref<i8>)
+! CHECK: %[[ADDR_OP:.*]] = fir.address_of(@_QMreduce_modPred_int1) : (!fir.ref<i8>, !fir.ref<i8>) -> i8
+! CHECK: %[[BOX_PROC:.*]] = fir.emboxproc %[[ADDR_OP]] : ((!fir.ref<i8>, !fir.ref<i8>) -> i8) -> !fir.boxproc<() -> ()>
+! CHECK: %[[MASK:.*]] = fir.absent !fir.box<i1>
+! CHECK: %[[IDENTITY:.*]] = fir.absent !fir.ref<i8>
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_PROC]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i8>, !fir.ref<i8>) -> !fir.ref<i8>)
+! CHECK: %[[A_NONE:.*]] = fir.convert %[[A]]#1 : (!fir.box<!fir.array<?xi8>>) -> !fir.box<none>
+! CHECK: %[[MASK_NONE:.*]] = fir.convert %[[MASK]] : (!fir.box<i1>) -> !fir.box<none>
+! CHECK: %[[REDUCE_RES:.*]] = fir.call @_FortranAReduceInteger1(%[[A_NONE]], %[[BOX_ADDR]], %{{.*}}, %{{.*}}, %c1{{.*}}, %[[MASK_NONE]], %[[IDENTITY]], %true) fastmath<contract> : (!fir.box<none>, (!fir.ref<i8>, !fir.ref<i8>) -> !fir.ref<i8>, !fir.ref<i8>, i32, i32, !fir.box<none>, !fir.ref<i8>, i1) -> i8
+! CHECK: hlfir.assign %[[REDUCE_RES]] to %[[RES]]#0 : i8, !fir.ref<i8>
+! CHECK: %[[ADDR_OP:.*]] = fir.address_of(@_QMreduce_modPred_int1) : (!fir.ref<i8>, !fir.ref<i8>) -> i8
+! CHECK: %[[BOX_PROC:.*]] = fir.emboxproc %[[ADDR_OP]] : ((!fir.ref<i8>, !fir.ref<i8>) -> i8) -> !fir.boxproc<() -> ()>
+! CHECK: %[[MASK:.*]] = fir.absent !fir.box<i1>
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_PROC]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i8>, !fir.ref<i8>) -> !fir.ref<i8>)
+! CHECK: %[[A_NONE:.*]] = fir.convert %[[A]]#1 : (!fir.box<!fir.array<?xi8>>) -> !fir.box<none>
+! CHECK: %[[MASK_NONE:.*]] = fir.convert %[[MASK]] : (!fir.box<i1>) -> !fir.box<none>
+! CHECK: %{{.*}} = fir.call @_FortranAReduceInteger1(%[[A_NONE]], %[[BOX_ADDR]], %{{.*}}, %{{.*}}, %c1{{.*}}, %[[MASK_NONE]], %[[ID]]#1, %true{{.*}}) fastmath<contract> : (!fir.box<none>, (!fir.ref<i8>, !fir.ref<i8>) -> !fir.ref<i8>, !fir.ref<i8>, i32, i32, !fir.box<none>, !fir.ref<i8>, i1) -> i8
+! CHECK: fir.call @_FortranAReduceInteger1(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}#1, %false)
+! CHECK: %[[MASK:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro.3xl4.0"} : (!fir.ref<!fir.array<3x!fir.logical<4>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<3x!fir.logical<4>>>, !fir.ref<!fir.array<3x!fir.logical<4>>>)
+! CHECK: %[[SHAPE_C3:.*]] = fir.shape %c3{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[BOXED_MASK:.*]] = fir.embox %[[MASK]]#1(%[[SHAPE_C3]]) : (!fir.ref<!fir.array<3x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.array<3x!fir.logical<4>>>
+! CHECK: %[[CONV_MASK:.*]] = fir.convert %[[BOXED_MASK]] : (!fir.box<!fir.array<3x!fir.logical<4>>>) -> !fir.box<none>
+! CHECK: fir.call @_FortranAReduceInteger1(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[CONV_MASK]], %{{.*}}, %true{{.*}})
+
+pure function red_int2(a,b)
+  integer(2), intent(in) :: a, b
+  integer(2) :: red_int2
+  red_int2 = a + b
+end function
+
+subroutine integer2(a)
+  integer(2), intent(in) :: a(:)
+  integer(2) :: res
+  res = reduce(a, red_int2, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger2
+
+pure function red_int4(a,b)
+  integer(4), intent(in) :: a, b
+  integer(4) :: red_int4
+  red_int4 = a + b
+end function
+
+subroutine integer4(a)
+  integer(4), intent(in) :: a(:)
+  integer(4) :: res
+  res = reduce(a, red_int4, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger4
+
+pure function red_int8(a,b)
+  integer(8), intent(in) :: a, b
+  integer(8) :: red_int8
+  red_int8 = a + b
+end function
+
+subroutine integer8(a)
+  integer(8), intent(in) :: a(:)
+  integer(8) :: res
+  res = reduce(a, red_int8, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger8
+
+pure function red_int16(a,b)
+  integer(16), intent(in) :: a, b
+  integer(16) :: red_int16
+  red_int16 = a + b
+end function
+
+subroutine integer16(a)
+  integer(16), intent(in) :: a(:)
+  integer(16) :: res
+  res = reduce(a, red_int16, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger16
+
+pure function red_real2(a,b)
+  real(2), intent(in) :: a, b
+  real(2) :: red_real2
+  red_real2 = a + b
+end function
+
+subroutine real2(a)
+  real(2), intent(in) :: a(:)
+  real(2) :: res
+  res = reduce(a, red_real2, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal2
+
+pure function red_real3(a,b)
+  real(3), intent(in) :: a, b
+  real(3) :: red_real3
+  red_real3 = a + b
+end function
+
+subroutine real3(a)
+  real(3), intent(in) :: a(:)
+  real(3) :: res
+  res = reduce(a, red_real3, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal3
+
+pure function red_real4(a,b)
+  real(4), intent(in) :: a, b
+  real(4) :: red_real4
+  red_real4 = a + b
+end function
+
+subroutine real4(a)
+  real(4), intent(in) :: a(:)
+  real(4) :: res
+  res = reduce(a, red_real4, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal4
+
+pure function red_real8(a,b)
+  real(8), intent(in) :: a, b
+  real(8) :: red_real8
+  red_real8 = a + b
+end function
+
+subroutine real8(a)
+  real(8), intent(in) :: a(:)
+  real(8) :: res
+  res = reduce(a, red_real8, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal8
+
+pure function red_real10(a,b)
+  real(10), intent(in) :: a, b
+  real(10) :: red_real10
+  red_real10 = a + b
+end function
+
+subroutine real10(a)
+  real(10), intent(in) :: a(:)
+  real(10) :: res
+  res = reduce(a, red_real10, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal10
+
+pure function red_real16(a,b)
+  real(16), intent(in) :: a, b
+  real(16) :: red_real16
+  red_real16 = a + b
+end function
+
+subroutine real16(a)
+  real(16), intent(in) :: a(:)
+  real(16) :: res
+  res = reduce(a, red_real16, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal16
+
+pure function red_complex2(a,b)
+  complex(2), intent(in) :: a, b
+  complex(2) :: red_complex2
+  red_complex2 = a + b
+end function
+
+subroutine complex2(a)
+  complex(2), intent(in) :: a(:)
+  complex(2) :: res
+  res = reduce(a, red_complex2, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex2
+
+pure function red_complex3(a,b)
+  complex(3), intent(in) :: a, b
+  complex(3) :: red_complex3
+  red_complex3 = a + b
+end function
+
+subroutine complex3(a)
+  complex(3), intent(in) :: a(:)
+  complex(3) :: res
+  res = reduce(a, red_complex3, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex3
+
+pure function red_complex4(a,b)
+  complex(4), intent(in) :: a, b
+  complex(4) :: red_complex4
+  red_complex4 = a + b
+end function
+
+subroutine complex4(a)
+  complex(4), intent(in) :: a(:)
+  complex(4) :: res
+  res = reduce(a, red_complex4, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex4
+
+pure function red_complex8(a,b)
+  complex(8), intent(in) :: a, b
+  complex(8) :: red_complex8
+  red_complex8 = a + b
+end function
+
+subroutine complex8(a)
+  complex(8), intent(in) :: a(:)
+  complex(8) :: res
+  res = reduce(a, red_complex8, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex8
+
+pure function red_complex10(a,b)
+  complex(10), intent(in) :: a, b
+  complex(10) :: red_complex10
+  red_complex10 = a + b
+end function
+
+subroutine complex10(a)
+  complex(10), intent(in) :: a(:)
+  complex(10) :: res
+!  res = reduce(a, red_complex10, 1)
+end subroutine
+
+pure function red_log1(a,b)
+  logical(1), intent(in) :: a, b
+  logical(1) :: red_log1
+  red_log1 = a .and. b
+end function
+
+subroutine log1(a)
+  logical(1), intent(in) :: a(:)
+  logical(1) :: res
+  res = reduce(a, red_log1, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceLogical1
+
+pure function red_log2(a,b)
+  logical(2), intent(in) :: a, b
+  logical(2) :: red_log2
+  red_log2 = a .and. b
+end function
+
+subroutine log2(a)
+  logical(2), intent(in) :: a(:)
+  logical(2) :: res
+  res = reduce(a, red_log2, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceLogical2
+
+pure function red_log4(a,b)
+  logical(4), intent(in) :: a, b
+  logical(4) :: red_log4
+  red_log4 = a .and. b
+end function
+
+subroutine log4(a)
+  logical(4), intent(in) :: a(:)
+  logical(4) :: res
+  res = reduce(a, red_log4, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceLogical4
+
+pure function red_log8(a,b)
+  logical(8), intent(in) :: a, b
+  logical(8) :: red_log8
+  red_log8 = a .and. b
+end function
+
+subroutine log8(a)
+  logical(8), intent(in) :: a(:)
+  logical(8) :: res
+  res = reduce(a, red_log8, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceLogical8
+
+pure function red_char1(a,b)
+  character(1), intent(in) :: a, b
+  character(1) :: red_char1
+  red_char1 = a // b
+end function
+
+subroutine char1(a)
+  character(1), intent(in) :: a(:)
+  character(1) :: res
+  res = reduce(a, red_char1, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceChar1
+
+pure function red_char2(a,b)
+  character(kind=2), intent(in) :: a, b
+  character(kind=2) :: red_char2
+  red_char2 = a // b
+end function
+
+subroutine char2(a)
+  character(kind=2), intent(in) :: a(:)
+  character(kind=2) :: res
+  res = reduce(a, red_char2, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceChar2
+
+pure function red_char4(a,b)
+  character(kind=4), intent(in) :: a, b
+  character(kind=4) :: red_char4
+  red_char4 = a // b
+end function
+
+subroutine char4(a)
+  character(kind=4), intent(in) :: a(:)
+  character(kind=4) :: res
+  res = reduce(a, red_char4, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceChar4
+
+pure function red_type(a,b)
+  type(t1), intent(in) :: a, b
+  type(t1) :: red_type
+  red_type%a = a%a + b%a
+end function
+
+subroutine testtype(a)
+  type(t1), intent(in) :: a(:)
+  type(t1) :: res
+  res = reduce(a, red_type, 1)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceDerivedType
+
+end module

>From c6d422e0aff1085a48b7ec67474504336179041f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 6 Jun 2024 13:50:36 -0700
Subject: [PATCH 2/3] Update for absent dim

---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |  2 +
 .../Optimizer/Builder/Runtime/Reduction.h     |  5 +-
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 19 +++----
 .../Optimizer/Builder/Runtime/Reduction.cpp   |  6 +--
 flang/test/Lower/Intrinsics/reduce.f90        | 52 +++++++++----------
 5 files changed, 41 insertions(+), 43 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index c47e41eab18b2..a8e1e131e8e4b 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -326,6 +326,8 @@ struct IntrinsicLibrary {
   void genRandomNumber(llvm::ArrayRef<fir::ExtendedValue>);
   void genRandomSeed(llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genReduce(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+  fir::ExtendedValue genReduceDim(mlir::Type,
+                                  llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genRepeat(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genReshape(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genRRSpacing(mlir::Type resultType,
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h b/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
index a4adaa72fa41a..b586fc3f3e608 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
@@ -228,9 +228,8 @@ void genIParityDim(fir::FirOpBuilder &builder, mlir::Location loc,
 /// that does not take a dim argument.
 mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
                       mlir::Value arrayBox, mlir::Value operation,
-                      mlir::Value dim, mlir::Value maskBox,
-                      mlir::Value identity, mlir::Value ordered,
-                      mlir::Value resultBox);
+                      mlir::Value maskBox, mlir::Value identity,
+                      mlir::Value ordered, mlir::Value resultBox);
 
 } // namespace fir::runtime
 
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index b1d0be6a3ec4c..93bbc2465ed07 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -5718,9 +5718,7 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
   mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
 
   // Handle optional mask argument
-  auto dim = isStaticallyAbsent(args[3])
-                 ? builder.createIntegerConstant(loc, builder.getI32Type(), 1)
-                 : fir::getBase(args[2]);
+  bool absentDim = isStaticallyAbsent(args[1]);
 
   auto mask = isStaticallyAbsent(args[3])
                   ? builder.create<fir::AbsentOp>(
@@ -5738,11 +5736,11 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
 
   // We call the type specific versions because the result is scalar
   // in the case below.
-  if (rank == 1) {
+  if (absentDim || rank == 1) {
     if (fir::isa_complex(eleTy) || fir::isa_derived(eleTy)) {
       mlir::Value result = builder.createTemporary(loc, eleTy);
-      fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
-                              identity, ordered, result);
+      fir::runtime::genReduce(builder, loc, array, operation, mask, identity,
+                              ordered, result);
       if (fir::isa_derived(eleTy))
         return result;
       return builder.create<fir::LoadOp>(loc, result);
@@ -5753,18 +5751,17 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
           fir::factory::createTempMutableBox(builder, loc, eleTy);
       mlir::Value resultIrBox =
           fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
-      fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
-                              identity, ordered, resultIrBox);
+      fir::runtime::genReduce(builder, loc, array, operation, mask, identity,
+                              ordered, resultIrBox);
       // Handle cleanup of allocatable result descriptor and return
       return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
     }
     auto resultBox = builder.create<fir::AbsentOp>(
         loc, fir::BoxType::get(builder.getI1Type()));
-    return fir::runtime::genReduce(builder, loc, array, operation, dim, mask,
+    return fir::runtime::genReduce(builder, loc, array, operation, mask,
                                    identity, ordered, resultBox);
   }
-
-  TODO(loc, "intrinsic: reduce with non scalar result");
+  TODO(loc, "reduce with array result");
 }
 
 // REPEAT
diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index 0c1af6159c939..85cbaef453b4c 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -1322,14 +1322,14 @@ GEN_IALL_IANY_IPARITY(IParity)
 /// that does have a scalar result.
 mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
                                     mlir::Location loc, mlir::Value arrayBox,
-                                    mlir::Value operation, mlir::Value dim,
-                                    mlir::Value maskBox, mlir::Value identity,
-                                    mlir::Value ordered,
+                                    mlir::Value operation, mlir::Value maskBox,
+                                    mlir::Value identity, mlir::Value ordered,
                                     mlir::Value resultBox) {
   mlir::func::FuncOp func;
   auto ty = arrayBox.getType();
   auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
   auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
+  auto dim = builder.createIntegerConstant(loc, builder.getI32Type(), 1);
 
   mlir::MLIRContext *ctx = builder.getContext();
   fir::factory::CharacterExprHelper charHelper{builder, loc};
diff --git a/flang/test/Lower/Intrinsics/reduce.f90 b/flang/test/Lower/Intrinsics/reduce.f90
index 3c548816b77c9..b2db3088e6de9 100644
--- a/flang/test/Lower/Intrinsics/reduce.f90
+++ b/flang/test/Lower/Intrinsics/reduce.f90
@@ -20,11 +20,11 @@ subroutine integer1(a, id)
 
   res = reduce(a, red_int1)
 
-  res = reduce(a, red_int1, 1, identity=id)
+  res = reduce(a, red_int1, identity=id)
   
-  res = reduce(a, red_int1, 1, identity=id, ordered = .false.)
+  res = reduce(a, red_int1, identity=id, ordered = .false.)
 
-  res = reduce(a, red_int1, 1, [.true., .true., .false.])
+  res = reduce(a, red_int1, [.true., .true., .false.])
 end subroutine
 
 ! CHECK-LABEL: func.func @_QMreduce_modPinteger1(
@@ -65,7 +65,7 @@ pure function red_int2(a,b)
 subroutine integer2(a)
   integer(2), intent(in) :: a(:)
   integer(2) :: res
-  res = reduce(a, red_int2, 1)
+  res = reduce(a, red_int2)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceInteger2
@@ -79,7 +79,7 @@ pure function red_int4(a,b)
 subroutine integer4(a)
   integer(4), intent(in) :: a(:)
   integer(4) :: res
-  res = reduce(a, red_int4, 1)
+  res = reduce(a, red_int4)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceInteger4
@@ -93,7 +93,7 @@ pure function red_int8(a,b)
 subroutine integer8(a)
   integer(8), intent(in) :: a(:)
   integer(8) :: res
-  res = reduce(a, red_int8, 1)
+  res = reduce(a, red_int8)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceInteger8
@@ -107,7 +107,7 @@ pure function red_int16(a,b)
 subroutine integer16(a)
   integer(16), intent(in) :: a(:)
   integer(16) :: res
-  res = reduce(a, red_int16, 1)
+  res = reduce(a, red_int16)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceInteger16
@@ -121,7 +121,7 @@ pure function red_real2(a,b)
 subroutine real2(a)
   real(2), intent(in) :: a(:)
   real(2) :: res
-  res = reduce(a, red_real2, 1)
+  res = reduce(a, red_real2)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceReal2
@@ -135,7 +135,7 @@ pure function red_real3(a,b)
 subroutine real3(a)
   real(3), intent(in) :: a(:)
   real(3) :: res
-  res = reduce(a, red_real3, 1)
+  res = reduce(a, red_real3)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceReal3
@@ -149,7 +149,7 @@ pure function red_real4(a,b)
 subroutine real4(a)
   real(4), intent(in) :: a(:)
   real(4) :: res
-  res = reduce(a, red_real4, 1)
+  res = reduce(a, red_real4)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceReal4
@@ -163,7 +163,7 @@ pure function red_real8(a,b)
 subroutine real8(a)
   real(8), intent(in) :: a(:)
   real(8) :: res
-  res = reduce(a, red_real8, 1)
+  res = reduce(a, red_real8)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceReal8
@@ -177,7 +177,7 @@ pure function red_real10(a,b)
 subroutine real10(a)
   real(10), intent(in) :: a(:)
   real(10) :: res
-  res = reduce(a, red_real10, 1)
+  res = reduce(a, red_real10)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceReal10
@@ -191,7 +191,7 @@ pure function red_real16(a,b)
 subroutine real16(a)
   real(16), intent(in) :: a(:)
   real(16) :: res
-  res = reduce(a, red_real16, 1)
+  res = reduce(a, red_real16)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceReal16
@@ -205,7 +205,7 @@ pure function red_complex2(a,b)
 subroutine complex2(a)
   complex(2), intent(in) :: a(:)
   complex(2) :: res
-  res = reduce(a, red_complex2, 1)
+  res = reduce(a, red_complex2)
 end subroutine
 
 ! CHECK: fir.call @_FortranACppReduceComplex2
@@ -219,7 +219,7 @@ pure function red_complex3(a,b)
 subroutine complex3(a)
   complex(3), intent(in) :: a(:)
   complex(3) :: res
-  res = reduce(a, red_complex3, 1)
+  res = reduce(a, red_complex3)
 end subroutine
 
 ! CHECK: fir.call @_FortranACppReduceComplex3
@@ -233,7 +233,7 @@ pure function red_complex4(a,b)
 subroutine complex4(a)
   complex(4), intent(in) :: a(:)
   complex(4) :: res
-  res = reduce(a, red_complex4, 1)
+  res = reduce(a, red_complex4)
 end subroutine
 
 ! CHECK: fir.call @_FortranACppReduceComplex4
@@ -247,7 +247,7 @@ pure function red_complex8(a,b)
 subroutine complex8(a)
   complex(8), intent(in) :: a(:)
   complex(8) :: res
-  res = reduce(a, red_complex8, 1)
+  res = reduce(a, red_complex8)
 end subroutine
 
 ! CHECK: fir.call @_FortranACppReduceComplex8
@@ -261,7 +261,7 @@ pure function red_complex10(a,b)
 subroutine complex10(a)
   complex(10), intent(in) :: a(:)
   complex(10) :: res
-!  res = reduce(a, red_complex10, 1)
+!  res = reduce(a, red_complex10)
 end subroutine
 
 pure function red_log1(a,b)
@@ -273,7 +273,7 @@ pure function red_log1(a,b)
 subroutine log1(a)
   logical(1), intent(in) :: a(:)
   logical(1) :: res
-  res = reduce(a, red_log1, 1)
+  res = reduce(a, red_log1)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceLogical1
@@ -287,7 +287,7 @@ pure function red_log2(a,b)
 subroutine log2(a)
   logical(2), intent(in) :: a(:)
   logical(2) :: res
-  res = reduce(a, red_log2, 1)
+  res = reduce(a, red_log2)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceLogical2
@@ -301,7 +301,7 @@ pure function red_log4(a,b)
 subroutine log4(a)
   logical(4), intent(in) :: a(:)
   logical(4) :: res
-  res = reduce(a, red_log4, 1)
+  res = reduce(a, red_log4)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceLogical4
@@ -315,7 +315,7 @@ pure function red_log8(a,b)
 subroutine log8(a)
   logical(8), intent(in) :: a(:)
   logical(8) :: res
-  res = reduce(a, red_log8, 1)
+  res = reduce(a, red_log8)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceLogical8
@@ -329,7 +329,7 @@ pure function red_char1(a,b)
 subroutine char1(a)
   character(1), intent(in) :: a(:)
   character(1) :: res
-  res = reduce(a, red_char1, 1)
+  res = reduce(a, red_char1)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceChar1
@@ -343,7 +343,7 @@ pure function red_char2(a,b)
 subroutine char2(a)
   character(kind=2), intent(in) :: a(:)
   character(kind=2) :: res
-  res = reduce(a, red_char2, 1)
+  res = reduce(a, red_char2)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceChar2
@@ -357,7 +357,7 @@ pure function red_char4(a,b)
 subroutine char4(a)
   character(kind=4), intent(in) :: a(:)
   character(kind=4) :: res
-  res = reduce(a, red_char4, 1)
+  res = reduce(a, red_char4)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceChar4
@@ -371,7 +371,7 @@ pure function red_type(a,b)
 subroutine testtype(a)
   type(t1), intent(in) :: a(:)
   type(t1) :: res
-  res = reduce(a, red_type, 1)
+  res = reduce(a, red_type)
 end subroutine
 
 ! CHECK: fir.call @_FortranAReduceDerivedType

>From 369388aacab517687e5f330309a68e12faf296ac Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 6 Jun 2024 13:54:54 -0700
Subject: [PATCH 3/3] Update comment and fix DIM argument index

DIM is args[2]; args[1] is OPERATION, which is never absent, so checking
args[1] always reported DIM as present and rank > 1 calls without DIM fell
through to the "reduce with array result" TODO.
---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 93bbc2465ed07..c424d08319ab8 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -5717,7 +5717,7 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
   mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
   mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
 
-  // Handle optional mask argument
-  bool absentDim = isStaticallyAbsent(args[1]);
+  // Handle optional arguments
+  bool absentDim = isStaticallyAbsent(args[2]);
 
   auto mask = isStaticallyAbsent(args[3])



More information about the flang-commits mailing list