[flang-commits] [flang] [flang] Fixed LoopVersioning for array slices. (PR #65703)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Thu Sep 7 18:02:15 PDT 2023


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/65703:

The first test case added in the LIT test demonstrates the problem.
Even though we did not consider the inner loop as a candidate for
the transformation (because of the array_coor with a slice), we still
decided to version the outer loop for the same function argument.
While cloning the outer loop, we dropped the slicing completely,
producing invalid code.

I restructured the code so that we record all arg uses that cannot be
transformed (regardless of the reason), and then fix up the usage
information across the loop nests. I also noticed that we may generate
redundant contiguity checks for the inner loops, so I fixed that as well,
since it was easy to do with the new way of keeping the usage data.


>From 1edd969cb5439a657fb69356c8af461b42009dab Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 7 Sep 2023 17:04:13 -0700
Subject: [PATCH] [flang] Fixed LoopVersioning for array slices.

The first test case added in the LIT test demonstrates the problem.
Even though we did not consider the inner loop as a candidate for
the transformation (because of the array_coor with a slice), we still
decided to version the outer loop for the same function argument.
While cloning the outer loop, we dropped the slicing completely,
producing invalid code.

I restructured the code so that we record all arg uses that cannot be
transformed (regardless of the reason), and then fix up the usage
information across the loop nests. I also noticed that we may generate
redundant contiguity checks for the inner loops, so I fixed that as well,
since it was easy to do with the new way of keeping the usage data.
---
 .../Optimizer/Transforms/LoopVersioning.cpp   | 217 ++++++++++++++----
 flang/test/Transforms/loop-versioning.fir     | 172 +++++++++++++-
 2 files changed, 346 insertions(+), 43 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
index b524b11f5966443..4d3ea51ae1a5f71 100644
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -77,6 +77,72 @@ class LoopVersioningPass
   void runOnOperation() override;
 };
 
+/// @struct ArgInfo
+/// A structure to hold an argument, the size of the argument, its rank and
+/// the per-dimension fir::BoxDimsOp operations (storage for CFI_MAX_RANK).
+struct ArgInfo {
+  mlir::Value arg;
+  size_t size;
+  unsigned rank;
+  fir::BoxDimsOp dims[CFI_MAX_RANK];
+};
+
+/// @struct ArgsUsageInLoop
+/// Information about how the function arguments are used by the operations
+/// immediately nested in a loop (operations in inner loops are not included).
+struct ArgsUsageInLoop {
+  /// Maps the memref operand of an array indexing op (e.g. fir.coordinate_of)
+  /// to its ArgInfo. NOTE(review): iteration order is not deterministic.
+  llvm::DenseMap<mlir::Value, ArgInfo> usageInfo;
+  /// Some array indexing operations inside a loop cannot be transformed
+  /// (e.g. fir.array_coor with a slice). This set holds their memref
+  /// operands, so that no enclosing loop is versioned for them either:
+  /// versioning an enclosing loop would also rewrite these operations
+  /// inside this loop.
+  llvm::SetVector<mlir::Value> cannotTransform;
+
+  // Debug dump (only under LLVM_DEBUG) of the structure members, assuming
+  // that the information has been collected for the given loop.
+  void dump(fir::DoLoopOp loop) const {
+    // clang-format off
+    LLVM_DEBUG(
+        mlir::OpPrintingFlags printFlags;
+        printFlags.skipRegions();
+        llvm::dbgs() << "Arguments usage info for loop:\n";
+        loop.print(llvm::dbgs(), printFlags);
+        llvm::dbgs() << "\nUsed args:\n";
+        for (auto &use : usageInfo) {
+          mlir::Value v = use.first;
+          v.print(llvm::dbgs(), printFlags);
+          llvm::dbgs() << "\n";
+        }
+        llvm::dbgs() << "\nCannot transform args:\n";
+        for (mlir::Value arg : cannotTransform) {
+          arg.print(llvm::dbgs(), printFlags);
+          llvm::dbgs() << "\n";
+        }
+        llvm::dbgs() << "====\n"
+    );
+    // clang-format on
+  }
+
+  // Erase the usageInfo and cannotTransform entries for each of
+  // the given memref operands.
+  void eraseUsage(const llvm::SetVector<mlir::Value> &args) {
+    for (auto &arg : args)
+      usageInfo.erase(arg);
+    cannotTransform.set_subtract(args);
+  }
+
+  // Erase the usageInfo and cannotTransform entries for each key
+  // of the given map (e.g. an ancestor loop's usageInfo).
+  void eraseUsage(const llvm::DenseMap<mlir::Value, ArgInfo> &args) {
+    for (auto &arg : args) {
+      usageInfo.erase(arg.first);
+      cannotTransform.remove(arg.first);
+    }
+  }
+};
 } // namespace
 
 /// @c replaceOuterUses - replace uses outside of @c op with result of @c
