[flang-commits] [flang] [flang] Inline minval/maxval over elemental/designate (PR #103503)
via flang-commits
flang-commits at lists.llvm.org
Tue Aug 13 20:08:07 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: None (khaki3)
<details>
<summary>Changes</summary>
This PR aims to optimize away `hlfir.elemental` operations, which otherwise leave temporary buffers (`allocmem`) in FIR. Elemental operations typically appear as arguments of reduction intrinsics, so extending `OptimizedBufferization` is the first step toward heap-free code.
This patch adds handling for `minval`/`maxval` alongside the reduction intrinsics that were already handled (`count`, `any`, `all`). A reduction whose argument is an `hlfir.elemental` is rewritten into a do-loop. Furthermore, the same rewrite is applied when the argument is an `hlfir.designate`, so that more intrinsics are inlined instead of calling runtime routines.
---
Patch is 28.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/103503.diff
3 Files Affected:
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+170-94)
- (added) flang/test/HLFIR/maxval-elemental.fir (+95)
- (added) flang/test/HLFIR/minval-elemental.fir (+95)
``````````diff
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c5b809514c54c6..273079585b7035 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -702,8 +702,53 @@ static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
return reduction;
}
-/// Given a reduction operation with an elemental mask, attempt to generate a
-/// do-loop to perform the operation inline.
+auto makeMinMaxInitValGenerator(bool isMax) {
+ return [isMax](fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Type elementType) -> mlir::Value {
+ if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
+ return builder.createRealConstant(loc, elementType, limit);
+ }
+ unsigned bits = elementType.getIntOrFloatBitWidth();
+ int64_t limitInt =
+ isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
+ : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, elementType, limitInt);
+ };
+}
+
+mlir::Value generateMinMaxComparison(fir::FirOpBuilder builder,
+ mlir::Location loc, mlir::Value elem,
+ mlir::Value reduction, bool isMax) {
+ if (mlir::isa<mlir::FloatType>(reduction.getType())) {
+ // For FP reductions we want the first smallest value to be used, that
+ // is not NaN. A OGL/OLT condition will usually work for this unless all
+ // the values are Nan or Inf. This follows the same logic as
+ // NumericCompare for Minloc/Maxlox in extrema.cpp.
+ mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
+ loc,
+ isMax ? mlir::arith::CmpFPredicate::OGT
+ : mlir::arith::CmpFPredicate::OLT,
+ elem, reduction);
+ mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
+ loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
+ mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
+ loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
+ cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
+ return builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
+ } else if (mlir::isa<mlir::IntegerType>(reduction.getType())) {
+ return builder.create<mlir::arith::CmpIOp>(
+ loc,
+ isMax ? mlir::arith::CmpIPredicate::sgt
+ : mlir::arith::CmpIPredicate::slt,
+ elem, reduction);
+ }
+ llvm_unreachable("unsupported type");
+}
+
+/// Given a reduction operation with an elemental/designate source, attempt to
+/// generate a do-loop to perform the operation inline.
/// %e = hlfir.elemental %shape unordered
/// %r = hlfir.count %e
/// =>
@@ -712,17 +757,66 @@ static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
/// %c = <reduce count> %i
/// fir.result %c
template <typename Op>
-class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
+class ReductionConversion : public mlir::OpRewritePattern<Op> {
public:
using mlir::OpRewritePattern<Op>::OpRewritePattern;
llvm::LogicalResult
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = op.getLoc();
- hlfir::ElementalOp elemental =
- op.getMask().template getDefiningOp<hlfir::ElementalOp>();
- if (!elemental || op.getDim())
- return rewriter.notifyMatchFailure(op, "Did not find valid elemental");
+ // Select source and validate its arguments.
+ mlir::Value source;
+ bool valid = false;
+ if constexpr (std::is_same_v<Op, hlfir::AnyOp> ||
+ std::is_same_v<Op, hlfir::AllOp> ||
+ std::is_same_v<Op, hlfir::CountOp>) {
+ source = op.getMask();
+ valid = !op.getDim();
+ } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
+ std::is_same_v<Op, hlfir::MinvalOp>) {
+ source = op.getArray();
+ valid = !op.getDim() && !op.getDim();
+ } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
+ std::is_same_v<Op, hlfir::MinlocOp>) {
+ source = op.getArray();
+ valid = !op.getDim() && !op.getDim() && !op.getBack();
+ }
+ if (!valid)
+ return rewriter.notifyMatchFailure(
+ op, "Currently does not accept optional arguments");
+
+ hlfir::ElementalOp elemental;
+ hlfir::DesignateOp designate;
+ mlir::Value shape;
+ if ((elemental = source.template getDefiningOp<hlfir::ElementalOp>())) {
+ shape = elemental.getOperand(0);
+ } else if ((designate =
+ source.template getDefiningOp<hlfir::DesignateOp>())) {
+ shape = designate.getShape();
+ } else {
+ return rewriter.notifyMatchFailure(op, "Did not find valid argument");
+ }
+
+ auto inlineSource =
+ [elemental, &designate](
+ fir::FirOpBuilder builder, mlir::Location loc,
+ const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
+ if (elemental) {
+ // Inline the elemental and get the value from it.
+ auto yield = inlineElementalOp(loc, builder, elemental, indices);
+ auto tmp = yield.getElementValue();
+ yield->erase();
+ return tmp;
+ }
+ if (designate) {
+ // Create a designator over designator, then load the reference.
+ auto resEntity = hlfir::Entity{designate.getResult()};
+ auto tmp = builder.create<hlfir::DesignateOp>(
+ loc, getVariableElementType(resEntity), designate, indices);
+ return builder.create<fir::LoadOp>(loc, tmp);
+ }
+ llvm_unreachable("unsupported type");
+ };
fir::KindMapping kindMap =
fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
@@ -732,47 +826,38 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
GenBodyFn genBodyFn;
if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
- genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Value reduction,
- const llvm::SmallVectorImpl<mlir::Value> &indices)
+ genBodyFn =
+ [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &indices)
-> mlir::Value {
- // Inline the elemental and get the condition from it.
- auto yield = inlineElementalOp(loc, builder, elemental, indices);
- mlir::Value cond = builder.create<fir::ConvertOp>(
- loc, builder.getI1Type(), yield.getElementValue());
- yield->erase();
-
// Conditionally set the reduction variable.
+ mlir::Value cond = builder.create<fir::ConvertOp>(
+ loc, builder.getI1Type(), inlineSource(builder, loc, indices));
return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
};
} else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
- genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Value reduction,
- const llvm::SmallVectorImpl<mlir::Value> &indices)
+ genBodyFn =
+ [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &indices)
-> mlir::Value {
- // Inline the elemental and get the condition from it.
- auto yield = inlineElementalOp(loc, builder, elemental, indices);
- mlir::Value cond = builder.create<fir::ConvertOp>(
- loc, builder.getI1Type(), yield.getElementValue());
- yield->erase();
-
// Conditionally set the reduction variable.
+ mlir::Value cond = builder.create<fir::ConvertOp>(
+ loc, builder.getI1Type(), inlineSource(builder, loc, indices));
return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
};
} else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
init = builder.createIntegerConstant(loc, op.getType(), 0);
- genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Value reduction,
- const llvm::SmallVectorImpl<mlir::Value> &indices)
+ genBodyFn =
+ [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &indices)
-> mlir::Value {
- // Inline the elemental and get the condition from it.
- auto yield = inlineElementalOp(loc, builder, elemental, indices);
- mlir::Value cond = builder.create<fir::ConvertOp>(
- loc, builder.getI1Type(), yield.getElementValue());
- yield->erase();
-
// Conditionally add one to the current value
+ mlir::Value cond = builder.create<fir::ConvertOp>(
+ loc, builder.getI1Type(), inlineSource(builder, loc, indices));
mlir::Value one =
builder.createIntegerConstant(loc, reduction.getType(), 1);
mlir::Value add1 =
@@ -780,29 +865,49 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
reduction);
};
+ } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
+ std::is_same_v<Op, hlfir::MinlocOp>) {
+ // TODO: implement minloc/maxloc conversion.
+ return rewriter.notifyMatchFailure(
+ op, "Currently minloc/maxloc is not handled");
+ } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
+ std::is_same_v<Op, hlfir::MinvalOp>) {
+ bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
+ init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType());
+ genBodyFn = [inlineSource,
+ isMax](fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &indices)
+ -> mlir::Value {
+ mlir::Value val = inlineSource(builder, loc, indices);
+ mlir::Value cmp =
+ generateMinMaxComparison(builder, loc, val, reduction, isMax);
+ return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction);
+ };
} else {
- return mlir::failure();
+ llvm_unreachable("unsupported type");
}
- mlir::Value res = generateReductionLoop(builder, loc, init,
- elemental.getOperand(0), genBodyFn);
+ mlir::Value res =
+ generateReductionLoop(builder, loc, init, shape, genBodyFn);
if (res.getType() != op.getType())
res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
- // Check if the op was the only user of the elemental (apart from a
- // destroy), and remove it if so.
- mlir::Operation::user_range elemUsers = elemental->getUsers();
- hlfir::DestroyOp elemDestroy;
- if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
- elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
- if (!elemDestroy)
- elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
+ // Check if the op was the only user of the source (apart from a destroy),
+ // and remove it if so.
+ mlir::Operation *sourceOp = source.getDefiningOp();
+ mlir::Operation::user_range srcUsers = sourceOp->getUsers();
+ hlfir::DestroyOp srcDestroy;
+ if (std::distance(srcUsers.begin(), srcUsers.end()) == 2) {
+ srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*srcUsers.begin());
+ if (!srcDestroy)
+ srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++srcUsers.begin());
}
rewriter.replaceOp(op, res);
- if (elemDestroy) {
- rewriter.eraseOp(elemDestroy);
- rewriter.eraseOp(elemental);
+ if (srcDestroy) {
+ rewriter.eraseOp(srcDestroy);
+ rewriter.eraseOp(sourceOp);
}
return mlir::success();
}
@@ -813,7 +918,7 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
// %e = hlfir.elemental %shape ({ ... })
// %m = hlfir.minloc %array mask %e
template <typename Op>
-class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
+class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
public:
using mlir::OpRewritePattern<Op>::OpRewritePattern;
@@ -848,19 +953,7 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
loc, fir::SequenceType::get(
rank, hlfir::getFortranElementType(mloc.getType())));
- auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
- mlir::Type elementType) {
- if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
- return builder.createRealConstant(loc, elementType, limit);
- }
- unsigned bits = elementType.getIntOrFloatBitWidth();
- int64_t limitInt =
- isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
- : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, elementType, limitInt);
- };
+ auto init = makeMinMaxInitValGenerator(isMax);
auto genBodyOp =
[&rank, &resultArr, &elemental, isMax](
@@ -900,33 +993,8 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
// Compare with the max reduction value
- mlir::Value cmp;
- if (mlir::isa<mlir::FloatType>(elementType)) {
- // For FP reductions we want the first smallest value to be used, that
- // is not NaN. A OGL/OLT condition will usually work for this unless all
- // the values are Nan or Inf. This follows the same logic as
- // NumericCompare for Minloc/Maxlox in extrema.cpp.
- cmp = builder.create<mlir::arith::CmpFOp>(
- loc,
- isMax ? mlir::arith::CmpFPredicate::OGT
- : mlir::arith::CmpFPredicate::OLT,
- elem, reduction);
-
- mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
- loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
- mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
- loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
- cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
- cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
- } else if (mlir::isa<mlir::IntegerType>(elementType)) {
- cmp = builder.create<mlir::arith::CmpIOp>(
- loc,
- isMax ? mlir::arith::CmpIPredicate::sgt
- : mlir::arith::CmpIPredicate::slt,
- elem, reduction);
- } else {
- llvm_unreachable("unsupported type");
- }
+ mlir::Value cmp =
+ generateMinMaxComparison(builder, loc, elem, reduction, isMax);
// The condition used for the loop is isFirst || <the condition above>.
isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
@@ -1055,11 +1123,19 @@ class OptimizedBufferizationPass
patterns.insert<ElementalAssignBufferization>(context);
patterns.insert<BroadcastAssignBufferization>(context);
patterns.insert<VariableAssignBufferization>(context);
- patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
- patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
- patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
- patterns.insert<MinMaxlocElementalConversion<hlfir::MinlocOp>>(context);
- patterns.insert<MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context);
+ patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
+ patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
+ patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
+ // TODO: implement basic minloc/maxloc conversion.
+ // patterns.insert<ReductionConversion<hlfir::MaxlocOp>>(context);
+ // patterns.insert<ReductionConversion<hlfir::MinlocOp>>(context);
+ patterns.insert<ReductionConversion<hlfir::MaxvalOp>>(context);
+ patterns.insert<ReductionConversion<hlfir::MinvalOp>>(context);
+ patterns.insert<ReductionMaskConversion<hlfir::MinlocOp>>(context);
+ patterns.insert<ReductionMaskConversion<hlfir::MaxlocOp>>(context);
+ // TODO: implement masked minval/maxval conversion.
+ // patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context);
+ // patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
getOperation(), std::move(patterns), config))) {
diff --git a/flang/test/HLFIR/maxval-elemental.fir b/flang/test/HLFIR/maxval-elemental.fir
new file mode 100644
index 00000000000000..aa642253b08323
--- /dev/null
+++ b/flang/test/HLFIR/maxval-elemental.fir
@@ -0,0 +1,95 @@
+// Test maxval inlining for both elemental and designate
+// RUN: fir-opt %s -opt-bufferization | FileCheck %s
+
+// subroutine test(array)
+// integer :: array(:), x
+// x = maxval(abs(array))
+// end subroutine test
+
+func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}) {
+ %c31_i32 = arith.constant 31 : i32
+ %c0 = arith.constant 0 : index
+ %0 = fir.dummy_scope : !fir.dscope
+ %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xi32>>, !fir.dscope) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+ %2 = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFtestEx"}
+ %3:2 = hlfir.declare %2 {uniq_name = "_QFtestEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %4:3 = fir.box_dims %1#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+ %5 = fir.shape %4#1 : (index) -> !fir.shape<1>
+ %6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
+ ^bb0(%arg1: index):
+ %8 = hlfir.designate %1#0 (%arg1) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+ %9 = fir.load %8 : !fir.ref<i32>
+ %10 = arith.shrsi %9, %c31_i32 : i32
+ %11 = arith.xori %9, %10 : i32
+ %12 = arith.subi %11, %10 : i32
+ hlfir.yield_element %12 : i32
+ }
+ %7 = hlfir.maxval %6 {fastmath = #arith.fastmath<contract>} : (!hlfir.expr<?xi32>) -> i32
+ hlfir.assign %7 to %3#0 : i32, !fir.ref<i32>
+ hlfir.destroy %6 : !hlfir.expr<?xi32>
+ return
+}
+
+// CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}) {
+// CHECK-NEXT: %c1 = arith.constant 1 : index
+// CHECK-NEXT: %c-2147483648_i32 = arith.constant -2147483648 : i32
+// CHECK-NEXT: %c31_i32 = arith.constant 31 : i32
+// CHECK-NEXT: %c0 = arith.constant 0 : index
+// CHECK-NEXT: %[[V0:.*]] = fir.dummy_scope : !fir.dscope
+// CHECK-NEXT: %[[V1:.*]]:2 = hlfir.declare %arg0 dummy_scope %[[V0]] {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xi32>>, !fir.dscope) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+// CHECK-NEXT: %[[V2:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFtestEx"}
+// CHECK-NEXT: %[[V3:.*]]:2 = hlfir.declare %[[V2]] {uniq_name = "_QFtestEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK-NEXT: %[[V4:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK-NEXT: %[[V5:.*]] = fir.do_loop %arg1 = %c1 to %[[V4]]#1 step %c1 iter_args(%arg2 = %c-2147483648_i32) -> (i32) {
+// CHECK-NEXT: %[[V6:.*]] = hlfir.designate %[[V1]]#0 (%arg...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/103503
More information about the flang-commits
mailing list