[flang-commits] [flang] Reland [flang] Generalized simplification of HLFIR reduction ops. (#136071) (PR #136246)
via flang-commits
flang-commits at lists.llvm.org
Thu Apr 17 20:02:04 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
This change generalizes the SumAsElemental inlining in
the SimplifyHLFIRIntrinsics pass so that it can be applied
to ALL, ANY, COUNT, MAXLOC, MAXVAL, MINLOC, MINVAL, and SUM.
This change makes the special handling of the reduction
operations in OptimizedBufferization redundant: once the HLFIR
reduction operations are inlined, the existing hlfir.elemental
inlining is expected to handle the rest.
---
The patch is 276.21 KiB; it is truncated to 20.00 KiB below. Full version: https://github.com/llvm/llvm-project/pull/136246.diff
16 Files Affected:
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (-465)
- (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+864-200)
- (removed) flang/test/HLFIR/all-elemental.fir (-91)
- (removed) flang/test/HLFIR/any-elemental.fir (-190)
- (removed) flang/test/HLFIR/count-elemental.fir (-314)
- (removed) flang/test/HLFIR/maxloc-elemental.fir (-133)
- (removed) flang/test/HLFIR/maxval-elemental.fir (-117)
- (removed) flang/test/HLFIR/minloc-elemental.fir (-397)
- (removed) flang/test/HLFIR/minval-elemental.fir (-95)
- (added) flang/test/HLFIR/simplify-hlfir-intrinsics-all.fir (+123)
- (added) flang/test/HLFIR/simplify-hlfir-intrinsics-any.fir (+123)
- (added) flang/test/HLFIR/simplify-hlfir-intrinsics-count.fir (+127)
- (added) flang/test/HLFIR/simplify-hlfir-intrinsics-maxloc.fir (+312)
- (added) flang/test/HLFIR/simplify-hlfir-intrinsics-maxval.fir (+186)
- (added) flang/test/HLFIR/simplify-hlfir-intrinsics-minloc.fir (+312)
- (added) flang/test/HLFIR/simplify-hlfir-intrinsics-minval.fir (+186)
``````````diff
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c489450384a35..79aabd2981e1a 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -772,458 +772,6 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
return mlir::success();
}
-using GenBodyFn =
- std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
- const llvm::SmallVectorImpl<mlir::Value> &)>;
-static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
- mlir::Location loc, mlir::Value init,
- mlir::Value shape, GenBodyFn genBody) {
- auto extents = hlfir::getIndexExtents(loc, builder, shape);
- mlir::Value reduction = init;
- mlir::IndexType idxTy = builder.getIndexType();
- mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
-
- // Create a reduction loop nest. We use one-based indices so that they can be
- // passed to the elemental, and reverse the order so that they can be
- // generated in column-major order for better performance.
- llvm::SmallVector<mlir::Value> indices(extents.size(), mlir::Value{});
- for (unsigned i = 0; i < extents.size(); ++i) {
- auto loop = builder.create<fir::DoLoopOp>(
- loc, oneIdx, extents[extents.size() - i - 1], oneIdx, false,
- /*finalCountValue=*/false, reduction);
- reduction = loop.getRegionIterArgs()[0];
- indices[extents.size() - i - 1] = loop.getInductionVar();
- // Set insertion point to the loop body so that the next loop
- // is inserted inside the current one.
- builder.setInsertionPointToStart(loop.getBody());
- }
-
- // Generate the body
- reduction = genBody(builder, loc, reduction, indices);
-
- // Unwind the loop nest.
- for (unsigned i = 0; i < extents.size(); ++i) {
- auto result = builder.create<fir::ResultOp>(loc, reduction);
- auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
- reduction = loop.getResult(0);
- // Set insertion point after the loop operation that we have
- // just processed.
- builder.setInsertionPointAfter(loop.getOperation());
- }
-
- return reduction;
-}
-
-auto makeMinMaxInitValGenerator(bool isMax) {
- return [isMax](fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Type elementType) -> mlir::Value {
- if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
- return builder.createRealConstant(loc, elementType, limit);
- }
- unsigned bits = elementType.getIntOrFloatBitWidth();
- int64_t limitInt =
- isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
- : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, elementType, limitInt);
- };
-}
-
-mlir::Value generateMinMaxComparison(fir::FirOpBuilder builder,
- mlir::Location loc, mlir::Value elem,
- mlir::Value reduction, bool isMax) {
- if (mlir::isa<mlir::FloatType>(reduction.getType())) {
- // For FP reductions we want the first smallest value to be used, that
- // is not NaN. A OGL/OLT condition will usually work for this unless all
- // the values are Nan or Inf. This follows the same logic as
- // NumericCompare for Minloc/Maxlox in extrema.cpp.
- mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
- loc,
- isMax ? mlir::arith::CmpFPredicate::OGT
- : mlir::arith::CmpFPredicate::OLT,
- elem, reduction);
- mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
- loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
- mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
- loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
- cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
- return builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
- } else if (mlir::isa<mlir::IntegerType>(reduction.getType())) {
- return builder.create<mlir::arith::CmpIOp>(
- loc,
- isMax ? mlir::arith::CmpIPredicate::sgt
- : mlir::arith::CmpIPredicate::slt,
- elem, reduction);
- }
- llvm_unreachable("unsupported type");
-}
-
-/// Given a reduction operation with an elemental/designate source, attempt to
-/// generate a do-loop to perform the operation inline.
-/// %e = hlfir.elemental %shape unordered
-/// %r = hlfir.count %e
-/// =>
-/// %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
-/// %i = <inline elemental>
-/// %c = <reduce count> %i
-/// fir.result %c
-template <typename Op>
-class ReductionConversion : public mlir::OpRewritePattern<Op> {
-public:
- using mlir::OpRewritePattern<Op>::OpRewritePattern;
-
- llvm::LogicalResult
- matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
- mlir::Location loc = op.getLoc();
- // Select source and validate its arguments.
- mlir::Value source;
- bool valid = false;
- if constexpr (std::is_same_v<Op, hlfir::AnyOp> ||
- std::is_same_v<Op, hlfir::AllOp> ||
- std::is_same_v<Op, hlfir::CountOp>) {
- source = op.getMask();
- valid = !op.getDim();
- } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
- std::is_same_v<Op, hlfir::MinvalOp>) {
- source = op.getArray();
- valid = !op.getDim() && !op.getMask();
- } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
- std::is_same_v<Op, hlfir::MinlocOp>) {
- source = op.getArray();
- valid = !op.getDim() && !op.getMask() && !op.getBack();
- }
- if (!valid)
- return rewriter.notifyMatchFailure(
- op, "Currently does not accept optional arguments");
-
- hlfir::ElementalOp elemental;
- hlfir::DesignateOp designate;
- mlir::Value shape;
- if ((elemental = source.template getDefiningOp<hlfir::ElementalOp>())) {
- shape = elemental.getOperand(0);
- } else if ((designate =
- source.template getDefiningOp<hlfir::DesignateOp>())) {
- shape = designate.getShape();
- } else {
- return rewriter.notifyMatchFailure(op, "Did not find valid argument");
- }
-
- auto inlineSource =
- [elemental,
- &designate](fir::FirOpBuilder builder, mlir::Location loc,
- const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
- -> mlir::Value {
- if (elemental) {
- // Inline the elemental and get the value from it.
- auto yield =
- inlineElementalOp(loc, builder, elemental, oneBasedIndices);
- auto tmp = yield.getElementValue();
- yield->erase();
- return tmp;
- }
- if (designate) {
- // Create a designator over the array designator, then load the
- // reference.
- mlir::Value elementAddr = hlfir::getElementAt(
- loc, builder, hlfir::Entity{designate.getResult()},
- oneBasedIndices);
- return builder.create<fir::LoadOp>(loc, elementAddr);
- }
- llvm_unreachable("unsupported type");
- };
-
- fir::FirOpBuilder builder{rewriter, op.getOperation()};
-
- mlir::Value init;
- GenBodyFn genBodyFn;
- if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
- init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
- genBodyFn = [inlineSource](
- fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Value reduction,
- const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
- -> mlir::Value {
- // Conditionally set the reduction variable.
- mlir::Value cond = builder.create<fir::ConvertOp>(
- loc, builder.getI1Type(),
- inlineSource(builder, loc, oneBasedIndices));
- return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
- };
- } else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
- init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
- genBodyFn = [inlineSource](
- fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Value reduction,
- const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
- -> mlir::Value {
- // Conditionally set the reduction variable.
- mlir::Value cond = builder.create<fir::ConvertOp>(
- loc, builder.getI1Type(),
- inlineSource(builder, loc, oneBasedIndices));
- return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
- };
- } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
- init = builder.createIntegerConstant(loc, op.getType(), 0);
- genBodyFn = [inlineSource](
- fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Value reduction,
- const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
- -> mlir::Value {
- // Conditionally add one to the current value
- mlir::Value cond = builder.create<fir::ConvertOp>(
- loc, builder.getI1Type(),
- inlineSource(builder, loc, oneBasedIndices));
- mlir::Value one =
- builder.createIntegerConstant(loc, reduction.getType(), 1);
- mlir::Value add1 =
- builder.create<mlir::arith::AddIOp>(loc, reduction, one);
- return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
- reduction);
- };
- } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
- std::is_same_v<Op, hlfir::MinlocOp>) {
- // TODO: implement minloc/maxloc conversion.
- return rewriter.notifyMatchFailure(
- op, "Currently minloc/maxloc is not handled");
- } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
- std::is_same_v<Op, hlfir::MinvalOp>) {
- mlir::Type ty = op.getType();
- if (!(mlir::isa<mlir::FloatType>(ty) ||
- mlir::isa<mlir::IntegerType>(ty))) {
- return rewriter.notifyMatchFailure(
- op, "Type is not supported for Maxval or Minval yet");
- }
-
- bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
- init = makeMinMaxInitValGenerator(isMax)(builder, loc, ty);
- genBodyFn = [inlineSource, isMax](
- fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Value reduction,
- const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
- -> mlir::Value {
- mlir::Value val = inlineSource(builder, loc, oneBasedIndices);
- mlir::Value cmp =
- generateMinMaxComparison(builder, loc, val, reduction, isMax);
- return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction);
- };
- } else {
- llvm_unreachable("unsupported type");
- }
-
- mlir::Value res =
- generateReductionLoop(builder, loc, init, shape, genBodyFn);
- if (res.getType() != op.getType())
- res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
-
- // Check if the op was the only user of the source (apart from a destroy),
- // and remove it if so.
- mlir::Operation *sourceOp = source.getDefiningOp();
- mlir::Operation::user_range srcUsers = sourceOp->getUsers();
- hlfir::DestroyOp srcDestroy;
- if (std::distance(srcUsers.begin(), srcUsers.end()) == 2) {
- srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*srcUsers.begin());
- if (!srcDestroy)
- srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++srcUsers.begin());
- }
-
- rewriter.replaceOp(op, res);
- if (srcDestroy) {
- rewriter.eraseOp(srcDestroy);
- rewriter.eraseOp(sourceOp);
- }
- return mlir::success();
- }
-};
-
-// Look for minloc(mask=elemental) and generate the minloc loop with
-// inlined elemental.
-// %e = hlfir.elemental %shape ({ ... })
-// %m = hlfir.minloc %array mask %e
-template <typename Op>
-class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
-public:
- using mlir::OpRewritePattern<Op>::OpRewritePattern;
-
- llvm::LogicalResult
- matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
- if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
- return rewriter.notifyMatchFailure(mloc,
- "Did not find valid minloc/maxloc");
-
- bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;
-
- auto elemental =
- mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
- if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
- return rewriter.notifyMatchFailure(mloc, "Did not find elemental");
-
- mlir::Value array = mloc.getArray();
-
- unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
- mlir::Type arrayType = array.getType();
- if (!mlir::isa<fir::BoxType>(arrayType))
- return rewriter.notifyMatchFailure(
- mloc, "Currently requires a boxed type input");
- mlir::Type elementType = hlfir::getFortranElementType(arrayType);
- if (!fir::isa_trivial(elementType))
- return rewriter.notifyMatchFailure(
- mloc, "Character arrays are currently not handled");
-
- mlir::Location loc = mloc.getLoc();
- fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
- mlir::Value resultArr = builder.createTemporary(
- loc, fir::SequenceType::get(
- rank, hlfir::getFortranElementType(mloc.getType())));
-
- auto init = makeMinMaxInitValGenerator(isMax);
-
- auto genBodyOp =
- [&rank, &resultArr, &elemental, isMax](
- fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
- mlir::Value reduction,
- const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
- // We are in the innermost loop: generate the elemental inline
- mlir::Value oneIdx =
- builder.createIntegerConstant(loc, builder.getIndexType(), 1);
- llvm::SmallVector<mlir::Value> oneBasedIndices;
- llvm::transform(
- indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
- return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
- });
- hlfir::YieldElementOp yield =
- hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
- mlir::Value maskElem = yield.getElementValue();
- yield->erase();
-
- mlir::Type ifCompatType = builder.getI1Type();
- mlir::Value ifCompatElem =
- builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
-
- llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
- fir::IfOp maskIfOp =
- builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
- /*withElseRegion=*/true);
- builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());
-
- // Set flag that mask was true at some point
- mlir::Value flagSet = builder.createIntegerConstant(
- loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
- mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
- mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
- oneBasedIndices);
- mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
-
- // Compare with the max reduction value
- mlir::Value cmp =
- generateMinMaxComparison(builder, loc, elem, reduction, isMax);
-
- // The condition used for the loop is isFirst || <the condition above>.
- isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
- isFirst = builder.create<mlir::arith::XOrIOp>(
- loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
- cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
-
- // Set the new coordinate to the result
- fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
- /*withElseRegion*/ true);
-
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- builder.create<fir::StoreOp>(loc, flagSet, flagRef);
- mlir::Type resultElemTy =
- hlfir::getFortranElementType(resultArr.getType());
- mlir::Type returnRefTy = builder.getRefType(resultElemTy);
- mlir::IndexType idxTy = builder.getIndexType();
-
- for (unsigned int i = 0; i < rank; ++i) {
- mlir::Value index = builder.createIntegerConstant(loc, idxTy, i + 1);
- mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
- loc, returnRefTy, resultArr, index);
- mlir::Value fortranIndex = builder.create<fir::ConvertOp>(
- loc, resultElemTy, oneBasedIndices[i]);
- builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
- }
- builder.create<fir::ResultOp>(loc, elem);
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<fir::ResultOp>(loc, reduction);
- builder.setInsertionPointAfter(ifOp);
-
- // Close the mask if
- builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
- builder.setInsertionPointToStart(&maskIfOp.getElseRegion().front());
- builder.create<fir::ResultOp>(loc, reduction);
- builder.setInsertionPointAfter(maskIfOp);
-
- return maskIfOp.getResult(0);
- };
- auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
- const mlir::Type &resultElemType, mlir::Value resultArr,
- mlir::Value index) {
- mlir::Type resultRefTy = builder.getRefType(resultElemType);
- mlir::Value oneIdx =
- builder.createIntegerConstant(loc, builder.getIndexType(), 1);
- index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
- return builder.create<hlfir::DesignateOp>(loc, resultRefTy, resultArr,
- index);
- };
-
- // Initialize the result
- mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
- mlir::Type resultRefTy = builder.getRefType(resultElemTy);
- mlir::Value returnValue =
- builder.createIntegerConstant(loc, resultElemTy, 0);
- for (unsigned int i = 0; i < rank; ++i) {
- mlir::Value index =
- builder.createIntegerConstant(loc, builder.getIndexType(), i + 1);
- mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
- loc, resultRefTy, resultArr, index);
- builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
- }
-
- fir::genMinMaxlocReductionLoop(builder, array, init, genBodyOp, getAddrFn,
- rank, elementType, loc, builder.getI1Type(),
- resultArr, false);
-
- mlir::Value asExpr = builder.create<hlfir::AsExprOp>(
- loc, resultArr, builder.createBool(loc, false));
-
- // Check all the users - the destroy is no longer required, and any assign
- // can use resultArr directly so that InlineHLFIRAssign pass
- // can optimize the results. Other operations are replaced with an AsExpr
- // for the temporary resultArr.
- llvm::SmallVector<hlfir::DestroyOp> destroys;
- llvm::SmallVector<hlfir::AssignOp> assigns;
- for (auto user : mloc->getUsers()) {
- if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
- destroys.push_back(destroy);
- else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
- assigns.push_back(assign);
- }
-
- // Check if the minloc/maxloc was the only user of the elemental (apart from
- // a destroy), and remove it if so.
- mlir::Operation::user_range elemUsers = elemental->getUsers();
- hlfir::DestroyOp elemDestroy;
- if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
- elemDestroy = mlir::dyn_cast<hlfir::Dest...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/136246
More information about the flang-commits
mailing list