@@ -179,16 +245,6 @@ void LoopVersioningPass::runOnOperation() {
   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
   mlir::func::FuncOp func = getOperation();
 
-  /// @c ArgInfo
-  /// A structure to hold an argument, the size of the argument and dimension
-  /// information.
-  struct ArgInfo {
-    mlir::Value arg;
-    size_t size;
-    unsigned rank;
-    fir::BoxDimsOp dims[CFI_MAX_RANK];
-  };
-
   // First look for arguments with assumed shape = unknown extent in the lowest
   // dimension.
   LLVM_DEBUG(llvm::dbgs() << "Func-name:" << func.getSymName() << "\n");
@@ -224,58 +280,137 @@ void LoopVersioningPass::runOnOperation() {
     }
   }
 
-  if (argsOfInterest.empty())
+  if (argsOfInterest.empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "No suitable arguments.\n=== End " DEBUG_TYPE " ===\n");
     return;
+  }
 
-  struct OpsWithArgs {
-    mlir::Operation *op;
-    mlir::SmallVector<ArgInfo, 4> argsAndDims;
-  };
-  // Now see if those arguments are used inside any loop.
-  mlir::SmallVector<OpsWithArgs, 4> loopsOfInterest;
+  // A list of all loops in the function in post-order.
+  mlir::SmallVector<fir::DoLoopOp> originalLoops;
+  // Information about the arguments usage by the instructions
+  // immediately nested in a loop.
+  llvm::DenseMap<fir::DoLoopOp, ArgsUsageInLoop> argsInLoops;
 
+  // Traverse the loops in post-order and see
+  // if those arguments are used inside any loop.
   func.walk([&](fir::DoLoopOp loop) {
     mlir::Block &body = *loop.getBody();
-    mlir::SmallVector<ArgInfo, 4> argsInLoop;
+    auto &argsInLoop = argsInLoops[loop];
+    originalLoops.push_back(loop);
     body.walk([&](mlir::Operation *op) {
-      // support either fir.array_coor or fir.coordinate_of
-      if (auto arrayCoor = mlir::dyn_cast<fir::ArrayCoorOp>(op)) {
-        // no support currently for sliced arrays
-        if (arrayCoor.getSlice())
-          return;
-      } else if (!mlir::isa<fir::CoordinateOp>(op)) {
+      // Support either fir.array_coor or fir.coordinate_of.
+      if (!mlir::isa<fir::ArrayCoorOp, fir::CoordinateOp>(op))
         return;
-      }
-
-      // The current operation could be inside another loop than
-      // the one we're currently processing. Skip it, we'll get
-      // to it later.
+      // Process only operations immediately nested in the current loop.
       if (op->getParentOfType<fir::DoLoopOp>() != loop)
         return;
       mlir::Value operand = op->getOperand(0);
       for (auto a : argsOfInterest) {
         if (a.arg == normaliseVal(operand)) {
-          // use the reboxed value, not the block arg when re-creating the loop:
+          // Use the reboxed value, not the block arg when re-creating the loop.
+          // TODO: should we check that the operand dominates the loop?
+          // If this might be a case, we should record such operands in
+          // argsInLoop.cannotTransform, so that they disable the transformation
+          // for the parent loops as well.
           a.arg = operand;
-          // Only add if it's not already in the list.
-          if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) {
-                return it.arg == a.arg;
-              }) == argsInLoop.end()) {
 
-            argsInLoop.push_back(a);
+          // No support currently for sliced arrays.
+          // This means that we cannot transform properly
+          // instructions referencing a.arg in the whole loop
+          // nest this loop is located in.
+          if (auto arrayCoor = mlir::dyn_cast<fir::ArrayCoorOp>(op))
+            if (arrayCoor.getSlice())
+              argsInLoop.cannotTransform.insert(a.arg);
+
+          if (argsInLoop.cannotTransform.contains(a.arg)) {
+            // Remove any previously recorded usage, if any.
+            argsInLoop.usageInfo.erase(a.arg);
             break;
           }
+
+          // Record the a.arg usage, if not recorded yet.
+          argsInLoop.usageInfo.try_emplace(a.arg, a);
+          break;
         }
       }
     });
