[flang-commits] [flang] 8bd76ac - [flang] Support multidimensional reductions in SimplifyIntrinsicsPass.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Mon Sep 19 12:17:48 PDT 2022


Author: Slava Zakharin
Date: 2022-09-19T12:16:23-07:00
New Revision: 8bd76ac151534d2b9534ed919c0a7f4511002d84

URL: https://github.com/llvm/llvm-project/commit/8bd76ac151534d2b9534ed919c0a7f4511002d84
DIFF: https://github.com/llvm/llvm-project/commit/8bd76ac151534d2b9534ed919c0a7f4511002d84.diff

LOG: [flang] Support multidimensional reductions in SimplifyIntrinsicsPass.

Create a simplified function for each rank, named with an "x<rank>" suffix,
that implements the multidimensional reduction. Enabling this required fixing
an issue where the wrong box shape was taken for sliced embox/rebox operations.

Differential Revision: https://reviews.llvm.org/D133820

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
    flang/test/Transforms/simplifyintrinsics.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index d23736ef8a68e..5682fa2816714 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -61,7 +61,7 @@ class SimplifyIntrinsicsPass
   using FunctionBodyGeneratorTy =
       llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
   using GenReductionBodyTy = llvm::function_ref<void(
-      fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
+      fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
 
 public:
   /// Generate a new function implementing a simplified version
@@ -110,10 +110,11 @@ using InitValGeneratorTy = llvm::function_ref<mlir::Value(
 ///    the reduction value
 /// \p genBody is called to fill in the actual reduciton operation
 ///    for example add for SUM, MAX for MAXVAL, etc.
+/// \p rank is the rank of the input argument.
 static void genReductionLoop(fir::FirOpBuilder &builder,
                              mlir::func::FuncOp &funcOp,
                              InitValGeneratorTy initVal,
-                             BodyOpGeneratorTy genBody) {
+                             BodyOpGeneratorTy genBody, unsigned rank) {
   auto loc = mlir::UnknownLoc::get(builder.getContext());
   mlir::Type elementType = funcOp.getResultTypes()[0];
   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
@@ -125,59 +126,98 @@ static void genReductionLoop(fir::FirOpBuilder &builder,
 
   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
 
-  fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
+  fir::SequenceType::Shape flatShape(rank,
+                                     fir::SequenceType::getUnknownExtent());
   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
-  auto dims =
-      builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
-  mlir::Value len = dims.getResult(1);
+  mlir::Value init = initVal(builder, loc, elementType);
+
+  llvm::SmallVector<mlir::Value, 15> bounds;
+
+  assert(rank > 0 && "rank cannot be zero");
   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-  mlir::Value step = one;
 
-  // We use C indexing here, so len-1 as loopcount
-  mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
-  mlir::Value init = initVal(builder, loc, elementType);
-  auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
-                                            /*unordered=*/false,
-                                            /*finalCountValue=*/false, init);
-  mlir::Value reductionVal = loop.getRegionIterArgs()[0];
+  // Compute all the upper bounds before the loop nest.
+  // It is not strictly necessary for performance, since the loop nest
+  // does not have any store operations and any LICM optimization
+  // should be able to optimize the redundancy.
+  for (unsigned i = 0; i < rank; ++i) {
+    mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
+    auto dims =
+        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
+    mlir::Value len = dims.getResult(1);
+    // We use C indexing here, so len-1 as loopcount
+    mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
+    bounds.push_back(loopCount);
+  }
 
-  // Begin loop code
-  mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(loop.getBody());
+  // Create a loop nest consisting of DoLoopOp operations.
+  // Collect the loops' induction variables into indices array,
+  // which will be used in the innermost loop to load the input
+  // array's element.
+  // The loops are generated such that the innermost loop processes
+  // the 0 dimension.
+  llvm::SmallVector<mlir::Value, 15> indices;
+  for (unsigned i = rank; 0 < i; --i) {
+    mlir::Value step = one;
+    mlir::Value loopCount = bounds[i - 1];
+    auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
+                                              /*unordered=*/false,
+                                              /*finalCountValue=*/false, init);
+    init = loop.getRegionIterArgs()[0];
+    indices.push_back(loop.getInductionVar());
+    // Set insertion point to the loop body so that the next loop
+    // is inserted inside the current one.
+    builder.setInsertionPointToStart(loop.getBody());
+  }
+
+  // Reverse the indices such that they are ordered as:
+  //   <dim-0-idx, dim-1-idx, ...>
+  std::reverse(indices.begin(), indices.end());
 
+  // We are in the innermost loop: generate the reduction body.
   mlir::Type eleRefTy = builder.getRefType(elementType);
-  mlir::Value index = loop.getInductionVar();
   mlir::Value addr =
-      builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
+      builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
 
-  reductionVal = genBody(builder, loc, elementType, elem, reductionVal);
-
-  builder.create<fir::ResultOp>(loc, reductionVal);
-  // End of loop.
-  builder.restoreInsertionPoint(loopEndPt);
+  mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
+
+  // Unwind the loop nest and insert ResultOp on each level
+  // to return the updated value of the reduction to the enclosing
+  // loops.
+  for (unsigned i = 0; i < rank; ++i) {
+    auto result = builder.create<fir::ResultOp>(loc, reductionVal);
+    // Proceed to the outer loop.
+    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+    reductionVal = loop.getResult(0);
+    // Set insertion point after the loop operation that we have
+    // just processed.
+    builder.setInsertionPointAfter(loop.getOperation());
+  }
 
-  mlir::Value resultVal = loop.getResult(0);
-  builder.create<mlir::func::ReturnOp>(loc, resultVal);
+  // End of loop nest. The insertion point is after the outermost loop.
+  // Return the reduction value from the function.
+  builder.create<mlir::func::ReturnOp>(loc, reductionVal);
 }
 
 /// Generate function body of the simplified version of RTNAME(Sum)
 /// with signature provided by \p funcOp. The caller is responsible
 /// for saving/restoring the original insertion point of \p builder.
 /// \p funcOp is expected to be empty on entry to this function.
+/// \p rank specifies the rank of the input argument.
 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
-                              mlir::func::FuncOp &funcOp) {
-  // function RTNAME(Sum)<T>_simplified(arr)
+                              mlir::func::FuncOp &funcOp, unsigned rank) {
+  // function RTNAME(Sum)<T>x<rank>_simplified(arr)
   //   T, dimension(:) :: arr
   //   T sum = 0
   //   integer iter
   //   do iter = 0, extent(arr)
   //     sum = sum + arr[iter]
   //   end do
-  //   RTNAME(Sum)<T>_simplified = sum
-  // end function RTNAME(Sum)<T>_simplified
+  //   RTNAME(Sum)<T>x<rank>_simplified = sum
+  // end function RTNAME(Sum)<T>x<rank>_simplified
   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
                  mlir::Type elementType) {
     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@@ -200,11 +240,11 @@ static void genRuntimeSumBody(fir::FirOpBuilder &builder,
     return {};
   };
 
-  genReductionLoop(builder, funcOp, zero, genBodyOp);
+  genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
 }
 
 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
