[llvm-branch-commits] [flang] [flang] Lower REDUCE intrinsic with DIM argument (PR #94771)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jun 7 10:06:40 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

This is a follow-up patch to #<!-- -->94652 that handles the lowering of the REDUCE intrinsic with a DIM argument and a non-scalar result.

---
Full diff: https://github.com/llvm/llvm-project/pull/94771.diff


4 Files Affected:

- (modified) flang/include/flang/Optimizer/Builder/Runtime/Reduction.h (+7) 
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+11-1) 
- (modified) flang/lib/Optimizer/Builder/Runtime/Reduction.cpp (+184-2) 
- (modified) flang/test/Lower/Intrinsics/reduce.f90 (+221) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h b/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
index 27652208b524e..fedf453a6dc8d 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
@@ -240,6 +240,13 @@ mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
                       mlir::Value maskBox, mlir::Value identity,
                       mlir::Value ordered);
 
+/// Generate call to `Reduce` intrinsic runtime routine. This is the version
+/// that takes arrays of any rank with a dim argument specified.
+void genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
+                  mlir::Value arrayBox, mlir::Value operation, mlir::Value dim,
+                  mlir::Value maskBox, mlir::Value identity,
+                  mlir::Value ordered, mlir::Value resultBox);
+
 } // namespace fir::runtime
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_REDUCTION_H
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 0e29849a57688..e250a476b5802 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -5790,7 +5790,17 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
     return fir::runtime::genReduce(builder, loc, array, operation, mask,
                                    identity, ordered);
   }
-  TODO(loc, "reduce with array result");
+  // Handle cases that have an array result.
+  // Create mutable fir.box to be passed to the runtime for the result.
+  mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, rank - 1);
+  fir::MutableBoxValue resultMutableBox =
+      fir::factory::createTempMutableBox(builder, loc, resultArrayType);
+  mlir::Value resultIrBox =
+      fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+  mlir::Value dim = fir::getBase(args[2]);
+  fir::runtime::genReduceDim(builder, loc, array, operation, dim, mask,
+                             identity, ordered, resultIrBox);
+  return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
 }
 
 // REPEAT
diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index a7cd53328d69a..e83af63916dcd 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -486,6 +486,28 @@ struct ForcedReduceReal16 {
   }
 };
 
+/// Placeholder for DIM real*16 version of Reduce Intrinsic
+struct ForcedReduceReal16Dim {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(ReduceReal16Dim));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::FloatType::getF128(ctx);
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      auto refTy = fir::ReferenceType::get(ty);
+      auto refBoxTy = fir::ReferenceType::get(boxTy);
+      auto i1Ty = mlir::IntegerType::get(ctx, 1);
+      return mlir::FunctionType::get(
+          ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
+          {});
+    };
+  }
+};
+
 /// Placeholder for integer*16 version of Reduce Intrinsic
 struct ForcedReduceInteger16 {
   static constexpr const char *name =
@@ -506,6 +528,28 @@ struct ForcedReduceInteger16 {
   }
 };
 
+/// Placeholder for DIM integer*16 version of Reduce Intrinsic
+struct ForcedReduceInteger16Dim {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(ReduceInteger16Dim));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::IntegerType::get(ctx, 128);
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      auto refTy = fir::ReferenceType::get(ty);
+      auto refBoxTy = fir::ReferenceType::get(boxTy);
+      auto i1Ty = mlir::IntegerType::get(ctx, 1);
+      return mlir::FunctionType::get(
+          ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
+          {});
+    };
+  }
+};
+
 /// Placeholder for complex(10) version of Reduce Intrinsic
 struct ForcedReduceComplex10 {
   static constexpr const char *name =
@@ -527,10 +571,32 @@ struct ForcedReduceComplex10 {
   }
 };
 
+/// Placeholder for Dim complex(10) version of Reduce Intrinsic
+struct ForcedReduceComplex10Dim {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(CppReduceComplex10Dim));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::ComplexType::get(mlir::FloatType::getF80(ctx));
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      auto refTy = fir::ReferenceType::get(ty);
+      auto refBoxTy = fir::ReferenceType::get(boxTy);
+      auto i1Ty = mlir::IntegerType::get(ctx, 1);
+      return mlir::FunctionType::get(
+          ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
+          {});
+    };
+  }
+};
+
 /// Placeholder for complex(16) version of Reduce Intrinsic
 struct ForcedReduceComplex16 {
   static constexpr const char *name =
       ExpandAndQuoteKey(RTNAME(CppReduceComplex16));
   static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
     return [](mlir::MLIRContext *ctx) {
       auto ty = mlir::ComplexType::get(mlir::FloatType::getF128(ctx));
@@ -540,9 +606,31 @@ struct ForcedReduceComplex16 {
       auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
       auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
       auto refTy = fir::ReferenceType::get(ty);
       auto i1Ty = mlir::IntegerType::get(ctx, 1);
       return mlir::FunctionType::get(
           ctx, {refTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
+          {});
+    };
+  }
+};
+
+/// Placeholder for Dim complex(16) version of Reduce Intrinsic
+struct ForcedReduceComplex16Dim {
+  static constexpr const char *name =
+      ExpandAndQuoteKey(RTNAME(CppReduceComplex16Dim));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::ComplexType::get(mlir::FloatType::getF128(ctx));
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      auto refTy = fir::ReferenceType::get(ty);
+      auto refBoxTy = fir::ReferenceType::get(boxTy);
+      auto i1Ty = mlir::IntegerType::get(ctx, 1);
+      return mlir::FunctionType::get(
+          ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
           {});
     };
   }
