[flang-commits] [flang] [flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. (PR #118556)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Dec 4 14:49:25 PST 2024


https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/118556

>From 196d6de06e1c9b18db5b5bb06fa9ad05d068ac54 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 3 Dec 2024 14:21:11 -0800
Subject: [PATCH 1/3] [flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental.

An array SUM with the specified constant DIM argument
may be expanded into hlfir.elemental with a reduction loop
inside it processing all elements of the specified dimension.
The expansion allows further optimization of the cases like
`A=SUM(B+1,DIM=1)` in the optimized bufferization pass
(given that it can prove there are no read/write conflicts).
---
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    | 204 ++++++++++
 .../HLFIR/simplify-hlfir-intrinsics-sum.fir   | 361 ++++++++++++++++++
 2 files changed, 565 insertions(+)
 create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 60b06437e6a987..35dc881e880df2 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -10,6 +10,7 @@
 // into the calling function.
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,6 +91,190 @@ class TransposeAsElementalConversion
   }
 };
 
+// Expand the SUM(DIM=CONSTANT) operation into an hlfir.elemental operation.
+// Each element of the elemental is computed by an explicit fir.do_loop that
+// reduces all elements of the input array along dimension DIM (honoring
+// MASK, if present). Only array results are handled, and DIM must be a
+// compile-time constant within [1, rank(ARRAY)].
+class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::SumOp sum,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = sum.getLoc();
+    // A scalar result (e.g. SUM of a 1D array with DIM=1) cannot be produced
+    // by hlfir.elemental, so reject it instead of only asserting.
+    auto expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
+    if (!expr)
+      return rewriter.notifyMatchFailure(
+          sum, "expected an expression type for the result of hlfir.sum");
+    mlir::Type elementType = expr.getElementType();
+    hlfir::Entity array = hlfir::Entity{sum.getArray()};
+    mlir::Value mask = sum.getMask();
+
+    // DIM is used below to index the input extents and to position the
+    // reduction loop's IV among the element indices, so it must be a
+    // constant in [1, rank(ARRAY)]. Validate it here: an assert vanishes in
+    // release builds and would leave an out-of-bounds access behind.
+    mlir::Value dim = sum.getDim();
+    int64_t dimVal = dim ? fir::getIntIfConstant(dim).value_or(0) : 0;
+    if (dimVal <= 0 || dimVal > array.getRank())
+      return rewriter.notifyMatchFailure(
+          sum, "DIM must be a constant within [1, rank(ARRAY)]");
+
+    fir::FirOpBuilder builder{rewriter, sum.getOperation()};
+    mlir::Value resultShape, dimExtent;
+    std::tie(resultShape, dimExtent) =
+        genResultShape(loc, builder, array, dimVal);
+
+    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange inputIndices) -> hlfir::Entity {
+      // Loop over all indices in the DIM dimension, and reduce all values.
+      // We do not need to create the reduction loop always: if we can
+      // slice the input array given the inputIndices, then we can
+      // just apply a new SUM operation (total reduction) to the slice.
+      // For the time being, generate the explicit loop because the slicing
+      // requires generating an elemental operation for the input array
+      // (and the mask, if present).
+      // TODO: produce the slices and new SUM after adding a pattern
+      // for expanding total reduction SUM case.
+      mlir::Type indexType = builder.getIndexType();
+      auto one = builder.createIntegerConstant(loc, indexType, 1);
+      auto ub = builder.createConvert(loc, indexType, dimExtent);
+
+      // Initial value for the reduction.
+      mlir::Value initValue = genInitValue(loc, builder, elementType);
+
+      // The reduction loop may be unordered if FastMathFlags::reassoc
+      // transformations are allowed. The integer reduction is always
+      // unordered.
+      bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
+                         static_cast<bool>(sum.getFastmath() &
+                                           mlir::arith::FastMathFlags::reassoc);
+
+      // If the mask is present and is a scalar, then we'd better load its value
+      // outside of the reduction loop making the loop unswitching easier.
+      // Maybe it is worth hoisting it from the elemental operation as well.
+      // A local copy is used so the MASK value captured from the enclosing
+      // scope is left untouched.
+      // NOTE(review): a MASK passed as a !fir.box is dereferenced here and in
+      // the loop below without a presence check; confirm that lowering never
+      // passes a dynamically absent (OPTIONAL) box to hlfir.sum.
+      mlir::Value kernelMask = mask;
+      if (kernelMask) {
+        hlfir::Entity maskEntity{kernelMask};
+        if (maskEntity.isScalar())
+          kernelMask = hlfir::loadTrivialScalar(loc, builder, maskEntity);
+      }
+
+      // NOTE: the outer elemental operation may be lowered into
+      // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
+      // loop may appear disjoint from the workshare loop nest.
+      // Moreover, the inner loop is not strictly nested (due to the reduction
+      // starting value initialization), and the above omp dialect operations
+      // cannot produce results.
+      // It is unclear what we should do about it yet.
+      auto doLoop = builder.create<fir::DoLoopOp>(
+          loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
+          mlir::ValueRange{initValue});
+
+      // Address the input array using the reduction loop's IV
+      // for the DIM dimension.
+      mlir::Value iv = doLoop.getInductionVar();
+      llvm::SmallVector<mlir::Value> indices{inputIndices};
+      indices.insert(indices.begin() + dimVal - 1, iv);
+
+      mlir::OpBuilder::InsertionGuard guard(builder);
+      builder.setInsertionPointToStart(doLoop.getBody());
+      mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
+      fir::IfOp ifOp;
+      if (kernelMask) {
+        // Make the reduction value update conditional on the value
+        // of the mask.
+        hlfir::Entity maskValue{kernelMask};
+        if (!maskValue.isScalar()) {
+          // If the mask is an array, use the elemental and the loop indices
+          // to address the proper mask element.
+          maskValue = hlfir::getElementAt(loc, builder, maskValue, indices);
+          maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue);
+        }
+        mlir::Value isUnmasked =
+            builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
+        ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
+                                         /*withElseRegion=*/true);
+        // In the 'else' block return the current reduction value.
+        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+        builder.create<fir::ResultOp>(loc, reductionValue);
+
+        // In the 'then' block do the actual addition.
+        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      }
+
+      hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
+      hlfir::Entity elementValue =
+          hlfir::loadTrivialScalar(loc, builder, element);
+      // NOTE: we can use "Kahan summation" same way as the runtime
+      // (e.g. when fast-math is not allowed), but let's start with
+      // the simple version.
+      reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
+      builder.create<fir::ResultOp>(loc, reductionValue);
+
+      if (ifOp) {
+        // Yield the fir.if result as the loop-carried value.
+        builder.setInsertionPointAfter(ifOp);
+        builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
+      }
+
+      return hlfir::Entity{doLoop.getResult(0)};
+    };
+    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+        loc, builder, elementType, resultShape, {}, genKernel,
+        /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
+        sum.getResult().getType());
+
+    // it wouldn't be safe to replace block arguments with a different
+    // hlfir.expr type. Types can differ due to differing amounts of shape
+    // information
+    assert(elementalOp.getResult().getType() == sum.getResult().getType());
+
+    rewriter.replaceOp(sum, elementalOp);
+    return mlir::success();
+  }
+
+private:
+  // Return fir.shape specifying the shape of the result
+  // of a SUM reduction with DIM=dimVal. The second return value
+  // is the extent of the DIM dimension.
+  // Precondition: 1 <= dimVal <= rank(array) (validated by the caller).
+  static std::tuple<mlir::Value, mlir::Value>
+  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
+                 hlfir::Entity array, int64_t dimVal) {
+    mlir::Value inShape = hlfir::genShape(loc, builder, array);
+    llvm::SmallVector<mlir::Value> inExtents =
+        hlfir::getExplicitExtentsFromShape(inShape, builder);
+    // The shape is only needed for its extents; drop it if nothing else
+    // ended up using it.
+    if (inShape.use_empty())
+      if (mlir::Operation *shapeOp = inShape.getDefiningOp())
+        shapeOp->erase();
+
+    assert(dimVal > 0 && static_cast<size_t>(dimVal) <= inExtents.size() &&
+           "DIM is out of range of the input array rank");
+    mlir::Value dimExtent = inExtents[dimVal - 1];
+    inExtents.erase(inExtents.begin() + dimVal - 1);
+    return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
+  }
+
+  // Generate the initial value (zero) for a SUM reduction with the given
+  // data type: real, complex or integer.
+  static mlir::Value genInitValue(mlir::Location loc,
+                                  fir::FirOpBuilder &builder,
+                                  mlir::Type elementType) {
+    if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(loc, elementType,
+                                        llvm::APFloat::getZero(sem));
+    }
+    if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
+      // Complex zero: both parts are the zero of the component type.
+      mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
+      return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
+                                                               initValue);
+    }
+    if (mlir::isa<mlir::IntegerType>(elementType))
+      return builder.createIntegerConstant(loc, elementType, 0);
+
+    llvm_unreachable("unsupported SUM reduction type");
+  }
+
+  // Generate scalar addition of the two values (of the same data type).
+  static mlir::Value genScalarAdd(mlir::Location loc,
+                                  fir::FirOpBuilder &builder,
+                                  mlir::Value value1, mlir::Value value2) {
+    mlir::Type ty = value1.getType();
+    assert(ty == value2.getType() && "reduction values' types do not match");
+    if (mlir::isa<mlir::FloatType>(ty))
+      return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
+    if (mlir::isa<mlir::ComplexType>(ty))
+      return builder.create<fir::AddcOp>(loc, value1, value2);
+    if (mlir::isa<mlir::IntegerType>(ty))
+      return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
+
+    llvm_unreachable("unsupported SUM reduction type");
+  }
+};
+
 class SimplifyHLFIRIntrinsics
     : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
 public:
