[flang-commits] [flang] [flang] Improve designate/elemental indices match in opt-bufferization. (PR #121371)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Mon Dec 30 21:09:14 PST 2024


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/121371

This pattern appears in `tonto`: `rys1%w = rys1%w * ...`, where
component `w` is a pointer. Because of the computations that transform
the elemental's one-based indices into the array's lb-based indices,
the index-matching check did not pass in opt-bufferization.
This patch recognizes this index-adjusting pattern
and returns the one-based indices for the designator.


From 56d72e8289dc110a54ae86ef4dee0d9915ef039f Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 30 Dec 2024 21:03:09 -0800
Subject: [PATCH] [flang] Improve designate/elemental indices match in
 opt-bufferization.

This pattern appears in `tonto`: `rys1%w = rys1%w * ...`, where
component `w` is a pointer. Because of the computations that transform
the elemental's one-based indices into the array's lb-based indices,
the index-matching check did not pass in opt-bufferization.
This patch recognizes this index-adjusting pattern
and returns the one-based indices for the designator.
---
 .../Transforms/OptimizedBufferization.cpp     | 76 ++++++++++++++++++-
 .../opt-bufferization-same-ptr-elemental.fir  | 69 +++++++++++++++++
 2 files changed, 144 insertions(+), 1 deletion(-)
 create mode 100644 flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index bf3cf861e46f4a..bfaabed0136785 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -87,6 +87,13 @@ class ElementalAssignBufferization
   /// determines if the transformation can be applied to this elemental
   static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);
 
+  /// Returns the array indices for the given hlfir.designate.
+  /// If the designated object is a box and every index is computed by adding
+  /// the box's (lb - 1) to some value, those values (the one-based indices)
+  /// are returned; otherwise, the designator's own indices are returned.
+  static llvm::SmallVector<mlir::Value>
+  getDesignatorIndices(hlfir::DesignateOp designate);
+
 public:
   using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
 
@@ -430,6 +437,73 @@ bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
   return false;
 }
 
+llvm::SmallVector<mlir::Value>
+ElementalAssignBufferization::getDesignatorIndices(
+    hlfir::DesignateOp designate) {
+  mlir::Value memref = designate.getMemref();
+
+  // If the object is a box, then the indices may have been adjusted
+  // according to the box's lower bound(s). Scan through the computations
+  // defining the indices to try to recover the one-based indices.
+  if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
+    // Look for the following pattern:
+    //   %13 = fir.load %12 : !fir.ref<!fir.box<...>>
+    //   %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
+    //   %17 = arith.subi %14#0, %c1 : index
+    //   %18 = arith.addi %arg2, %17 : index
+    //   %19 = hlfir.designate %13 (%18)  : (!fir.box<...>, index) -> ...
+    //
+    // Here %arg2 is the one-based index that is returned in place of %18.
+
+    auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
+      // Return true if v is %17 below, %13 is memref and %dim is constant dim:
+      //   %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
+      //   %17 = arith.subi %14#0, %c1 : index
+      //   %19 = hlfir.designate %13 (...)  : (!fir.box<...>, index) -> ...
+      if (auto subOp =
+              mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
+        auto cst = fir::getIntIfConstant(subOp.getRhs());
+        if (!cst || *cst != 1)
+          return false;
+        if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
+                subOp.getLhs().getDefiningOp())) {
+          if (memref != dimsOp.getVal() ||
+              dimsOp.getResult(0) != subOp.getLhs())
+            return false;
+          auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
+          return dimsOpDim && dimsOpDim == dim;
+        }
+      }
+      return false;
+    };
+
+    llvm::SmallVector<mlir::Value> newIndices;
+    for (auto index : llvm::enumerate(designate.getIndices())) {
+      if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
+              index.value().getDefiningOp())) {
+        for (unsigned opNum = 0; opNum < 2; ++opNum)
+          if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
+            newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
+            break;
+          }
+
+        // If no one-based index was added for this position, exit early.
+        if (newIndices.size() <= index.index())
+          break;
+      }
+    }
+
+    // If any index was not recognized as adjusted to the array's lb (this
+    // includes indices not defined by arith.addi), return the original ones.
+    if (newIndices.size() != designate.getIndices().size())
+      return designate.getIndices();
+
+    return newIndices;
+  }
+
+  return designate.getIndices();
+}
+
 std::optional<ElementalAssignBufferization::MatchInfo>
 ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
   mlir::Operation::user_range users = elemental->getUsers();
@@ -557,7 +631,7 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
                                   << " at " << elemental.getLoc() << "\n");
           return std::nullopt;
         }
-        auto indices = designate.getIndices();
+        auto indices = getDesignatorIndices(designate);
         auto elementalIndices = elemental.getIndices();
         if (indices.size() == elementalIndices.size() &&
             std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
diff --git a/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir b/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
new file mode 100644
index 00000000000000..ae91930d44eb12
--- /dev/null
+++ b/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
@@ -0,0 +1,69 @@
+// RUN: fir-opt --opt-bufferization %s | FileCheck %s
+
+// Verify that the hlfir.assign of hlfir.elemental is optimized
+// into element-per-element assignment:
+// subroutine test1(p)
+//   real, pointer :: p(:)
+//   p = p + 1.0
+// end subroutine test1
+
+func.func @_QPtest1(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "p"}) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest1Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+  %2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+  %3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> (index, index, index)
+  %4 = fir.shape %3#1 : (index) -> !fir.shape<1>
+  %5 = hlfir.elemental %4 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+  ^bb0(%arg1: index):
+    %6 = arith.subi %3#0, %c1 : index
+    %7 = arith.addi %arg1, %6 : index
+    %8 = hlfir.designate %2 (%7)  : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
+    %9 = fir.load %8 : !fir.ref<f32>
+    %10 = arith.addf %9, %cst fastmath<contract> : f32
+    hlfir.yield_element %10 : f32
+  }
+  hlfir.assign %5 to %2 : !hlfir.expr<?xf32>, !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  hlfir.destroy %5 : !hlfir.expr<?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @_QPtest1(
+// CHECK-NOT:         hlfir.assign
+// CHECK:             hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
+// CHECK-NOT:         hlfir.assign
+
+// subroutine test2(p)
+//   real, pointer :: p(:,:)
+//   p = p + 1.0
+// end subroutine test2
+func.func @_QPtest2(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> {fir.bindc_name = "p"}) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest2Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>)
+  %2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
+  %3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
+  %4:3 = fir.box_dims %2, %c1 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
+  %5 = fir.shape %3#1, %4#1 : (index, index) -> !fir.shape<2>
+  %6 = hlfir.elemental %5 unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+  ^bb0(%arg1: index, %arg2: index):
+    %7 = arith.subi %3#0, %c1 : index
+    %8 = arith.addi %arg1, %7 : index
+    %9 = arith.subi %4#0, %c1 : index
+    %10 = arith.addi %arg2, %9 : index
+    %11 = hlfir.designate %2 (%8, %10)  : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index, index) -> !fir.ref<f32>
+    %12 = fir.load %11 : !fir.ref<f32>
+    %13 = arith.addf %12, %cst fastmath<contract> : f32
+    hlfir.yield_element %13 : f32
+  }
+  hlfir.assign %6 to %2 : !hlfir.expr<?x?xf32>, !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
+  hlfir.destroy %6 : !hlfir.expr<?x?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @_QPtest2(
+// CHECK-NOT:         hlfir.assign
+// CHECK:             hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
+// CHECK-NOT:         hlfir.assign



More information about the flang-commits mailing list