@@ -1442,3 +1530,97 @@ mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
                                             maskBox, identity, ordered);
   return builder.create<fir::CallOp>(loc, func, args).getResult(0);
 }
+
+void fir::runtime::genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
+                                mlir::Value arrayBox, mlir::Value operation,
+                                mlir::Value dim, mlir::Value maskBox,
+                                mlir::Value identity, mlir::Value ordered,
+                                mlir::Value resultBox) {
+  mlir::func::FuncOp func;
+  auto ty = arrayBox.getType();
+  auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
+  auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
+
+  mlir::MLIRContext *ctx = builder.getContext();
+  fir::factory::CharacterExprHelper charHelper{builder, loc};
+
+  if (eleTy.isF16())
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal2Dim)>(loc, builder);
+  else if (eleTy.isBF16())
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal3Dim)>(loc, builder);
+  else if (eleTy.isF32())
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal4Dim)>(loc, builder);
+  else if (eleTy.isF64())
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal8Dim)>(loc, builder);
+  else if (eleTy.isF80())
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal10Dim)>(loc, builder);
+  else if (eleTy.isF128())
+    func = fir::runtime::getRuntimeFunc<ForcedReduceReal16Dim>(loc, builder);
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
+    func =
+        fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger1Dim)>(loc, builder);
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
+    func =
+        fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger2Dim)>(loc, builder);
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
+    func =
+        fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger4Dim)>(loc, builder);
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
+    func =
+        fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger8Dim)>(loc, builder);
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
+    func = fir::runtime::getRuntimeFunc<ForcedReduceInteger16Dim>(loc, builder);
+  else if (eleTy == fir::ComplexType::get(ctx, 2))
+    func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex2Dim)>(loc,
+                                                                       builder);
+  else if (eleTy == fir::ComplexType::get(ctx, 3))
+    func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex3Dim)>(loc,
+                                                                       builder);
+  else if (eleTy == fir::ComplexType::get(ctx, 4))
+    func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex4Dim)>(loc,
+                                                                       builder);
+  else if (eleTy == fir::ComplexType::get(ctx, 8))
+    func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex8Dim)>(loc,
+                                                                       builder);
+  else if (eleTy == fir::ComplexType::get(ctx, 10))
+    func = fir::runtime::getRuntimeFunc<ForcedReduceComplex10Dim>(loc, builder);
+  else if (eleTy == fir::ComplexType::get(ctx, 16))
+    func = fir::runtime::getRuntimeFunc<ForcedReduceComplex16Dim>(loc, builder);
+  else if (eleTy == fir::LogicalType::get(ctx, 1))
+    func =
+        fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical1Dim)>(loc, builder);
+  else if (eleTy == fir::LogicalType::get(ctx, 2))
+    func =
+        fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical2Dim)>(loc, builder);
+  else if (eleTy == fir::LogicalType::get(ctx, 4))
+    func =
+        fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical4Dim)>(loc, builder);
+  else if (eleTy == fir::LogicalType::get(ctx, 8))
+    func =
+        fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical8Dim)>(loc, builder);
+  else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 1)
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter1Dim)>(loc,
+                                                                      builder);
+  else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 2)
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter2Dim)>(loc,
+                                                                      builder);
+  else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 4)
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter4Dim)>(loc,
+                                                                      builder);
+  else if (fir::isa_derived(eleTy))
+    func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceDerivedTypeDim)>(loc,
+                                                                       builder);
+  else
+    fir::intrinsicTypeTODO(builder, eleTy, loc, "REDUCE");
+
+  auto fTy = func.getFunctionType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+  auto opAddr = builder.create<fir::BoxAddrOp>(loc, fTy.getInput(2), operation);
+  auto args = fir::runtime::createArguments(
+      builder, loc, fTy, resultBox, arrayBox, opAddr, sourceFile, sourceLine,
+      dim, maskBox, identity, ordered);
+  builder.create<fir::CallOp>(loc, func, args);
+}
diff --git a/flang/test/Lower/Intrinsics/reduce.f90 b/flang/test/Lower/Intrinsics/reduce.f90
index 36900abaa79f8..842e626d7cc39 100644
--- a/flang/test/Lower/Intrinsics/reduce.f90
+++ b/flang/test/Lower/Intrinsics/reduce.f90
@@ -392,4 +392,225 @@ subroutine testtype(a)
 
 ! CHECK: fir.call @_FortranAReduceDerivedType
 
