[flang-commits] [flang] [flang] Improve designate/elemental indices match in opt-bufferization. (PR #121371)
via flang-commits
flang-commits at lists.llvm.org
Mon Dec 30 21:09:50 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
This pattern appears in `tonto`: `rys1%w = rys1%w * ...`, where
component `w` is a pointer. Because of the computations that transform
the elemental's one-based indices into the array's lower-bound-based
indices, the index-match check in opt-bufferization did not pass.
This patch recognizes this index-adjustment pattern
and returns the one-based indices for the designator.
---
Full diff: https://github.com/llvm/llvm-project/pull/121371.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+75-1)
- (added) flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir (+69)
``````````diff
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index bf3cf861e46f4a..bfaabed0136785 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -87,6 +87,13 @@ class ElementalAssignBufferization
/// determines if the transformation can be applied to this elemental
static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);
+ /// Returns the array indices for the given hlfir.designate.
+ /// It recognizes the computations used to transform the one-based indices
+ /// into the array's lb-based indices, and returns the one-based indices
+ /// in these cases.
+ static llvm::SmallVector<mlir::Value>
+ getDesignatorIndices(hlfir::DesignateOp designate);
+
public:
using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
@@ -430,6 +437,73 @@ bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
return false;
}
+llvm::SmallVector<mlir::Value>
+ElementalAssignBufferization::getDesignatorIndices(
+ hlfir::DesignateOp designate) {
+ mlir::Value memref = designate.getMemref();
+
+ // If the object is a box, then the indices may have been adjusted
+ // by the box's lower bound(s). Scan through the index
+ // computations to try to recover the one-based indices.
+ if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
+ // Look for the following pattern (shown for dimension 0):
+ // %13 = fir.load %12 : !fir.ref<!fir.box<...>>
+ // %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
+ // %17 = arith.subi %14#0, %c1 : index
+ // %18 = arith.addi %arg2, %17 : index
+ // %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ...
+ //
+ // %arg2 is the one-based index; it is returned in place of %18.
+
+ auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
+ // Return true if v is (lower bound - 1) of memref along dim, i.e.:
+ // %14:3 = fir.box_dims <memref>, %dim : (!fir.box<...>, index) -> ...
+ // %17 = arith.subi %14#0, %c1 : index // v is %17
+ // where %dim must be a constant equal to dim.
+ if (auto subOp =
+ mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
+ auto cst = fir::getIntIfConstant(subOp.getRhs());
+ if (!cst || *cst != 1)
+ return false;
+ if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
+ subOp.getLhs().getDefiningOp())) {
+ if (memref != dimsOp.getVal() ||
+ dimsOp.getResult(0) != subOp.getLhs())
+ return false;
+ auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
+ return dimsOpDim && dimsOpDim == dim;
+ }
+ }
+ return false;
+ };
+
+ llvm::SmallVector<mlir::Value> newIndices;
+ for (auto index : llvm::enumerate(designate.getIndices())) {
+ if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
+ index.value().getDefiningOp())) {
+ for (unsigned opNum = 0; opNum < 2; ++opNum)
+ if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
+ newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
+ break;
+ }
+
+ // No one-based index was recovered for this dimension: stop early.
+ if (newIndices.size() <= index.index())
+ break;
+ }
+ }
+
+ // If some index was not recognized (not an addi, or no matching
+ // lb adjustment), then return the original designator indices.
+ if (newIndices.size() != designate.getIndices().size())
+ return designate.getIndices();
+
+ return newIndices;
+ }
+
+ return designate.getIndices();
+}
+
std::optional<ElementalAssignBufferization::MatchInfo>
ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
mlir::Operation::user_range users = elemental->getUsers();
@@ -557,7 +631,7 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
<< " at " << elemental.getLoc() << "\n");
return std::nullopt;
}
- auto indices = designate.getIndices();
+ auto indices = getDesignatorIndices(designate);
auto elementalIndices = elemental.getIndices();
if (indices.size() == elementalIndices.size() &&
std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
diff --git a/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir b/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
new file mode 100644
index 00000000000000..ae91930d44eb12
--- /dev/null
+++ b/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
@@ -0,0 +1,69 @@
+// RUN: fir-opt --opt-bufferization %s | FileCheck %s
+
+// Verify that the hlfir.assign of hlfir.elemental is optimized
+// into an element-by-element assignment:
+// subroutine test1(p)
+// real, pointer :: p(:)
+// p = p + 1.0
+// end subroutine test1
+
+func.func @_QPtest1(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "p"}) {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = fir.dummy_scope : !fir.dscope
+ %1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest1Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+ %2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> (index, index, index)
+ %4 = fir.shape %3#1 : (index) -> !fir.shape<1>
+ %5 = hlfir.elemental %4 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+ ^bb0(%arg1: index):
+ %6 = arith.subi %3#0, %c1 : index
+ %7 = arith.addi %arg1, %6 : index
+ %8 = hlfir.designate %2 (%7) : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
+ %9 = fir.load %8 : !fir.ref<f32>
+ %10 = arith.addf %9, %cst fastmath<contract> : f32
+ hlfir.yield_element %10 : f32
+ }
+ hlfir.assign %5 to %2 : !hlfir.expr<?xf32>, !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ hlfir.destroy %5 : !hlfir.expr<?xf32>
+ return
+}
+// CHECK-LABEL: func.func @_QPtest1(
+// CHECK-NOT: hlfir.assign
+// CHECK: hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
+// CHECK-NOT: hlfir.assign
+
+// subroutine test2(p)
+// real, pointer :: p(:,:)
+// p = p + 1.0
+// end subroutine test2
+func.func @_QPtest2(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> {fir.bindc_name = "p"}) {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = fir.dummy_scope : !fir.dscope
+ %1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest2Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>)
+ %2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
+ %3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
+ %4:3 = fir.box_dims %2, %c1 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
+ %5 = fir.shape %3#1, %4#1 : (index, index) -> !fir.shape<2>
+ %6 = hlfir.elemental %5 unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+ ^bb0(%arg1: index, %arg2: index):
+ %7 = arith.subi %3#0, %c1 : index
+ %8 = arith.addi %arg1, %7 : index
+ %9 = arith.subi %4#0, %c1 : index
+ %10 = arith.addi %arg2, %9 : index
+ %11 = hlfir.designate %2 (%8, %10) : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index, index) -> !fir.ref<f32>
+ %12 = fir.load %11 : !fir.ref<f32>
+ %13 = arith.addf %12, %cst fastmath<contract> : f32
+ hlfir.yield_element %13 : f32
+ }
+ hlfir.assign %6 to %2 : !hlfir.expr<?x?xf32>, !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
+ hlfir.destroy %6 : !hlfir.expr<?x?xf32>
+ return
+}
+// CHECK-LABEL: func.func @_QPtest2(
+// CHECK-NOT: hlfir.assign
+// CHECK: hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
+// CHECK-NOT: hlfir.assign
``````````
</details>
https://github.com/llvm/llvm-project/pull/121371
More information about the flang-commits
mailing list