[Mlir-commits] [mlir] [mlir][Transforms] LISH: Improve bypass analysis for loop-like ops (PR #70623)

Matthias Springer llvmlistbot at llvm.org
Tue Oct 31 19:07:15 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70623

From 5feaa44d6f7ec66de821ee3d11cf7e8f3724e5eb Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 1 Nov 2023 11:06:12 +0900
Subject: [PATCH] [mlir][WIP] Bypass analysis for loops

---
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 72 ++++++++++++++-----
 .../loop-invariant-subset-hoisting.mlir       | 35 +++++++++
 2 files changed, 91 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 01318cf7328b543..53bdb7aafe41a0c 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -120,8 +120,10 @@ namespace {
 class MatchingSubsets {
 public:
   /// Insert a subset op.
-  void insert(SubsetOpInterface op) {
+  void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
     allSubsetOps.push_back(op);
+    if (!collectHoistableOps)
+      return;
     if (auto extractionOp =
             dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
       insertExtractionOp(extractionOp);
@@ -148,6 +150,15 @@ class MatchingSubsets {
         });
   }
 
+  /// Populate subset ops starting from the given region iter_arg. Return
+  /// "failure" if non-subset ops are found along the path to the loop yielding
+  /// op or if there is no single path to the tied yielded operand. If
+  /// `collectHoistableOps` is set to "false", subset ops are gathered
+  /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
+  LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
+                                           BlockArgument iterArg,
+                                           bool collectHoistableOps = true);
+
 private:
   /// Helper function for equivalence of tensor values. Since only insertion
   /// subset ops (that are also destination style ops) are followed when
@@ -225,18 +236,12 @@ static OpOperand *getSingleTerminatorUse(Value value) {
   return nullptr;
 }
 
-/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
-/// loop-like op and index into loop-invariant subset locations. Return the
-/// newly created loop op (that has extra iter_args) or the original loop op if
-/// nothing was hoisted.
-static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
-                                                BlockArgument iterArg) {
-  IRRewriter rewriter(loopLike.getContext());
+LogicalResult
+MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
+                                            BlockArgument iterArg,
+                                            bool collectHoistableOps) {
   assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
-  auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
-  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
   Value value = iterArg;
-  MatchingSubsets subsets;
 
   // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
   // use-def chain starting from the region iter_arg are subset extraction or
@@ -249,21 +254,39 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
     Value nextValue = {};
 
     for (OpOperand &use : value.getUses()) {
+      if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
+        // Subset ops in nested loops are collected to check if there are only
+        // disjoint subset ops, but such subset ops are not subject to hoisting.
+        // To hoist subset ops from nested loops, the hoisting transformation
+        // should be run on the nested loop.
+        auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
+        if (!nestedIterArg)
+          return failure();
+        // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
+        // use-def chain that starts at `nestedIterArg` and ends at the tied
+        // yielded operand.
+        if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
+                                              /*collectHoistableOps=*/false)))
+          return failure();
+        nextValue = nestedLoop.getTiedLoopResult(&use);
+        continue;
+      }
+
       auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
       if (!subsetOp)
-        return loopLike;
-      subsets.insert(subsetOp);
+        return failure();
+      insert(subsetOp);
 
       if (auto insertionOp =
               dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
         // The value must be used as a destination. (In case of a source, the
         // entire tensor would be read, which would prevent any hoisting.)
         if (&use != &insertionOp.getDestinationOperand())
-          return loopLike;
+          return failure();
         // There must be a single use-def chain from the region iter_arg to the
         // terminator. I.e., only one insertion op. Branches are not supported.
         if (nextValue)
-          return loopLike;
+          return failure();
         nextValue = insertionOp.getUpdatedDestination();
       }
     }
@@ -271,7 +294,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
     // Nothing can be hoisted if the chain does not continue with loop yielding
     // op or a subset insertion op.
     if (!nextValue)
-      return loopLike;
+      return failure();
     value = nextValue;
   }
 
@@ -279,6 +302,23 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
   // loop and the yielded value is the `idx`-th operand. (I.e., there is no
   // swapping yield.)
   if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+    return failure();
+
+  return success();
+}
+
+/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
+/// loop-like op and index into loop-invariant subset locations. Return the
+/// newly created loop op (that has extra iter_args) or the original loop op if
+/// nothing was hoisted.
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+                                                BlockArgument iterArg) {
+  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+  auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+  IRRewriter rewriter(loopLike.getContext());
+  MatchingSubsets subsets;
+  if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
     return loopLike;
 
   // Hoist all matching extraction-insertion pairs one-by-one.
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index 5cded4c99182c14..b9161f4e20d1927 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -235,3 +235,38 @@ func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
 
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @nested_hoisting(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]][0] [5] [1]
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][5] [5] [1]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]], %[[hoisted2:.*]] = %[[extract2]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: %[[for2:.*]]:2 = {{.*}} iter_args(%[[t2:.*]] = %[[t]], %[[hoisted2_nested:.*]] = %[[hoisted2]])
+    %4 = scf.for %iv2 = %lb to %ub step %step iter_args(%t2 = %3) -> (tensor<?xf32>) {
+      %5 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+      // CHECK: %[[foo2:.*]] = "test.foo"(%[[hoisted2_nested]])
+      %6 = "test.foo"(%5) : (tensor<5xf32>) -> (tensor<5xf32>)
+      %7 = tensor.insert_slice %6 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+      // CHECK: scf.yield %[[t2]], %[[foo2]]
+      scf.yield %7 : tensor<?xf32>
+    }
+    // CHECK: scf.yield %[[for2]]#0, %[[foo]], %[[for2]]#1
+    scf.yield %4 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[5] [5] [1]
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert]][0] [5] [1]
+  // CHECK: return %[[insert2]]
+  return %0 : tensor<?xf32>
+}



More information about the Mlir-commits mailing list