-                                 mlir::func::FuncOp &funcOp) {
+                                 mlir::func::FuncOp &funcOp, unsigned rank) {
   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
                  mlir::Type elementType) {
     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@@ -228,7 +268,7 @@ static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
     llvm_unreachable("unsupported type");
     return {};
   };
-  genReductionLoop(builder, funcOp, init, genBodyOp);
+  genReductionLoop(builder, funcOp, init, genBodyOp, rank);
 }
 
 /// Generate function type for the simplified version of RTNAME(DotProduct)
@@ -410,21 +450,31 @@ static bool isZero(mlir::Value val) {
   return false;
 }
 
-static mlir::Value findShape(mlir::Value val) {
+static mlir::Value findBoxDef(mlir::Value val) {
   if (auto op = expectConvertOp(val)) {
     assert(op->getOperands().size() != 0);
     if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
             op->getOperand(0).getDefiningOp()))
-      return box.getShape();
+      return box.getResult();
+    if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
+            op->getOperand(0).getDefiningOp()))
+      return box.getResult();
   }
   return {};
 }
 
 static unsigned getDimCount(mlir::Value val) {
-  if (mlir::Value shapeVal = findShape(val)) {
-    mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
-    return fir::getRankOfShapeType(resType);
-  }
+  // In order to find the dimensions count, we look for EmboxOp/ReboxOp
+  // and take the count from its *result* type. Note that in case
+  // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
+  // have different types.
+  // Actually, we can take the box type from the operand of
+  // the first ConvertOp that has non-opaque box type that we meet
+  // going through the ConvertOp chain.
+  if (mlir::Value emboxVal = findBoxDef(val))
+    if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
+      if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
+        return seqTy.getDimension();
   return 0;
 }
 
