[flang-commits] [flang] [flang] Simplify hlfir.sum total reductions. (PR #119482)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Thu Dec 12 02:25:35 PST 2024


================
@@ -105,34 +105,47 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
                   mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = sum.getLoc();
     fir::FirOpBuilder builder{rewriter, sum.getOperation()};
-    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
-    assert(expr && "expected an expression type for the result of hlfir.sum");
-    mlir::Type elementType = expr.getElementType();
+    mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
     hlfir::Entity array = hlfir::Entity{sum.getArray()};
     mlir::Value mask = sum.getMask();
     mlir::Value dim = sum.getDim();
-    int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
+    int64_t dimVal =
+        isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
     mlir::Value resultShape, dimExtent;
-    std::tie(resultShape, dimExtent) =
-        genResultShape(loc, builder, array, dimVal);
+    llvm::SmallVector<mlir::Value> arrayExtents;
+    if (isTotalReduction)
+      arrayExtents = genArrayExtents(loc, builder, array);
+    else
+      std::tie(resultShape, dimExtent) =
+          genResultShapeForPartialReduction(loc, builder, array, dimVal);
+
+    // If the mask is present and is a scalar, then we'd better load its value
+    // outside of the reduction loop making the loop unswitching easier.
+    mlir::Value isPresentPred, maskValue;
+    if (mask) {
+      if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
+        // MASK represented by a box might be dynamically optional,
+        // so we have to check for its presence before accessing it.
+        isPresentPred =
+            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
+      }
+
+      if (hlfir::Entity{mask}.isScalar())
+        maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
+    }
 
     auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                          mlir::ValueRange inputIndices) -> hlfir::Entity {
       // Loop over all indices in the DIM dimension, and reduce all values.
-      // We do not need to create the reduction loop always: if we can
-      // slice the input array given the inputIndices, then we can
-      // just apply a new SUM operation (total reduction) to the slice.
-      // For the time being, generate the explicit loop because the slicing
-      // requires generating an elemental operation for the input array
-      // (and the mask, if present).
-      // TODO: produce the slices and new SUM after adding a pattern
-      // for expanding total reduction SUM case.
-      mlir::Type indexType = builder.getIndexType();
-      auto one = builder.createIntegerConstant(loc, indexType, 1);
-      auto ub = builder.createConvert(loc, indexType, dimExtent);
+      // If DIM is not present, do total reduction.
 
+      // Create temporary scalar for keeping the running reduction value.
+      mlir::Value reductionTemp =
+          builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
----------------
tblah wrote:

`fir::FirOpBuilder::getAllocaBlock` understands OpenMP operations, so it should give safe insertion points when the stack-reclaim pass moves allocas. All OpenMP parallelisation will have happened by the time that pass runs.

If OpenMP workshare support is blocking optimizations in earlier passes, please let me know and I will see whether I can rethink the design.

https://github.com/llvm/llvm-project/pull/119482


More information about the flang-commits mailing list