[flang-commits] [flang] [flang] Simplify hlfir.sum total reductions. (PR #119482)
via flang-commits flang-commits at lists.llvm.org
Thu Dec 12 02:16:15 PST 2024
================
@@ -105,34 +105,47 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = sum.getLoc();
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
- hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
- assert(expr && "expected an expression type for the result of hlfir.sum");
- mlir::Type elementType = expr.getElementType();
+ mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
hlfir::Entity array = hlfir::Entity{sum.getArray()};
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
- int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+ bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
+ int64_t dimVal =
+ isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
mlir::Value resultShape, dimExtent;
- std::tie(resultShape, dimExtent) =
- genResultShape(loc, builder, array, dimVal);
+ llvm::SmallVector<mlir::Value> arrayExtents;
+ if (isTotalReduction)
+ arrayExtents = genArrayExtents(loc, builder, array);
+ else
+ std::tie(resultShape, dimExtent) =
+ genResultShapeForPartialReduction(loc, builder, array, dimVal);
+
+ // If the mask is present and is a scalar, then we'd better load its value
+ // outside of the reduction loop making the loop unswitching easier.
+ mlir::Value isPresentPred, maskValue;
+ if (mask) {
+ if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
+ // MASK represented by a box might be dynamically optional,
+ // so we have to check for its presence before accessing it.
+ isPresentPred =
+ builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
+ }
+
+ if (hlfir::Entity{mask}.isScalar())
+ maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
+ }
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange inputIndices) -> hlfir::Entity {
// Loop over all indices in the DIM dimension, and reduce all values.
- // We do not need to create the reduction loop always: if we can
- // slice the input array given the inputIndices, then we can
- // just apply a new SUM operation (total reduction) to the slice.
- // For the time being, generate the explicit loop because the slicing
- // requires generating an elemental operation for the input array
- // (and the mask, if present).
- // TODO: produce the slices and new SUM after adding a pattern
- // for expanding total reduction SUM case.
- mlir::Type indexType = builder.getIndexType();
- auto one = builder.createIntegerConstant(loc, indexType, 1);
- auto ub = builder.createConvert(loc, indexType, dimExtent);
+ // If DIM is not present, do total reduction.
+ // Create temporary scalar for keeping the running reduction value.
+ mlir::Value reductionTemp =
+ builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
----------------
jeanPerier wrote:
Makes some sense to me; the only impact I see is that it may make it harder to parallelize SUM(, DIM), which is otherwise trivial to parallelize (each thread reduces into its own element of the result array).
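For illustration, a minimal sketch (the subroutine name, shapes, and OpenMP directive are made up, not from the patch) of the loop structure where that per-element parallelization is trivial:
```fortran
! Illustrative only: with one independent reduction per result element,
! parallelizing over the result index j is trivial.
subroutine col_sums(a, r, n, m)
  implicit none
  integer, intent(in) :: n, m
  real, intent(in) :: a(n, m)
  real, intent(out) :: r(m)
  integer :: i, j
  !$omp parallel do private(i)
  do j = 1, m
    r(j) = 0.0
    do i = 1, n
      r(j) = r(j) + a(i, j)  ! each thread accumulates into its own r(j)
    end do
  end do
  !$omp end parallel do
end subroutine col_sums
```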
Can you check what happens with a SUM(, DIM) inside a workshare construct? Since you enable the rewrite to an elemental, I think the elemental-to-omp-loop conversion should kick in, and hoisting the alloca may be bad there.
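A hypothetical reproducer along those lines (not taken from the PR or its tests; names and shapes are illustrative):
```fortran
! Hypothetical test: SUM(,DIM) assigned inside an OpenMP workshare
! construct, to see how the rewritten elemental (with the hoisted
! reduction temporary) is lowered.
subroutine ws_sum(a, r, n, m)
  implicit none
  integer, intent(in) :: n, m
  real, intent(in) :: a(n, m)
  real, intent(out) :: r(m)
  !$omp parallel
  !$omp workshare
  r = sum(a, dim=1)
  !$omp end workshare
  !$omp end parallel
end subroutine ws_sum
```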
Maybe the stack reclaim pass should hoist constant-size allocas out of loops (assuming any loop parallelization has already happened by that point), at least for scalars. This may impact stack size of course, but for scalars that should be limited.
That said, since SUM(, DIM) was not parallelized before anyway, your solution would still be acceptable to me.
https://github.com/llvm/llvm-project/pull/119482