[flang-commits] [flang] [Flang] Maxloc elemental intrinsic lowering. (PR #79469)
via flang-commits
flang-commits at lists.llvm.org
Thu Jan 25 08:55:39 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: David Green (davemgreen)
<details>
<summary>Changes</summary>
This is an extension to #74828 to handle maxloc too, keeping the minloc and maxloc lowering symmetric.
---
Patch is 29.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79469.diff
4 Files Affected:
- (modified) flang/include/flang/Optimizer/Support/Utils.h (+128-7)
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+38-26)
- (modified) flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (-128)
- (added) flang/test/HLFIR/maxloc-elemental.fir (+140)
``````````diff
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index e31121260acdae7..b50f297a7d31410 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -18,6 +18,7 @@
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -144,13 +145,133 @@ using AddrGeneratorTy = llvm::function_ref<mlir::Value(
mlir::Value)>;
// Produces a loop nest for a Minloc intrinsic.
-void genMinMaxlocReductionLoop(fir::FirOpBuilder &builder, mlir::Value array,
- InitValGeneratorTy initVal,
- MinlocBodyOpGeneratorTy genBody,
- fir::AddrGeneratorTy getAddrFn, unsigned rank,
- mlir::Type elementType, mlir::Location loc,
- mlir::Type maskElemType, mlir::Value resultArr,
- bool maskMayBeLogicalScalar);
+inline void genMinMaxlocReductionLoop(
+ fir::FirOpBuilder &builder, mlir::Value array,
+ fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
+ fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
+ mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
+ bool maskMayBeLogicalScalar) {
+ mlir::IndexType idxTy = builder.getIndexType();
+
+ mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
+
+ fir::SequenceType::Shape flatShape(rank,
+ fir::SequenceType::getUnknownExtent());
+ mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
+ mlir::Type boxArrTy = fir::BoxType::get(arrTy);
+ array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);
+
+ mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
+ mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
+ mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
+ mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
+ builder.create<fir::StoreOp>(loc, zero, flagRef);
+
+ mlir::Value init = initVal(builder, loc, elementType);
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
+
+ assert(rank > 0 && "rank cannot be zero");
+ mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+
+ // Compute all the upper bounds before the loop nest.
+ // This hoists them explicitly instead of relying on LICM: genBody is
+ // handed flagRef and may store to memory inside the nest, which can
+ // make it harder for LICM to prove the box_dims are invariant.
+ for (unsigned i = 0; i < rank; ++i) {
+ mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
+ auto dims =
+ builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
+ mlir::Value len = dims.getResult(1);
+ // Loops are 0-based with an inclusive upper bound, so len-1 is the bound.
+ mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
+ bounds.push_back(loopCount);
+ }
+ // Create a loop nest consisting of fir.do_loop operations.
+ // Collect the loops' induction variables into indices array,
+ // which will be used in the innermost loop to load the input
+ // array's element.
+ // The loops are generated such that the innermost loop processes
+ // the 0 dimension.
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
+ for (unsigned i = rank; 0 < i; --i) {
+ mlir::Value step = one;
+ mlir::Value loopCount = bounds[i - 1];
+ auto loop =
+ builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
+ /*finalCountValue=*/false, init);
+ init = loop.getRegionIterArgs()[0];
+ indices.push_back(loop.getInductionVar());
+ // Set insertion point to the loop body so that the next loop
+ // is inserted inside the current one.
+ builder.setInsertionPointToStart(loop.getBody());
+ }
+
+ // Reverse the indices such that they are ordered as:
+ // <dim-0-idx, dim-1-idx, ...>
+ std::reverse(indices.begin(), indices.end());
+ mlir::Value reductionVal =
+ genBody(builder, loc, elementType, array, flagRef, init, indices);
+
+ // Unwind the loop nest and insert ResultOp on each level
+ // to return the updated value of the reduction to the enclosing
+ // loops.
+ for (unsigned i = 0; i < rank; ++i) {
+ auto result = builder.create<fir::ResultOp>(loc, reductionVal);
+ // Proceed to the outer loop.
+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+ reductionVal = loop.getResult(0);
+ // Set insertion point after the loop operation that we have
+ // just processed.
+ builder.setInsertionPointAfter(loop.getOperation());
+ }
+ // End of loop nest. The insertion point is after the outermost loop.
+ if (maskMayBeLogicalScalar) {
+ if (fir::IfOp ifOp =
+ mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
+ builder.create<fir::ResultOp>(loc, reductionVal);
+ builder.setInsertionPointAfter(ifOp);
+ // Re-create flagSet: the earlier constant may be inside ifOp's region.
+ flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
+ reductionVal = ifOp.getResult(0);
+ }
+ }
+
+ // Check for the case where the reduction never moved off its init value.
+ // flag will be 0 if mask was never true, 1 if mask was true at some point,
+ // this is needed to avoid catching cases where we didn't access any elements
+ // e.g. mask=.FALSE.
+ mlir::Value flagValue =
+ builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
+ mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
+ fir::IfOp ifMaskTrueOp =
+ builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
+
+ mlir::Value testInit = initVal(builder, loc, elementType);
+ fir::IfOp ifMinSetOp;
+ if (elementType.isa<mlir::FloatType>()) {
+ mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
+ loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
+ ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
+ /*withElseRegion*/ false);
+ } else {
+ mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
+ ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
+ /*withElseRegion*/ false);
+ }
+ builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
+
+ // Store 1 into every entry of the result array instead of leaving 0s
+ for (unsigned int i = 0; i < rank; ++i) {
+ mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
+ mlir::Value resultElemAddr =
+ getAddrFn(builder, loc, resultElemType, resultArr, index);
+ builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
+ }
+ builder.setInsertionPointAfter(ifMaskTrueOp);
+}
} // namespace fir
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 3f4ec4f3bccc80f..0915f25d49375de 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -812,38 +812,41 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
// inlined elemental.
// %e = hlfir.elemental %shape ({ ... })
// %m = hlfir.minloc %array mask %e
-class MinMaxlocElementalConversion
- : public mlir::OpRewritePattern<hlfir::MinlocOp> {
+template <typename Op>
+class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
public:
- using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
mlir::LogicalResult
- matchAndRewrite(hlfir::MinlocOp minloc,
- mlir::PatternRewriter &rewriter) const override {
- if (!minloc.getMask() || minloc.getDim() || minloc.getBack())
- return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc");
+ matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
+ if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
+ return rewriter.notifyMatchFailure(mloc,
+ "Did not find valid minloc/maxloc");
- auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
+ constexpr bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;
+
+ auto elemental =
+ mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
- return rewriter.notifyMatchFailure(minloc, "Did not find elemental");
+ return rewriter.notifyMatchFailure(mloc, "Did not find elemental");
- mlir::Value array = minloc.getArray();
+ mlir::Value array = mloc.getArray();
- unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
+ unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
mlir::Type arrayType = array.getType();
if (!arrayType.isa<fir::BoxType>())
return rewriter.notifyMatchFailure(
- minloc, "Currently requires a boxed type input");
+ mloc, "Currently requires a boxed type input");
mlir::Type elementType = hlfir::getFortranElementType(arrayType);
if (!fir::isa_trivial(elementType))
return rewriter.notifyMatchFailure(
- minloc, "Character arrays are currently not handled");
+ mloc, "Character arrays are currently not handled");
- mlir::Location loc = minloc.getLoc();
- fir::FirOpBuilder builder{rewriter, minloc.getOperation()};
+ mlir::Location loc = mloc.getLoc();
+ fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
mlir::Value resultArr = builder.createTemporary(
loc, fir::SequenceType::get(
- rank, hlfir::getFortranElementType(minloc.getType())));
+ rank, hlfir::getFortranElementType(mloc.getType())));
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
@@ -851,11 +854,13 @@ class MinMaxlocElementalConversion
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
loc, elementType,
- llvm::APFloat::getLargest(sem, /*Negative=*/false));
+ llvm::APFloat::getLargest(sem, /*Negative=*/!isMax));
}
unsigned bits = elementType.getIntOrFloatBitWidth();
- int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, elementType, maxInt);
+ int64_t limitInt =
+ isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
+ : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, elementType, limitInt);
};
auto genBodyOp =
@@ -899,10 +904,16 @@ class MinMaxlocElementalConversion
mlir::Value cmp;
if (elementType.isa<mlir::FloatType>()) {
cmp = builder.create<mlir::arith::CmpFOp>(
- loc, mlir::arith::CmpFPredicate::OLT, elem, reduction);
+ loc,
+ isMax ? mlir::arith::CmpFPredicate::OGT
+ : mlir::arith::CmpFPredicate::OLT,
+ elem, reduction);
} else if (elementType.isa<mlir::IntegerType>()) {
cmp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::slt, elem, reduction);
+ loc,
+ isMax ? mlir::arith::CmpIPredicate::sgt
+ : mlir::arith::CmpIPredicate::slt,
+ elem, reduction);
} else {
llvm_unreachable("unsupported type");
}
@@ -975,15 +986,15 @@ class MinMaxlocElementalConversion
// AsExpr for the temporary resultArr.
llvm::SmallVector<hlfir::DestroyOp> destroys;
llvm::SmallVector<hlfir::AssignOp> assigns;
- for (auto user : minloc->getUsers()) {
+ for (auto user : mloc->getUsers()) {
if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
destroys.push_back(destroy);
else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
assigns.push_back(assign);
}
- // Check if the minloc was the only user of the elemental (apart from a
- // destroy), and remove it if so.
+ // Check if the minloc/maxloc was the only user of the elemental (apart from
+ // a destroy), and remove it if so.
mlir::Operation::user_range elemUsers = elemental->getUsers();
hlfir::DestroyOp elemDestroy;
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
@@ -996,7 +1007,7 @@ class MinMaxlocElementalConversion
rewriter.eraseOp(d);
for (auto a : assigns)
a.setOperand(0, resultArr);
- rewriter.replaceOp(minloc, asExpr);
+ rewriter.replaceOp(mloc, asExpr);
if (elemDestroy) {
rewriter.eraseOp(elemDestroy);
rewriter.eraseOp(elemental);
@@ -1030,7 +1041,8 @@ class OptimizedBufferizationPass
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
- patterns.insert<MinMaxlocElementalConversion>(context);
+ patterns.insert<MinMaxlocElementalConversion<hlfir::MinlocOp>>(context);
+ patterns.insert<MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
func, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 2301e7146f14101..b415463075d68f9 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -353,134 +353,6 @@ genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
}
-void fir::genMinMaxlocReductionLoop(
- fir::FirOpBuilder &builder, mlir::Value array,
- fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
- fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
- mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
- bool maskMayBeLogicalScalar) {
- mlir::IndexType idxTy = builder.getIndexType();
-
- mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
-
- fir::SequenceType::Shape flatShape(rank,
- fir::SequenceType::getUnknownExtent());
- mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
- mlir::Type boxArrTy = fir::BoxType::get(arrTy);
- array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);
-
- mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
- mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
- mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
- mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
- builder.create<fir::StoreOp>(loc, zero, flagRef);
-
- mlir::Value init = initVal(builder, loc, elementType);
- llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
-
- assert(rank > 0 && "rank cannot be zero");
- mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-
- // Compute all the upper bounds before the loop nest.
- // It is not strictly necessary for performance, since the loop nest
- // does not have any store operations and any LICM optimization
- // should be able to optimize the redundancy.
- for (unsigned i = 0; i < rank; ++i) {
- mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
- auto dims =
- builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
- mlir::Value len = dims.getResult(1);
- // We use C indexing here, so len-1 as loopcount
- mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
- bounds.push_back(loopCount);
- }
- // Create a loop nest consisting of OP operations.
- // Collect the loops' induction variables into indices array,
- // which will be used in the innermost loop to load the input
- // array's element.
- // The loops are generated such that the innermost loop processes
- // the 0 dimension.
- llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
- for (unsigned i = rank; 0 < i; --i) {
- mlir::Value step = one;
- mlir::Value loopCount = bounds[i - 1];
- auto loop =
- builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
- /*finalCountValue=*/false, init);
- init = loop.getRegionIterArgs()[0];
- indices.push_back(loop.getInductionVar());
- // Set insertion point to the loop body so that the next loop
- // is inserted inside the current one.
- builder.setInsertionPointToStart(loop.getBody());
- }
-
- // Reverse the indices such that they are ordered as:
- // <dim-0-idx, dim-1-idx, ...>
- std::reverse(indices.begin(), indices.end());
- mlir::Value reductionVal =
- genBody(builder, loc, elementType, array, flagRef, init, indices);
-
- // Unwind the loop nest and insert ResultOp on each level
- // to return the updated value of the reduction to the enclosing
- // loops.
- for (unsigned i = 0; i < rank; ++i) {
- auto result = builder.create<fir::ResultOp>(loc, reductionVal);
- // Proceed to the outer loop.
- auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
- reductionVal = loop.getResult(0);
- // Set insertion point after the loop operation that we have
- // just processed.
- builder.setInsertionPointAfter(loop.getOperation());
- }
- // End of loop nest. The insertion point is after the outermost loop.
- if (maskMayBeLogicalScalar) {
- if (fir::IfOp ifOp =
- mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
- builder.create<fir::ResultOp>(loc, reductionVal);
- builder.setInsertionPointAfter(ifOp);
- // Redefine flagSet to escape scope of ifOp
- flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
- reductionVal = ifOp.getResult(0);
- }
- }
-
- // Check for case where array was full of max values.
- // flag will be 0 if mask was never true, 1 if mask was true as some point,
- // this is needed to avoid catching cases where we didn't access any elements
- // e.g. mask=.FALSE.
- mlir::Value flagValue =
- builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
- mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
- fir::IfOp ifMaskTrueOp =
- builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
- builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
-
- mlir::Value testInit = initVal(builder, loc, elementType);
- fir::IfOp ifMinSetOp;
- if (elementType.isa<mlir::FloatType>()) {
- mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
- loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
- ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
- /*withElseRegion*/ false);
- } else {
- mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
- ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
- /*withElseRegion*/ false);
- }
- builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
-
- // Load output array with 1s instead of 0s
- for (unsigned int i = 0; i < rank; ++i) {
- mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
- mlir::Value resultElemAddr =
- getAddrFn(builder, loc, resultElemType, resultArr, index);
- builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
- }
- builder.setInsertionPointAfter(ifMaskTrueOp);
-}
-
s...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/79469
More information about the flang-commits
mailing list