+subroutine integer1dim(a, id)
+  integer(1), intent(in) :: a(:,:)
+  integer(1), allocatable :: res(:)
+
+  res = reduce(a, red_int1, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger1Dim
+
+subroutine integer2dim(a, id)
+  integer(2), intent(in) :: a(:,:)
+  integer(2), allocatable :: res(:)
+
+  res = reduce(a, red_int2, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger2Dim
+
+subroutine integer4dim(a, id)
+  integer(4), intent(in) :: a(:,:)
+  integer(4), allocatable :: res(:)
+
+  res = reduce(a, red_int4, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger4Dim
+
+subroutine integer8dim(a, id)
+  integer(8), intent(in) :: a(:,:)
+  integer(8), allocatable :: res(:)
+
+  res = reduce(a, red_int8, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger8Dim
+
+subroutine integer16dim(a, id)
+  integer(16), intent(in) :: a(:,:)
+  integer(16), allocatable :: res(:)
+
+  res = reduce(a, red_int16, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceInteger16Dim
+
+subroutine real2dim(a, id)
+  real(2), intent(in) :: a(:,:)
+  real(2), allocatable :: res(:)
+
+  res = reduce(a, red_real2, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal2Dim
+
+subroutine real3dim(a, id)
+  real(3), intent(in) :: a(:,:)
+  real(3), allocatable :: res(:)
+
+  res = reduce(a, red_real3, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal3Dim
+
+subroutine real4dim(a, id)
+  real(4), intent(in) :: a(:,:)
+  real(4), allocatable :: res(:)
+
+  res = reduce(a, red_real4, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal4Dim
+
+subroutine real8dim(a, id)
+  real(8), intent(in) :: a(:,:)
+  real(8), allocatable :: res(:)
+
+  res = reduce(a, red_real8, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal8Dim
+
+subroutine real10dim(a, id)
+  real(10), intent(in) :: a(:,:)
+  real(10), allocatable :: res(:)
+
+  res = reduce(a, red_real10, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal10Dim
+
+subroutine real16dim(a, id)
+  real(16), intent(in) :: a(:,:)
+  real(16), allocatable :: res(:)
+
+  res = reduce(a, red_real16, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceReal16Dim
+
+subroutine complex2dim(a, id)
+  complex(2), intent(in) :: a(:,:)
+  complex(2), allocatable :: res(:)
+
+  res = reduce(a, red_complex2, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex2Dim
+
+subroutine complex3dim(a, id)
+  complex(3), intent(in) :: a(:,:)
+  complex(3), allocatable :: res(:)
+
+  res = reduce(a, red_complex3, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex3Dim
+
+subroutine complex4dim(a, id)
+  complex(4), intent(in) :: a(:,:)
+  complex(4), allocatable :: res(:)
+
+  res = reduce(a, red_complex4, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex4Dim
+
+subroutine complex8dim(a, id)
+  complex(8), intent(in) :: a(:,:)
+  complex(8), allocatable :: res(:)
+
+  res = reduce(a, red_complex8, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex8Dim
+
+subroutine complex10dim(a, id)
+  complex(10), intent(in) :: a(:,:)
+  complex(10), allocatable :: res(:)
+
+  res = reduce(a, red_complex10, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex10Dim
+
+subroutine complex16dim(a, id)
+  complex(16), intent(in) :: a(:,:)
+  complex(16), allocatable :: res(:)
+
+  res = reduce(a, red_complex16, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranACppReduceComplex16Dim
+
+subroutine logical1dim(a, id)
+  logical(1), intent(in) :: a(:,:)
+  logical(1), allocatable :: res(:)
+
+  res = reduce(a, red_log1, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceLogical1Dim
+
+subroutine logical2dim(a, id)
+  logical(2), intent(in) :: a(:,:)
+  logical(2), allocatable :: res(:)
+
+  res = reduce(a, red_log2, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceLogical2Dim
+
+subroutine logical4dim(a, id)
+  logical(4), intent(in) :: a(:,:)
+  logical(4), allocatable :: res(:)
+
+  res = reduce(a, red_log4, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceLogical4Dim
+
+subroutine logical8dim(a, id)
+  logical(8), intent(in) :: a(:,:)
+  logical(8), allocatable :: res(:)
+
+  res = reduce(a, red_log8, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceLogical8Dim
+
+subroutine testtypeDim(a)
+  type(t1), intent(in) :: a(:,:)
+  type(t1), allocatable :: res(:)
+  res = reduce(a, red_type, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceDerivedTypeDim
+
+subroutine char1dim(a)
+  character(1), intent(in) :: a(:, :)
+  character(1), allocatable :: res(:)
+  res = reduce(a, red_char1, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceCharacter1Dim
+
+subroutine char2dim(a)
+  character(kind=2), intent(in) :: a(:, :)
+  character(kind=2), allocatable :: res(:)
+  res = reduce(a, red_char2, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceCharacter2Dim
+
+subroutine char4dim(a)
+  character(kind=4), intent(in) :: a(:, :)
+  character(kind=4), allocatable :: res(:)
+  res = reduce(a, red_char4, 2)
+end subroutine
+
+! CHECK: fir.call @_FortranAReduceCharacter4Dim
+
 end module

``````````

</details>


https://github.com/llvm/llvm-project/pull/94771


More information about the llvm-branch-commits mailing list