[flang-commits] [flang] Revert "[flang] Generalized simplification of HLFIR reduction ops." (PR #136218)
via flang-commits
flang-commits at lists.llvm.org
Thu Apr 17 15:48:17 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
Reverts llvm/llvm-project#136071
---
Patch is 281.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136218.diff
18 Files Affected:
- (modified) flang/include/flang/Optimizer/Builder/HLFIRTools.h (-5)
- (modified) flang/lib/Optimizer/Builder/HLFIRTools.cpp (-27)
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+465)
- (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+197-842)
- (added) flang/test/HLFIR/all-elemental.fir (+91)
- (added) flang/test/HLFIR/any-elemental.fir (+190)
- (added) flang/test/HLFIR/count-elemental.fir (+314)
- (added) flang/test/HLFIR/maxloc-elemental.fir (+133)
- (added) flang/test/HLFIR/maxval-elemental.fir (+117)
- (added) flang/test/HLFIR/minloc-elemental.fir (+397)
- (added) flang/test/HLFIR/minval-elemental.fir (+95)
- (removed) flang/test/HLFIR/simplify-hlfir-intrinsics-all.fir (-123)
- (removed) flang/test/HLFIR/simplify-hlfir-intrinsics-any.fir (-123)
- (removed) flang/test/HLFIR/simplify-hlfir-intrinsics-count.fir (-127)
- (removed) flang/test/HLFIR/simplify-hlfir-intrinsics-maxloc.fir (-343)
- (removed) flang/test/HLFIR/simplify-hlfir-intrinsics-maxval.fir (-177)
- (removed) flang/test/HLFIR/simplify-hlfir-intrinsics-minloc.fir (-343)
- (removed) flang/test/HLFIR/simplify-hlfir-intrinsics-minval.fir (-177)
``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index cd259b9dc6071..ac80873dc374f 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -301,11 +301,6 @@ mlir::Value genExtent(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Value genLBound(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity, unsigned dim);
-/// Compute the lower bounds of \p entity, which is an array of known rank.
-llvm::SmallVector<mlir::Value> genLBounds(mlir::Location loc,
- fir::FirOpBuilder &builder,
- hlfir::Entity entity);
-
/// Generate a vector of extents with index type from a fir.shape
/// of fir.shape_shift value.
llvm::SmallVector<mlir::Value> getIndexExtents(mlir::Location loc,
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 2a5e136c57c62..558ebcb876ddb 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -659,33 +659,6 @@ mlir::Value hlfir::genLBound(mlir::Location loc, fir::FirOpBuilder &builder,
return dimInfo.getLowerBound();
}
-// Return the lower bounds of the array \p entity, one value per dimension.
-llvm::SmallVector<mlir::Value> hlfir::genLBounds(mlir::Location loc,
- fir::FirOpBuilder &builder,
- hlfir::Entity entity) {
- assert(!entity.isAssumedRank() &&
- "cannot compute all lower bounds for assumed rank");
- assert(!entity.isScalar() && "expected an array entity");
- int rank = entity.getRank();
- mlir::Type idxTy = builder.getIndexType();
- // Fast path: when the entity cannot have non-default lower bounds, every
- // lower bound is the index constant 1.
- if (!entity.mayHaveNonDefaultLowerBounds())
- return {static_cast<std::size_t>(rank),
- builder.createIntegerConstant(loc, idxTy, 1)};
-
- // Use the explicit lower bounds carried by the entity's shape/shift when
- // one can be retrieved; an empty result falls through to the box query.
- if (auto shape = tryRetrievingShapeOrShift(entity)) {
- auto lbounds = getExplicitLboundsFromShape(shape);
- if (!lbounds.empty())
- return lbounds;
- }
-
- // Otherwise read the lower bounds from the box descriptor, dereferencing
- // pointers/allocatables first; extents and strides are not requested.
- if (entity.isMutableBox())
- entity = hlfir::derefPointersAndAllocatables(loc, builder, entity);
-
- llvm::SmallVector<mlir::Value, Fortran::common::maxRank> lbounds;
- fir::factory::genDimInfoFromBox(builder, loc, entity, &lbounds,
- /*extents=*/nullptr, /*strides=*/nullptr);
- return lbounds;
-}
-
void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity,
llvm::SmallVectorImpl<mlir::Value> &result) {
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 79aabd2981e1a..c489450384a35 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -772,6 +772,458 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
return mlir::success();
}
+/// Callback that emits the body of a reduction loop nest. It receives the
+/// builder (positioned inside the innermost loop), a location, the running
+/// reduction value and the one-based indices of the current element, and
+/// returns the updated reduction value.
+using GenBodyFn =
+ std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
+ const llvm::SmallVectorImpl<mlir::Value> &)>;
+/// Generate a loop nest over the extents of \p shape that carries a single
+/// reduction value, starting from \p init, and return the value produced by
+/// the outermost loop. \p genBody is invoked exactly once, with the insertion
+/// point inside the innermost loop; its result is yielded as the reduction
+/// value for the next iteration. On return the insertion point is right after
+/// the outermost loop.
+static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value init,
+ mlir::Value shape,
+ const GenBodyFn &genBody) {
+ auto extents = hlfir::getIndexExtents(loc, builder, shape);
+ mlir::Value reduction = init;
+ mlir::IndexType idxTy = builder.getIndexType();
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+ // Create a reduction loop nest. We use one-based indices so that they can be
+ // passed to the elemental, and reverse the order so that they can be
+ // generated in column-major order for better performance: the last
+ // dimension gets the outermost loop and the first dimension the innermost.
+ llvm::SmallVector<mlir::Value> indices(extents.size(), mlir::Value{});
+ for (unsigned i = 0; i < extents.size(); ++i) {
+ unsigned dim = extents.size() - i - 1;
+ auto loop = builder.create<fir::DoLoopOp>(
+ loc, oneIdx, extents[dim], oneIdx, /*unordered=*/false,
+ /*finalCountValue=*/false, reduction);
+ reduction = loop.getRegionIterArgs()[0];
+ indices[dim] = loop.getInductionVar();
+ // Set insertion point to the loop body so that the next loop
+ // is inserted inside the current one.
+ builder.setInsertionPointToStart(loop.getBody());
+ }
+
+ // Generate the body
+ reduction = genBody(builder, loc, reduction, indices);
+
+ // Unwind the loop nest: terminate each loop with a fir.result yielding the
+ // current reduction value, then continue with that loop's result.
+ for (unsigned i = 0; i < extents.size(); ++i) {
+ auto result = builder.create<fir::ResultOp>(loc, reduction);
+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+ reduction = loop.getResult(0);
+ // Set insertion point after the loop operation that we have
+ // just processed.
+ builder.setInsertionPointAfter(loop.getOperation());
+ }
+
+ return reduction;
+}
+
+/// Return a generator for the initial value of a MAX (\p isMax) or MIN
+/// reduction over \p elementType. Floating-point types get -inf (MAX) or
+/// +inf (MIN). Integer types get the most negative (MAX) or most positive
+/// (MIN) signed value of their bit width.
+auto makeMinMaxInitValGenerator(bool isMax) {
+ return [isMax](fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Type elementType) -> mlir::Value {
+ if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
+ return builder.createRealConstant(loc, elementType, limit);
+ }
+ // Keep the limit as a full-width APInt and materialize it directly.
+ // Narrowing it through getSExtValue() asserts for types wider than 64
+ // bits, e.g. i128. Release builds silently truncate the value instead.
+ // createIntegerConstant() only accepts an int64_t, so it cannot be used.
+ unsigned bits = elementType.getIntOrFloatBitWidth();
+ llvm::APInt limit = isMax ? llvm::APInt::getSignedMinValue(bits)
+ : llvm::APInt::getSignedMaxValue(bits);
+ return builder.create<mlir::arith::ConstantOp>(
+ loc, elementType, builder.getIntegerAttr(elementType, limit));
+ };
+}
+
+/// Generate the i1 condition under which \p elem replaces the running
+/// \p reduction value of a MIN/MAX-style reduction. The comparison is a
+/// strict "greater than" when \p isMax is true and a strict "less than"
+/// otherwise, so ties keep the earlier value. Only floating-point and integer
+/// types are supported.
+mlir::Value generateMinMaxComparison(fir::FirOpBuilder builder,
+ mlir::Location loc, mlir::Value elem,
+ mlir::Value reduction, bool isMax) {
+ if (mlir::isa<mlir::FloatType>(reduction.getType())) {
+ // For FP reductions we want the first extreme value (largest for MAX,
+ // smallest for MIN) that is not NaN. An OGT/OLT condition usually works
+ // for this unless all the values are NaN or Inf. This follows the same
+ // logic as NumericCompare for MINLOC/MAXLOC in extrema.cpp.
+ mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
+ loc,
+ isMax ? mlir::arith::CmpFPredicate::OGT
+ : mlir::arith::CmpFPredicate::OLT,
+ elem, reduction);
+ // Also pick elem when the running value is NaN (UNE with itself) and
+ // elem is not NaN (OEQ with itself).
+ mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
+ loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
+ mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
+ loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
+ cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
+ return builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
+ }
+ if (mlir::isa<mlir::IntegerType>(reduction.getType()))
+ return builder.create<mlir::arith::CmpIOp>(
+ loc,
+ isMax ? mlir::arith::CmpIPredicate::sgt
+ : mlir::arith::CmpIPredicate::slt,
+ elem, reduction);
+ llvm_unreachable("unsupported type");
+}
+
+/// Given a reduction operation with an elemental/designate source, attempt to
+/// generate a do-loop to perform the operation inline.
+/// %e = hlfir.elemental %shape unordered
+/// %r = hlfir.count %e
+/// =>
+/// %r = fir.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
+/// %i = <inline elemental>
+/// %c = <reduce count> %i
+/// fir.result %c
+///
+/// Supported forms:
+/// - hlfir.any/all/count reducing MASK without DIM.
+/// - hlfir.maxval/minval reducing ARRAY without DIM or MASK, with a scalar
+///   real or integer result type.
+/// hlfir.maxloc/minloc are recognized but currently rejected.
+/// After the rewrite, the source operation and its hlfir.destroy are erased
+/// when the source has exactly two users and one of them is that destroy.
+template <typename Op>
+class ReductionConversion : public mlir::OpRewritePattern<Op> {
+public:
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+ mlir::Location loc = op.getLoc();
+ // Select source and validate its arguments.
+ mlir::Value source;
+ bool valid = false;
+ if constexpr (std::is_same_v<Op, hlfir::AnyOp> ||
+ std::is_same_v<Op, hlfir::AllOp> ||
+ std::is_same_v<Op, hlfir::CountOp>) {
+ source = op.getMask();
+ valid = !op.getDim();
+ } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
+ std::is_same_v<Op, hlfir::MinvalOp>) {
+ source = op.getArray();
+ valid = !op.getDim() && !op.getMask();
+ } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
+ std::is_same_v<Op, hlfir::MinlocOp>) {
+ source = op.getArray();
+ valid = !op.getDim() && !op.getMask() && !op.getBack();
+ }
+ if (!valid)
+ return rewriter.notifyMatchFailure(
+ op, "Currently does not accept optional arguments");
+
+ // The source must be defined by an hlfir.elemental or an hlfir.designate;
+ // its shape drives the generated loop nest.
+ hlfir::ElementalOp elemental;
+ hlfir::DesignateOp designate;
+ mlir::Value shape;
+ if ((elemental = source.template getDefiningOp<hlfir::ElementalOp>())) {
+ shape = elemental.getOperand(0);
+ } else if ((designate =
+ source.template getDefiningOp<hlfir::DesignateOp>())) {
+ shape = designate.getShape();
+ } else {
+ return rewriter.notifyMatchFailure(op, "Did not find valid argument");
+ }
+
+ // Produce the source element at the given one-based indices. For an
+ // elemental, clone its body in place and drop the cloned yield. For a
+ // designate, load through the element address. `designate` is captured
+ // by reference. This is safe only because the lambda is called from
+ // generateReductionLoop below, before matchAndRewrite returns.
+ auto inlineSource =
+ [elemental,
+ &designate](fir::FirOpBuilder builder, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
+ -> mlir::Value {
+ if (elemental) {
+ // Inline the elemental and get the value from it.
+ auto yield =
+ inlineElementalOp(loc, builder, elemental, oneBasedIndices);
+ auto tmp = yield.getElementValue();
+ yield->erase();
+ return tmp;
+ }
+ if (designate) {
+ // Create a designator over the array designator, then load the
+ // reference.
+ mlir::Value elementAddr = hlfir::getElementAt(
+ loc, builder, hlfir::Entity{designate.getResult()},
+ oneBasedIndices);
+ return builder.create<fir::LoadOp>(loc, elementAddr);
+ }
+ llvm_unreachable("unsupported type");
+ };
+
+ fir::FirOpBuilder builder{rewriter, op.getOperation()};
+
+ // Select the initial reduction value and the per-element update for the
+ // matched operation. ANY/ALL/COUNT convert each mask element to i1 first.
+ mlir::Value init;
+ GenBodyFn genBodyFn;
+ if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
+ init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
+ genBodyFn = [inlineSource](
+ fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
+ -> mlir::Value {
+ // Conditionally set the reduction variable.
+ mlir::Value cond = builder.create<fir::ConvertOp>(
+ loc, builder.getI1Type(),
+ inlineSource(builder, loc, oneBasedIndices));
+ return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
+ };
+ } else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
+ init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+ genBodyFn = [inlineSource](
+ fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
+ -> mlir::Value {
+ // Conditionally set the reduction variable.
+ mlir::Value cond = builder.create<fir::ConvertOp>(
+ loc, builder.getI1Type(),
+ inlineSource(builder, loc, oneBasedIndices));
+ return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
+ };
+ } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
+ init = builder.createIntegerConstant(loc, op.getType(), 0);
+ genBodyFn = [inlineSource](
+ fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
+ -> mlir::Value {
+ // Conditionally add one to the current value
+ mlir::Value cond = builder.create<fir::ConvertOp>(
+ loc, builder.getI1Type(),
+ inlineSource(builder, loc, oneBasedIndices));
+ mlir::Value one =
+ builder.createIntegerConstant(loc, reduction.getType(), 1);
+ mlir::Value add1 =
+ builder.create<mlir::arith::AddIOp>(loc, reduction, one);
+ return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
+ reduction);
+ };
+ } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
+ std::is_same_v<Op, hlfir::MinlocOp>) {
+ // TODO: implement minloc/maxloc conversion.
+ return rewriter.notifyMatchFailure(
+ op, "Currently minloc/maxloc is not handled");
+ } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
+ std::is_same_v<Op, hlfir::MinvalOp>) {
+ // Only scalar real/integer results are handled here.
+ mlir::Type ty = op.getType();
+ if (!(mlir::isa<mlir::FloatType>(ty) ||
+ mlir::isa<mlir::IntegerType>(ty))) {
+ return rewriter.notifyMatchFailure(
+ op, "Type is not supported for Maxval or Minval yet");
+ }
+
+ bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
+ init = makeMinMaxInitValGenerator(isMax)(builder, loc, ty);
+ // Keep the running value unless the comparison selects the new element.
+ genBodyFn = [inlineSource, isMax](
+ fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
+ -> mlir::Value {
+ mlir::Value val = inlineSource(builder, loc, oneBasedIndices);
+ mlir::Value cmp =
+ generateMinMaxComparison(builder, loc, val, reduction, isMax);
+ return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction);
+ };
+ } else {
+ llvm_unreachable("unsupported type");
+ }
+
+ // Convert the loop result (e.g. i1 for ANY/ALL) to the op's result type
+ // when they differ.
+ mlir::Value res =
+ generateReductionLoop(builder, loc, init, shape, genBodyFn);
+ if (res.getType() != op.getType())
+ res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
+
+ // Check if the op was the only user of the source (apart from a destroy),
+ // and remove it if so.
+ mlir::Operation *sourceOp = source.getDefiningOp();
+ mlir::Operation::user_range srcUsers = sourceOp->getUsers();
+ hlfir::DestroyOp srcDestroy;
+ if (std::distance(srcUsers.begin(), srcUsers.end()) == 2) {
+ srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*srcUsers.begin());
+ if (!srcDestroy)
+ srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++srcUsers.begin());
+ }
+
+ // Replace op first so that the source no longer has that use, then erase
+ // the destroy before the operation it destroys.
+ rewriter.replaceOp(op, res);
+ if (srcDestroy) {
+ rewriter.eraseOp(srcDestroy);
+ rewriter.eraseOp(sourceOp);
+ }
+ return mlir::success();
+ }
+};
+
+// Look for minloc(mask=elemental) and generate the minloc loop with
+// inlined elemental.
+// %e = hlfir.elemental %shape ({ ... })
+// %m = hlfir.minloc %array mask %e
+template <typename Op>
+class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
+public:
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
+ if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
+ return rewriter.notifyMatchFailure(mloc,
+ "Did not find valid minloc/maxloc");
+
+ bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;
+
+ auto elemental =
+ mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
+ if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
+ return rewriter.notifyMatchFailure(mloc, "Did not find elemental");
+
+ mlir::Value array = mloc.getArray();
+
+ unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
+ mlir::Type arrayType = array.getType();
+ if (!mlir::isa<fir::BoxType>(arrayType))
+ return rewriter.notifyMatchFailure(
+ mloc, "Currently requires a boxed type input");
+ mlir::Type elementType = hlfir::getFortranElementType(arrayType);
+ if (!fir::isa_trivial(elementType))
+ return rewriter.notifyMatchFailure(
+ mloc, "Character arrays are currently not handled");
+
+ mlir::Location loc = mloc.getLoc();
+ fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
+ mlir::Value resultArr = builder.createTemporary(
+ loc, fir::SequenceType::get(
+ rank, hlfir::getFortranElementType(mloc.getType())));
+
+ auto init = makeMinMaxInitValGenerator(isMax);
+
+ auto genBodyOp =
+ [&rank, &resultArr, &elemental, isMax](
+ fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
+ // We are in the innermost loop: generate the elemental inline
+ mlir::Value oneIdx =
+ builder.createIntegerConstant(loc, builder.getIndexType(), 1);
+ llvm::SmallVector<mlir::Value> oneBasedIndices;
+ llvm::transform(
+ indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
+ return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
+ });
+ hlfir::YieldElementOp yield =
+ hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
+ mlir::Value maskElem = yield.getElementValue();
+ yield->erase();
+
+ mlir::Type ifCompatType = builder.getI1Type();
+ mlir::Value ifCompatElem =
+ builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
+
+ llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
+ fir::IfOp maskIfOp =
+ builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
+ /*withElseRegion=*/true);
+ builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());
+
+ // Set flag that mask was true at some point
+ mlir::Value flagSet = builder.createIntegerConstant(
+ loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
+ mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
+ mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
+ oneBasedIndices);
+ mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
+
+ // Compare with the max reduction value
+ mlir::Value cmp =
+ generateMinMaxComparison(builder, loc, elem, reduction, isMax);
+
+ // The condition used for the loop is isFirst || <the condition above>.
+ isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
+ isFirst = builder.create<mlir::arith::XOrIOp>(
+ loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
+ cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
+
+ // Set the new coordinate to the result
+ fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
+ /*withElseRegion*/ true);
+
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ builder.create<fir::StoreOp>(loc, flagSet, flagRef);
+ mlir::Type resultElemTy =
+ hlfir::getFortranElementType(resultArr.getType());
+ mlir::Type returnRefTy = builder.getRefType(resultElemTy);
+ mlir::IndexType idxTy = builder.getIndexType();
+
+ for (unsigned int i = 0; i < rank; ++i) {
+ mlir::Value index = builder.createIntegerConstant(loc, idxTy, i + 1);
+ mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
+ loc, returnRefTy, resultArr, index);
+ mlir::Value fortranIndex = builder.create<fir::ConvertOp>(
+ loc, resultElemTy, oneBasedIndices[i]);
+ builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
+ }
+ builder.create<fir::ResultOp>(loc, elem);
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<fir::ResultOp>(loc, reduction);
+ builder.setInsertionPointAfter(ifOp);
+
+ // Close the mask if
+ builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
+ builder.setInsertionPointToStart(&maskIfOp.getElseRegion().front());
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/136218
More information about the flang-commits
mailing list