[flang-commits] [flang] [flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. (PR #118556)

via flang-commits flang-commits at lists.llvm.org
Wed Dec 4 02:41:58 PST 2024


================
@@ -90,13 +91,198 @@ class TransposeAsElementalConversion
   }
 };
 
+// Expand the SUM(DIM=CONSTANT) operation into an hlfir.elemental.
+class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::SumOp sum,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = sum.getLoc();
+    fir::FirOpBuilder builder{rewriter, sum.getOperation()};
+    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
+    assert(expr && "expected an expression type for the result of hlfir.sum");
+    mlir::Type elementType = expr.getElementType();
+    hlfir::Entity array = hlfir::Entity{sum.getArray()};
+    mlir::Value mask = sum.getMask();
+    mlir::Value dim = sum.getDim();
+    int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+    assert(dimVal > 0 && "DIM must be present and a positive constant");
+    mlir::Value resultShape, dimExtent;
+    std::tie(resultShape, dimExtent) =
+        genResultShape(loc, builder, array, dimVal);
+
+    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange inputIndices) -> hlfir::Entity {
+      // Loop over all indices in the DIM dimension, and reduce all values.
+      // We do not need to create the reduction loop always: if we can
+      // slice the input array given the inputIndices, then we can
+      // just apply a new SUM operation (total reduction) to the slice.
+      // For the time being, generate the explicit loop because the slicing
+      // requires generating an elemental operation for the input array
+      // (and the mask, if present).
+      // TODO: produce the slices and new SUM after adding a pattern
+      // for expanding total reduction SUM case.
+      mlir::Type indexType = builder.getIndexType();
+      auto one = builder.createIntegerConstant(loc, indexType, 1);
+      auto ub = builder.createConvert(loc, indexType, dimExtent);
+
+      // Initial value for the reduction.
+      mlir::Value initValue = genInitValue(loc, builder, elementType);
+
+      // The reduction loop may be unordered if FastMathFlags::reassoc
+      // transformations are allowed. The integer reduction is always
+      // unordered.
+      bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
+                         static_cast<bool>(sum.getFastmath() &
+                                           mlir::arith::FastMathFlags::reassoc);
----------------
jeanPerier wrote:

Is this a language requirement, or is it to be safe with regard to what existing compilers are doing?

I think we may want to lift this when parallel models are enabled (no need to bother now since, as you noted below, this currently does not happen).

https://github.com/llvm/llvm-project/pull/118556


More information about the flang-commits mailing list