@@ -97,6 +282,7 @@ class SimplifyHLFIRIntrinsics
     mlir::MLIRContext *context = &getContext();
     mlir::RewritePatternSet patterns(context);
     patterns.insert<TransposeAsElementalConversion>(context);
+    patterns.insert<SumAsElementalConversion>(context);
     mlir::ConversionTarget target(*context);
     // don't transform transpose of polymorphic arrays (not currently supported
     // by hlfir.elemental)
@@ -105,6 +291,24 @@ class SimplifyHLFIRIntrinsics
           return mlir::cast<hlfir::ExprType>(transpose.getType())
               .isPolymorphic();
         });
+    // Handle only SUM(DIM=CONSTANT) case for now.
+    // It may be beneficial to expand the non-DIM case as well.
+    // E.g. when the input array is an elemental array expression,
+    // expanding the SUM into a total reduction loop nest
+    // would avoid creating a temporary for the elemental array expression.
+    target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
+      if (mlir::Value dim = sum.getDim()) {
+        if (fir::getIntIfConstant(dim)) {
+          if (!fir::isa_trivial(sum.getType())) {
+            // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
+            // It is only legal when X is 1, and it should probably be
+            // canonicalized into SUM(a).
+            return false;
+          }
+        }
+      }
+      return true;
+    });
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
     if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
