[flang-commits] [flang] [Flang] Generate inline reduction loops for elemental count intrinsics (PR #75774)

David Green via flang-commits flang-commits at lists.llvm.org
Tue Jan 9 01:27:48 PST 2024


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/75774

>From ed74ee78ed16be141c236cd18e6e96311555fb43 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Mon, 18 Dec 2023 09:22:50 +0000
Subject: [PATCH 1/2] [Flang] Generate inline reduction loops for elemental
 count intrinsics

This adds a ReductionElementalConversion transform to OptimizedBufferizationPass,
taking hlfir::count(hlfir::elemental) and generating the inline loop to perform
the count of true elements. This lets us generate a single loop instead of
ending up with two loops plus a temporary.

This is currently part of OptimizedBufferization, similar to #74828. I
attempted to move it to LowerHLFIRIntrinsics to make it part of the existing
lowering, but it hit problems with inlining elementals that contain operations
that are being legalized by the same pass.

Any and All should be able to share the same code with a different
function/initial value.
---
 .../Transforms/OptimizedBufferization.cpp     | 119 +++++++
 flang/test/HLFIR/count-elemental.fir          | 314 ++++++++++++++++++
 2 files changed, 433 insertions(+)
 create mode 100644 flang/test/HLFIR/count-elemental.fir

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 7abfa20493c736..7c839f1d20af52 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -659,6 +659,124 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
   return mlir::success();
 }
 
