[mlir] [llvm] [mlir][Transforms] Add loop-invariant subset hoisting (LISH) transformation (PR #70619)

Matthias Springer via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 31 18:51:22 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70619

>From 67b0157f76da2113672fa09acc6aa8f2cd4445b2 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 1 Nov 2023 10:48:25 +0900
Subject: [PATCH] [mlir] Loop-invariant subset hoisting

---
 .../Transforms/LoopInvariantCodeMotionUtils.h |  39 +++
 mlir/include/mlir/Transforms/Passes.h         |   3 +
 mlir/include/mlir/Transforms/Passes.td        |   5 +
 .../Transforms/LoopInvariantCodeMotion.cpp    |  20 ++
 mlir/lib/Transforms/Utils/CMakeLists.txt      |   1 +
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 253 +++++++++++++++++-
 .../loop-invariant-subset-hoisting.mlir       | 237 ++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 8 files changed, 555 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Transforms/loop-invariant-subset-hoisting.mlir

diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index c7b816eb28faf5f..579054070f729b0 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -71,6 +71,45 @@ size_t moveLoopInvariantCode(
 /// methods provided by the interface.
 size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
 
+/// Hoist loop-invariant tensor subsets (subset extraction and subset insertion
+/// ops) from loop-like ops. Extraction ops are moved before the loop. Insertion
+/// ops are moved after the loop. The loop body operates on newly added region
+/// iter_args (one per extraction-insertion pair).
+///
+/// A subset extraction op (`SubsetExtractionOpInterface`) extracts from a
+/// tensor value at a subset. The result of the op may have an arbitrary type,
+/// i.e., not necessarily a tensor type. Example: "tensor.extract_slice".
+///
+/// A subset insertion op (`SubsetInsertionOpInterface`) inserts into a tensor
+/// value ("destination") at a subset. Example: "tensor.insert_slice".
+///
+/// Matching extraction-insertion subset ops can be hoisted from a loop if there
+/// are no other ops within the loop that operate on the same or on an
+/// overlapping subset. In particular, non-subset ops can prevent hoisting
+/// because the analysis does not know what subset they operate on.
+///
+/// Example:
+/// ```
+/// %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+///   %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+///   %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+///   %2 = tensor.insert_slice %1 into %t[0][5][1]
+///       : tensor<5xf32> into tensor<?xf32>
+///   scf.yield %2 : tensor<?xf32>
+/// }
+/// ```
+/// Is rewritten to:
+/// ```
+/// %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+/// %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>) {
+///   %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+///   scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+/// }
+/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
+///     : tensor<5xf32> into tensor<?xf32>
+/// ```
+LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOPINVARIANTCODEMOTIONUTILS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 320932bb999561f..11f5b23e62c663b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -78,6 +78,9 @@ std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
 
+/// Creates a pass that hoists loop-invariant subset ops.
+std::unique_ptr<Pass> createLoopInvariantSubsetHoistingPass();
+
 /// Creates a pass to strip debug information from a function.
 std::unique_ptr<Pass> createStripDebugInfoPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 26d2ff3c30ded57..2d2d54fb8fb5eaa 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -329,6 +329,11 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
   let constructor = "mlir::createLoopInvariantCodeMotionPass()";
 }
 
+def LoopInvariantSubsetHoisting : Pass<"loop-invariant-subset-hoisting"> {
+  let summary = "Hoist loop invariant subset ops outside of the loop";
+  let constructor = "mlir::createLoopInvariantSubsetHoistingPass()";
+}
+
 def Mem2Reg : Pass<"mem2reg"> {
   let summary = "Promotes memory slots into values.";
   let description = [{
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index 854fde09bac796e..e6d8af8f05832d3 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -18,6 +18,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_LOOPINVARIANTCODEMOTION
+#define GEN_PASS_DEF_LOOPINVARIANTSUBSETHOISTING
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -29,6 +30,12 @@ struct LoopInvariantCodeMotion
     : public impl::LoopInvariantCodeMotionBase<LoopInvariantCodeMotion> {
   void runOnOperation() override;
 };
+
+struct LoopInvariantSubsetHoisting
+    : public impl::LoopInvariantSubsetHoistingBase<
+          LoopInvariantSubsetHoisting> {
+  void runOnOperation() override;
+};
 } // namespace
 
 void LoopInvariantCodeMotion::runOnOperation() {
@@ -39,6 +46,19 @@ void LoopInvariantCodeMotion::runOnOperation() {
       [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
 }
 
+void LoopInvariantSubsetHoisting::runOnOperation() {
+  // Walk through all loops in a function in innermost-loop-first order. This
+  // way, we first hoist from the inner loop, and place the ops in the outer
+  // loop, which in turn can be further hoisted from.
+  getOperation()->walk([&](LoopLikeOpInterface loopLike) {
+    (void)hoistLoopInvariantSubsets(loopLike);
+  });
+}
+
 std::unique_ptr<Pass> mlir::createLoopInvariantCodeMotionPass() {
   return std::make_unique<LoopInvariantCodeMotion>();
 }
+
+std::unique_ptr<Pass> mlir::createLoopInvariantSubsetHoistingPass() {
+  return std::make_unique<LoopInvariantSubsetHoisting>();
+}
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index efc7a5160b2399e..1c608e0634a67e2 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -20,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
+  MLIRSubsetOpInterface
   MLIRRewrite
   )
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 080492da6ae4b97..f39f8ec3598f322 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 #include "llvm/Support/Debug.h"
 #include <queue>
 
@@ -26,7 +29,7 @@ using namespace mlir;
 ///   loop (by means of calling definedOutside).
 /// - the op has no side-effects.
 static bool canBeHoisted(Operation *op,
-                         function_ref<bool(Value)> definedOutside) {
+                         function_ref<bool(OpOperand &)> condition) {
   // Do not move terminators.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return false;
@@ -35,11 +38,11 @@ static bool canBeHoisted(Operation *op,
   // defined outside of the loop or in a nested region, but not at the level of
   // the loop body.
   auto walkFn = [&](Operation *child) {
-    for (Value operand : child->getOperands()) {
+    for (OpOperand &operand : child->getOpOperands()) {
       // Ignore values defined in a nested region.
-      if (op->isAncestor(operand.getParentRegion()->getParentOp()))
+      if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
         continue;
-      if (!definedOutside(operand))
+      if (!condition(operand))
         return WalkResult::interrupt();
     }
     return WalkResult::advance();
@@ -47,6 +50,12 @@ static bool canBeHoisted(Operation *op,
   return !op->walk(walkFn).wasInterrupted();
 }
 
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  return canBeHoisted(
+      op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
+}
+
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
@@ -105,3 +114,239 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
       },
       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
 }
+
+namespace {
+/// Helper data structure that keeps track of equivalent/disjoint subset ops.
+class MatchingSubsets {
+public:
+  /// Insert a subset op.
+  void insert(SubsetOpInterface op) {
+    allSubsetOps.push_back(op);
+    if (auto extractionOp =
+            dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
+      insertExtractionOp(extractionOp);
+    if (auto insertionOp =
+            dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
+      insertInsertionOp(insertionOp);
+  }
+
+  /// Return a range of matching extraction-insertion subset ops. If there is no
+  /// matching extraction/insertion op, the respective value is empty. Ops are
+  /// skipped if there are other subset ops that are not guaranteed to operate
+  /// on disjoint subsets.
+  auto getHoistableSubsetOps() {
+    return llvm::make_filter_range(
+        llvm::zip(extractions, insertions), [&](auto pair) {
+          auto [extractionOp, insertionOp] = pair;
+          // Hoist only if the extracted and inserted values have the same type.
+          if (extractionOp && insertionOp &&
+              extractionOp->getResult(0).getType() !=
+                  insertionOp.getSourceOperand().get().getType())
+            return false;
+          // Hoist only if there are no conflicting subset ops.
+          return allDisjoint(extractionOp, insertionOp);
+        });
+  }
+
+private:
+  /// Helper function for equivalence of tensor values. Since only insertion
+  /// subset ops (that are also destination style ops) are followed when
+  /// traversing the SSA use-def chain, all tensor values are equivalent.
+  static bool isEquivalent(Value v1, Value v2) { return true; }
+
+  /// Return "true" if the subsets of the given extraction and insertion ops
+  /// are disjoint from the subsets that all other known subset ops operate
+  /// on.
+  bool allDisjoint(SubsetExtractionOpInterface extractionOp,
+                   SubsetInsertionOpInterface insertionOp) const {
+    for (SubsetOpInterface other : allSubsetOps) {
+      if (other == extractionOp || other == insertionOp)
+        continue;
+      if (extractionOp &&
+          !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
+        return false;
+      if (insertionOp &&
+          !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
+        return false;
+    }
+    return true;
+  }
+
+  /// Insert a subset extraction op. If the subset is equivalent to an existing
+  /// subset insertion op, pair them up. (If there is already a paired up subset
+  /// extraction op, overwrite the subset extraction op.)
+  void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
+    for (auto it : llvm::enumerate(insertions)) {
+      if (!it.value())
+        continue;
+      auto other = cast<SubsetOpInterface>(it.value().getOperation());
+      if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
+        extractions[it.index()] = extractionOp;
+        return;
+      }
+    }
+    // There is no known equivalent insertion op. Create a new entry.
+    extractions.push_back(extractionOp);
+    insertions.push_back({});
+  }
+
+  /// Insert a subset insertion op. If the subset is equivalent to an existing
+  /// subset extraction op, pair them up. (If there is already a paired up
+  /// subset insertion op, overwrite the subset insertion op.)
+  void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
+    for (auto it : llvm::enumerate(extractions)) {
+      if (!it.value())
+        continue;
+      auto other = cast<SubsetOpInterface>(it.value().getOperation());
+      if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
+        insertions[it.index()] = insertionOp;
+        return;
+      }
+    }
+    // There is no known equivalent extraction op. Create a new entry.
+    extractions.push_back({});
+    insertions.push_back(insertionOp);
+  }
+
+  SmallVector<SubsetExtractionOpInterface> extractions;
+  SmallVector<SubsetInsertionOpInterface> insertions;
+  SmallVector<SubsetOpInterface> allSubsetOps;
+};
+} // namespace
+
+/// If the given value has a single use by an op that is a terminator, return
+/// that use. Otherwise, return nullptr.
+static OpOperand *getSingleTerminatorUse(Value value) {
+  if (!value.hasOneUse())
+    return nullptr;
+  OpOperand &use = *value.getUses().begin();
+  if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
+    return &use;
+  return nullptr;
+}
+
+/// Hoist all subset ops that operate on the region iter_arg `iterArg` of the
+/// given loop-like op and index into loop-invariant subset locations. Return
+/// the newly created loop op (that has extra iter_args) or the original loop
+/// op if nothing was hoisted.
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+                                                BlockArgument iterArg) {
+  IRRewriter rewriter(loopLike.getContext());
+  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+  auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+  Value value = iterArg;
+  MatchingSubsets subsets;
+
+  // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
+  // use-def chain starting from the region iter_arg are subset extraction or
+  // subset insertion ops. The chain must terminate at the corresponding yield
+  // operand (e.g., no swapping of iter_args).
+  OpOperand *yieldedOperand = nullptr;
+  // Iterate until the single use of the current SSA value is a terminator,
+  // which is expected to be the yielding operation of the loop.
+  while (!(yieldedOperand = getSingleTerminatorUse(value))) {
+    Value nextValue = {};
+
+    for (OpOperand &use : value.getUses()) {
+      auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
+      if (!subsetOp)
+        return loopLike;
+      subsets.insert(subsetOp);
+
+      if (auto insertionOp =
+              dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
+        // The value must be used as a destination. (In case of a source, the
+        // entire tensor would be read, which would prevent any hoisting.)
+        if (&use != &insertionOp.getDestinationOperand())
+          return loopLike;
+        // There must be a single use-def chain from the region iter_arg to the
+        // terminator. I.e., only one insertion op. Branches are not supported.
+        if (nextValue)
+          return loopLike;
+        nextValue = insertionOp.getUpdatedDestination();
+      }
+    }
+
+    // Nothing can be hoisted if the chain does not continue with the loop's
+    // yielding op or a subset insertion op.
+    if (!nextValue)
+      return loopLike;
+    value = nextValue;
+  }
+
+  // Hoist only if the SSA use-def chain ends in the yielding terminator of the
+  // loop and the yielded operand is the one tied to `iterArg`. (I.e., there is
+  // no swapping yield.)
+  if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+    return loopLike;
+
+  // Hoist all matching extraction-insertion pairs one-by-one.
+  for (auto it : subsets.getHoistableSubsetOps()) {
+    auto extractionOp = std::get<0>(it);
+    auto insertionOp = std::get<1>(it);
+
+    // Ops cannot be hoisted if they depend on loop-variant values.
+    if (extractionOp) {
+      if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
+            return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+                   &operand == &extractionOp.getSourceOperand();
+          }))
+        extractionOp = {};
+    }
+    if (insertionOp) {
+      if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
+            return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+                   &operand == &insertionOp.getSourceOperand() ||
+                   &operand == &insertionOp.getDestinationOperand();
+          }))
+        insertionOp = {};
+    }
+
+    // Only hoist extraction-insertion pairs for now. Standalone extractions/
+    // insertions that are loop-invariant could be hoisted, but there may be
+    // easier ways to canonicalize the IR.
+    if (extractionOp && insertionOp) {
+      // Create a new loop with an additional iter_arg.
+      NewYieldValuesFn newYieldValuesFn =
+          [&](OpBuilder &b, Location loc,
+              ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
+        return {insertionOp.getSourceOperand().get()};
+      };
+      FailureOr<LoopLikeOpInterface> newLoop =
+          loopLike.replaceWithAdditionalYields(
+              rewriter, extractionOp.getResult(),
+              /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
+      if (failed(newLoop))
+        return loopLike;
+      loopLike = *newLoop;
+
+      // Hoist the extraction/insertion ops.
+      iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
+      OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
+      OpResult newLoopResult = loopLike.getLoopResults()->back();
+      extractionOp->moveBefore(loopLike);
+      insertionOp->moveAfter(loopLike);
+      insertionOp.getUpdatedDestination().replaceAllUsesWith(
+          insertionOp.getDestinationOperand().get());
+      extractionOp.getSourceOperand().set(
+          loopLike.getTiedLoopInit(iterArg)->get());
+      loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+      insertionOp.getSourceOperand().set(newLoopResult);
+      insertionOp.getDestinationOperand().set(loopResult);
+    }
+  }
+
+  return loopLike;
+}
+
+LoopLikeOpInterface
+mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+  // Note: As subset ops are getting hoisted, the number of region iter_args
+  // increases. This can enable further hoisting opportunities on the new
+  // iter_args.
+  for (int64_t i = 0; i < loopLike.getRegionIterArgs().size(); ++i) {
+    loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+  }
+  return loopLike;
+}
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
new file mode 100644
index 000000000000000..5cded4c99182c14
--- /dev/null
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt %s  -split-input-file -loop-invariant-subset-hoisting | FileCheck %s
+
+// CHECK-LABEL: func @hoist_matching_extract_insert(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // CHECK: tensor.extract_slice %[[t]][9] [5] [1]
+    %standalone = tensor.extract_slice %t[9][5][1] : tensor<?xf32> to tensor<5xf32>
+    "test.foo"(%standalone) : (tensor<5xf32>) -> ()
+
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo]]
+    scf.yield %3 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+
+  // CHECK: return %[[insert]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+func.func @subset_of_subset(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[extract1]]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted1:.*]] = %[[extract1]], %[[hoisted2:.*]] = %[[extract2]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %extract1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %extract2 = tensor.extract_slice %extract1[1][2][1] : tensor<5xf32> to tensor<2xf32>
+
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted2]])
+    %2 = "test.foo"(%extract2) : (tensor<2xf32>) -> (tensor<2xf32>)
+
+    %insert1 = tensor.insert_slice %2 into %extract1[1][2][1] : tensor<2xf32> into tensor<5xf32>
+    %insert2 = tensor.insert_slice %insert1 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+
+    // CHECK: scf.yield %[[t]], %[[hoisted1]], %[[foo]]
+    scf.yield %insert2 : tensor<?xf32>
+  }
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#1[1] [2] [1]
+  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[insert2]] into %[[for]]#0[0] [5] [1]
+
+  // CHECK: return %[[insert1]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_matching_chain(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_chain(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  %sz = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][%{{.*}}] [5] [1]
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]][0] [%{{.*}}] [1]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted2:.*]] = %[[extract2]], %[[hoisted1:.*]] = %[[extract1]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
+    %2 = tensor.extract_slice %t[%sz][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK-DAG: %[[foo1:.*]] = "test.foo"(%[[hoisted1]])
+    // CHECK-DAG: %[[foo2:.*]] = "test.foo"(%[[hoisted2]])
+    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %5 = tensor.insert_slice %foo2 into %t[%sz][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %5[0][%sz][1] : tensor<?xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo2]], %[[foo1]]
+    scf.yield %6 : tensor<?xf32>
+  }
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[0] [%{{.*}}] [1]
+  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert2]][%{{.*}}] [5] [1]
+
+  // CHECK: return %[[insert1]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_overlapping_subsets(
+func.func @do_not_hoist_overlapping_subsets(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  %sz1 = "test.foo"() : () -> (index)
+  %sz2 = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // These two slices are potentially overlapping. Do not hoist.
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[0][%sz1][1] : tensor<?xf32> to tensor<?xf32>
+    %2 = tensor.extract_slice %t[10][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<?xf32>) -> (tensor<?xf32>)
+    // CHECK: tensor.insert_slice
+    // CHECK: tensor.insert_slice
+    %5 = tensor.insert_slice %foo2 into %t[0][%sz1][1] : tensor<?xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %5[10][%sz2][1] : tensor<?xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %6 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multiple_yields(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @multiple_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice
+  // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[arg]], %{{.*}} = %[[arg]], %{{.*}} = %[[extract1]], %{{.*}} = %[[extract2]])
+  %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %5, %6 : tensor<?xf32>, tensor<?xf32>
+  }
+  // CHECK: tensor.insert_slice
+  // CHECK: tensor.insert_slice
+
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_swapping_yields(
+func.func @do_not_hoist_swapping_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    // CHECK: tensor.insert_slice
+    %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+    // Swapping yields: do not hoist.
+    // CHECK: scf.yield
+    scf.yield %6, %5 : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_subset_op(
+func.func @non_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // If any value along the use-def chain from the region iter_arg to the
+    // terminator is used by a non-subset op, no subset op along that chain can
+    // be hoisted. That is because it is unknown which parts of the value are
+    // accessed by the non-subset op.
+    // CHECK: "test.non_subset_op"
+    "test.non_subset_op"(%t) : (tensor<?xf32>) -> ()
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %3 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_loop_invariant_subset_op(
+func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // Subset ops that are not loop-invariant cannot be hoisted.
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[%iv][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    %3 = tensor.insert_slice %2 into %t[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %3 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 0448ecbef655fa5..0a2ae427169a99a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7035,6 +7035,7 @@ cc_library(
         ":MemorySlotInterfaces",
         ":Rewrite",
         ":SideEffectInterfaces",
+        ":SubsetOpInterface",
         ":Support",
         ":TransformsPassIncGen",
         ":config",



More information about the llvm-commits mailing list