new file mode 100644
index 00000000000000..05a4dfde6344e2
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -0,0 +1,361 @@
+// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
+
+// box with known extents
+// SUM(DIM=2) of a boxed 2x3 i32 array: the result shape is built from the
+// constant extent 2, and the reduction loop runs up to the constant extent 3.
+// Integer reductions produce an unordered reduction loop.
+func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
+  %cst = arith.constant 2 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<2x3xi32>>, i32) -> !hlfir.expr<2xi32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_box_known_extents(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: !fir.box<!fir.array<2x3xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xi32> {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: index):
+// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:               %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK:               %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_14]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_17:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_16]] : index
+// CHECK:               %[[VAL_18:.*]] = arith.addi %[[VAL_6]], %[[VAL_17]] : index
+// CHECK:               %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_16]] : index
+// CHECK:               %[[VAL_20:.*]] = arith.addi %[[VAL_10]], %[[VAL_19]] : index
+// CHECK:               %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]])  : (!fir.box<!fir.array<2x3xi32>>, index, index) -> !fir.ref<i32>
+// CHECK:               %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
+// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32
+// CHECK:               fir.result %[[VAL_23]] : i32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// expr with known extents
+// SUM(DIM=1) of a 2x3 i32 expression: elements are read with hlfir.apply
+// using (loop IV, elemental index), i.e. the loop IV is placed at DIM=1.
+func.func @sum_expr_known_extents(%arg0: !hlfir.expr<2x3xi32>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_expr_known_extents(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !hlfir.expr<2x3xi32>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: index):
+// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:               %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xi32>, index, index) -> i32
+// CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
+// CHECK:               fir.result %[[VAL_13]] : i32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// box with unknown extent
+// SUM(DIM=1) of a boxed ?x3 complex<f64> array: the loop upper bound is the
+// extent from fir.box_dims on dimension 0. The initial value is a complex
+// zero assembled with fir.insert_value, and the addition uses fir.addc.
+// The loop is not unordered: the element type is not an integer and no
+// reassoc fast-math flag is requested.
+func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<?x3xcomplex<f64>>>, i32) -> !hlfir.expr<3xcomplex<f64>>
+  return
+}
+// CHECK-LABEL:   func.func @sum_box_unknown_extent1(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xcomplex<f64>> {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: index):
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
+// CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_3]]#1 step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+// CHECK:               %[[VAL_16:.*]] = arith.constant 0 : index
+// CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_18]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_20:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
+// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_14]], %[[VAL_21]] : index
+// CHECK:               %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
+// CHECK:               %[[VAL_24:.*]] = arith.addi %[[VAL_7]], %[[VAL_23]] : index
+// CHECK:               %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
+// CHECK:               %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
+// CHECK:               %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
+// CHECK:               fir.result %[[VAL_27]] : complex<f64>
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_13]] : complex<f64>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// SUM(DIM=2) of the same boxed ?x3 complex<f64> array: now the unknown
+// extent (fir.box_dims on dimension 0) defines the result shape, while the
+// reduction loop runs up to the constant extent 3.
+func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+  %cst = arith.constant 2 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<?x3xcomplex<f64>>>, i32) -> !hlfir.expr<?xcomplex<f64>>
+  return
+}
+// CHECK-LABEL:   func.func @sum_box_unknown_extent2(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xcomplex<f64>> {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: index):
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
+// CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+// CHECK:               %[[VAL_16:.*]] = arith.constant 0 : index
+// CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_18]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_20:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
+// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_7]], %[[VAL_21]] : index
+// CHECK:               %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
+// CHECK:               %[[VAL_24:.*]] = arith.addi %[[VAL_14]], %[[VAL_23]] : index
+// CHECK:               %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
+// CHECK:               %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
+// CHECK:               %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
+// CHECK:               fir.result %[[VAL_27]] : complex<f64>
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_13]] : complex<f64>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// expr with unknown extent
+// SUM(DIM=1) of a ?x3 f32 expression: the unknown extent comes from
+// hlfir.shape_of + hlfir.get_extent and bounds the floating-point reduction
+// loop, which is not unordered (no reassoc fast-math flag).
+// NOTE: "unkwnonw" in the symbol name is a typo for "unknown"; renaming it
+// also requires updating the matching label line below.
+func.func @sum_expr_unkwnonw_extent1(%arg0: !hlfir.expr<?x3xf32>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<?x3xf32>, i32) -> !hlfir.expr<3xf32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_expr_unkwnonw_extent1(
+// CHECK-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
+// CHECK:           %[[VAL_3:.*]] = hlfir.get_extent %[[VAL_2]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: index):
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_3]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (f32) {
+// CHECK:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_7]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:               %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
+// CHECK:               fir.result %[[VAL_14]] : f32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_10]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// SUM(DIM=2) of a ?x3 f32 expression: the extent from hlfir.get_extent
+// defines the result shape, the loop runs up to the constant extent 3, and
+// hlfir.apply takes (elemental index, loop IV).
+func.func @sum_expr_unkwnonw_extent2(%arg0: !hlfir.expr<?x3xf32>) {
+  %cst = arith.constant 2 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<?x3xf32>, i32) -> !hlfir.expr<?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_expr_unkwnonw_extent2(
+// CHECK-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
+// CHECK:           %[[VAL_3:.*]] = hlfir.get_extent %[[VAL_2]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: index):
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (f32) {
+// CHECK:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]], %[[VAL_11]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:               %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
+// CHECK:               fir.result %[[VAL_14]] : f32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_10]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// scalar mask
+// The scalar mask is loaded once, before the reduction loop. Each iteration
+// converts it to i1 and adds the element under fir.if, yielding the
+// unchanged accumulator from the else branch.
+func.func @sum_scalar_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.ref<!fir.logical<1>>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr<?x3xf32>, i32, !fir.ref<!fir.logical<1>>) -> !hlfir.expr<3xf32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_scalar_mask(
+// CHECK-SAME:                               %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>,
+// CHECK-SAME:                               %[[VAL_1:.*]]: !fir.ref<!fir.logical<1>>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
+// CHECK:           %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.logical<1>>
+// CHECK:             %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) {
+// CHECK:               %[[VAL_15:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_16:.*]] = fir.if %[[VAL_15]] -> (f32) {
+// CHECK:                 %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_13]], %[[VAL_8]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_18:.*]] = arith.addf %[[VAL_14]], %[[VAL_17]] : f32
+// CHECK:                 fir.result %[[VAL_18]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_14]] : f32
+// CHECK:               }
+// CHECK:               fir.result %[[VAL_16]] : f32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_12]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// array mask
+// The boxed array mask is addressed inside the reduction loop with the same
+// indices as the input (elemental index, then loop IV for DIM=2), using
+// hlfir.designate + fir.load. The loaded value then guards the addition
+// via fir.if.
+func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.array<?x3x!fir.logical<1>>>) {
+  %cst = arith.constant 2 : i32
+  %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr<?x3xf32>, i32, !fir.box<!fir.array<?x3x!fir.logical<1>>>) -> !hlfir.expr<?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_array_mask(
+// CHECK-SAME:                              %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: !fir.box<!fir.array<?x3x!fir.logical<1>>>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
+// CHECK:           %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) {
+// CHECK:               %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:               %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_14]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_18]] : index
+// CHECK:               %[[VAL_20:.*]] = arith.addi %[[VAL_8]], %[[VAL_19]] : index
+// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_18]] : index
+// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_12]], %[[VAL_21]] : index
+// CHECK:               %[[VAL_23:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_20]], %[[VAL_22]])  : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
+// CHECK:               %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<!fir.logical<1>>
+// CHECK:               %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_26:.*]] = fir.if %[[VAL_25]] -> (f32) {
+// CHECK:                 %[[VAL_27:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_28:.*]] = arith.addf %[[VAL_13]], %[[VAL_27]] : f32
+// CHECK:                 fir.result %[[VAL_28]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_13]] : f32
+// CHECK:               }
+// CHECK:               fir.result %[[VAL_26]] : f32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_11]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// array expr mask
+func.func @sum_array_expr_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !hlfir.expr<?x3x!fir.logical<1>>) {
+  %cst = arith.constant 2 : i32
+  %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr<?x3xf32>, i32, !hlfir.expr<?x3x!fir.logical<1>>) -> !hlfir.expr<?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_array_expr_mask(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>,
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !hlfir.expr<?x3x!fir.logical<1>>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
+// CHECK:           %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) {
+// CHECK:               %[[VAL_14:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3x!fir.logical<1>>, index, index) -> !fir.logical<1>
+// CHECK:               %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_16:.*]] = fir.if %[[VAL_15]] -> (f32) {
+// CHECK:                 %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_18:.*]] = arith.addf %[[VAL_13]], %[[VAL_17]] : f32
+// CHECK:                 fir.result %[[VAL_18]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_13]] : f32
+// CHECK:               }
+// CHECK:               fir.result %[[VAL_16]] : f32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_11]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// unordered floating point reduction
+func.func @sum_unordered_reduction(%arg0: !hlfir.expr<2x3xf32>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst {fastmath = #arith.fastmath<reassoc>} : (!hlfir.expr<2x3xf32>, i32) -> !hlfir.expr<3xf32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_unordered_reduction(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: !hlfir.expr<2x3xf32>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: index):
+// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (f32) {
+// CHECK:               %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xf32>, index, index) -> f32
+// CHECK:               %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] fastmath<reassoc> : f32
+// CHECK:               fir.result %[[VAL_13]] : f32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_9]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// negative: total reduction
+func.func @sum_total_reduction(%arg0: !fir.box<!fir.array<3xi32>>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<3xi32>>, i32) -> i32
+  return
+}
+// CHECK-LABEL:   func.func @sum_total_reduction(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: !fir.box<!fir.array<3xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
+// CHECK:           return
+// CHECK:         }
+
+// negative: non-const dim
+func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
+  %res = hlfir.sum %arg0 dim %dim : (!fir.box<!fir.array<3xi32>>, i32) -> i32
+  return
+}
+// CHECK-LABEL:   func.func @sum_non_const_dim(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: !fir.box<!fir.array<3xi32>>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: i32) {
+// CHECK:           %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
+// CHECK:           return
+// CHECK:         }

>From 0542e16c0b45f74be5f2ca6d52d9420876d5bc60 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 4 Dec 2024 14:26:35 -0800
Subject: [PATCH 2/3] Handle dynamically absent mask argument properly.

---
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    | 64 +++++++++++--
 .../HLFIR/simplify-hlfir-intrinsics-sum.fir   | 92 +++++++++++++++----
 2 files changed, 129 insertions(+), 27 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 35dc881e880df2..0c34c8221aeda6 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -141,10 +141,17 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
       // If the mask is present and is a scalar, then we'd better load its value
       // outside of the reduction loop making the loop unswitching easier.
       // Maybe it is worth hoisting it from the elemental operation as well.
+      mlir::Value isPresentPred, maskValue;
       if (mask) {
-        hlfir::Entity maskValue{mask};
-        if (maskValue.isScalar())
-          mask = hlfir::loadTrivialScalar(loc, builder, maskValue);
+        if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
+          // MASK represented by a box might be dynamically optional,
+          // so we have to check for its presence before accessing it.
+          isPresentPred =
+              builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
+        }
+
+        if (hlfir::Entity{mask}.isScalar())
+          maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
       }
 
       // NOTE: the outer elemental operation may be lowered into
@@ -171,12 +178,10 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
       if (mask) {
         // Make the reduction value update conditional on the value
         // of the mask.
-        hlfir::Entity maskValue{mask};
-        if (!maskValue.isScalar()) {
+        if (!maskValue) {
           // If the mask is an array, use the elemental and the loop indices
           // to address the proper mask element.
-          maskValue = hlfir::getElementAt(loc, builder, maskValue, indices);
-          maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue);
+          maskValue = genMaskValue(loc, builder, mask, isPresentPred, indices);
         }
         mlir::Value isUnmasked =
             builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
@@ -273,6 +278,51 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
 
     llvm_unreachable("unsupported SUM reduction type");
   }
+  // Return the MASK value (element at `indices` for arrays); .true. if absent.
+  static mlir::Value genMaskValue(mlir::Location loc,
+                                  fir::FirOpBuilder &builder, mlir::Value mask,
+                                  mlir::Value isPresentPred,
+                                  mlir::ValueRange indices) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    fir::IfOp ifOp; // Only created when isPresentPred is non-null.
+    mlir::Type maskType =
+        hlfir::getFortranElementType(fir::unwrapPassByRefType(mask.getType()));
+    if (isPresentPred) {
+      ifOp = builder.create<fir::IfOp>(loc, maskType, isPresentPred,
+                                       /*withElseRegion=*/true);
+
+      // Absent MASK: yield .true. converted to the mask element type.
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      mlir::Value trueValue = builder.createBool(loc, true);
+      trueValue = builder.createConvert(loc, maskType, trueValue);
+      builder.create<fir::ResultOp>(loc, trueValue);
+
+      // Present MASK: load it below and yield it from the 'then' region.
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    }
+    // Load MASK at the current insertion point (the 'then' region, if any).
+    hlfir::Entity maskVar{mask};
+    if (maskVar.isScalar()) {
+      if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
+        // Boxed scalar MASK: load through the raw address held in the box.
+        mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, maskVar);
+        mask = builder.create<fir::LoadOp>(loc, hlfir::Entity{addr});
+      } else {
+        mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
+      }
+    } else {
+      // Array MASK: load the element addressed by `indices`.
+      assert(!indices.empty() && "no indices for addressing the mask array");
+      maskVar = hlfir::getElementAt(loc, builder, maskVar, indices);
+      mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
+    }
+    // No presence check: the loaded value is the result.
+    if (!isPresentPred)
+      return mask;
+    // Otherwise yield it from the 'then' region; return the fir.if result.
+    builder.create<fir::ResultOp>(loc, mask);
+    return ifOp.getResult(0);
+  }
 };
 
 class SimplifyHLFIRIntrinsics
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 05a4dfde6344e2..48c4144f70393f 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -229,6 +229,50 @@ func.func @sum_scalar_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.ref<!fir.log
 // CHECK:           return
 // CHECK:         }
 