+using GenBodyFn =
+    std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
+                              const llvm::SmallVectorImpl<mlir::Value> &)>;
+static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
+                                         mlir::Location loc, mlir::Value init,
+                                         mlir::Value shape, GenBodyFn genBody) {
+  auto extents = hlfir::getIndexExtents(loc, builder, shape);
+  mlir::Value reduction = init;
+  mlir::IndexType idxTy = builder.getIndexType();
+  mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+  // Create a reduction loop nest. We use one-based indices so that they can be
+  // passed to the elemental.
+  llvm::SmallVector<mlir::Value> indices;
+  for (unsigned i = 0; i < extents.size(); ++i) {
+    auto loop =
+        builder.create<fir::DoLoopOp>(loc, oneIdx, extents[i], oneIdx, false,
+                                      /*finalCountValue=*/false, reduction);
+    reduction = loop.getRegionIterArgs()[0];
+    indices.push_back(loop.getInductionVar());
+    // Set insertion point to the loop body so that the next loop
+    // is inserted inside the current one.
+    builder.setInsertionPointToStart(loop.getBody());
+  }
+
+  // Generate the body in the innermost loop; it returns the updated value.
+  reduction = genBody(builder, loc, reduction, indices);
+
+  // Unwind the loop nest, yielding each loop's result to its parent.
+  for (unsigned i = 0; i < extents.size(); ++i) {
+    auto result = builder.create<fir::ResultOp>(loc, reduction);
+    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+    reduction = loop.getResult(0);
+    // Set insertion point after the loop operation that we have
+    // just processed.
+    builder.setInsertionPointAfter(loop.getOperation());
+  }
+
+  return reduction;
+}
+
+/// Given a reduction operation with an elemental mask, attempt to generate a
+/// do-loop to perform the operation inline.
+///   %e = hlfir.elemental %shape unordered
+///   %r = hlfir.count %e
+/// =>
+///   %r = fir.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
+///     %i = <inline elemental>
+///     %c = <reduce count> %i
+///     fir.result %c
+template <typename Op>
+class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
+public:
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = op.getLoc();
+    hlfir::ElementalOp elemental =
+        op.getMask().template getDefiningOp<hlfir::ElementalOp>();
+    if (!elemental || op.getDim())
+      return rewriter.notifyMatchFailure(op, "Did not find valid elemental");
+
+    fir::KindMapping kindMap =
+        fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
+    fir::FirOpBuilder builder{op, kindMap};
+
+    mlir::Value init;
+    GenBodyFn genBodyFn;
+    if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
+      init = builder.createIntegerConstant(loc, op.getType(), 0);
+      genBodyFn = [elemental](fir::FirOpBuilder &builder, mlir::Location loc,
+                              mlir::Value reduction,
+                              const llvm::SmallVectorImpl<mlir::Value> &indices)
+          -> mlir::Value {
+        // Inline the elemental and get the condition from it.
+        auto yield = inlineElementalOp(loc, builder, elemental, indices);
+        mlir::Value cond = builder.create<fir::ConvertOp>(
+            loc, builder.getI1Type(), yield.getElementValue());
+        yield->erase();
+
+        // Conditionally add one to the current value
+        mlir::Value one =
+            builder.createIntegerConstant(loc, reduction.getType(), 1);
+        mlir::Value add1 =
+            builder.create<mlir::arith::AddIOp>(loc, reduction, one);
+        return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
+                                                     reduction);
+      };
+    } else {
+      static_assert(!std::is_same_v<Op, Op>, "Expected Op to be handled");
+      return mlir::failure();
+    }
+
+    mlir::Value res = generateReductionLoop(builder, loc, init,
+                                            elemental.getOperand(0), genBodyFn);
+    if (res.getType() != op.getType())
+      res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
+
+    // Check if the op was the only user of the elemental (apart from a
+    // destroy), and remove it if so.
+    mlir::Operation::user_range elemUsers = elemental->getUsers();
+    hlfir::DestroyOp elemDestroy;
+    if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
+      elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
+      if (!elemDestroy)
+        elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
+    }
+
+    rewriter.replaceOp(op, res);
+    if (elemDestroy) {
+      rewriter.eraseOp(elemDestroy);
+      rewriter.eraseOp(elemental);
+    }
+    return mlir::success();
+  }
+};
+
 class OptimizedBufferizationPass
     : public hlfir::impl::OptimizedBufferizationBase<
           OptimizedBufferizationPass> {
@@ -681,6 +799,7 @@ class OptimizedBufferizationPass
     patterns.insert<ElementalAssignBufferization>(context);
     patterns.insert<BroadcastAssignBufferization>(context);
     patterns.insert<VariableAssignBufferization>(context);
+    patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
 
     if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
             func, std::move(patterns), config))) {
diff --git a/flang/test/HLFIR/count-elemental.fir b/flang/test/HLFIR/count-elemental.fir
new file mode 100644
index 00000000000000..1641e0fae6fb55
--- /dev/null
+++ b/flang/test/HLFIR/count-elemental.fir
@@ -0,0 +1,314 @@
+// RUN: fir-opt %s -opt-bufferization | FileCheck %s
+
+func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca i32 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %6 = fir.load %2#0 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> i64
+  %8 = fir.shape %c7 : (index) -> !fir.shape<1>
+  %9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1)  shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+  %10 = fir.load %5#0 : !fir.ref<i32>
+  %11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %14 = hlfir.designate %9 (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+    %15 = fir.load %14 : !fir.ref<i32>
+    %16 = arith.cmpi sge, %15, %10 : i32
+    %17 = fir.convert %16 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %17 : !fir.logical<4>
+  }
+  %12 = hlfir.count %11 : (!hlfir.expr<7x!fir.logical<4>>) -> i32
+  hlfir.assign %12 to %4#0 : i32, !fir.ref<i32>
+  hlfir.destroy %11 : !hlfir.expr<7x!fir.logical<4>>
+  %13 = fir.load %4#1 : !fir.ref<i32>
+  return %13 : i32
+}
+// CHECK-LABEL:  func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+// CHECK-NEXT:     %c1_i32 = arith.constant 1 : i32
+// CHECK-NEXT:     %c0_i32 = arith.constant 0 : i32
+// CHECK-NEXT:     %c1 = arith.constant 1 : index
+// CHECK-NEXT:     %c4 = arith.constant 4 : index
+// CHECK-NEXT:     %c7 = arith.constant 7 : index
+// CHECK-NEXT:     %[[V0:.*]] = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+// CHECK-NEXT:     %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]])
+// CHECK-NEXT:     %[[V2:.*]]:2 = hlfir.declare %arg1
+// CHECK-NEXT:     %[[V3:.*]] = fir.alloca i32
+// CHECK-NEXT:     %[[V4:.*]]:2 = hlfir.declare %[[V3]]
+// CHECK-NEXT:     %[[V5:.*]]:2 = hlfir.declare %arg2
+// CHECK-NEXT:     %[[V6:.*]] = fir.load %[[V2]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V7:.*]] = fir.convert %[[V6]] : (i32) -> i64
+// CHECK-NEXT:     %[[V8:.*]] = fir.shape %c7 : (index) -> !fir.shape<1>
+// CHECK-NEXT:     %[[V9:.*]] = hlfir.designate %[[V1]]#0 (%[[V7]], %c1:%c7:%c1)  shape %[[V8]] : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+// CHECK-NEXT:     %[[V10:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V11:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %c0_i32) -> (i32) {
+// CHECK-NEXT:       %[[V13:.*]] = hlfir.designate %[[V9]] (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:       %[[V14:.*]] = fir.load %[[V13]] : !fir.ref<i32>
+// CHECK-NEXT:       %[[V15:.*]] = arith.cmpi sge, %[[V14]], %[[V10]] : i32
+// CHECK-NEXT:       %[[V16:.*]] = arith.addi %arg4, %c1_i32 : i32
+// CHECK-NEXT:       %[[V17:.*]] = arith.select %[[V15]], %[[V16]], %arg4 : i32
+// CHECK-NEXT:       fir.result %[[V17]] : i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     hlfir.assign %[[V11]] to %[[V4]]#0 : i32, !fir.ref<i32>
+// CHECK-NEXT:     %[[V12:.*]] = fir.load %[[V4]]#1 : !fir.ref<i32>
+// CHECK-NEXT:     return %[[V12]] : i32
+
+func.func @_QFPtest_kind2(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i16 {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca i16 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<i16>) -> (!fir.ref<i16>, !fir.ref<i16>)
+  %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %6 = fir.load %2#0 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> i64
+  %8 = fir.shape %c7 : (index) -> !fir.shape<1>
+  %9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1)  shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+  %10 = fir.load %5#0 : !fir.ref<i32>
+  %11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %14 = hlfir.designate %9 (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+    %15 = fir.load %14 : !fir.ref<i32>
+    %16 = arith.cmpi sge, %15, %10 : i32
+    %17 = fir.convert %16 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %17 : !fir.logical<4>
+  }
+  %12 = hlfir.count %11 : (!hlfir.expr<7x!fir.logical<4>>) -> i16
+  hlfir.assign %12 to %4#0 : i16, !fir.ref<i16>
+  hlfir.destroy %11 : !hlfir.expr<7x!fir.logical<4>>
+  %13 = fir.load %4#1 : !fir.ref<i16>
+  return %13 : i16
+}
+// CHECK-LABEL:  func.func @_QFPtest_kind2(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i16 {
+// CHECK-NEXT:     %c1_i16 = arith.constant 1 : i16
+// CHECK-NEXT:     %c0_i16 = arith.constant 0 : i16
+// CHECK-NEXT:     %c1 = arith.constant 1 : index
+// CHECK-NEXT:     %c4 = arith.constant 4 : index
+// CHECK-NEXT:     %c7 = arith.constant 7 : index
+// CHECK-NEXT:     %[[V0:.*]] = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+// CHECK-NEXT:     %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]])
+// CHECK-NEXT:     %[[V2:.*]]:2 = hlfir.declare %arg1
+// CHECK-NEXT:     %[[V3:.*]] = fir.alloca i16
+// CHECK-NEXT:     %[[V4:.*]]:2 = hlfir.declare %[[V3]]
+// CHECK-NEXT:     %[[V5:.*]]:2 = hlfir.declare %arg2
+// CHECK-NEXT:     %[[V6:.*]] = fir.load %[[V2]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V7:.*]] = fir.convert %[[V6]] : (i32) -> i64
+// CHECK-NEXT:     %[[V8:.*]] = fir.shape %c7 : (index) -> !fir.shape<1>
+// CHECK-NEXT:     %[[V9:.*]] = hlfir.designate %[[V1]]#0 (%[[V7]], %c1:%c7:%c1)  shape %[[V8]] : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+// CHECK-NEXT:     %[[V10:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V11:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %c0_i16) -> (i16) {
+// CHECK-NEXT:       %[[V13:.*]] = hlfir.designate %[[V9]] (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:       %[[V14:.*]] = fir.load %[[V13]] : !fir.ref<i32>
+// CHECK-NEXT:       %[[V15:.*]] = arith.cmpi sge, %[[V14]], %[[V10]] : i32
+// CHECK-NEXT:       %[[V16:.*]] = arith.addi %arg4, %c1_i16 : i16
+// CHECK-NEXT:       %[[V17:.*]] = arith.select %[[V15]], %[[V16]], %arg4 : i16
+// CHECK-NEXT:       fir.result %[[V17]] : i16
+// CHECK-NEXT:     }
+// CHECK-NEXT:     hlfir.assign %[[V11]] to %[[V4]]#0 : i16, !fir.ref<i16>
+// CHECK-NEXT:     %[[V12:.*]] = fir.load %[[V4]]#1 : !fir.ref<i16>
+// CHECK-NEXT:     return %[[V12]] : i16
+
+func.func @_QFPtest_dim(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.array<7xi32> {
+  %c1_i32 = arith.constant 1 : i32
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca !fir.array<7xi32> {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4 = fir.shape %c7 : (index) -> !fir.shape<1>
+  %5:2 = hlfir.declare %3(%4) {uniq_name = "_QFFtestEtest"} : (!fir.ref<!fir.array<7xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<7xi32>>, !fir.ref<!fir.array<7xi32>>)
+  %6:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %7 = hlfir.designate %1#0 (%c1:%c4:%c1, %c1:%c7:%c1)  shape %0 : (!fir.ref<!fir.array<4x7xi32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.ref<!fir.array<4x7xi32>>
+  %8 = fir.load %6#0 : !fir.ref<i32>
+  %9 = hlfir.elemental %0 unordered : (!fir.shape<2>) -> !hlfir.expr<4x7x!fir.logical<4>> {
+  ^bb0(%arg3: index, %arg4: index):
+    %12 = hlfir.designate %7 (%arg3, %arg4)  : (!fir.ref<!fir.array<4x7xi32>>, index, index) -> !fir.ref<i32>
+    %13 = fir.load %12 : !fir.ref<i32>
+    %14 = arith.cmpi sge, %13, %8 : i32
+    %15 = fir.convert %14 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %15 : !fir.logical<4>
+  }
+  %10 = hlfir.count %9 dim %c1_i32 : (!hlfir.expr<4x7x!fir.logical<4>>, i32) -> !hlfir.expr<7xi32>
+  hlfir.assign %10 to %5#0 : !hlfir.expr<7xi32>, !fir.ref<!fir.array<7xi32>>
+  hlfir.destroy %10 : !hlfir.expr<7xi32>
+  hlfir.destroy %9 : !hlfir.expr<4x7x!fir.logical<4>>
+  %11 = fir.load %5#1 : !fir.ref<!fir.array<7xi32>>
+  return %11 : !fir.array<7xi32>
+}
+// CHECK-LABEL:  func.func @_QFPtest_dim(
+// CHECK: %{{.*}} = hlfir.count %{{.*}} dim %c1_i32
+
+
+func.func @_QFPtest_multi(%arg0: !fir.ref<!fir.array<4x7x2xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %c2 = arith.constant 2 : index
+  %0 = fir.shape %c4, %c7, %c2 : (index, index, index) -> !fir.shape<3>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7x2xi32>>, !fir.shape<3>) -> (!fir.ref<!fir.array<4x7x2xi32>>, !fir.ref<!fir.array<4x7x2xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca i32 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %6 = hlfir.designate %1#0 (%c1:%c4:%c1, %c1:%c7:%c1, %c1:%c2:%c1)  shape %0 : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index, index, index, index, index, index, index, !fir.shape<3>) -> !fir.ref<!fir.array<4x7x2xi32>>
+  %7 = fir.load %5#0 : !fir.ref<i32>
+  %8 = hlfir.elemental %0 unordered : (!fir.shape<3>) -> !hlfir.expr<4x7x2x!fir.logical<4>> {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    %11 = hlfir.designate %6 (%arg3, %arg4, %arg5)  : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index) -> !fir.ref<i32>
+    %12 = fir.load %11 : !fir.ref<i32>
+    %13 = arith.cmpi sge, %12, %7 : i32
+    %14 = fir.convert %13 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %14 : !fir.logical<4>
+  }
+  %9 = hlfir.count %8 : (!hlfir.expr<4x7x2x!fir.logical<4>>) -> i32
+  hlfir.assign %9 to %4#0 : i32, !fir.ref<i32>
+  hlfir.destroy %8 : !hlfir.expr<4x7x2x!fir.logical<4>>
+  %10 = fir.load %4#1 : !fir.ref<i32>
+  return %10 : i32
+}
+// CHECK-LABEL:  func.func @_QFPtest_multi(%arg0: !fir.ref<!fir.array<4x7x2xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+// CHECK-NEXT:     %c1_i32 = arith.constant 1 : i32
+// CHECK-NEXT:     %c0_i32 = arith.constant 0 : i32
+// CHECK-NEXT:     %c1 = arith.constant 1 : index
+// CHECK-NEXT:     %c4 = arith.constant 4 : index
+// CHECK-NEXT:     %c7 = arith.constant 7 : index
+// CHECK-NEXT:     %c2 = arith.constant 2 : index
+// CHECK-NEXT:     %[[V0:.*]] = fir.shape %c4, %c7, %c2 : (index, index, index) -> !fir.shape<3>
+// CHECK-NEXT:     %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]]) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7x2xi32>>, !fir.shape<3>) -> (!fir.ref<!fir.array<4x7x2xi32>>, !fir.ref<!fir.array<4x7x2xi32>>)
+// CHECK-NEXT:     %[[V2:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK-NEXT:     %[[V3:.*]] = fir.alloca i32 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+// CHECK-NEXT:     %[[V4:.*]]:2 = hlfir.declare %[[V3]] {uniq_name = "_QFFtestEtest"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK-NEXT:     %[[V5:.*]]:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK-NEXT:     %[[V6:.*]] = hlfir.designate %[[V1]]#0 (%c1:%c4:%c1, %c1:%c7:%c1, %c1:%c2:%c1)  shape %[[V0]] : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index, index, index, index, index, index, index, !fir.shape<3>) -> !fir.ref<!fir.array<4x7x2xi32>>
+// CHECK-NEXT:     %[[V7:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V8:.*]] = fir.do_loop %arg3 = %c1 to %c4 step %c1 iter_args(%arg4 = %c0_i32) -> (i32) {
+// CHECK-NEXT:       %[[V10:.*]] = fir.do_loop %arg5 = %c1 to %c7 step %c1 iter_args(%arg6 = %arg4) -> (i32) {
+// CHECK-NEXT:         %[[V11:.*]] = fir.do_loop %arg7 = %c1 to %c2 step %c1 iter_args(%arg8 = %arg6) -> (i32) {
+// CHECK-NEXT:           %[[V12:.*]] = hlfir.designate %[[V6]] (%arg3, %arg5, %arg7)  : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index) -> !fir.ref<i32>
+// CHECK-NEXT:           %[[V13:.*]] = fir.load %[[V12]] : !fir.ref<i32>
+// CHECK-NEXT:           %[[V14:.*]] = arith.cmpi sge, %[[V13]], %[[V7]] : i32
+// CHECK-NEXT:           %[[V15:.*]] = arith.addi %arg8, %c1_i32 : i32
+// CHECK-NEXT:           %[[V16:.*]] = arith.select %[[V14]], %[[V15]], %arg8 : i32
+// CHECK-NEXT:           fir.result %[[V16]] : i32
+// CHECK-NEXT:         }
+// CHECK-NEXT:         fir.result %[[V11]] : i32
+// CHECK-NEXT:       }
+// CHECK-NEXT:       fir.result %[[V10]] : i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     hlfir.assign %[[V8]] to %[[V4]]#0 : i32, !fir.ref<i32>
+// CHECK-NEXT:     %[[V9:.*]] = fir.load %[[V4]]#1 : !fir.ref<i32>
+// CHECK-NEXT:     return %[[V9]] : i32
+
+
+
+
+
+func.func @_QFPtest_rec_sum(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca i32 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %6 = fir.load %2#0 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> i64
+  %8 = fir.shape %c7 : (index) -> !fir.shape<1>
+  %9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1)  shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+  %10 = fir.load %5#0 : !fir.ref<i32>
+  %11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7xi32> {
+  ^bb0(%arg3: index):
+    %15 = hlfir.designate %9 (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+    %16 = fir.load %15 : !fir.ref<i32>
+    hlfir.yield_element %16 : i32
+  }
+  %12 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %15 = hlfir.sum %11 : (!hlfir.expr<7xi32>) -> i32
+    %16 = arith.cmpi sge, %15, %10 : i32
+    %17 = fir.convert %16 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %17 : !fir.logical<4>
+  }
+  %13 = hlfir.count %12 : (!hlfir.expr<7x!fir.logical<4>>) -> i32
+  hlfir.assign %13 to %4#0 : i32, !fir.ref<i32>
+  hlfir.destroy %12 : !hlfir.expr<7x!fir.logical<4>>
+  hlfir.destroy %11 : !hlfir.expr<7xi32>
+  %14 = fir.load %4#1 : !fir.ref<i32>
+  return %14 : i32
+}
+// CHECK-LABEL:  func.func @_QFPtest_rec_sum(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+// CHECK:    %[[V12:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %c0_i32) -> (i32) {
+// CHECK:      %[[V14:.*]] = hlfir.sum %[[V11]] : (!hlfir.expr<7xi32>) -> i32
+// CHECK:      %[[V15:.*]] = arith.cmpi sge, %[[V14]], %[[V10]] : i32
+// CHECK:      %[[V16:.*]] = arith.addi %arg4, %c1_i32 : i32
+// CHECK:      %[[V17:.*]] = arith.select %[[V15]], %[[V16]], %arg4 : i32
+// CHECK:      fir.result %[[V17]] : i32
+// CHECK:    }
+
+
+
+
+func.func @_QFPtest_rec_count(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca i32 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %6 = fir.load %2#0 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> i64
+  %8 = fir.shape %c7 : (index) -> !fir.shape<1>
+  %9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1)  shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+  %10 = fir.load %5#0 : !fir.ref<i32>
+  %11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %15 = hlfir.designate %9 (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+    %16 = fir.load %15 : !fir.ref<i32>
+    %17 = arith.cmpi sge, %16, %10 : i32
+    %18 = fir.convert %17 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %18 : !fir.logical<4>
+  }
+  %12 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %15 = hlfir.count %11 : (!hlfir.expr<7x!fir.logical<4>>) -> i32
+    %16 = arith.cmpi sge, %15, %10 : i32
+    %17 = fir.convert %16 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %17 : !fir.logical<4>
+  }
+  %13 = hlfir.count %12 : (!hlfir.expr<7x!fir.logical<4>>) -> i32
+  hlfir.assign %13 to %4#0 : i32, !fir.ref<i32>
+  hlfir.destroy %12 : !hlfir.expr<7x!fir.logical<4>>
+  hlfir.destroy %11 : !hlfir.expr<7x!fir.logical<4>>
+  %14 = fir.load %4#1 : !fir.ref<i32>
+  return %14 : i32
+}
+// CHECK-LABEL:  func.func @_QFPtest_rec_count(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+// CHECK:    %[[V11:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %c0_i32) -> (i32) {
+// CHECK:      %[[V13:.*]] = fir.do_loop %arg5 = %c1 to %c7 step %c1 iter_args(%arg6 = %c0_i32) -> (i32) {
+// CHECK:        %[[V17:.*]] = hlfir.designate %[[V9]] (%arg5)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+// CHECK:        %[[V18:.*]] = fir.load %[[V17]] : !fir.ref<i32>
+// CHECK:        %[[V19:.*]] = arith.cmpi sge, %[[V18]], %[[V10]] : i32
+// CHECK:        %[[V20:.*]] = arith.addi %arg6, %c1_i32 : i32
+// CHECK:        %[[V21:.*]] = arith.select %[[V19]], %[[V20]], %arg6 : i32
+// CHECK:        fir.result %[[V21]] : i32
+// CHECK:      }
+// CHECK:      %[[V14:.*]] = arith.cmpi sge, %[[V13]], %[[V10]] : i32
+// CHECK:      %[[V15:.*]] = arith.addi %arg4, %c1_i32 : i32
+// CHECK:      %[[V16:.*]] = arith.select %[[V14]], %[[V15]], %arg4 : i32
+// CHECK:      fir.result %[[V16]] : i32
+// CHECK:    }

