[flang-commits] [flang] [flang] Simplify hlfir.sum total reductions. (PR #119482)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Thu Dec 12 21:45:42 PST 2024


https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/119482

>From c4b4a59999cb49ce0992f4ec07ebe8cee4f3fd8c Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 10 Dec 2024 17:47:21 -0800
Subject: [PATCH 1/3] [flang] Simplify hlfir.sum total reductions.

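Keep the running reduction value in a temporary scalar location instead
of a loop-carried value, so that hlfir::genLoopNest can be used to build
the reduction loop nest. With this, total reductions (SUM without DIM)
are now always inlined. This also allows using omp.loop_nest with
worksharing for OpenMP.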
---
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    | 181 ++++++-----
 .../HLFIR/simplify-hlfir-intrinsics-sum.fir   | 289 ++++++++++--------
 2 files changed, 261 insertions(+), 209 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index b61f9767ccc2b8..2bb1a786f6c12c 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -17,6 +17,7 @@
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -105,34 +106,47 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
                   mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = sum.getLoc();
     fir::FirOpBuilder builder{rewriter, sum.getOperation()};
-    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
-    assert(expr && "expected an expression type for the result of hlfir.sum");
-    mlir::Type elementType = expr.getElementType();
+    mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
     hlfir::Entity array = hlfir::Entity{sum.getArray()};
     mlir::Value mask = sum.getMask();
     mlir::Value dim = sum.getDim();
-    int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
+    int64_t dimVal =
+        isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
     mlir::Value resultShape, dimExtent;
-    std::tie(resultShape, dimExtent) =
-        genResultShape(loc, builder, array, dimVal);
+    llvm::SmallVector<mlir::Value> arrayExtents;
+    if (isTotalReduction)
+      arrayExtents = genArrayExtents(loc, builder, array);
+    else
+      std::tie(resultShape, dimExtent) =
+          genResultShapeForPartialReduction(loc, builder, array, dimVal);
+
+    // If the mask is present and is a scalar, then we'd better load its value
+    // outside of the reduction loop making the loop unswitching easier.
+    mlir::Value isPresentPred, maskValue;
+    if (mask) {
+      if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
+        // MASK represented by a box might be dynamically optional,
+        // so we have to check for its presence before accessing it.
+        isPresentPred =
+            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
+      }
+
+      if (hlfir::Entity{mask}.isScalar())
+        maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
+    }
 
     auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                          mlir::ValueRange inputIndices) -> hlfir::Entity {
       // Loop over all indices in the DIM dimension, and reduce all values.
-      // We do not need to create the reduction loop always: if we can
-      // slice the input array given the inputIndices, then we can
-      // just apply a new SUM operation (total reduction) to the slice.
-      // For the time being, generate the explicit loop because the slicing
-      // requires generating an elemental operation for the input array
-      // (and the mask, if present).
-      // TODO: produce the slices and new SUM after adding a pattern
-      // for expanding total reduction SUM case.
-      mlir::Type indexType = builder.getIndexType();
-      auto one = builder.createIntegerConstant(loc, indexType, 1);
-      auto ub = builder.createConvert(loc, indexType, dimExtent);
+      // If DIM is not present, do total reduction.
 
+      // Create temporary scalar for keeping the running reduction value.
+      mlir::Value reductionTemp =
+          builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
       // Initial value for the reduction.
       mlir::Value initValue = genInitValue(loc, builder, elementType);
+      builder.create<fir::StoreOp>(loc, initValue, reductionTemp);
 
       // The reduction loop may be unordered if FastMathFlags::reassoc
       // transformations are allowed. The integer reduction is always
@@ -141,42 +155,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
                          static_cast<bool>(sum.getFastmath() &
                                            mlir::arith::FastMathFlags::reassoc);
 
-      // If the mask is present and is a scalar, then we'd better load its value
-      // outside of the reduction loop making the loop unswitching easier.
-      // Maybe it is worth hoisting it from the elemental operation as well.
-      mlir::Value isPresentPred, maskValue;
-      if (mask) {
-        if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
-          // MASK represented by a box might be dynamically optional,
-          // so we have to check for its presence before accessing it.
-          isPresentPred =
-              builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
-        }
-
-        if (hlfir::Entity{mask}.isScalar())
-          maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
-      }
+      llvm::SmallVector<mlir::Value> extents;
+      if (isTotalReduction)
+        extents = arrayExtents;
+      else
+        extents.push_back(
+            builder.createConvert(loc, builder.getIndexType(), dimExtent));
 
       // NOTE: the outer elemental operation may be lowered into
       // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
       // loop may appear disjoint from the workshare loop nest.
-      // Moreover, the inner loop is not strictly nested (due to the reduction
-      // starting value initialization), and the above omp dialect operations
-      // cannot produce results.
-      // It is unclear what we should do about it yet.
-      auto doLoop = builder.create<fir::DoLoopOp>(
-          loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
-          mlir::ValueRange{initValue});
-
-      // Address the input array using the reduction loop's IV
-      // for the DIM dimension.
-      mlir::Value iv = doLoop.getInductionVar();
-      llvm::SmallVector<mlir::Value> indices{inputIndices};
-      indices.insert(indices.begin() + dimVal - 1, iv);
-
-      mlir::OpBuilder::InsertionGuard guard(builder);
-      builder.setInsertionPointToStart(doLoop.getBody());
-      mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
+      bool emitWorkshareLoop =
+          isTotalReduction ? flangomp::shouldUseWorkshareLowering(sum) : false;
+
+      hlfir::LoopNest loopNest = hlfir::genLoopNest(
+          loc, builder, extents, isUnordered, emitWorkshareLoop);
+
+      llvm::SmallVector<mlir::Value> indices;
+      if (isTotalReduction) {
+        indices = loopNest.oneBasedIndices;
+      } else {
+        indices = inputIndices;
+        indices.insert(indices.begin() + dimVal - 1,
+                       loopNest.oneBasedIndices[0]);
+      }
+
+      builder.setInsertionPointToStart(loopNest.body);
       fir::IfOp ifOp;
       if (mask) {
         // Make the reduction value update conditional on the value
@@ -188,16 +192,15 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
         }
         mlir::Value isUnmasked =
             builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
-        ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
-                                         /*withElseRegion=*/true);
-        // In the 'else' block return the current reduction value.
-        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-        builder.create<fir::ResultOp>(loc, reductionValue);
+        ifOp = builder.create<fir::IfOp>(loc, isUnmasked,
+                                         /*withElseRegion=*/false);
 
         // In the 'then' block do the actual addition.
         builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
       }
 
+      mlir::Value reductionValue =
+          builder.create<fir::LoadOp>(loc, reductionTemp);
       hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
       hlfir::Entity elementValue =
           hlfir::loadTrivialScalar(loc, builder, element);
@@ -205,15 +208,18 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
       // (e.g. when fast-math is not allowed), but let's start with
       // the simple version.
       reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
-      builder.create<fir::ResultOp>(loc, reductionValue);
-
-      if (ifOp) {
-        builder.setInsertionPointAfter(ifOp);
-        builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
-      }
+      builder.create<fir::StoreOp>(loc, reductionValue, reductionTemp);
 
-      return hlfir::Entity{doLoop.getResult(0)};
+      builder.setInsertionPointAfter(loopNest.outerOp);
+      return hlfir::Entity{builder.create<fir::LoadOp>(loc, reductionTemp)};
     };
+
+    if (isTotalReduction) {
+      hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
+      rewriter.replaceOp(sum, result);
+      return mlir::success();
+    }
+
     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
         loc, builder, elementType, resultShape, {}, genKernel,
         /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
@@ -229,20 +235,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
   }
 
 private:
+  static llvm::SmallVector<mlir::Value>
+  genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
+                  hlfir::Entity array) {
+    mlir::Value inShape = hlfir::genShape(loc, builder, array);
+    llvm::SmallVector<mlir::Value> inExtents =
+        hlfir::getExplicitExtentsFromShape(inShape, builder);
+    if (inShape.getUses().empty())
+      inShape.getDefiningOp()->erase();
+    return inExtents;
+  }
+
   // Return fir.shape specifying the shape of the result
   // of a SUM reduction with DIM=dimVal. The second return value
   // is the extent of the DIM dimension.
   static std::tuple<mlir::Value, mlir::Value>
-  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
-                 hlfir::Entity array, int64_t dimVal) {
-    mlir::Value inShape = hlfir::genShape(loc, builder, array);
+  genResultShapeForPartialReduction(mlir::Location loc,
+                                    fir::FirOpBuilder &builder,
+                                    hlfir::Entity array, int64_t dimVal) {
     llvm::SmallVector<mlir::Value> inExtents =
-        hlfir::getExplicitExtentsFromShape(inShape, builder);
+        genArrayExtents(loc, builder, array);
     assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
            "DIM must be present and a positive constant not exceeding "
            "the array's rank");
-    if (inShape.getUses().empty())
-      inShape.getDefiningOp()->erase();
 
     mlir::Value dimExtent = inExtents[dimVal - 1];
     inExtents.erase(inExtents.begin() + dimVal - 1);
@@ -355,22 +370,22 @@ class SimplifyHLFIRIntrinsics
     target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
       if (!simplifySum)
         return true;
-      if (mlir::Value dim = sum.getDim()) {
-        if (auto dimVal = fir::getIntIfConstant(dim)) {
-          if (!fir::isa_trivial(sum.getType())) {
-            // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
-            // It is only legal when X is 1, and it should probably be
-            // canonicalized into SUM(a).
-            fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
-                hlfir::getFortranElementOrSequenceType(
-                    sum.getArray().getType()));
-            if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
-              // Ignore SUMs with illegal DIM values.
-              // They may appear in dead code,
-              // and they do not have to be converted.
-              return false;
-            }
-          }
+
+      // Always inline total reductions.
+      if (hlfir::Entity{sum}.getRank() == 0)
+        return false;
+      mlir::Value dim = sum.getDim();
+      if (!dim)
+        return false;
+
+      if (auto dimVal = fir::getIntIfConstant(dim)) {
+        fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
+            hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
+        if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
+          // Ignore SUMs with illegal DIM values.
+          // They may appear in dead code,
+          // and they do not have to be converted.
+          return false;
         }
       }
       return true;
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 54a592a66670f1..572b9f0da1e4ab 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -14,9 +14,12 @@ func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xi32> {
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index):
-// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:             fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_3]] step %[[VAL_9]] unordered {
+// CHECK:               %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
 // CHECK:               %[[VAL_12:.*]] = arith.constant 0 : index
 // CHECK:               %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
 // CHECK:               %[[VAL_14:.*]] = arith.constant 1 : index
@@ -29,9 +32,10 @@ func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
 // CHECK:               %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]])  : (!fir.box<!fir.array<2x3xi32>>, index, index) -> !fir.ref<i32>
 // CHECK:               %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
 // CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32
-// CHECK:               fir.result %[[VAL_23]] : i32
+// CHECK:               fir.store %[[VAL_23]] to %[[VAL_7]] : !fir.ref<i32>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
+// CHECK:             %[[VAL_24:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             hlfir.yield_element %[[VAL_24]] : i32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -50,14 +54,18 @@ func.func @sum_expr_known_extents(%arg0: !hlfir.expr<2x3xi32>) {
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> {
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index):
-// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:             fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_2]] step %[[VAL_9]] unordered {
+// CHECK:               %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
 // CHECK:               %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xi32>, index, index) -> i32
 // CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
-// CHECK:               fir.result %[[VAL_13]] : i32
+// CHECK:               fir.store %[[VAL_13]] to %[[VAL_7]] : !fir.ref<i32>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
+// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             hlfir.yield_element %[[VAL_14]] : i32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -77,12 +85,15 @@ func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xcomplex<f64>> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = fir.alloca complex<f64> {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
 // CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
 // CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_3]]#1 step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+// CHECK:             fir.store %[[VAL_12]] to %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:             %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_14:.*]] = %[[VAL_13]] to %[[VAL_3]]#1 step %[[VAL_13]] {
+// CHECK:               %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_16:.*]] = arith.constant 0 : index
 // CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
 // CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
@@ -95,9 +106,10 @@ func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:               %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
-// CHECK:               fir.result %[[VAL_27]] : complex<f64>
+// CHECK:               fir.store %[[VAL_27]] to %[[VAL_8]] : !fir.ref<complex<f64>>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_13]] : complex<f64>
+// CHECK:             %[[VAL_28:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:             hlfir.yield_element %[[VAL_28]] : complex<f64>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -116,12 +128,15 @@ func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xcomplex<f64>> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = fir.alloca complex<f64> {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
 // CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
 // CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+// CHECK:             fir.store %[[VAL_12]] to %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:             %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_14:.*]] = %[[VAL_13]] to %[[VAL_4]] step %[[VAL_13]] {
+// CHECK:               %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_16:.*]] = arith.constant 0 : index
 // CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
 // CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
@@ -134,9 +149,10 @@ func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:               %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
-// CHECK:               fir.result %[[VAL_27]] : complex<f64>
+// CHECK:               fir.store %[[VAL_27]] to %[[VAL_8]] : !fir.ref<complex<f64>>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_13]] : complex<f64>
+// CHECK:             %[[VAL_28:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:             hlfir.yield_element %[[VAL_28]] : complex<f64>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -148,7 +164,7 @@ func.func @sum_expr_unknown_extent1(%arg0: !hlfir.expr<?x3xf32>) {
   return
 }
 // CHECK-LABEL:   func.func @sum_expr_unknown_extent1(
-// CHECK-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>) {
+// CHECK-SAME:                                        %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
 // CHECK:           %[[VAL_3:.*]] = hlfir.get_extent %[[VAL_2]] {dim = 0 : index} : (!fir.shape<2>) -> index
@@ -156,14 +172,18 @@ func.func @sum_expr_unknown_extent1(%arg0: !hlfir.expr<?x3xf32>) {
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_3]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (f32) {
+// CHECK:             fir.store %[[VAL_9]] to %[[VAL_8]] : !fir.ref<f32>
+// CHECK:             %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_11:.*]] = %[[VAL_10]] to %[[VAL_3]] step %[[VAL_10]] {
+// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_8]] : !fir.ref<f32>
 // CHECK:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_7]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
 // CHECK:               %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK:               fir.result %[[VAL_14]] : f32
+// CHECK:               fir.store %[[VAL_14]] to %[[VAL_8]] : !fir.ref<f32>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_10]] : f32
+// CHECK:             %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<f32>
+// CHECK:             hlfir.yield_element %[[VAL_15]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -174,7 +194,7 @@ func.func @sum_expr_unknown_extent2(%arg0: !hlfir.expr<?x3xf32>) {
   return
 }
 // CHECK-LABEL:   func.func @sum_expr_unknown_extent2(
-// CHECK-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>) {
+// CHECK-SAME:                                        %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 2 : i32
 // CHECK:           %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
 // CHECK:           %[[VAL_3:.*]] = hlfir.get_extent %[[VAL_2]] {dim = 0 : index} : (!fir.shape<2>) -> index
@@ -182,14 +202,18 @@ func.func @sum_expr_unknown_extent2(%arg0: !hlfir.expr<?x3xf32>) {
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (f32) {
+// CHECK:             fir.store %[[VAL_9]] to %[[VAL_8]] : !fir.ref<f32>
+// CHECK:             %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_11:.*]] = %[[VAL_10]] to %[[VAL_4]] step %[[VAL_10]] {
+// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_8]] : !fir.ref<f32>
 // CHECK:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]], %[[VAL_11]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
 // CHECK:               %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK:               fir.result %[[VAL_14]] : f32
+// CHECK:               fir.store %[[VAL_14]] to %[[VAL_8]] : !fir.ref<f32>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_10]] : f32
+// CHECK:             %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<f32>
+// CHECK:             hlfir.yield_element %[[VAL_15]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -208,23 +232,24 @@ func.func @sum_scalar_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.ref<!fir.log
 // CHECK:           %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
 // CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
-// CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
-// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
-// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.logical<1>>
-// CHECK:             %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) {
-// CHECK:               %[[VAL_15:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<1>) -> i1
-// CHECK:               %[[VAL_16:.*]] = fir.if %[[VAL_15]] -> (f32) {
-// CHECK:                 %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_13]], %[[VAL_8]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_18:.*]] = arith.addf %[[VAL_14]], %[[VAL_17]] : f32
-// CHECK:                 fir.result %[[VAL_18]] : f32
-// CHECK:               } else {
-// CHECK:                 fir.result %[[VAL_14]] : f32
+// CHECK:           %[[VAL_7:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.logical<1>>
+// CHECK:           %[[VAL_8:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
+// CHECK:           ^bb0(%[[VAL_9:.*]]: index):
+// CHECK:             %[[VAL_10:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
+// CHECK:             %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             fir.store %[[VAL_11]] to %[[VAL_10]] : !fir.ref<f32>
+// CHECK:             %[[VAL_12:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_13:.*]] = %[[VAL_12]] to %[[VAL_4]] step %[[VAL_12]] {
+// CHECK:               %[[VAL_14:.*]] = fir.convert %[[VAL_7]] : (!fir.logical<1>) -> i1
+// CHECK:               fir.if %[[VAL_14]] {
+// CHECK:                 %[[VAL_15:.*]] = fir.load %[[VAL_10]] : !fir.ref<f32>
+// CHECK:                 %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_13]], %[[VAL_9]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_17:.*]] = arith.addf %[[VAL_15]], %[[VAL_16]] : f32
+// CHECK:                 fir.store %[[VAL_17]] to %[[VAL_10]] : !fir.ref<f32>
 // CHECK:               }
-// CHECK:               fir.result %[[VAL_16]] : f32
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_12]] : f32
+// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_10]] : !fir.ref<f32>
+// CHECK:             hlfir.yield_element %[[VAL_18]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -243,32 +268,33 @@ func.func @sum_scalar_boxed_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!f
 // CHECK:           %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
 // CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
-// CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
-// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
-// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> i1
-// CHECK:             %[[VAL_12:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) {
-// CHECK:               %[[VAL_13:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> !fir.ref<!fir.logical<1>>
-// CHECK:               %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<!fir.logical<1>>
-// CHECK:               fir.result %[[VAL_14]] : !fir.logical<1>
-// CHECK:             } else {
-// CHECK:               %[[VAL_15:.*]] = arith.constant true
-// CHECK:               %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (i1) -> !fir.logical<1>
-// CHECK:               fir.result %[[VAL_16]] : !fir.logical<1>
-// CHECK:             }
-// CHECK:             %[[VAL_17:.*]] = fir.do_loop %[[VAL_18:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_10]]) -> (f32) {
-// CHECK:               %[[VAL_20:.*]] = fir.convert %[[VAL_12]] : (!fir.logical<1>) -> i1
-// CHECK:               %[[VAL_21:.*]] = fir.if %[[VAL_20]] -> (f32) {
-// CHECK:                 %[[VAL_22:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_8]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_23:.*]] = arith.addf %[[VAL_19]], %[[VAL_22]] : f32
-// CHECK:                 fir.result %[[VAL_23]] : f32
-// CHECK:               } else {
-// CHECK:                 fir.result %[[VAL_19]] : f32
+// CHECK:           %[[VAL_7:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> i1
+// CHECK:           %[[VAL_8:.*]] = fir.if %[[VAL_7]] -> (!fir.logical<1>) {
+// CHECK:             %[[VAL_9:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> !fir.ref<!fir.logical<1>>
+// CHECK:             %[[VAL_10:.*]] = fir.load %[[VAL_9]] : !fir.ref<!fir.logical<1>>
+// CHECK:             fir.result %[[VAL_10]] : !fir.logical<1>
+// CHECK:           } else {
+// CHECK:             %[[VAL_11:.*]] = arith.constant true
+// CHECK:             %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (i1) -> !fir.logical<1>
+// CHECK:             fir.result %[[VAL_12]] : !fir.logical<1>
+// CHECK:           }
+// CHECK:           %[[VAL_13:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
+// CHECK:           ^bb0(%[[VAL_14:.*]]: index):
+// CHECK:             %[[VAL_15:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
+// CHECK:             %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             fir.store %[[VAL_16]] to %[[VAL_15]] : !fir.ref<f32>
+// CHECK:             %[[VAL_17:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_18:.*]] = %[[VAL_17]] to %[[VAL_4]] step %[[VAL_17]] {
+// CHECK:               %[[VAL_19:.*]] = fir.convert %[[VAL_8]] : (!fir.logical<1>) -> i1
+// CHECK:               fir.if %[[VAL_19]] {
+// CHECK:                 %[[VAL_20:.*]] = fir.load %[[VAL_15]] : !fir.ref<f32>
+// CHECK:                 %[[VAL_21:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_14]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_22:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK:                 fir.store %[[VAL_22]] to %[[VAL_15]] : !fir.ref<f32>
 // CHECK:               }
-// CHECK:               fir.result %[[VAL_21]] : f32
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_17]] : f32
+// CHECK:             %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<f32>
+// CHECK:             hlfir.yield_element %[[VAL_23]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -287,41 +313,42 @@ func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.arra
 // CHECK:           %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
 // CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
-// CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
-// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
-// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>) -> i1
-// CHECK:             %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) {
-// CHECK:               %[[VAL_15:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) {
-// CHECK:                 %[[VAL_16:.*]] = arith.constant 0 : index
-// CHECK:                 %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
-// CHECK:                 %[[VAL_18:.*]] = arith.constant 1 : index
-// CHECK:                 %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_18]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
-// CHECK:                 %[[VAL_20:.*]] = arith.constant 1 : index
-// CHECK:                 %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
-// CHECK:                 %[[VAL_22:.*]] = arith.addi %[[VAL_8]], %[[VAL_21]] : index
-// CHECK:                 %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
-// CHECK:                 %[[VAL_24:.*]] = arith.addi %[[VAL_13]], %[[VAL_23]] : index
-// CHECK:                 %[[VAL_25:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
-// CHECK:                 %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<!fir.logical<1>>
-// CHECK:                 fir.result %[[VAL_26]] : !fir.logical<1>
+// CHECK:           %[[VAL_7:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>) -> i1
+// CHECK:           %[[VAL_8:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK:           ^bb0(%[[VAL_9:.*]]: index):
+// CHECK:             %[[VAL_10:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
+// CHECK:             %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             fir.store %[[VAL_11]] to %[[VAL_10]] : !fir.ref<f32>
+// CHECK:             %[[VAL_12:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_13:.*]] = %[[VAL_12]] to %[[VAL_5]] step %[[VAL_12]] {
+// CHECK:               %[[VAL_14:.*]] = fir.if %[[VAL_7]] -> (!fir.logical<1>) {
+// CHECK:                 %[[VAL_15:.*]] = arith.constant 0 : index
+// CHECK:                 %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_15]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
+// CHECK:                 %[[VAL_17:.*]] = arith.constant 1 : index
+// CHECK:                 %[[VAL_18:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_17]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
+// CHECK:                 %[[VAL_19:.*]] = arith.constant 1 : index
+// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_16]]#0, %[[VAL_19]] : index
+// CHECK:                 %[[VAL_21:.*]] = arith.addi %[[VAL_9]], %[[VAL_20]] : index
+// CHECK:                 %[[VAL_22:.*]] = arith.subi %[[VAL_18]]#0, %[[VAL_19]] : index
+// CHECK:                 %[[VAL_23:.*]] = arith.addi %[[VAL_13]], %[[VAL_22]] : index
+// CHECK:                 %[[VAL_24:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_21]], %[[VAL_23]])  : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
+// CHECK:                 %[[VAL_25:.*]] = fir.load %[[VAL_24]] : !fir.ref<!fir.logical<1>>
+// CHECK:                 fir.result %[[VAL_25]] : !fir.logical<1>
 // CHECK:               } else {
-// CHECK:                 %[[VAL_27:.*]] = arith.constant true
-// CHECK:                 %[[VAL_28:.*]] = fir.convert %[[VAL_27]] : (i1) -> !fir.logical<1>
-// CHECK:                 fir.result %[[VAL_28]] : !fir.logical<1>
+// CHECK:                 %[[VAL_26:.*]] = arith.constant true
+// CHECK:                 %[[VAL_27:.*]] = fir.convert %[[VAL_26]] : (i1) -> !fir.logical<1>
+// CHECK:                 fir.result %[[VAL_27]] : !fir.logical<1>
 // CHECK:               }
-// CHECK:               %[[VAL_29:.*]] = fir.convert %[[VAL_15]] : (!fir.logical<1>) -> i1
-// CHECK:               %[[VAL_30:.*]] = fir.if %[[VAL_29]] -> (f32) {
-// CHECK:                 %[[VAL_31:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_13]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_32:.*]] = arith.addf %[[VAL_14]], %[[VAL_31]] : f32
-// CHECK:                 fir.result %[[VAL_32]] : f32
-// CHECK:               } else {
-// CHECK:                 fir.result %[[VAL_14]] : f32
+// CHECK:               %[[VAL_28:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<1>) -> i1
+// CHECK:               fir.if %[[VAL_28]] {
+// CHECK:                 %[[VAL_29:.*]] = fir.load %[[VAL_10]] : !fir.ref<f32>
+// CHECK:                 %[[VAL_30:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_9]], %[[VAL_13]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[VAL_30]] : f32
+// CHECK:                 fir.store %[[VAL_31]] to %[[VAL_10]] : !fir.ref<f32>
 // CHECK:               }
-// CHECK:               fir.result %[[VAL_30]] : f32
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_12]] : f32
+// CHECK:             %[[VAL_32:.*]] = fir.load %[[VAL_10]] : !fir.ref<f32>
+// CHECK:             hlfir.yield_element %[[VAL_32]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -342,21 +369,22 @@ func.func @sum_array_expr_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !hlfir.expr<?
 // CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
 // CHECK:           ^bb0(%[[VAL_8:.*]]: index):
-// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) {
-// CHECK:               %[[VAL_14:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3x!fir.logical<1>>, index, index) -> !fir.logical<1>
-// CHECK:               %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<1>) -> i1
-// CHECK:               %[[VAL_16:.*]] = fir.if %[[VAL_15]] -> (f32) {
-// CHECK:                 %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_18:.*]] = arith.addf %[[VAL_13]], %[[VAL_17]] : f32
-// CHECK:                 fir.result %[[VAL_18]] : f32
-// CHECK:               } else {
-// CHECK:                 fir.result %[[VAL_13]] : f32
+// CHECK:             fir.store %[[VAL_10]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK:             %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_12:.*]] = %[[VAL_11]] to %[[VAL_5]] step %[[VAL_11]] {
+// CHECK:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3x!fir.logical<1>>, index, index) -> !fir.logical<1>
+// CHECK:               %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (!fir.logical<1>) -> i1
+// CHECK:               fir.if %[[VAL_14]] {
+// CHECK:                 %[[VAL_15:.*]] = fir.load %[[VAL_9]] : !fir.ref<f32>
+// CHECK:                 %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_17:.*]] = arith.addf %[[VAL_15]], %[[VAL_16]] : f32
+// CHECK:                 fir.store %[[VAL_17]] to %[[VAL_9]] : !fir.ref<f32>
 // CHECK:               }
-// CHECK:               fir.result %[[VAL_16]] : f32
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_11]] : f32
+// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_9]] : !fir.ref<f32>
+// CHECK:             hlfir.yield_element %[[VAL_18]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -375,19 +403,23 @@ func.func @sum_unordered_reduction(%arg0: !hlfir.expr<2x3xf32>) {
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index):
-// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_7:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (f32) {
+// CHECK:             fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_2]] step %[[VAL_9]] unordered {
+// CHECK:               %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<f32>
 // CHECK:               %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xf32>, index, index) -> f32
 // CHECK:               %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] fastmath<reassoc> : f32
-// CHECK:               fir.result %[[VAL_13]] : f32
+// CHECK:               fir.store %[[VAL_13]] to %[[VAL_7]] : !fir.ref<f32>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_9]] : f32
+// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_7]] : !fir.ref<f32>
+// CHECK:             hlfir.yield_element %[[VAL_14]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
 
-// negative: total reduction
+// total reduction
 func.func @sum_total_reduction(%arg0: !fir.box<!fir.array<3xi32>>) {
   %cst = arith.constant 1 : i32
   %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<3xi32>>, i32) -> i32
@@ -396,19 +428,24 @@ func.func @sum_total_reduction(%arg0: !fir.box<!fir.array<3xi32>>) {
 // CHECK-LABEL:   func.func @sum_total_reduction(
 // CHECK-SAME:                                   %[[VAL_0:.*]]: !fir.box<!fir.array<3xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
-// CHECK:           %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
-// CHECK:           return
-// CHECK:         }
-
-// negative: non-const dim
-func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
-  %res = hlfir.sum %arg0 dim %dim : (!fir.box<!fir.array<3xi32>>, i32) -> i32
-  return
-}
-// CHECK-LABEL:   func.func @sum_non_const_dim(
-// CHECK-SAME:                                 %[[VAL_0:.*]]: !fir.box<!fir.array<3xi32>>,
-// CHECK-SAME:                                 %[[VAL_1:.*]]: i32) {
-// CHECK:           %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_3:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK:           fir.store %[[VAL_4]] to %[[VAL_3]] : !fir.ref<i32>
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           fir.do_loop %[[VAL_6:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_5]] unordered {
+// CHECK:             %[[VAL_7:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:             %[[VAL_9:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_8]] : (!fir.box<!fir.array<3xi32>>, index) -> (index, index, index)
+// CHECK:             %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_11:.*]] = arith.subi %[[VAL_9]]#0, %[[VAL_10]] : index
+// CHECK:             %[[VAL_12:.*]] = arith.addi %[[VAL_6]], %[[VAL_11]] : index
+// CHECK:             %[[VAL_13:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_12]])  : (!fir.box<!fir.array<3xi32>>, index) -> !fir.ref<i32>
+// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<i32>
+// CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_7]], %[[VAL_14]] : i32
+// CHECK:             fir.store %[[VAL_15]] to %[[VAL_3]] : !fir.ref<i32>
+// CHECK:           }
+// CHECK:           %[[VAL_16:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
 // CHECK:           return
 // CHECK:         }
 

>From 3d37ec81b7abdeb4c0d9ab103e0310901ab65398 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 11 Dec 2024 08:38:18 -0800
Subject: [PATCH 2/3] Addressed the omp "issue".

---
 .../HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp         | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 2bb1a786f6c12c..eec51a403cdd90 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -17,7 +17,6 @@
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
-#include "flang/Optimizer/OpenMP/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -165,11 +164,11 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
       // NOTE: the outer elemental operation may be lowered into
       // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
       // loop may appear disjoint from the workshare loop nest.
-      bool emitWorkshareLoop =
-          isTotalReduction ? flangomp::shouldUseWorkshareLowering(sum) : false;
-
+      //
+      // TODO: a workshare loop nest can be used for the total reductions,
+      // but a proper reduction clause is required to make it work.
       hlfir::LoopNest loopNest = hlfir::genLoopNest(
-          loc, builder, extents, isUnordered, emitWorkshareLoop);
+          loc, builder, extents, isUnordered, /*emitWorkshareLoop=*/false);
 
       llvm::SmallVector<mlir::Value> indices;
       if (isTotalReduction) {

>From 76811eb6c8a0c2526fb480f40a88c338d0ff39c6 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 12 Dec 2024 20:33:32 -0800
Subject: [PATCH 3/3] Reverted to SSA reductions.

---
 .../flang/Optimizer/Builder/HLFIRTools.h      |  35 +++
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    |  50 +++
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    | 110 ++++---
 .../HLFIR/simplify-hlfir-intrinsics-sum.fir   | 297 +++++++++---------
 4 files changed, 291 insertions(+), 201 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index efbd9e4f50d432..c8aad644bc784a 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -366,6 +366,10 @@ struct LoopNest {
 /// Generate a fir.do_loop nest looping from 1 to extents[i].
 /// \p isUnordered specifies whether the loops in the loop nest
 /// are unordered.
+///
+/// NOTE: genLoopNestWithReductions() should be preferred over
+/// this method, although it currently cannot generate OpenMP
+/// workshare loop constructs.
 LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                      mlir::ValueRange extents, bool isUnordered = false,
                      bool emitWorkshareLoop = false);
@@ -376,6 +380,37 @@ inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                      isUnordered, emitWorkshareLoop);
 }
 
+/// The type of a callback that generates the body of a reduction
+/// loop nest. It takes a location and a builder, as usual.
+/// In addition, the first set of values are the values of the loops'
+/// induction variables. The second set of values are the values
+/// of the reductions on entry to the innermost loop.
+/// The callback must return the updated values of the reductions.
+using ReductionLoopBodyGenerator = std::function<llvm::SmallVector<mlir::Value>(
+    mlir::Location, fir::FirOpBuilder &, mlir::ValueRange, mlir::ValueRange)>;
+
+/// Generate a loop nest looping from 1 to \p extents[i] and reducing
+/// a set of values.
+/// \p isUnordered specifies whether the loops in the loop nest
+/// are unordered.
+/// \p reductionInits are the initial values of the reductions
+/// on entry to the outermost loop.
+/// \p genBody callback is responsible for generating the code
+/// that updates the reduction values in the innermost loop.
+///
+/// NOTE: the implementation of this function may decide
+/// to perform the reductions on SSA or in memory.
+/// In the latter case, this function is responsible for
+/// allocating/loading/storing the reduction variables,
+/// and making sure they have proper data sharing attributes
+/// in case any parallel constructs are present around the point
+/// of the loop nest insertion, or if the function decides
+/// to use any worksharing loop constructs for the loop nest.
+llvm::SmallVector<mlir::Value> genLoopNestWithReductions(
+    mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents,
+    mlir::ValueRange reductionInits, const ReductionLoopBodyGenerator &genBody,
+    bool isUnordered = false);
+
 /// Inline the body of an hlfir.elemental at the current insertion point
 /// given a list of one based indices. This generates the computation
 /// of one element of the elemental expression. Return the YieldElementOp
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 1bd950f2445ee4..94238bc24e453d 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -910,6 +910,56 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
   return loopNest;
 }
 
+llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
+    mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents,
+    mlir::ValueRange reductionInits, const ReductionLoopBodyGenerator &genBody,
+    bool isUnordered) {
+  assert(!extents.empty() && "must have at least one extent");
+  // Build the loop nest outside-in, starting from the last extent.
+  auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
+  mlir::Type indexType = builder.getIndexType();
+  unsigned dim = extents.size() - 1;
+  fir::DoLoopOp outerLoop = nullptr;
+  fir::DoLoopOp parentLoop = nullptr;
+  llvm::SmallVector<mlir::Value> oneBasedIndices;
+  oneBasedIndices.resize(dim + 1);
+  for (auto extent : llvm::reverse(extents)) {
+    auto ub = builder.createConvert(loc, indexType, extent);
+
+    // The outermost loop takes reductionInits as the initial
+    // values of its iter-args.
+    // A child loop takes its iter-args from the region iter-args
+    // of its parent loop.
+    fir::DoLoopOp doLoop;
+    if (!parentLoop) {
+      doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
+                                             /*finalCountValue=*/false,
+                                             reductionInits);
+    } else {
+      doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
+                                             /*finalCountValue=*/false,
+                                             parentLoop.getRegionIterArgs());
+      // Return the results of the child loop from its parent loop.
+      builder.create<fir::ResultOp>(loc, doLoop.getResults());
+    }
+
+    builder.setInsertionPointToStart(doLoop.getBody());
+    // Fill indices in reverse so that oneBasedIndices[i] indexes extents[i].
+    oneBasedIndices[dim--] = doLoop.getInductionVar();
+    if (!outerLoop)
+      outerLoop = doLoop;
+    parentLoop = doLoop;
+  }
+
+  llvm::SmallVector<mlir::Value> reductionValues;
+  reductionValues =
+      genBody(loc, builder, oneBasedIndices, parentLoop.getRegionIterArgs());
+  builder.setInsertionPointToEnd(parentLoop.getBody());
+  builder.create<fir::ResultOp>(loc, reductionValues);
+  builder.setInsertionPointAfter(outerLoop);
+  return outerLoop->getResults();
+}
+
 static fir::ExtendedValue translateVariableToExtendedValue(
     mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity variable,
     bool forceHlfirBase = false, bool contiguousHint = false) {
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index eec51a403cdd90..f58fde9eb5f36d 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -140,12 +140,8 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
       // Loop over all indices in the DIM dimension, and reduce all values.
       // If DIM is not present, do total reduction.
 
-      // Create temporary scalar for keeping the running reduction value.
-      mlir::Value reductionTemp =
-          builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
       // Initial value for the reduction.
-      mlir::Value initValue = genInitValue(loc, builder, elementType);
-      builder.create<fir::StoreOp>(loc, initValue, reductionTemp);
+      mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);
 
       // The reduction loop may be unordered if FastMathFlags::reassoc
       // transformations are allowed. The integer reduction is always
@@ -161,56 +157,68 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
         extents.push_back(
             builder.createConvert(loc, builder.getIndexType(), dimExtent));
 
-      // NOTE: the outer elemental operation may be lowered into
-      // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
-      // loop may appear disjoint from the workshare loop nest.
-      //
-      // TODO: a workshare loop nest can be used for the total reductions,
-      // but a proper reduction clause is required to make it work.
-      hlfir::LoopNest loopNest = hlfir::genLoopNest(
-          loc, builder, extents, isUnordered, /*emitWorkshareLoop=*/false);
-
-      llvm::SmallVector<mlir::Value> indices;
-      if (isTotalReduction) {
-        indices = loopNest.oneBasedIndices;
-      } else {
-        indices = inputIndices;
-        indices.insert(indices.begin() + dimVal - 1,
-                       loopNest.oneBasedIndices[0]);
-      }
+      auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange oneBasedIndices,
+                         mlir::ValueRange reductionArgs)
+          -> llvm::SmallVector<mlir::Value, 1> {
+        // Generate the reduction loop-nest body.
+        // The initial reduction value in the innermost loop
+        // is passed via reductionArgs[0].
+        llvm::SmallVector<mlir::Value> indices;
+        if (isTotalReduction) {
+          indices = oneBasedIndices;
+        } else {
+          indices = inputIndices;
+          indices.insert(indices.begin() + dimVal - 1, oneBasedIndices[0]);
+        }
 
-      builder.setInsertionPointToStart(loopNest.body);
-      fir::IfOp ifOp;
-      if (mask) {
-        // Make the reduction value update conditional on the value
-        // of the mask.
-        if (!maskValue) {
-          // If the mask is an array, use the elemental and the loop indices
-          // to address the proper mask element.
-          maskValue = genMaskValue(loc, builder, mask, isPresentPred, indices);
+        mlir::Value reductionValue = reductionArgs[0];
+        fir::IfOp ifOp;
+        if (mask) {
+          // Make the reduction value update conditional on the value
+          // of the mask.
+          if (!maskValue) {
+            // If the mask is an array, use the elemental and the loop indices
+            // to address the proper mask element.
+            maskValue =
+                genMaskValue(loc, builder, mask, isPresentPred, indices);
+          }
+          mlir::Value isUnmasked = builder.create<fir::ConvertOp>(
+              loc, builder.getI1Type(), maskValue);
+          ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
+                                           /*withElseRegion=*/true);
+          // In the 'else' block return the current reduction value.
+          builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+          builder.create<fir::ResultOp>(loc, reductionValue);
+
+          // In the 'then' block do the actual addition.
+          builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
         }
-        mlir::Value isUnmasked =
-            builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
-        ifOp = builder.create<fir::IfOp>(loc, isUnmasked,
-                                         /*withElseRegion=*/false);
 
-        // In the 'then' block do the actual addition.
-        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      }
+        hlfir::Entity element =
+            hlfir::getElementAt(loc, builder, array, indices);
+        hlfir::Entity elementValue =
+            hlfir::loadTrivialScalar(loc, builder, element);
+        // NOTE: we can use "Kahan summation" the same way as the runtime
+        // (e.g. when fast-math is not allowed), but let's start with
+        // the simple version.
+        reductionValue =
+            genScalarAdd(loc, builder, reductionValue, elementValue);
+
+        if (ifOp) {
+          builder.create<fir::ResultOp>(loc, reductionValue);
+          builder.setInsertionPointAfter(ifOp);
+          reductionValue = ifOp.getResult(0);
+        }
+
+        return {reductionValue};
+      };
 
-      mlir::Value reductionValue =
-          builder.create<fir::LoadOp>(loc, reductionTemp);
-      hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
-      hlfir::Entity elementValue =
-          hlfir::loadTrivialScalar(loc, builder, element);
-      // NOTE: we can use "Kahan summation" same way as the runtime
-      // (e.g. when fast-math is not allowed), but let's start with
-      // the simple version.
-      reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
-      builder.create<fir::StoreOp>(loc, reductionValue, reductionTemp);
-
-      builder.setInsertionPointAfter(loopNest.outerOp);
-      return hlfir::Entity{builder.create<fir::LoadOp>(loc, reductionTemp)};
+      llvm::SmallVector<mlir::Value, 1> reductionFinalValues =
+          hlfir::genLoopNestWithReductions(loc, builder, extents,
+                                           {reductionInitValue}, genBody,
+                                           isUnordered);
+      return hlfir::Entity{reductionFinalValues[0]};
     };
 
     if (isTotalReduction) {
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 572b9f0da1e4ab..bb406a6ad359df 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -14,12 +14,9 @@ func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xi32> {
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index):
-// CHECK:             %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK:             fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<i32>
-// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_3]] step %[[VAL_9]] unordered {
-// CHECK:               %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             %[[VAL_7:.*]] = arith.constant 0 : i32
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_3]] step %[[VAL_8]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_7]]) -> (i32) {
 // CHECK:               %[[VAL_12:.*]] = arith.constant 0 : index
 // CHECK:               %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
 // CHECK:               %[[VAL_14:.*]] = arith.constant 1 : index
@@ -32,10 +29,9 @@ func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
 // CHECK:               %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]])  : (!fir.box<!fir.array<2x3xi32>>, index, index) -> !fir.ref<i32>
 // CHECK:               %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
 // CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32
-// CHECK:               fir.store %[[VAL_23]] to %[[VAL_7]] : !fir.ref<i32>
+// CHECK:               fir.result %[[VAL_23]] : i32
 // CHECK:             }
-// CHECK:             %[[VAL_24:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
-// CHECK:             hlfir.yield_element %[[VAL_24]] : i32
+// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -54,18 +50,14 @@ func.func @sum_expr_known_extents(%arg0: !hlfir.expr<2x3xi32>) {
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> {
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index):
-// CHECK:             %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK:             fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<i32>
-// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_2]] step %[[VAL_9]] unordered {
-// CHECK:               %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             %[[VAL_7:.*]] = arith.constant 0 : i32
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_2]] step %[[VAL_8]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_7]]) -> (i32) {
 // CHECK:               %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xi32>, index, index) -> i32
 // CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
-// CHECK:               fir.store %[[VAL_13]] to %[[VAL_7]] : !fir.ref<i32>
+// CHECK:               fir.result %[[VAL_13]] : i32
 // CHECK:             }
-// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
-// CHECK:             hlfir.yield_element %[[VAL_14]] : i32
+// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -85,15 +77,12 @@ func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xcomplex<f64>> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = fir.alloca complex<f64> {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
-// CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK:             fir.store %[[VAL_12]] to %[[VAL_8]] : !fir.ref<complex<f64>>
-// CHECK:             %[[VAL_13:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_14:.*]] = %[[VAL_13]] to %[[VAL_3]]#1 step %[[VAL_13]] {
-// CHECK:               %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:             %[[VAL_9:.*]] = fir.undefined complex<f64>
+// CHECK:             %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_8]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_12:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_3]]#1 step %[[VAL_12]] iter_args(%[[VAL_15:.*]] = %[[VAL_11]]) -> (complex<f64>) {
 // CHECK:               %[[VAL_16:.*]] = arith.constant 0 : index
 // CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
 // CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
@@ -106,10 +95,9 @@ func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:               %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
-// CHECK:               fir.store %[[VAL_27]] to %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:               fir.result %[[VAL_27]] : complex<f64>
 // CHECK:             }
-// CHECK:             %[[VAL_28:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
-// CHECK:             hlfir.yield_element %[[VAL_28]] : complex<f64>
+// CHECK:             hlfir.yield_element %[[VAL_13]] : complex<f64>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -128,15 +116,12 @@ func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xcomplex<f64>> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = fir.alloca complex<f64> {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
-// CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK:             fir.store %[[VAL_12]] to %[[VAL_8]] : !fir.ref<complex<f64>>
-// CHECK:             %[[VAL_13:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_14:.*]] = %[[VAL_13]] to %[[VAL_4]] step %[[VAL_13]] {
-// CHECK:               %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:             %[[VAL_9:.*]] = fir.undefined complex<f64>
+// CHECK:             %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_8]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_12:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_4]] step %[[VAL_12]] iter_args(%[[VAL_15:.*]] = %[[VAL_11]]) -> (complex<f64>) {
 // CHECK:               %[[VAL_16:.*]] = arith.constant 0 : index
 // CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
 // CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
@@ -149,10 +134,9 @@ func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:               %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
-// CHECK:               fir.store %[[VAL_27]] to %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:               fir.result %[[VAL_27]] : complex<f64>
 // CHECK:             }
-// CHECK:             %[[VAL_28:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
-// CHECK:             hlfir.yield_element %[[VAL_28]] : complex<f64>
+// CHECK:             hlfir.yield_element %[[VAL_13]] : complex<f64>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -172,18 +156,14 @@ func.func @sum_expr_unknown_extent1(%arg0: !hlfir.expr<?x3xf32>) {
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             fir.store %[[VAL_9]] to %[[VAL_8]] : !fir.ref<f32>
-// CHECK:             %[[VAL_10:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_11:.*]] = %[[VAL_10]] to %[[VAL_3]] step %[[VAL_10]] {
-// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_8]] : !fir.ref<f32>
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_3]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_8]]) -> (f32) {
 // CHECK:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_7]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
 // CHECK:               %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK:               fir.store %[[VAL_14]] to %[[VAL_8]] : !fir.ref<f32>
+// CHECK:               fir.result %[[VAL_14]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<f32>
-// CHECK:             hlfir.yield_element %[[VAL_15]] : f32
+// CHECK:             hlfir.yield_element %[[VAL_10]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -202,18 +182,14 @@ func.func @sum_expr_unknown_extent2(%arg0: !hlfir.expr<?x3xf32>) {
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             fir.store %[[VAL_9]] to %[[VAL_8]] : !fir.ref<f32>
-// CHECK:             %[[VAL_10:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_11:.*]] = %[[VAL_10]] to %[[VAL_4]] step %[[VAL_10]] {
-// CHECK:               %[[VAL_12:.*]] = fir.load %[[VAL_8]] : !fir.ref<f32>
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_8]]) -> (f32) {
 // CHECK:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]], %[[VAL_11]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
 // CHECK:               %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK:               fir.store %[[VAL_14]] to %[[VAL_8]] : !fir.ref<f32>
+// CHECK:               fir.result %[[VAL_14]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<f32>
-// CHECK:             hlfir.yield_element %[[VAL_15]] : f32
+// CHECK:             hlfir.yield_element %[[VAL_10]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -235,21 +211,20 @@ func.func @sum_scalar_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.ref<!fir.log
 // CHECK:           %[[VAL_7:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.logical<1>>
 // CHECK:           %[[VAL_8:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
 // CHECK:           ^bb0(%[[VAL_9:.*]]: index):
-// CHECK:             %[[VAL_10:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             fir.store %[[VAL_11]] to %[[VAL_10]] : !fir.ref<f32>
-// CHECK:             %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_13:.*]] = %[[VAL_12]] to %[[VAL_4]] step %[[VAL_12]] {
-// CHECK:               %[[VAL_14:.*]] = fir.convert %[[VAL_7]] : (!fir.logical<1>) -> i1
-// CHECK:               fir.if %[[VAL_14]] {
-// CHECK:                 %[[VAL_15:.*]] = fir.load %[[VAL_10]] : !fir.ref<f32>
-// CHECK:                 %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_13]], %[[VAL_9]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_17:.*]] = arith.addf %[[VAL_15]], %[[VAL_16]] : f32
-// CHECK:                 fir.store %[[VAL_17]] to %[[VAL_10]] : !fir.ref<f32>
+// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_4]] step %[[VAL_11]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) {
+// CHECK:               %[[VAL_15:.*]] = fir.convert %[[VAL_7]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_16:.*]] = fir.if %[[VAL_15]] -> (f32) {
+// CHECK:                 %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_13]], %[[VAL_9]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_18:.*]] = arith.addf %[[VAL_14]], %[[VAL_17]] : f32
+// CHECK:                 fir.result %[[VAL_18]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_14]] : f32
 // CHECK:               }
+// CHECK:               fir.result %[[VAL_16]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_10]] : !fir.ref<f32>
-// CHECK:             hlfir.yield_element %[[VAL_18]] : f32
+// CHECK:             hlfir.yield_element %[[VAL_12]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -280,21 +255,20 @@ func.func @sum_scalar_boxed_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!f
 // CHECK:           }
 // CHECK:           %[[VAL_13:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
 // CHECK:           ^bb0(%[[VAL_14:.*]]: index):
-// CHECK:             %[[VAL_15:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             fir.store %[[VAL_16]] to %[[VAL_15]] : !fir.ref<f32>
-// CHECK:             %[[VAL_17:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_18:.*]] = %[[VAL_17]] to %[[VAL_4]] step %[[VAL_17]] {
-// CHECK:               %[[VAL_19:.*]] = fir.convert %[[VAL_8]] : (!fir.logical<1>) -> i1
-// CHECK:               fir.if %[[VAL_19]] {
-// CHECK:                 %[[VAL_20:.*]] = fir.load %[[VAL_15]] : !fir.ref<f32>
-// CHECK:                 %[[VAL_21:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_14]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_22:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
-// CHECK:                 fir.store %[[VAL_22]] to %[[VAL_15]] : !fir.ref<f32>
+// CHECK:             %[[VAL_15:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_17:.*]] = fir.do_loop %[[VAL_18:.*]] = %[[VAL_16]] to %[[VAL_4]] step %[[VAL_16]] iter_args(%[[VAL_19:.*]] = %[[VAL_15]]) -> (f32) {
+// CHECK:               %[[VAL_20:.*]] = fir.convert %[[VAL_8]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_21:.*]] = fir.if %[[VAL_20]] -> (f32) {
+// CHECK:                 %[[VAL_22:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_14]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_23:.*]] = arith.addf %[[VAL_19]], %[[VAL_22]] : f32
+// CHECK:                 fir.result %[[VAL_23]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_19]] : f32
 // CHECK:               }
+// CHECK:               fir.result %[[VAL_21]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_23:.*]] = fir.load %[[VAL_15]] : !fir.ref<f32>
-// CHECK:             hlfir.yield_element %[[VAL_23]] : f32
+// CHECK:             hlfir.yield_element %[[VAL_17]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -316,39 +290,38 @@ func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.arra
 // CHECK:           %[[VAL_7:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>) -> i1
 // CHECK:           %[[VAL_8:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
 // CHECK:           ^bb0(%[[VAL_9:.*]]: index):
-// CHECK:             %[[VAL_10:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             fir.store %[[VAL_11]] to %[[VAL_10]] : !fir.ref<f32>
-// CHECK:             %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_13:.*]] = %[[VAL_12]] to %[[VAL_5]] step %[[VAL_12]] {
-// CHECK:               %[[VAL_14:.*]] = fir.if %[[VAL_7]] -> (!fir.logical<1>) {
-// CHECK:                 %[[VAL_15:.*]] = arith.constant 0 : index
-// CHECK:                 %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_15]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
-// CHECK:                 %[[VAL_17:.*]] = arith.constant 1 : index
-// CHECK:                 %[[VAL_18:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_17]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
-// CHECK:                 %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_16]]#0, %[[VAL_19]] : index
-// CHECK:                 %[[VAL_21:.*]] = arith.addi %[[VAL_9]], %[[VAL_20]] : index
-// CHECK:                 %[[VAL_22:.*]] = arith.subi %[[VAL_18]]#0, %[[VAL_19]] : index
-// CHECK:                 %[[VAL_23:.*]] = arith.addi %[[VAL_13]], %[[VAL_22]] : index
-// CHECK:                 %[[VAL_24:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_21]], %[[VAL_23]])  : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
-// CHECK:                 %[[VAL_25:.*]] = fir.load %[[VAL_24]] : !fir.ref<!fir.logical<1>>
-// CHECK:                 fir.result %[[VAL_25]] : !fir.logical<1>
+// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_5]] step %[[VAL_11]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) {
+// CHECK:               %[[VAL_15:.*]] = fir.if %[[VAL_7]] -> (!fir.logical<1>) {
+// CHECK:                 %[[VAL_16:.*]] = arith.constant 0 : index
+// CHECK:                 %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
+// CHECK:                 %[[VAL_18:.*]] = arith.constant 1 : index
+// CHECK:                 %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_18]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
+// CHECK:                 %[[VAL_20:.*]] = arith.constant 1 : index
+// CHECK:                 %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
+// CHECK:                 %[[VAL_22:.*]] = arith.addi %[[VAL_9]], %[[VAL_21]] : index
+// CHECK:                 %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
+// CHECK:                 %[[VAL_24:.*]] = arith.addi %[[VAL_13]], %[[VAL_23]] : index
+// CHECK:                 %[[VAL_25:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
+// CHECK:                 %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<!fir.logical<1>>
+// CHECK:                 fir.result %[[VAL_26]] : !fir.logical<1>
 // CHECK:               } else {
-// CHECK:                 %[[VAL_26:.*]] = arith.constant true
-// CHECK:                 %[[VAL_27:.*]] = fir.convert %[[VAL_26]] : (i1) -> !fir.logical<1>
-// CHECK:                 fir.result %[[VAL_27]] : !fir.logical<1>
+// CHECK:                 %[[VAL_27:.*]] = arith.constant true
+// CHECK:                 %[[VAL_28:.*]] = fir.convert %[[VAL_27]] : (i1) -> !fir.logical<1>
+// CHECK:                 fir.result %[[VAL_28]] : !fir.logical<1>
 // CHECK:               }
-// CHECK:               %[[VAL_28:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<1>) -> i1
-// CHECK:               fir.if %[[VAL_28]] {
-// CHECK:                 %[[VAL_29:.*]] = fir.load %[[VAL_10]] : !fir.ref<f32>
-// CHECK:                 %[[VAL_30:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_9]], %[[VAL_13]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[VAL_30]] : f32
-// CHECK:                 fir.store %[[VAL_31]] to %[[VAL_10]] : !fir.ref<f32>
+// CHECK:               %[[VAL_29:.*]] = fir.convert %[[VAL_15]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_30:.*]] = fir.if %[[VAL_29]] -> (f32) {
+// CHECK:                 %[[VAL_31:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_9]], %[[VAL_13]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_32:.*]] = arith.addf %[[VAL_14]], %[[VAL_31]] : f32
+// CHECK:                 fir.result %[[VAL_32]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_14]] : f32
 // CHECK:               }
+// CHECK:               fir.result %[[VAL_30]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_32:.*]] = fir.load %[[VAL_10]] : !fir.ref<f32>
-// CHECK:             hlfir.yield_element %[[VAL_32]] : f32
+// CHECK:             hlfir.yield_element %[[VAL_12]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -369,22 +342,21 @@ func.func @sum_array_expr_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !hlfir.expr<?
 // CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
 // CHECK:           ^bb0(%[[VAL_8:.*]]: index):
-// CHECK:             %[[VAL_9:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             fir.store %[[VAL_10]] to %[[VAL_9]] : !fir.ref<f32>
-// CHECK:             %[[VAL_11:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_12:.*]] = %[[VAL_11]] to %[[VAL_5]] step %[[VAL_11]] {
-// CHECK:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3x!fir.logical<1>>, index, index) -> !fir.logical<1>
-// CHECK:               %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (!fir.logical<1>) -> i1
-// CHECK:               fir.if %[[VAL_14]] {
-// CHECK:                 %[[VAL_15:.*]] = fir.load %[[VAL_9]] : !fir.ref<f32>
-// CHECK:                 %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_17:.*]] = arith.addf %[[VAL_15]], %[[VAL_16]] : f32
-// CHECK:                 fir.store %[[VAL_17]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_5]] step %[[VAL_10]] iter_args(%[[VAL_13:.*]] = %[[VAL_9]]) -> (f32) {
+// CHECK:               %[[VAL_14:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3x!fir.logical<1>>, index, index) -> !fir.logical<1>
+// CHECK:               %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_16:.*]] = fir.if %[[VAL_15]] -> (f32) {
+// CHECK:                 %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_18:.*]] = arith.addf %[[VAL_13]], %[[VAL_17]] : f32
+// CHECK:                 fir.result %[[VAL_18]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_13]] : f32
 // CHECK:               }
+// CHECK:               fir.result %[[VAL_16]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_9]] : !fir.ref<f32>
-// CHECK:             hlfir.yield_element %[[VAL_18]] : f32
+// CHECK:             hlfir.yield_element %[[VAL_11]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -403,38 +375,31 @@ func.func @sum_unordered_reduction(%arg0: !hlfir.expr<2x3xf32>) {
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index):
-// CHECK:             %[[VAL_7:.*]] = fir.alloca f32 {bindc_name = ".sum.reduction"}
-// CHECK:             %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<f32>
-// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK:             fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_2]] step %[[VAL_9]] unordered {
-// CHECK:               %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<f32>
+// CHECK:             %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_2]] step %[[VAL_8]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_7]]) -> (f32) {
 // CHECK:               %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xf32>, index, index) -> f32
 // CHECK:               %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] fastmath<reassoc> : f32
-// CHECK:               fir.store %[[VAL_13]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:               fir.result %[[VAL_13]] : f32
 // CHECK:             }
-// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_7]] : !fir.ref<f32>
-// CHECK:             hlfir.yield_element %[[VAL_14]] : f32
+// CHECK:             hlfir.yield_element %[[VAL_9]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
 
-// total reduction
-func.func @sum_total_reduction(%arg0: !fir.box<!fir.array<3xi32>>) {
+// total 1d reduction
+func.func @sum_total_1d_reduction(%arg0: !fir.box<!fir.array<3xi32>>) {
   %cst = arith.constant 1 : i32
   %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<3xi32>>, i32) -> i32
   return
 }
-// CHECK-LABEL:   func.func @sum_total_reduction(
-// CHECK-SAME:                                   %[[VAL_0:.*]]: !fir.box<!fir.array<3xi32>>) {
+// CHECK-LABEL:   func.func @sum_total_1d_reduction(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !fir.box<!fir.array<3xi32>>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_2:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_3:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK:           fir.store %[[VAL_4]] to %[[VAL_3]] : !fir.ref<i32>
-// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK:           fir.do_loop %[[VAL_6:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_5]] unordered {
-// CHECK:             %[[VAL_7:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = fir.do_loop %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_4]] unordered iter_args(%[[VAL_7:.*]] = %[[VAL_3]]) -> (i32) {
 // CHECK:             %[[VAL_8:.*]] = arith.constant 0 : index
 // CHECK:             %[[VAL_9:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_8]] : (!fir.box<!fir.array<3xi32>>, index) -> (index, index, index)
 // CHECK:             %[[VAL_10:.*]] = arith.constant 1 : index
@@ -443,9 +408,41 @@ func.func @sum_total_reduction(%arg0: !fir.box<!fir.array<3xi32>>) {
 // CHECK:             %[[VAL_13:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_12]])  : (!fir.box<!fir.array<3xi32>>, index) -> !fir.ref<i32>
 // CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<i32>
 // CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_7]], %[[VAL_14]] : i32
-// CHECK:             fir.store %[[VAL_15]] to %[[VAL_3]] : !fir.ref<i32>
+// CHECK:             fir.result %[[VAL_15]] : i32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// Total reduction of a 2-D array (no DIM): expect two nested unordered fir.do_loops (inner over the runtime-extent first dimension, outer over the extent-3 second dimension) threading the sum through iter_args.
+func.func @sum_total_2d_reduction(%arg0: !fir.box<!fir.array<?x3xi32>>) {
+  %res = hlfir.sum %arg0 : (!fir.box<!fir.array<?x3xi32>>) -> i32
+  return
+}
+// CHECK-LABEL:   func.func @sum_total_2d_reduction(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !fir.box<!fir.array<?x3xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_1]] : (!fir.box<!fir.array<?x3xi32>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_5]] unordered iter_args(%[[VAL_8:.*]] = %[[VAL_4]]) -> (i32) {
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_2]]#1 step %[[VAL_5]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:               %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK:               %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box<!fir.array<?x3xi32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_14]] : (!fir.box<!fir.array<?x3xi32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_17:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_16]] : index
+// CHECK:               %[[VAL_18:.*]] = arith.addi %[[VAL_10]], %[[VAL_17]] : index
+// CHECK:               %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_16]] : index
+// CHECK:               %[[VAL_20:.*]] = arith.addi %[[VAL_7]], %[[VAL_19]] : index
+// CHECK:               %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]])  : (!fir.box<!fir.array<?x3xi32>>, index, index) -> !fir.ref<i32>
+// CHECK:               %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
+// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32
+// CHECK:               fir.result %[[VAL_23]] : i32
+// CHECK:             }
+// CHECK:             fir.result %[[VAL_9]] : i32
 // CHECK:           }
-// CHECK:           %[[VAL_16:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
 // CHECK:           return
 // CHECK:         }
 



More information about the flang-commits mailing list