[flang-commits] [flang] [flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. (PR #118556)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Wed Dec 4 14:48:08 PST 2024
================
@@ -90,13 +91,198 @@ class TransposeAsElementalConversion
}
};
+// Expand the SUM(DIM=CONSTANT) operation into an hlfir.elemental.
+class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::SumOp sum,
+ mlir::PatternRewriter &rewriter) const override {
+ mlir::Location loc = sum.getLoc();
+ fir::FirOpBuilder builder{rewriter, sum.getOperation()};
+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
+ assert(expr && "expected an expression type for the result of hlfir.sum");
+ mlir::Type elementType = expr.getElementType();
+ hlfir::Entity array = hlfir::Entity{sum.getArray()};
+ mlir::Value mask = sum.getMask();
+ mlir::Value dim = sum.getDim();
+ int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+ assert(dimVal > 0 && "DIM must be present and a positive constant");
+ mlir::Value resultShape, dimExtent;
+ std::tie(resultShape, dimExtent) =
+ genResultShape(loc, builder, array, dimVal);
+
+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange inputIndices) -> hlfir::Entity {
+ // Loop over all indices in the DIM dimension, and reduce all values.
+ // We do not need to create the reduction loop always: if we can
+ // slice the input array given the inputIndices, then we can
+ // just apply a new SUM operation (total reduction) to the slice.
+ // For the time being, generate the explicit loop because the slicing
+ // requires generating an elemental operation for the input array
+ // (and the mask, if present).
+ // TODO: produce the slices and new SUM after adding a pattern
+ // for expanding total reduction SUM case.
+ mlir::Type indexType = builder.getIndexType();
+ auto one = builder.createIntegerConstant(loc, indexType, 1);
+ auto ub = builder.createConvert(loc, indexType, dimExtent);
+
+ // Initial value for the reduction.
+ mlir::Value initValue = genInitValue(loc, builder, elementType);
+
+ // The reduction loop may be unordered if FastMathFlags::reassoc
+ // transformations are allowed. The integer reduction is always
+ // unordered.
+ bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
+ static_cast<bool>(sum.getFastmath() &
+ mlir::arith::FastMathFlags::reassoc);
----------------
vzakhari wrote:
That is not a language requirement. It has more to do with how the current clang/flang compilers behave when you write a SUM reduction by hand. For example, I believe the LLVM vectorizer will not vectorize a floating-point reduction loop unless `reassoc` is attached to the floating-point `add` instruction.
Yes, the parallel models may override this, on the assumption that the user requested an unordered reduction.
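
As a small illustration of why the ordering matters for floating point but not for integers, here is a hedged, standalone C++ sketch. It is not part of the patch; `sumForward` and `sumBackward` are hypothetical helper names, and the sketch assumes IEEE 754 double arithmetic compiled without fast-math options. Reassociating a floating-point sum can change the result, which is why the compiler needs `reassoc` before it may reorder the reduction, while wrapping integer addition gives the same result in any order.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: accumulate left to right, the order a plain
// scalar reduction loop would use.
template <typename T> T sumForward(const std::vector<T> &values) {
  T acc = T(0);
  for (const T &v : values)
    acc += v;
  return acc;
}

// Hypothetical helper: accumulate right to left, standing in for one of
// the many orders an unordered (reassociated) reduction might use.
template <typename T> T sumBackward(const std::vector<T> &values) {
  T acc = T(0);
  for (auto it = values.rbegin(); it != values.rend(); ++it)
    acc += *it;
  return acc;
}
```

For the input `{1.0e16, -1.0e16, 1.0}` the forward sum cancels the two large terms first and then adds `1.0`, whereas the backward sum absorbs `1.0` into `-1.0e16` by rounding before the large terms cancel, so the two orders disagree. For `std::uint32_t` inputs both orders agree even when the sum wraps around, because modular addition is associative.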
https://github.com/llvm/llvm-project/pull/118556