+// scalar boxed mask
+func.func @sum_scalar_boxed_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.logical<1>>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst mask %mask : (!hlfir.expr<?x3xf32>, i32, !fir.box<!fir.logical<1>>) -> !hlfir.expr<3xf32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_scalar_boxed_mask(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>,
+// CHECK-SAME:                                     %[[VAL_1:.*]]: !fir.box<!fir.logical<1>>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
+// CHECK:           %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xf32> {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> i1
+// CHECK:             %[[VAL_12:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) {
+// CHECK:               %[[VAL_13:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.logical<1>>) -> !fir.ref<!fir.logical<1>>
+// CHECK:               %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<!fir.logical<1>>
+// CHECK:               fir.result %[[VAL_14]] : !fir.logical<1>
+// CHECK:             } else {
+// CHECK:               %[[VAL_15:.*]] = arith.constant true
+// CHECK:               %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (i1) -> !fir.logical<1>
+// CHECK:               fir.result %[[VAL_16]] : !fir.logical<1>
+// CHECK:             }
+// CHECK:             %[[VAL_17:.*]] = fir.do_loop %[[VAL_18:.*]] = %[[VAL_9]] to %[[VAL_4]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_10]]) -> (f32) {
+// CHECK:               %[[VAL_20:.*]] = fir.convert %[[VAL_12]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_21:.*]] = fir.if %[[VAL_20]] -> (f32) {
+// CHECK:                 %[[VAL_22:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_8]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_23:.*]] = arith.addf %[[VAL_19]], %[[VAL_22]] : f32
+// CHECK:                 fir.result %[[VAL_23]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_19]] : f32
+// CHECK:               }
+// CHECK:               fir.result %[[VAL_21]] : f32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_17]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
 // array mask
 func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.array<?x3x!fir.logical<1>>>) {
   %cst = arith.constant 2 : i32
@@ -247,29 +291,37 @@ func.func @sum_array_mask(%arg0: !hlfir.expr<?x3xf32>, %mask: !fir.box<!fir.arra
 // CHECK:           ^bb0(%[[VAL_8:.*]]: index):
 // CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
 // CHECK:             %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:             %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f32) {
-// CHECK:               %[[VAL_14:.*]] = arith.constant 0 : index
-// CHECK:               %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_14]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
-// CHECK:               %[[VAL_16:.*]] = arith.constant 1 : index
-// CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
-// CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
-// CHECK:               %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_18]] : index
-// CHECK:               %[[VAL_20:.*]] = arith.addi %[[VAL_8]], %[[VAL_19]] : index
-// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_18]] : index
-// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_12]], %[[VAL_21]] : index
-// CHECK:               %[[VAL_23:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_20]], %[[VAL_22]])  : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
-// CHECK:               %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<!fir.logical<1>>
-// CHECK:               %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (!fir.logical<1>) -> i1
-// CHECK:               %[[VAL_26:.*]] = fir.if %[[VAL_25]] -> (f32) {
-// CHECK:                 %[[VAL_27:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_12]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
-// CHECK:                 %[[VAL_28:.*]] = arith.addf %[[VAL_13]], %[[VAL_27]] : f32
-// CHECK:                 fir.result %[[VAL_28]] : f32
+// CHECK:             %[[VAL_11:.*]] = fir.is_present %[[VAL_1]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>) -> i1
+// CHECK:             %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_9]] iter_args(%[[VAL_14:.*]] = %[[VAL_10]]) -> (f32) {
+// CHECK:               %[[VAL_15:.*]] = fir.if %[[VAL_11]] -> (!fir.logical<1>) {
+// CHECK:                 %[[VAL_16:.*]] = arith.constant 0 : index
+// CHECK:                 %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_16]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
+// CHECK:                 %[[VAL_18:.*]] = arith.constant 1 : index
+// CHECK:                 %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_18]] : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index) -> (index, index, index)
+// CHECK:                 %[[VAL_20:.*]] = arith.constant 1 : index
+// CHECK:                 %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
+// CHECK:                 %[[VAL_22:.*]] = arith.addi %[[VAL_8]], %[[VAL_21]] : index
+// CHECK:                 %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
+// CHECK:                 %[[VAL_24:.*]] = arith.addi %[[VAL_13]], %[[VAL_23]] : index
+// CHECK:                 %[[VAL_25:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3x!fir.logical<1>>>, index, index) -> !fir.ref<!fir.logical<1>>
+// CHECK:                 %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<!fir.logical<1>>
+// CHECK:                 fir.result %[[VAL_26]] : !fir.logical<1>
 // CHECK:               } else {
-// CHECK:                 fir.result %[[VAL_13]] : f32
+// CHECK:                 %[[VAL_27:.*]] = arith.constant true
+// CHECK:                 %[[VAL_28:.*]] = fir.convert %[[VAL_27]] : (i1) -> !fir.logical<1>
+// CHECK:                 fir.result %[[VAL_28]] : !fir.logical<1>
 // CHECK:               }
-// CHECK:               fir.result %[[VAL_26]] : f32
+// CHECK:               %[[VAL_29:.*]] = fir.convert %[[VAL_15]] : (!fir.logical<1>) -> i1
+// CHECK:               %[[VAL_30:.*]] = fir.if %[[VAL_29]] -> (f32) {
+// CHECK:                 %[[VAL_31:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]], %[[VAL_13]] : (!hlfir.expr<?x3xf32>, index, index) -> f32
+// CHECK:                 %[[VAL_32:.*]] = arith.addf %[[VAL_14]], %[[VAL_31]] : f32
+// CHECK:                 fir.result %[[VAL_32]] : f32
+// CHECK:               } else {
+// CHECK:                 fir.result %[[VAL_14]] : f32
+// CHECK:               }
+// CHECK:               fir.result %[[VAL_30]] : f32
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_11]] : f32
+// CHECK:             hlfir.yield_element %[[VAL_12]] : f32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }

