[flang-commits] [flang] [flang] Simplify hlfir.sum total reductions. (PR #119482)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Dec 11 17:23:38 PST 2024


================
@@ -105,34 +105,47 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
                   mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = sum.getLoc();
     fir::FirOpBuilder builder{rewriter, sum.getOperation()};
-    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
-    assert(expr && "expected an expression type for the result of hlfir.sum");
-    mlir::Type elementType = expr.getElementType();
+    mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
     hlfir::Entity array = hlfir::Entity{sum.getArray()};
     mlir::Value mask = sum.getMask();
     mlir::Value dim = sum.getDim();
-    int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
+    int64_t dimVal =
+        isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
     mlir::Value resultShape, dimExtent;
-    std::tie(resultShape, dimExtent) =
-        genResultShape(loc, builder, array, dimVal);
+    llvm::SmallVector<mlir::Value> arrayExtents;
+    if (isTotalReduction)
+      arrayExtents = genArrayExtents(loc, builder, array);
+    else
+      std::tie(resultShape, dimExtent) =
+          genResultShapeForPartialReduction(loc, builder, array, dimVal);
+
+    // If the mask is present and is a scalar, then we'd better load its value
+    // outside of the reduction loop making the loop unswitching easier.
+    mlir::Value isPresentPred, maskValue;
+    if (mask) {
+      if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
+        // MASK represented by a box might be dynamically optional,
+        // so we have to check for its presence before accessing it.
+        isPresentPred =
+            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
+      }
+
+      if (hlfir::Entity{mask}.isScalar())
+        maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
+    }
 
     auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                          mlir::ValueRange inputIndices) -> hlfir::Entity {
       // Loop over all indices in the DIM dimension, and reduce all values.
-      // We do not need to create the reduction loop always: if we can
-      // slice the input array given the inputIndices, then we can
-      // just apply a new SUM operation (total reduction) to the slice.
-      // For the time being, generate the explicit loop because the slicing
-      // requires generating an elemental operation for the input array
-      // (and the mask, if present).
-      // TODO: produce the slices and new SUM after adding a pattern
-      // for expanding total reduction SUM case.
-      mlir::Type indexType = builder.getIndexType();
-      auto one = builder.createIntegerConstant(loc, indexType, 1);
-      auto ub = builder.createConvert(loc, indexType, dimExtent);
+      // If DIM is not present, do total reduction.
 
+      // Create temporary scalar for keeping the running reduction value.
+      mlir::Value reductionTemp =
+          builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
----------------
vzakhari wrote:

@jeanPerier, what do you think about calling `createTemporaryAlloc` outside of `genKernel`? Allocating the temporary inside the kernel appears to result in stacksave/stackrestore being generated by the stack reclaim pass (after the elemental is transformed into loops), which is not ideal. I think it should be safe to hoist this call, provided that the initializing store is kept inside the elemental.

https://github.com/llvm/llvm-project/pull/119482


More information about the flang-commits mailing list