@@ -455,7 +505,6 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
                                                const fir::KindMapping &kindMap,
                                                GenReductionBodyTy genBodyFunc) {
   mlir::SymbolRefAttr callee = call.getCalleeAttr();
-  mlir::StringRef funcName = callee.getLeafReference().getValue();
   mlir::Operation::operand_range args = call.getArgs();
   // args[1] and args[2] are source filename and line number, ignored.
   const mlir::Value &dim = args[3];
@@ -464,7 +513,7 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
   // detail in the runtime library.
   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
   unsigned rank = getDimCount(args[0]);
-  if (dimAndMaskAbsent && rank == 1) {
+  if (dimAndMaskAbsent && rank > 0) {
     mlir::Location loc = call.getLoc();
     fir::FirOpBuilder builder(call, kindMap);
 
@@ -483,8 +532,17 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
     auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
       return genNoneBoxType(builder, resultType);
     };
+    auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
+                                               mlir::func::FuncOp &funcOp) {
+      genBodyFunc(builder, funcOp, rank);
+    };
+    // Mangle the function name with the rank value as "x<rank>".
+    std::string funcName =
+        (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
+         mlir::Twine{rank})
+            .str();
     mlir::func::FuncOp newFunc =
-        getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
+        getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
     auto newCall =
         builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
     call->replaceAllUsesWith(newCall.getResults());

diff  --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index b5d24c5785243..e3ac9c930d299 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -34,20 +34,21 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 // CHECK:           %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
 // CHECK:           %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10xi32>>) -> !fir.box<none>
 // CHECK-NOT:       fir.call @_FortranASumInteger4({{.*}})
-// CHECK:           %[[RES:.*]] = fir.call @_FortranASumInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
+// CHECK:           %[[RES:.*]] = fir.call @_FortranASumInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
 // CHECK-NOT:       fir.call @_FortranASumInteger4({{.*}})
 // CHECK:           return %{{.*}} : i32
 // CHECK:         }
 // CHECK:         func.func private @_FortranASumInteger4(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32 attributes {fir.runtime}
 
-// CHECK-LABEL:   func.func private @_FortranASumInteger4_simplified(
+// CHECK-LABEL:   func.func private @_FortranASumInteger4x1_simplified(
 // CHECK-SAME:                                                       %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
 // CHECK:           %[[CINDEX_0:.*]] = arith.constant 0 : index
 // CHECK:           %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
-// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK:           %[[CI32_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK:           %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK:           %[[CI32_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM:.*]] = %[[CI32_0]]) -> (i32) {
 // CHECK:             %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
 // CHECK:             %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
@@ -59,7 +60,7 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 
 // -----
 
-// Call to SUM with 2D I32 arrays is not replaced.
+// Call to SUM with 2D I32 arrays is replaced.
 module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} {
   func.func @sum_2d_array_int(%arg0: !fir.ref<!fir.array<10x10xi32>> {fir.bindc_name = "a"}) -> i32 {
     %c10 = arith.constant 10 : index
@@ -88,9 +89,39 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 }
 
 // CHECK-LABEL:   func.func @sum_2d_array_int({{.*}} !fir.ref<!fir.array<10x10xi32>> {fir.bindc_name = "a"}) -> i32 {
-// CHECK-NOT:       fir.call @_FortranASumInteger4_simplified({{.*}})
-// CHECK:           fir.call @_FortranASumInteger4({{.*}}) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
-// CHECK-NOT:       fir.call @_FortranASumInteger4_simplified({{.*}})
+// CHECK:           %[[SHAPE:.*]] = fir.shape %{{.*}} : (index, index) -> !fir.shape<2>
+// CHECK:           %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10x10xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<10x10xi32>>
+// CHECK:           %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10x10xi32>>) -> !fir.box<none>
+// CHECK-NOT:       fir.call @_FortranASumInteger4({{.*}})
+// CHECK:           %[[RES:.*]] = fir.call @_FortranASumInteger4x2_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
+// CHECK-NOT:       fir.call @_FortranASumInteger4({{.*}})
+// CHECK:           return %{{.*}} : i32
+// CHECK:         }
+// CHECK:         func.func private @_FortranASumInteger4(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32 attributes {fir.runtime}
+
+// CHECK-LABEL:   func.func private @_FortranASumInteger4x2_simplified(
+// CHECK-SAME:                                                       %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK:           %[[CINDEX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK:           %[[CI32_0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIMS_0:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// CHECK:           %[[EXTENT_0:.*]] = arith.subi %[[DIMS_0]]#1, %[[CINDEX_1]] : index
+// CHECK:           %[[DIMIDX_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMS_1:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_1]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// CHECK:           %[[EXTENT_1:.*]] = arith.subi %[[DIMS_1]]#1, %[[CINDEX_1]] : index
+// CHECK:           %[[RES_1:.*]] = fir.do_loop %[[ITER_1:.*]] = %[[CINDEX_0]] to %[[EXTENT_1]] step %[[CINDEX_1]] iter_args(%[[SUM_1:.*]] = %[[CI32_0]]) -> (i32) {
+// CHECK:             %[[RES_0:.*]] = fir.do_loop %[[ITER_0:.*]] = %[[CINDEX_0]] to %[[EXTENT_0]] step %[[CINDEX_1]] iter_args(%[[SUM_0:.*]] = %[[SUM_1]]) -> (i32) {
+// CHECK:               %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER_0]], %[[ITER_1]] : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
+// CHECK:               %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
+// CHECK:               %[[NEW_SUM:.*]] = arith.addi %[[ITEM_VAL]], %[[SUM_0]] : i32
+// CHECK:               fir.result %[[NEW_SUM]] : i32
+// CHECK:             }
+// CHECK:             fir.result %[[RES_0]]
+// CHECK:           }
+// CHECK:           return %[[RES_1]] : i32
+// CHECK:         }
 
 // -----
 
@@ -129,19 +160,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 // CHECK:           %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
 // CHECK:           %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
 // CHECK-NOT:       fir.call @_FortranASumReal8({{.*}})
-// CHECK:           %[[RES:.*]] = fir.call @_FortranASumReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
+// CHECK:           %[[RES:.*]] = fir.call @_FortranASumReal8x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
 // CHECK-NOT:       fir.call @_FortranASumReal8({{.*}})
 // CHECK:           return %{{.*}} : f64
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_FortranASumReal8_simplified(
+// CHECK-LABEL:   func.func private @_FortranASumReal8x1_simplified(
 // CHECK-SAME:                                                    %[[ARR:.*]]: !fir.box<none>) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
 // CHECK:           %[[CINDEX_0:.*]] = arith.constant 0 : index
 // CHECK:           %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf64>>
-// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
+// CHECK:           %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK:           %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
 // CHECK:           %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK:           %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK:           %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f64) {
 // CHECK:             %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
 // CHECK:             %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f64>
@@ -188,19 +220,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 // CHECK:           %[[A_BOX_F32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf32>>
 // CHECK:           %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F32]] : (!fir.box<!fir.array<10xf32>>) -> !fir.box<none>
 // CHECK-NOT:       fir.call @_FortranASumReal4({{.*}})
-// CHECK:           %[[RES:.*]] = fir.call @_FortranASumReal4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f32
+// CHECK:           %[[RES:.*]] = fir.call @_FortranASumReal4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f32
 // CHECK-NOT:       fir.call @_FortranASumReal4({{.*}})
 // CHECK:           return %{{.*}} : f32
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_FortranASumReal4_simplified(
+// CHECK-LABEL:   func.func private @_FortranASumReal4x1_simplified(
 // CHECK-SAME:                                                    %[[ARR:.*]]: !fir.box<none>) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
 // CHECK:           %[[CINDEX_0:.*]] = arith.constant 0 : index
 // CHECK:           %[[ARR_BOX_F32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf32>>
-// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// CHECK:           %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
 // CHECK:           %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK:           %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f32) {
 // CHECK:             %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F32]], %[[ITER]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
 // CHECK:             %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f32>
@@ -243,9 +276,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 }
 
 // CHECK-LABEL:   func.func @sum_1d_complex(%{{.*}}: !fir.ref<!fir.array<10x!fir.complex<4>>> {fir.bindc_name = "a"}) -> !fir.complex<4> {
-// CHECK-NOT:       fir.call @_FortranACppSumComplex4_simplified({{.*}})
+// CHECK-NOT:       fir.call @_FortranACppSumComplex4x1_simplified({{.*}})
 // CHECK:           fir.call @_FortranACppSumComplex4({{.*}}) : (!fir.ref<complex<f32>>, !fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> none
-// CHECK-NOT:       fir.call @_FortranACppSumComplex4_simplified({{.*}})
+// CHECK-NOT:       fir.call @_FortranACppSumComplex4x1_simplified({{.*}})
 
 // -----
 
@@ -298,20 +331,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 
 // CHECK-LABEL:   func.func @sum_1d_calla(%{{.*}}) -> i32 {
 // CHECK-NOT:       fir.call @_FortranASumInteger4({{.*}})
-// CHECK:           fir.call @_FortranASumInteger4_simplified(%{{.*}})
+// CHECK:           fir.call @_FortranASumInteger4x1_simplified(%{{.*}})
 // CHECK-NOT:       fir.call @_FortranASumInteger4({{.*}})
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @sum_1d_callb(%{{.*}}) -> i32 {
 // CHECK-NOT:       fir.call @_FortranASumInteger4({{.*}})
-// CHECK:           fir.call @_FortranASumInteger4_simplified(%{{.*}})
+// CHECK:           fir.call @_FortranASumInteger4x1_simplified(%{{.*}})
 // CHECK-NOT:       fir.call @_FortranASumInteger4({{.*}})
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_FortranASumInteger4_simplified({{.*}}) -> i32 {{.*}} {
+// CHECK-LABEL:   func.func private @_FortranASumInteger4x1_simplified({{.*}}) -> i32 {{.*}} {
 // CHECK:           return %{{.*}} : i32
 // CHECK:         }
-// CHECK-NOT:   func.func private @_FortranASumInteger4_simplified({{.*}})
+// CHECK-NOT:   func.func private @_FortranASumInteger4x1_simplified({{.*}})
 
 // -----
 
@@ -354,14 +387,14 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 // CHECK:           %[[SLICE:.*]] = fir.slice %{{.*}}, %{{.*}}, %[[CINDEX_2]] : (index, index, index) -> !fir.slice<1>
 // CHECK:           %[[A_BOX_I32:.*]] = fir.embox %{{.*}}(%[[SHAPE]]) {{\[}}%[[SLICE]]] : (!fir.ref<!fir.array<20xi32>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
 // CHECK:           %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
-// CHECK:           %{{.*}} = fir.call @_FortranASumInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
+// CHECK:           %{{.*}} = fir.call @_FortranASumInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
 // CHECK:           return %{{.*}} : i32
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_FortranASumInteger4_simplified(%{{.*}}) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK-LABEL:   func.func private @_FortranASumInteger4x1_simplified(%{{.*}}) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
 // CHECK:           %[[ARR_BOX_I32:.*]] = fir.convert %{{.*}} : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
-// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK:           %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK:           %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
 // CHECK:           %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %{{.*}} to %[[EXTENT]] step %[[CINDEX_1]] iter_args({{.*}}) -> (i32) {
 // CHECK:             %{{.*}} = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
@@ -792,18 +825,19 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 // CHECK:           %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
 // CHECK:           %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
 // CHECK:           %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10xi32>>) -> !fir.box<none>
-// CHECK:           %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
+// CHECK:           %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
 // CHECK:           return %{{.*}} : i32
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_FortranAMaxvalInteger4_simplified(
+// CHECK-LABEL:   func.func private @_FortranAMaxvalInteger4x1_simplified(
 // CHECK-SAME:                                                       %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
 // CHECK:           %[[CINDEX_0:.*]] = arith.constant 0 : index
 // CHECK:           %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
-// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK:           %[[CI32_MININT:.*]] = arith.constant -2147483648 : i32
 // CHECK:           %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK:           %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK:           %[[CI32_MININT:.*]] = arith.constant -2147483648 : i32
 // CHECK:           %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX:.*]] = %[[CI32_MININT]]) -> (i32) {
 // CHECK:             %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
 // CHECK:             %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
@@ -849,18 +883,19 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 // CHECK:           %[[SHAPE:.*]] = fir.shape %[[CINDEX_10]] : (index) -> !fir.shape<1>
 // CHECK:           %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
 // CHECK:           %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
-// CHECK:           %[[RES:.*]] = fir.call @_FortranAMaxvalReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
+// CHECK:           %[[RES:.*]] = fir.call @_FortranAMaxvalReal8x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
 // CHECK:           return %{{.*}} : f64
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_FortranAMaxvalReal8_simplified(
+// CHECK-LABEL:   func.func private @_FortranAMaxvalReal8x1_simplified(
 // CHECK-SAME:                                                    %[[ARR:.*]]: !fir.box<none>) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
 // CHECK:           %[[CINDEX_0:.*]] = arith.constant 0 : index
 // CHECK:           %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf64>>
-// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
+// CHECK:           %[[NEG_DBL_MAX:.*]] = arith.constant -1.7976931348623157E+308 : f64
 // CHECK:           %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
 // CHECK:           %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK:           %[[NEG_DBL_MAX:.*]] = arith.constant -1.7976931348623157E+308 : f64
 // CHECK:           %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX]] = %[[NEG_DBL_MAX]]) -> (f64) {
 // CHECK:             %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
 // CHECK:             %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f64>
@@ -869,3 +904,97 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
 // CHECK:           }
 // CHECK:           return %[[RES]] : f64
 // CHECK:         }
+
+// -----
+
+// SUM reduction of sliced explicit-shape array is replaced with
+// 2D simplified implementation.
+// This looks like a(1:10, 1:10, 1) of a 10x10x10 array: the third slice
+// triple is (1, undefined, undefined), and the fir.embox result type is the
+// rank-2 !fir.box<!fir.array<?x?xi64>> although the fir.shape operand is
+// rank 3. The pass must take the rank from the boxed result, not from the
+// shape, so the call is expected to become the "x2" (rank-2) variant.
+func.func @sum_sliced_embox_i64(%arg0: !fir.ref<!fir.array<10x10x10xi64>> {fir.bindc_name = "a"}) -> f32 {
+  %c10 = arith.constant 10 : index
+  %c10_0 = arith.constant 10 : index
+  %c10_1 = arith.constant 10 : index
+  %0 = fir.alloca f32 {bindc_name = "sum_sliced_embox_i64", uniq_name = "_QFsum_sliced_embox_i64Esum_sliced_embox_i64"}
+  %1 = fir.alloca i64 {bindc_name = "sum_sliced_i64", uniq_name = "_QFsum_sliced_embox_i64Esum_sliced_i64"}
+  // Slice triples for the first and second dimensions: (1, 1 + 10 - 1, 1).
+  %c1 = arith.constant 1 : index
+  %c1_i64 = arith.constant 1 : i64
+  %2 = fir.convert %c1_i64 : (i64) -> index
+  %3 = arith.addi %c1, %c10 : index
+  %4 = arith.subi %3, %c1 : index
+  %c1_i64_2 = arith.constant 1 : i64
+  %5 = fir.convert %c1_i64_2 : (i64) -> index
+  %6 = arith.addi %c1, %c10_0 : index
+  %7 = arith.subi %6, %c1 : index
+  // Third-dimension triple: i64 lower bound 1 with undefined upper bound and
+  // stride (apparently a scalar subscript, which drops this dimension).
+  %c1_i64_3 = arith.constant 1 : i64
+  %8 = fir.undefined index
+  %9 = fir.shape %c10, %c10_0, %c10_1 : (index, index, index) -> !fir.shape<3>
+  %10 = fir.slice %c1, %4, %2, %c1, %7, %5, %c1_i64_3, %8, %8 : (index, index, index, index, index, index, i64, index, index) -> !fir.slice<3>
+  %11 = fir.embox %arg0(%9) [%10] : (!fir.ref<!fir.array<10x10x10xi64>>, !fir.shape<3>, !fir.slice<3>) -> !fir.box<!fir.array<?x?xi64>>
+  // Runtime call operands: the boxed slice, the "./test.f90" name, line 3,
+  // the i32 value 0, and an absent box (presumably DIM = 0 and no MASK;
+  // confirm against the Fortran runtime's SumInteger8 signature).
+  %12 = fir.absent !fir.box<i1>
+  %c0 = arith.constant 0 : index
+  %13 = fir.address_of(@_QQcl.2E2F746573742E66393000) : !fir.ref<!fir.char<1,11>>
+  %c3_i32 = arith.constant 3 : i32
+  %14 = fir.convert %11 : (!fir.box<!fir.array<?x?xi64>>) -> !fir.box<none>
+  %15 = fir.convert %13 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+  %16 = fir.convert %c0 : (index) -> i32
+  %17 = fir.convert %12 : (!fir.box<i1>) -> !fir.box<none>
+  %18 = fir.call @_FortranASumInteger8(%14, %15, %c3_i32, %16, %17) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64
+  // The i64 sum is stored to %1. The returned f32 is loaded from %0, which is
+  // never stored in this function; the test only inspects the call above.
+  fir.store %18 to %1 : !fir.ref<i64>
+  %19 = fir.load %0 : !fir.ref<f32>
+  return %19 : f32
+}
+// Runtime entry point whose call the pass is expected to replace.
+func.func private @_FortranASumInteger8(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64 attributes {fir.runtime}
+// NUL-terminated "./test.f90" string (11 chars), used as the source-name operand.
+fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> {
+  %0 = fir.string_lit "./test.f90\00"(11) : !fir.char<1,11>
+  fir.has_value %0 : !fir.char<1,11>
+}
+
+// CHECK-NOT: call{{.*}}_FortranASumInteger8(
+// CHECK: call @_FortranASumInteger8x2_simplified(
+// CHECK-NOT: call{{.*}}_FortranASumInteger8(
+
+// -----
+
+// SUM reduction of sliced assumed-shape array is replaced with
+// 2D simplified implementation.
+// A rank-3 assumed-shape box is resliced with fir.rebox. The third slice
+// triple is (1, undefined, undefined), and the rebox result type is the
+// rank-2 !fir.box<!fir.array<?x?xi64>>. The pass must use the rank of the
+// rebox result rather than that of the input box, so the call is expected
+// to become the "x2" (rank-2) variant.
+func.func @_QPsum_sliced_rebox_i64(%arg0: !fir.box<!fir.array<?x?x?xi64>> {fir.bindc_name = "a"}) -> f32 {
+  %0 = fir.alloca i64 {bindc_name = "sum_sliced_i64", uniq_name = "_QFsum_sliced_rebox_i64Esum_sliced_i64"}
+  %1 = fir.alloca f32 {bindc_name = "sum_sliced_rebox_i64", uniq_name = "_QFsum_sliced_rebox_i64Esum_sliced_rebox_i64"}
+  %c1 = arith.constant 1 : index
+  %c1_i64 = arith.constant 1 : i64
+  %2 = fir.convert %c1_i64 : (i64) -> index
+  // First-dimension upper bound: 1 + extent - 1, where the extent is result #1
+  // of fir.box_dims on dimension index 0 of the input box.
+  %c0 = arith.constant 0 : index
+  %3:3 = fir.box_dims %arg0, %c0 : (!fir.box<!fir.array<?x?x?xi64>>, index) -> (index, index, index)
+  %4 = arith.addi %c1, %3#1 : index
+  %5 = arith.subi %4, %c1 : index
+  %c1_i64_0 = arith.constant 1 : i64
+  %6 = fir.convert %c1_i64_0 : (i64) -> index
+  // Second-dimension upper bound, likewise from fir.box_dims on dimension index 1.
+  %c1_1 = arith.constant 1 : index
+  %7:3 = fir.box_dims %arg0, %c1_1 : (!fir.box<!fir.array<?x?x?xi64>>, index) -> (index, index, index)
+  %8 = arith.addi %c1, %7#1 : index
+  %9 = arith.subi %8, %c1 : index
+  // Third-dimension triple: i64 lower bound 1 with undefined upper bound and
+  // stride (apparently a scalar subscript, which drops this dimension).
+  %c1_i64_2 = arith.constant 1 : i64
+  %10 = fir.undefined index
+  %11 = fir.slice %c1, %5, %2, %c1, %9, %6, %c1_i64_2, %10, %10 : (index, index, index, index, index, index, i64, index, index) -> !fir.slice<3>
+  %12 = fir.rebox %arg0 [%11] : (!fir.box<!fir.array<?x?x?xi64>>, !fir.slice<3>) -> !fir.box<!fir.array<?x?xi64>>
+  // Runtime call operands: the reboxed slice, the "./test.f90" name, line 8,
+  // the i32 value 0, and an absent box (presumably DIM = 0 and no MASK;
+  // confirm against the Fortran runtime's SumInteger8 signature).
+  %13 = fir.absent !fir.box<i1>
+  %c0_3 = arith.constant 0 : index
+  %14 = fir.address_of(@_QQcl.2E2F746573742E66393000) : !fir.ref<!fir.char<1,11>>
+  %c8_i32 = arith.constant 8 : i32
+  %15 = fir.convert %12 : (!fir.box<!fir.array<?x?xi64>>) -> !fir.box<none>
+  %16 = fir.convert %14 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+  %17 = fir.convert %c0_3 : (index) -> i32
+  %18 = fir.convert %13 : (!fir.box<i1>) -> !fir.box<none>
+  %19 = fir.call @_FortranASumInteger8(%15, %16, %c8_i32, %17, %18) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64
+  // The i64 sum is stored to %0. The returned f32 is loaded from %1, which is
+  // never stored in this function; the test only inspects the call above.
+  fir.store %19 to %0 : !fir.ref<i64>
+  %20 = fir.load %1 : !fir.ref<f32>
+  return %20 : f32
+}
+// Runtime entry point whose call the pass is expected to replace.
+func.func private @_FortranASumInteger8(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64 attributes {fir.runtime}
+// NUL-terminated "./test.f90" string (11 chars), used as the source-name operand.
+fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> {
+  %0 = fir.string_lit "./test.f90\00"(11) : !fir.char<1,11>
+  fir.has_value %0 : !fir.char<1,11>
+}
+
+// CHECK-NOT: call{{.*}}_FortranASumInteger8(
+// CHECK: call @_FortranASumInteger8x2_simplified(
+// CHECK-NOT: call{{.*}}_FortranASumInteger8(


        


More information about the flang-commits mailing list