>From 091e5922a4a6b9cb84dc323be2b753d69caf881e Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 4 Dec 2024 14:48:59 -0800
Subject: [PATCH 3/3] Fixed typo in the test.

---
 flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 48c4144f70393f..703b6673154f3f 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -142,12 +142,12 @@ func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:         }
 
 // expr with unknown extent
-func.func @sum_expr_unkwnonw_extent1(%arg0: !hlfir.expr<?x3xf32>) {
+func.func @sum_expr_unknown_extent1(%arg0: !hlfir.expr<?x3xf32>) {
   %cst = arith.constant 1 : i32
   %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<?x3xf32>, i32) -> !hlfir.expr<3xf32>
   return
 }
-// CHECK-LABEL:   func.func @sum_expr_unkwnonw_extent1(
+// CHECK-LABEL:   func.func @sum_expr_unknown_extent1(
 // CHECK-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>
@@ -168,12 +168,12 @@ func.func @sum_expr_unkwnonw_extent1(%arg0: !hlfir.expr<?x3xf32>) {
 // CHECK:           return
 // CHECK:         }
 
-func.func @sum_expr_unkwnonw_extent2(%arg0: !hlfir.expr<?x3xf32>) {
+func.func @sum_expr_unknown_extent2(%arg0: !hlfir.expr<?x3xf32>) {
   %cst = arith.constant 2 : i32
   %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<?x3xf32>, i32) -> !hlfir.expr<?xf32>
   return
 }
-// CHECK-LABEL:   func.func @sum_expr_unkwnonw_extent2(
+// CHECK-LABEL:   func.func @sum_expr_unknown_extent2(
 // CHECK-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?x3xf32>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 2 : i32
 // CHECK:           %[[VAL_2:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x3xf32>) -> !fir.shape<2>



More information about the flang-commits mailing list