-
-    if (!argsInLoop.empty()) {
-      OpsWithArgs ops = {loop, argsInLoop};
-      loopsOfInterest.push_back(ops);
-    }
   });
-  if (loopsOfInterest.empty())
+
+  // Dump loops info after initial collection.
+  // clang-format off
+  LLVM_DEBUG(
+      llvm::dbgs() << "Initial usage info:\n";
+      for (fir::DoLoopOp loop : originalLoops) {
+        auto &argsInLoop = argsInLoops[loop];
+        argsInLoop.dump(loop);
+      }
+  );
+  // clang-format on
+
+  // Clear argument usage for parent loops if an inner loop
+  // contains a non-transformable usage.
+  for (fir::DoLoopOp loop : originalLoops) {
+    auto &argsInLoop = argsInLoops[loop];
+    if (argsInLoop.cannotTransform.empty())
+      continue;
+
+    fir::DoLoopOp parent = loop;
+    while ((parent = parent->getParentOfType<fir::DoLoopOp>()))
+      argsInLoops[parent].eraseUsage(argsInLoop.cannotTransform);
+  }
+
+  // If an argument access can be optimized in a loop and
+  // its descendant loop, then it does not make sense to
+  // generate the contiguity check for the descendant loop.
+  // The check will be produced as part of the ancestor
+  // loop's transformation. So we can clear the argument
+  // usage for all descendant loops.
+  for (fir::DoLoopOp loop : originalLoops) {
+    auto &argsInLoop = argsInLoops[loop];
+    if (argsInLoop.usageInfo.empty())
+      continue;
+
+    loop.getBody()->walk([&](fir::DoLoopOp dloop) {
+      argsInLoops[dloop].eraseUsage(argsInLoop.usageInfo);
+    });
+  }
+
+  // clang-format off
+  LLVM_DEBUG(
+      llvm::dbgs() << "Final usage info:\n";
+      for (fir::DoLoopOp loop : originalLoops) {
+        auto &argsInLoop = argsInLoops[loop];
+        argsInLoop.dump(loop);
+      }
+  );
+  // clang-format on
+
+  // Reduce the collected information to a list of loops
+  // with attached arguments usage information.
+  // The list must hold the loops in post order, so that
+  // the inner loops are transformed before the outer loops.
+  struct OpsWithArgs {
+    mlir::Operation *op;
+    mlir::SmallVector<ArgInfo, 4> argsAndDims;
+  };
+  mlir::SmallVector<OpsWithArgs, 4> loopsOfInterest;
+  for (fir::DoLoopOp loop : originalLoops) {
+    auto &argsInLoop = argsInLoops[loop];
+    if (argsInLoop.usageInfo.empty())
+      continue;
+    OpsWithArgs info;
+    info.op = loop;
+    for (auto &arg : argsInLoop.usageInfo)
+      info.argsAndDims.push_back(arg.second);
+    loopsOfInterest.emplace_back(std::move(info));
+  }
+
+  if (loopsOfInterest.empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "No loops to transform.\n=== End " DEBUG_TYPE " ===\n");
     return;
+  }
 
   // If we get here, there are loops to process.
   fir::FirOpBuilder builder{module, std::move(kindMap)};
diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir
index 566903d0897f237..f2768d7325f7407 100644
--- a/flang/test/Transforms/loop-versioning.fir
+++ b/flang/test/Transforms/loop-versioning.fir
@@ -118,8 +118,6 @@ func.func @sum1dfixed(%arg0: !fir.ref<!fir.array<?xf64>> {fir.bindc_name = "a"},
 
 // -----
 
-// RUN: fir-opt --loop-versioning %s | FileCheck %s
-
 // Check that "no result" from a versioned loop works correctly
 // This code was the basis for this, but `read` is replaced with a function called Func
 // subroutine test3(x, y)
@@ -1266,4 +1264,174 @@ func.func @test_optional_arg(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name
 // CHECK:           fir.store %[[VAL_166:.*]]#1 to %[[VAL_18]] : !fir.ref<i32>
 // CHECK:           return
 // CHECK:         }
+
+// ! Verify that neither loop is versioned, because the inner loop
+// ! accesses 'x' through an array section:
+// subroutine test_slice(x)
+//   real :: x(:,:)
+//   do i=10,100
+//      x(i,7) = 1.0
+//      x(i,3:5) = 2.0
+//   end do
+// end subroutine test_slice
+func.func @_QPtest_slice(%arg0: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "x"}) {
+  %c10 = arith.constant 10 : index
+  %c100 = arith.constant 100 : index
+  %c6_i64 = arith.constant 6 : i64
+  %c3 = arith.constant 3 : index
+  %c2 = arith.constant 2 : index
+  %c5 = arith.constant 5 : index
+  %cst = arith.constant 2.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1_i64 = arith.constant 1 : i64
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %c1 = arith.constant 1 : index
+  %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_sliceEi"}
+  %1 = fir.convert %c10 : (index) -> i32
+  %2:2 = fir.do_loop %arg1 = %c10 to %c100 step %c1 iter_args(%arg2 = %1) -> (index, i32) {
+    fir.store %arg2 to %0 : !fir.ref<i32>
+    %3 = fir.load %0 : !fir.ref<i32>
+    %4 = fir.convert %3 : (i32) -> i64
+    %5 = arith.subi %4, %c1_i64 : i64
+    %6 = fir.coordinate_of %arg0, %5, %c6_i64 : (!fir.box<!fir.array<?x?xf32>>, i64, i64) -> !fir.ref<f32>
+    fir.store %cst_0 to %6 : !fir.ref<f32>
+    %7 = fir.load %0 : !fir.ref<i32>
+    %8 = fir.convert %7 : (i32) -> i64
+    %9 = fir.undefined index
+    %10 = fir.convert %7 : (i32) -> index
+    %11 = fir.slice %8, %9, %9, %c3, %c5, %c1 : (i64, index, index, index, index, index) -> !fir.slice<2>
+    %12 = fir.undefined !fir.array<?x?xf32>
+    %13 = fir.do_loop %arg3 = %c0 to %c2 step %c1 unordered iter_args(%arg4 = %12) -> (!fir.array<?x?xf32>) {
+      %18 = arith.addi %arg3, %c1 : index
+      %19 = fir.array_coor %arg0 [%11] %10, %18 : (!fir.box<!fir.array<?x?xf32>>, !fir.slice<2>, index, index) -> !fir.ref<f32>
+      fir.store %cst to %19 : !fir.ref<f32>
+      fir.result %12 : !fir.array<?x?xf32>
+    }
+    %14 = arith.addi %arg1, %c1 : index
+    %15 = fir.convert %c1 : (index) -> i32
+    %16 = fir.load %0 : !fir.ref<i32>
+    %17 = arith.addi %16, %15 : i32
+    fir.result %14, %17 : index, i32
+  }
+  fir.store %2#1 to %0 : !fir.ref<i32>
+  return
+}
+// CHECK-LABEL:   func.func @_QPtest_slice(
+// CHECK-NOT: fir.if
+
+// ! Verify versioning for 'x' but not for 'y' (accessed via an array section):
+// subroutine test_independent_args(x, y)
+//   real :: x(:,:), y(:,:)
+//   do i=10,100
+//      x(i,7) = 1.0
+//      y(i,3:5) = 2.0
+//   end do
+// end subroutine test_independent_args
+func.func @_QPtest_independent_args(%arg0: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "x"}, %arg1: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "y"}) {
+  %c10 = arith.constant 10 : index
+  %c100 = arith.constant 100 : index
+  %c6_i64 = arith.constant 6 : i64
+  %c3 = arith.constant 3 : index
+  %c2 = arith.constant 2 : index
+  %c5 = arith.constant 5 : index
+  %cst = arith.constant 2.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1_i64 = arith.constant 1 : i64
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %c1 = arith.constant 1 : index
+  %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_independent_argsEi"}
+  %1 = fir.convert %c10 : (index) -> i32
+  %2:2 = fir.do_loop %arg2 = %c10 to %c100 step %c1 iter_args(%arg3 = %1) -> (index, i32) {
+    fir.store %arg3 to %0 : !fir.ref<i32>
+    %3 = fir.load %0 : !fir.ref<i32>
+    %4 = fir.convert %3 : (i32) -> i64
+    %5 = arith.subi %4, %c1_i64 : i64
+    %6 = fir.coordinate_of %arg0, %5, %c6_i64 : (!fir.box<!fir.array<?x?xf32>>, i64, i64) -> !fir.ref<f32>
+    fir.store %cst_0 to %6 : !fir.ref<f32>
+    %7 = fir.load %0 : !fir.ref<i32>
+    %8 = fir.convert %7 : (i32) -> i64
+    %9 = fir.undefined index
+    %10 = fir.convert %7 : (i32) -> index
+    %11 = fir.slice %8, %9, %9, %c3, %c5, %c1 : (i64, index, index, index, index, index) -> !fir.slice<2>
+    %12 = fir.undefined !fir.array<?x?xf32>
+    %13 = fir.do_loop %arg4 = %c0 to %c2 step %c1 unordered iter_args(%arg5 = %12) -> (!fir.array<?x?xf32>) {
+      %18 = arith.addi %arg4, %c1 : index
+      %19 = fir.array_coor %arg1 [%11] %10, %18 : (!fir.box<!fir.array<?x?xf32>>, !fir.slice<2>, index, index) -> !fir.ref<f32>
+      fir.store %cst to %19 : !fir.ref<f32>
+      fir.result %12 : !fir.array<?x?xf32>
+    }
+    %14 = arith.addi %arg2, %c1 : index
+    %15 = fir.convert %c1 : (index) -> i32
+    %16 = fir.load %0 : !fir.ref<i32>
+    %17 = arith.addi %16, %15 : i32
+    fir.result %14, %17 : index, i32
+  }
+  fir.store %2#1 to %0 : !fir.ref<i32>
+  return
+}
+// CHECK-LABEL:   func.func @_QPtest_independent_args(
+// CHECK-SAME:        %[[VAL_0:.*]]: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "x"},
+// CHECK-SAME:        %[[VAL_1:.*]]: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "y"}) {
+// CHECK:           %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_0]], %{{.*}} : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_19:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_20:.*]] = arith.cmpi eq, %[[VAL_16]]#2, %[[VAL_19]] : index
+// CHECK:           %[[VAL_21:.*]]:2 = fir.if %[[VAL_20]] -> (index, i32) {
+// CHECK-NOT: fir.if
+
+
+// ! Verify that the whole loop nest is versioned once,
+// ! without an additional contiguity check for the inner loop:
+// subroutine test_loop_nest(x)
+//   real :: x(:)
+//   do i=10,100
+//      x(i) = 1.0
+//      do j=10,100
+//         x(j) = 2.0
+//      end do
+//   end do
+// end subroutine test_loop_nest
+func.func @_QPtest_loop_nest(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
+  %c10 = arith.constant 10 : index
+  %c100 = arith.constant 100 : index
+  %cst = arith.constant 2.000000e+00 : f32
+  %c1_i64 = arith.constant 1 : i64
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %c1 = arith.constant 1 : index
+  %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_loop_nestEi"}
+  %1 = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFtest_loop_nestEj"}
+  %2 = fir.convert %c10 : (index) -> i32
+  %3:2 = fir.do_loop %arg1 = %c10 to %c100 step %c1 iter_args(%arg2 = %2) -> (index, i32) {
+    fir.store %arg2 to %0 : !fir.ref<i32>
+    %4 = fir.load %0 : !fir.ref<i32>
+    %5 = fir.convert %4 : (i32) -> i64
+    %6 = arith.subi %5, %c1_i64 : i64
+    %7 = fir.coordinate_of %arg0, %6 : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+    fir.store %cst_0 to %7 : !fir.ref<f32>
+    %8:2 = fir.do_loop %arg3 = %c10 to %c100 step %c1 iter_args(%arg4 = %2) -> (index, i32) {
+      fir.store %arg4 to %1 : !fir.ref<i32>
+      %13 = fir.load %1 : !fir.ref<i32>
+      %14 = fir.convert %13 : (i32) -> i64
+      %15 = arith.subi %14, %c1_i64 : i64
+      %16 = fir.coordinate_of %arg0, %15 : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+      fir.store %cst to %16 : !fir.ref<f32>
+      %17 = arith.addi %arg3, %c1 : index
+      %18 = fir.convert %c1 : (index) -> i32
+      %19 = fir.load %1 : !fir.ref<i32>
+      %20 = arith.addi %19, %18 : i32
+      fir.result %17, %20 : index, i32
+    }
+    fir.store %8#1 to %1 : !fir.ref<i32>
+    %9 = arith.addi %arg1, %c1 : index
+    %10 = fir.convert %c1 : (index) -> i32
+    %11 = fir.load %0 : !fir.ref<i32>
+    %12 = arith.addi %11, %10 : i32
+    fir.result %9, %12 : index, i32
+  }
+  fir.store %3#1 to %0 : !fir.ref<i32>
+  return
+}
+// CHECK-LABEL:   func.func @_QPtest_loop_nest(
+// CHECK: fir.if
+// CHECK-NOT: fir.if
+
 } // End module



More information about the flang-commits mailing list