>From a377173ba59866bffc32a1eb2f54694ed0ccf81f Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 9 Jan 2024 09:27:35 +0000
Subject: [PATCH 2/2] Reverse loop generation order

---
 .../HLFIR/Transforms/OptimizedBufferization.cpp     | 13 +++++++------
 flang/test/HLFIR/count-elemental.fir                |  6 +++---
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 7c839f1d20af52..afdcda29e2fe19 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -671,14 +671,15 @@ static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
   mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
 
   // Create a reduction loop nest. We use one-based indices so that they can be
-  // passed to the elemental.
-  llvm::SmallVector<mlir::Value> indices;
+  // passed to the elemental, and reverse the order so that they can be
+  // generated in column-major order for better performance.
+  llvm::SmallVector<mlir::Value> indices(extents.size(), mlir::Value{});
   for (unsigned i = 0; i < extents.size(); ++i) {
-    auto loop =
-        builder.create<fir::DoLoopOp>(loc, oneIdx, extents[i], oneIdx, false,
-                                      /*finalCountValue=*/false, reduction);
+    auto loop = builder.create<fir::DoLoopOp>(
+        loc, oneIdx, extents[extents.size() - i - 1], oneIdx, false,
+        /*finalCountValue=*/false, reduction);
     reduction = loop.getRegionIterArgs()[0];
-    indices.push_back(loop.getInductionVar());
+    indices[extents.size() - i - 1] = loop.getInductionVar();
     // Set insertion point to the loop body so that the next loop
     // is inserted inside the current one.
     builder.setInsertionPointToStart(loop.getBody());
diff --git a/flang/test/HLFIR/count-elemental.fir b/flang/test/HLFIR/count-elemental.fir
index 1641e0fae6fb55..0df5cc3c031ea5 100644
--- a/flang/test/HLFIR/count-elemental.fir
+++ b/flang/test/HLFIR/count-elemental.fir
@@ -191,10 +191,10 @@ func.func @_QFPtest_multi(%arg0: !fir.ref<!fir.array<4x7x2xi32>> {fir.bindc_name
 // CHECK-NEXT:     %[[V5:.*]]:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 // CHECK-NEXT:     %[[V6:.*]] = hlfir.designate %[[V1]]#0 (%c1:%c4:%c1, %c1:%c7:%c1, %c1:%c2:%c1)  shape %[[V0]] : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index, index, index, index, index, index, index, !fir.shape<3>) -> !fir.ref<!fir.array<4x7x2xi32>>
 // CHECK-NEXT:     %[[V7:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
-// CHECK-NEXT:     %[[V8:.*]] = fir.do_loop %arg3 = %c1 to %c4 step %c1 iter_args(%arg4 = %c0_i32) -> (i32) {
+// CHECK-NEXT:     %[[V8:.*]] = fir.do_loop %arg3 = %c1 to %c2 step %c1 iter_args(%arg4 = %c0_i32) -> (i32) {
 // CHECK-NEXT:       %[[V10:.*]] = fir.do_loop %arg5 = %c1 to %c7 step %c1 iter_args(%arg6 = %arg4) -> (i32) {
-// CHECK-NEXT:         %[[V11:.*]] = fir.do_loop %arg7 = %c1 to %c2 step %c1 iter_args(%arg8 = %arg6) -> (i32) {
-// CHECK-NEXT:           %[[V12:.*]] = hlfir.designate %[[V6]] (%arg3, %arg5, %arg7)  : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index) -> !fir.ref<i32>
+// CHECK-NEXT:         %[[V11:.*]] = fir.do_loop %arg7 = %c1 to %c4 step %c1 iter_args(%arg8 = %arg6) -> (i32) {
+// CHECK-NEXT:           %[[V12:.*]] = hlfir.designate %[[V6]] (%arg7, %arg5, %arg3)  : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index) -> !fir.ref<i32>
 // CHECK-NEXT:           %[[V13:.*]] = fir.load %[[V12]] : !fir.ref<i32>
 // CHECK-NEXT:           %[[V14:.*]] = arith.cmpi sge, %[[V13]], %[[V7]] : i32
 // CHECK-NEXT:           %[[V15:.*]] = arith.addi %arg8, %c1_i32 : i32



More information about the flang-commits mailing list