[flang-commits] [flang] [flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. (PR #118556)

via flang-commits flang-commits at lists.llvm.org
Wed Dec 4 02:42:00 PST 2024


================
@@ -90,13 +91,198 @@ class TransposeAsElementalConversion
   }
 };
 
+// Expand the SUM(DIM=CONSTANT) operation into an hlfir.elemental.
+class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+
+  /// Rewrite an hlfir.sum that has a constant DIM into an hlfir::ElementalOp
+  /// whose kernel runs a fir::DoLoopOp over the DIM dimension and accumulates
+  /// the (optionally masked) elements of the input array.
+  ///
+  /// \param sum the hlfir.sum operation to expand.
+  /// \param rewriter the pattern rewriter used to create and replace ops.
+  /// \returns success if \p sum was replaced; a match failure (with the IR
+  /// left untouched) if DIM is absent, is not a compile-time constant, or
+  /// is outside of [1, rank of the array].
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::SumOp sum,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = sum.getLoc();
+    fir::FirOpBuilder builder{rewriter, sum.getOperation()};
+    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
+    assert(expr && "expected an expression type for the result of hlfir.sum");
+    mlir::Type elementType = expr.getElementType();
+    hlfir::Entity array = hlfir::Entity{sum.getArray()};
+    mlir::Value mask = sum.getMask();
+    mlir::Value dim = sum.getDim();
+
+    // Validate DIM before any operation is created, so that a failed match
+    // does not leave partially generated IR behind. An assertion is not
+    // enough here: with assertions disabled a bad DIM would be used below
+    // to index the extents and the indices vectors out of bounds.
+    if (!dim)
+      return rewriter.notifyMatchFailure(sum, "DIM must be present");
+    auto constDim = fir::getIntIfConstant(dim);
+    if (!constDim)
+      return rewriter.notifyMatchFailure(sum, "DIM must be a constant");
+    int64_t dimVal = *constDim;
+    if (dimVal <= 0 || dimVal > array.getRank())
+      return rewriter.notifyMatchFailure(
+          sum, "DIM must be positive and not exceed the array rank");
+
+    mlir::Value resultShape, dimExtent;
+    std::tie(resultShape, dimExtent) =
+        genResultShape(loc, builder, array, dimVal);
+
+    // Kernel of the elemental operation: given the indices of one result
+    // element, produce the reduced value for that element.
+    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange inputIndices) -> hlfir::Entity {
+      // Loop over all indices in the DIM dimension, and reduce all values.
+      // We do not need to create the reduction loop always: if we can
+      // slice the input array given the inputIndices, then we can
+      // just apply a new SUM operation (total reduction) to the slice.
+      // For the time being, generate the explicit loop because the slicing
+      // requires generating an elemental operation for the input array
+      // (and the mask, if present).
+      // TODO: produce the slices and new SUM after adding a pattern
+      // for expanding total reduction SUM case.
+      mlir::Type indexType = builder.getIndexType();
+      auto one = builder.createIntegerConstant(loc, indexType, 1);
+      auto ub = builder.createConvert(loc, indexType, dimExtent);
+
+      // Initial value for the reduction.
+      mlir::Value initValue = genInitValue(loc, builder, elementType);
+
+      // The reduction loop may be unordered if FastMathFlags::reassoc
+      // transformations are allowed. The integer reduction is always
+      // unordered.
+      bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
+                         static_cast<bool>(sum.getFastmath() &
+                                           mlir::arith::FastMathFlags::reassoc);
+
+      // If the mask is present and is a scalar, then we'd better load its value
+      // outside of the reduction loop making the loop unswitching easier.
+      // Maybe it is worth hoisting it from the elemental operation as well.
+      if (mask) {
+        hlfir::Entity maskValue{mask};
+        if (maskValue.isScalar())
+          mask = hlfir::loadTrivialScalar(loc, builder, maskValue);
+      }
+
+      // NOTE: the outer elemental operation may be lowered into
+      // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
+      // loop may appear disjoint from the workshare loop nest.
+      // Moreover, the inner loop is not strictly nested (due to the reduction
+      // starting value initialization), and the above omp dialect operations
+      // cannot produce results.
+      // It is unclear what we should do about it yet.
+      auto doLoop = builder.create<fir::DoLoopOp>(
+          loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
+          mlir::ValueRange{initValue});
+
+      // Address the input array using the reduction loop's IV
+      // for the DIM dimension.
+      mlir::Value iv = doLoop.getInductionVar();
+      llvm::SmallVector<mlir::Value> indices{inputIndices};
+      indices.insert(indices.begin() + dimVal - 1, iv);
+
+      // Everything below is generated inside the loop body; the guard
+      // restores the insertion point when the kernel returns.
+      mlir::OpBuilder::InsertionGuard guard(builder);
+      builder.setInsertionPointToStart(doLoop.getBody());
+      mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
+      fir::IfOp ifOp;
+      if (mask) {
+        // Make the reduction value update conditional on the value
+        // of the mask.
+        hlfir::Entity maskValue{mask};
+        if (!maskValue.isScalar()) {
+          // If the mask is an array, use the elemental and the loop indices
+          // to address the proper mask element.
+          maskValue = hlfir::getElementAt(loc, builder, maskValue, indices);
+          maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue);
+        }
+        mlir::Value isUnmasked =
+            builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
+        ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
+                                         /*withElseRegion=*/true);
+        // In the 'else' block return the current reduction value.
+        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+        builder.create<fir::ResultOp>(loc, reductionValue);
+
+        // In the 'then' block do the actual addition.
+        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      }
+
+      hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
+      hlfir::Entity elementValue =
+          hlfir::loadTrivialScalar(loc, builder, element);
+      // NOTE: we can use "Kahan summation" same way as the runtime
+      // (e.g. when fast-math is not allowed), but let's start with
+      // the simple version.
+      reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
+      // This terminates the 'then' block when the mask is present,
+      // and the loop body otherwise.
+      builder.create<fir::ResultOp>(loc, reductionValue);
+
+      if (ifOp) {
+        // With a mask, the loop body yields the result of the fir.if.
+        builder.setInsertionPointAfter(ifOp);
+        builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
+      }
+
+      return hlfir::Entity{doLoop.getResult(0)};
+    };
+    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+        loc, builder, elementType, resultShape, {}, genKernel,
+        /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
+        sum.getResult().getType());
+
+    // it wouldn't be safe to replace block arguments with a different
+    // hlfir.expr type. Types can differ due to differing amounts of shape
+    // information
+    assert(elementalOp.getResult().getType() == sum.getResult().getType());
+
+    rewriter.replaceOp(sum, elementalOp);
+    return mlir::success();
+  }
+
+private:
+  // Build the shape of the result of a SUM reduction over dimension
+  // dimVal (one-based) of the given array: the extents of the array
+  // with the dimVal-th extent removed, wrapped into a fir::ShapeOp.
+  // Returns the pair {result shape, extent of the reduced dimension}.
+  static std::tuple<mlir::Value, mlir::Value>
+  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
+                 hlfir::Entity array, int64_t dimVal) {
+    mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
+    llvm::SmallVector<mlir::Value> extents =
+        hlfir::getExplicitExtentsFromShape(arrayShape, builder);
+    // The shape of the input is only needed to get the extents:
+    // drop its defining operation if nothing else ended up using it.
+    if (arrayShape.use_empty())
+      arrayShape.getDefiningOp()->erase();
+
+    // Split the extents into the reduced one and the remaining ones.
+    auto reducedIt = extents.begin() + dimVal - 1;
+    mlir::Value dimExtent = *reducedIt;
+    extents.erase(reducedIt);
+
+    mlir::Value resultShape = builder.create<fir::ShapeOp>(loc, extents);
+    return {resultShape, dimExtent};
+  }
+
+  // Create the starting value (zero) of a SUM reduction over elements
+  // of the given type. Integer, floating point and complex element types
+  // are handled; any other type is treated as unreachable.
+  static mlir::Value genInitValue(mlir::Location loc,
+                                  fir::FirOpBuilder &builder,
+                                  mlir::Type elementType) {
+    if (mlir::isa<mlir::IntegerType>(elementType))
+      return builder.createIntegerConstant(loc, elementType, 0);
+
+    if (auto realType = mlir::dyn_cast<mlir::FloatType>(elementType)) {
+      const llvm::fltSemantics &semantics = realType.getFloatSemantics();
+      return builder.createRealConstant(loc, elementType,
+                                        llvm::APFloat::getZero(semantics));
+    }
+
+    if (auto complexType = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
+      // The same zero of the component type is used for both parts
+      // of the complex value.
+      mlir::Value zeroPart =
+          genInitValue(loc, builder, complexType.getElementType());
+      fir::factory::Complex complexHelper{builder, loc};
+      return complexHelper.createComplex(complexType, zeroPart, zeroPart);
+    }
+
+    llvm_unreachable("unsupported SUM reduction type");
+  }
+
+  // Generate scalar addition of the two values (of the same data type).
+  static mlir::Value genScalarAdd(mlir::Location loc,
+                                  fir::FirOpBuilder &builder,
+                                  mlir::Value value1, mlir::Value value2) {
+    mlir::Type ty = value1.getType();
+    assert(ty == value2.getType() && "reduction values' types do not match");
+    if (mlir::isa<mlir::FloatType>(ty))
+      return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
----------------
jeanPerier wrote:

Are the fastmath flags from hlfir.sum propagated here and in AddCOp?

https://github.com/llvm/llvm-project/pull/118556


More information about the flang-commits mailing list