[Mlir-commits] [mlir] [mlir][linalg][NFC] Remove linalg subset hoisting (PR #70636)

Matthias Springer llvmlistbot at llvm.org
Mon Oct 30 02:15:16 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/70636

Remove `SubsetHoisting.cpp` and migrate all remaining uses to the newly added loop-invariant subset hoisting transform in `mlir/Transforms`.

Depends on #70535, #70617, #70619, #70623, #70628, #70629, #70630. Only review the top commit.

>From d5a79135800b91462c8e4563af0e7295f80743f5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 14:32:21 +0900
Subject: [PATCH 1/7] [mlir] `SubsetOpInterface` and
 `SubsetExtractionOpInterface`

---
 .../Dialect/Bufferization/IR/Bufferization.h  |   2 +-
 .../Bufferization/IR/BufferizationOps.td      |   3 +-
 .../SubsetInsertionOpInterfaceImpl.h          |   3 +-
 .../SubsetInsertionOpInterfaceImpl.h          |   3 +-
 mlir/include/mlir/InitAllDialects.h           |   4 +-
 mlir/include/mlir/Interfaces/CMakeLists.txt   |   2 +-
 .../Interfaces/SubsetInsertionOpInterface.h   |  27 --
 .../Interfaces/SubsetInsertionOpInterface.td  | 155 ----------
 .../mlir/Interfaces/SubsetOpInterface.h       |  45 +++
 .../mlir/Interfaces/SubsetOpInterface.td      | 267 ++++++++++++++++++
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  36 +--
 .../Bufferization/IR/BufferizationOps.cpp     |  12 +
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   2 +-
 .../Bufferization/Transforms/CMakeLists.txt   |   2 +-
 .../Transforms/EmptyTensorElimination.cpp     |   2 +-
 .../Transforms/OneShotAnalysis.cpp            |   2 +-
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   2 +-
 .../SubsetInsertionOpInterfaceImpl.cpp        |  31 +-
 .../BufferizableOpInterfaceImpl.cpp           |   2 +-
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |   2 +-
 .../SubsetInsertionOpInterfaceImpl.cpp        | 137 +++++++--
 mlir/lib/Interfaces/CMakeLists.txt            |   9 +-
 .../Interfaces/SubsetInsertionOpInterface.cpp |  23 --
 mlir/lib/Interfaces/SubsetOpInterface.cpp     |  58 ++++
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |  76 +++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  33 +--
 .../mlir/python/BUILD.bazel                   |   2 +-
 27 files changed, 654 insertions(+), 288 deletions(-)
 delete mode 100644 mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.h
 delete mode 100644 mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.td
 create mode 100644 mlir/include/mlir/Interfaces/SubsetOpInterface.h
 create mode 100644 mlir/include/mlir/Interfaces/SubsetOpInterface.td
 delete mode 100644 mlir/lib/Interfaces/SubsetInsertionOpInterface.cpp
 create mode 100644 mlir/lib/Interfaces/SubsetOpInterface.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index c035190f43e3950..e98b5728b38ef81 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -15,7 +15,7 @@
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 
 //===----------------------------------------------------------------------===//
 // Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 72a4aa712f49c98..e6b6d052df96a8c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -15,7 +15,7 @@ include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/SubsetInsertionOpInterface.td"
+include "mlir/Interfaces/SubsetOpInterface.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 
 class Bufferization_Op<string mnemonic, list<Trait> traits = []>
@@ -220,6 +220,7 @@ def Bufferization_MaterializeInDestinationOp
          AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+         DeclareOpInterfaceMethods<SubsetOpInterface>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
             ["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
              "buildSubsetExtraction", "isEquivalentSubset"]>,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h
index 023a46df2620109..94b0fb25b506650 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h
@@ -13,8 +13,7 @@ namespace mlir {
 class DialectRegistry;
 
 namespace linalg {
-void registerSubsetInsertionOpInterfaceExternalModels(
-    DialectRegistry &registry);
+void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry);
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
index e21b07d8a2705a0..019da189a8c991b 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
@@ -13,8 +13,7 @@ namespace mlir {
 class DialectRegistry;
 
 namespace tensor {
-void registerSubsetInsertionOpInterfaceExternalModels(
-    DialectRegistry &registry);
+void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry);
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..7c2ffb7408d9afd 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -151,7 +151,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
-  linalg::registerSubsetInsertionOpInterfaceExternalModels(registry);
+  linalg::registerSubsetOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
   linalg::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
@@ -167,7 +167,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
-  tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
+  tensor::registerSubsetOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 36a04ff0eaeaf4b..d81298bb4daf014 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -12,7 +12,7 @@ add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
 add_mlir_interface(SideEffectInterfaces)
-add_mlir_interface(SubsetInsertionOpInterface)
+add_mlir_interface(SubsetOpInterface)
 add_mlir_interface(TilingInterface)
 add_mlir_interface(ValueBoundsOpInterface)
 add_mlir_interface(VectorInterfaces)
diff --git a/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.h b/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.h
deleted file mode 100644
index 3a6dfceadcce7c0..000000000000000
--- a/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- SubsetInsertionOpInterface.h - Tensor Subsets ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_SUBSETINSERTIONOPINTERFACE_H_
-#define MLIR_INTERFACES_SUBSETINSERTIONOPINTERFACE_H_
-
-#include "mlir/IR/OpDefinition.h"
-
-namespace mlir {
-namespace detail {
-
-/// Return the destination/"init" operand of the op if it implements the
-/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
-/// otherwise.
-OpOperand &defaultGetDestinationOperand(Operation *op);
-
-} // namespace detail
-} // namespace mlir
-
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h.inc"
-
-#endif // MLIR_INTERFACES_SUBSETINSERTIONOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.td b/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.td
deleted file mode 100644
index ef94a8ae9a60efd..000000000000000
--- a/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.td
+++ /dev/null
@@ -1,155 +0,0 @@
-//===-- SubsetInsertionOpInterface.td - Tensor Subsets -----*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef SUBSET_INSERTION_OP_INTERFACE
-#define SUBSET_INSERTION_OP_INTERFACE
-
-include "mlir/IR/OpBase.td"
-
-def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> {
-  let description = [{
-    This interface can be implemented by ops that insert a source tensor into
-    a destination tensor.
-
-    The elements in the destination tensor that are overwritten by this
-    insertion are called the "subset". How the subset is defined is up to the
-    op. E.g., "tensor.insert_slice" defines the subset via a hyperrectangular
-    slice. A scatter operation could define the subset via a list of indices.
-
-    Ops that deal with tensor subsets come in two flavours:
-    - Insertion flavor: Ops that insert a source tensor into a destination
-      tensor at the specified subset. Such ops usually return a new destination
-      tensor and implement the `DestinationStyleOpInterface`. Insertion ops can
-      implement the `SubsetInsertionOpInterface`. Example: "tensor.insert_slice"
-    - Extraction flavor: Ops that define a tensor subset. They extract a
-      specified subset from a tensor. There is currently no op interface for
-      such ops. Example: "tensor.extract_slice"
-
-    This interface provides helper methods for efficient bufferization of
-    subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
-    "aliases" (in contrast to one or multiple less efficient buffer allocation).
-
-    This interface is queried by One-Shot Bufferize to detect cases where a
-    seeming read-after-write is not actually a conflict because the respective
-    ops are operating on equivalent subsets. More details can be found in the
-    documentation of One-Shot Analysis (see `areNonConflictingSubsets`).
-
-    Note: This interface currently assumes that a subset op inserts a single
-    tensor (source) into a destination tensor at a single subset.
-  }];
-  let cppNamespace = "::mlir";
-  let methods = [
-      InterfaceMethod<
-        /*desc=*/[{
-          Return the source tensor operand.
-        }],
-        /*retType=*/"::mlir::OpOperand &",
-        /*methodName=*/"getSourceOperand",
-        /*args=*/(ins)
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Return the destination tensor operand.
-        }],
-        /*retType=*/"::mlir::OpOperand &",
-        /*methodName=*/"getDestinationOperand",
-        /*args=*/(ins),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          return ::mlir::detail::defaultGetDestinationOperand(
-              $_op.getOperation());
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Return "true" if this operation inserts into a subset that is
-          equivalent to the subset defined by `candidate`.
-
-          Two subsets are "equivalent" and "same" if they can bufferize to the
-          same buffer views/aliases. If they are "equivalent", the tensor IR
-          may be expressed in terms of different SSA values (but they could
-          bufferize to MemRef SSA values that can CSE without breaking
-          correctness). `equivalenceFn` should return "true" if the two given
-          values are equivalent.
-
-          Example:
-          ```
-          // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
-          // the subset defined by %2 (but not "same"):
-          %0 = arith.select %c, %t, %t : tensor<?xf32>
-          %1 = tensor.insert_slice %x into %0[0][5][1]
-              : tensor<5xf32> into tensor<?xf32>
-          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
-
-          // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
-          // and "same" as the subset defined by %2.
-          %1 = tensor.insert_slice %x into %t[0][5][1]
-              : tensor<5xf32> into tensor<?xf32>
-          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
-          ```
-        }],
-        /*retType=*/"bool",
-        /*methodName=*/"isEquivalentSubset",
-        /*args=*/(ins
-            "::mlir::Value":$candidate,
-            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Return the subset of the destination tensor that this operation
-          inserts into.
-
-          Example:
-          ```
-          // SubsetOpInterface op:
-          %0 = tensor.insert_slice %t0 into %t1[%pos][5][1]
-              : tensor<5xf32> into tensor<?xf32>
-          // Subset (built by this function):
-          %1 = tensor.extract_slice %t1[%pos][5][1]
-              : tensor<?xf32> to tensor<5xf32>
-          ```
-
-          Note: Implementations do not necessarily have to build new IR. They
-          may return existing SSA values.
-        }],
-        /*retType=*/"::mlir::Value",
-        /*methodName=*/"buildSubsetExtraction",
-        /*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc)
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Return all SSA values that are needed (i.e., must be in scope) at the
-          insertion of the builder when calling `buildSubsetExtraction`. Users
-          of `buildSubsetExtraction` can use this helper method to find a
-          suitable insertion point.
-
-          Example: The SSA values needed to build the subset in the example of
-          `buildSubsetExtraction` are %t1 and %pos.
-        }],
-        /*retType=*/"::llvm::SmallVector<::mlir::Value>",
-        /*methodName=*/"getValuesNeededToBuildSubsetExtraction",
-        /*args=*/(ins)
-      >,
-  ];
-
-  let extraClassDeclaration = [{
-    /// Return "true" if this operation inserts into the same subset as defined
-    /// by `candidate`.
-    ///
-    /// Note: This function is useful outside of bufferization, where no tensor
-    /// equivalence information is available.
-    bool isSameSubset(OpResult candidate) {
-      auto subsetOp = cast<::mlir::SubsetInsertionOpInterface>(
-          getOperation());
-      return subsetOp.isEquivalentSubset(
-          candidate, [](Value v1, Value v2) { return v1 == v2; });
-    }
-  }];
-}
-
-#endif // SUBSET_INSERTION_OP_INTERFACE
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.h b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
new file mode 100644
index 000000000000000..049cf2456a9c842
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
@@ -0,0 +1,45 @@
+//===- SubsetOpInterface.h - Tensor Subsets ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_SUBSETOPINTERFACE_H_
+#define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class SubsetOpInterface;
+class SubsetExtractionOpInterface;
+class SubsetInsertionOpInterface;
+
+namespace detail {
+
+/// Return the destination/"init" operand of the op if it implements the
+/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
+/// otherwise.
+OpOperand &defaultGetDestinationOperand(Operation *op);
+
+/// Return the updated destination result of the op if it implements the
+/// `DestinationStyleOpInterface`.
+OpResult defaultGetUpdatedDestination(Operation *op);
+
+/// Default implementation of `isEquivalentSubset`.
+bool defaultIsEquivalentSubset(Operation *op, Value candidate,
+                               function_ref<bool(Value, Value)> equivalenceFn);
+
+/// Verify `SubsetOpInterface`.
+LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
+
+/// Verify `SubsetExtractionOpInterface`.
+LogicalResult verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op);
+
+} // namespace detail
+} // namespace mlir
+
+#include "mlir/Interfaces/SubsetOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_SUBSETOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.td b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
new file mode 100644
index 000000000000000..07d62b8319c2961
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
@@ -0,0 +1,267 @@
+//===-- SubsetOpInterface.td - Tensor Subsets --------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SUBSET_OP_INTERFACE
+#define SUBSET_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
+  let description = [{
+    This interface can be implemented by ops that operate on tensor subsets. A
+    "subset" is a part of a tensor. This interface describes the subset that
+    an implementing op operates on. Only the specified subset may be accessed by
+    the op.
+
+    Subset ops come in two flavors and ops that implement the
+    `SubsetOpInterface` must also implement one of the respective interfaces.
+    - Insertion flavor: Ops that insert a source value into a destination
+      tensor at the specified subset. Such ops return an updated destination
+      tensor and usually implement the `DestinationStyleOpInterface`. Insertion
+      ops must implement the `SubsetInsertionOpInterface`.
+    - Extraction flavor: Ops that extract at a subset. Extraction ops must
+      implement the `SubsetExtractionOpInterface`.
+
+    How the subset is specified is up to the implementing op. E.g.:
+    - `tensor.extract_slice/insert_slice` describe the subset as a
+      hyperrectangular slice.
+    - `tensor.gather/scatter` describe the subset as a list of indices. (Not
+      implemented yet.)
+
+    Note: This interface does not expose any interface methods to get a
+    description of the accessed subset. That is because there is currently no
+    efficient way to describe arbitrary subsets. This interface merely provides
+    interface methods to check if two subsets are equivalent or disjoint.
+  }];
+
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if this op and the given candidate subset op operate on
+          an equivalent subset. Return "false" if the two subsets are disjoint
+          or cannot be proven to be equivalent.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"operatesOnEquivalentSubset",
+        /*args=*/(ins
+            "::mlir::SubsetOpInterface":$candidate,
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if this op and the given candidate subset op operate on
+          disjoint subsets. Return "false" if the two subsets are equivalent,
+          overlapping or cannot be proven to be disjoint.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"operatesOnDisjointSubset",
+        /*args=*/(ins
+            "::mlir::SubsetOpInterface":$candidate,
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+      >,
+  ];
+
+  let verify = [{
+    return ::mlir::detail::verifySubsetOpInterface(
+        ::mlir::cast<::mlir::SubsetOpInterface>($_op));
+  }];
+}
+
+def SubsetExtractionOpInterface
+    : OpInterface<"SubsetExtractionOpInterface", [SubsetOpInterface]> {
+  let description = [{
+    This interface can be implemented by ops that extract a value from
+    a source tensor at a specified subset. The elements in the source tensor
+    that are read by this extraction are called "subset".
+
+    Extraction ops must have a single result value.
+  }];
+
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the source tensor operand.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getSourceOperand",
+        /*args=*/(ins)
+      >,
+  ];
+
+  let verify = [{
+    return ::mlir::detail::verifySubsetExtractionOpInterface(
+        ::mlir::cast<::mlir::SubsetExtractionOpInterface>($_op));
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return the single result of this op.
+    ::mlir::Value getResult() {
+      return getOperation()->getResult(0);
+    }
+  }];
+}
+
+def SubsetInsertionOpInterface
+    : OpInterface<"SubsetInsertionOpInterface", [SubsetOpInterface]> {
+  let description = [{
+    This interface can be implemented by ops that insert a source value into
+    a destination tensor at a specified subset. The elements in the destination
+    tensor that are overwritten by this insertion are called "subset". The
+    updated destination tensor is returned.
+
+    This interface provides helper methods for efficient bufferization of
+    subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
+    "aliases" (in contrast to one or more less efficient buffer allocations).
+
+    This interface is queried by One-Shot Bufferize to detect cases where a
+    seeming read-after-write is not actually a conflict because the respective
+    ops are operating on equivalent subsets. More details can be found in the
+    documentation of One-Shot Analysis (see `areNonConflictingSubsets`).
+  }];
+
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the source operand. The source operand can have any type.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getSourceOperand",
+        /*args=*/(ins)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the destination operand. The destination operand must be a
+          tensor.
+
+          This function does not have to be implemented for destination style
+          ops that have exactly one "init" operand.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getDestinationOperand",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::detail::defaultGetDestinationOperand(
+              $_op.getOperation());
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the updated destination result.
+
+          This function does not have to be implemented for destination style
+          ops.
+        }],
+        /*retType=*/"::mlir::OpResult",
+        /*methodName=*/"getUpdatedDestination",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::detail::defaultGetUpdatedDestination(
+              $_op.getOperation());
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if this operation inserts into a subset that is
+          equivalent to the subset defined by `candidate`.
+
+          Two subsets are "equivalent" and "same" if they can bufferize to the
+          same buffer views/aliases. If they are "equivalent", the tensor IR
+          may be expressed in terms of different SSA values (but they could
+          bufferize to MemRef SSA values that can CSE without breaking
+          correctness). `equivalenceFn` should return "true" if the two given
+          values are equivalent.
+
+          Example:
+          ```
+          // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
+          // the subset defined by %2 (but not "same"):
+          %0 = arith.select %c, %t, %t : tensor<?xf32>
+          %1 = tensor.insert_slice %x into %0[0][5][1]
+              : tensor<5xf32> into tensor<?xf32>
+          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+
+          // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
+          // and "same" as the subset defined by %2.
+          %1 = tensor.insert_slice %x into %t[0][5][1]
+              : tensor<5xf32> into tensor<?xf32>
+          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+          ```
+
+          Note: By default, this function calls
+          `SubsetOpInterface::operatesOnEquivalentSubset`.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isEquivalentSubset",
+        /*args=*/(ins
+            "::mlir::Value":$candidate,
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::detail::defaultIsEquivalentSubset(
+              $_op.getOperation(), candidate, equivalenceFn);
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the subset of the destination tensor that this operation
+          inserts into.
+
+          Example:
+          ```
+          // SubsetOpInterface op:
+          %0 = tensor.insert_slice %t0 into %t1[%pos][5][1]
+              : tensor<5xf32> into tensor<?xf32>
+          // Subset (built by this function):
+          %1 = tensor.extract_slice %t1[%pos][5][1]
+              : tensor<?xf32> to tensor<5xf32>
+          ```
+
+          Note: Implementations do not necessarily have to build new IR. They
+          may return existing SSA values.
+        }],
+        /*retType=*/"::mlir::Value",
+        /*methodName=*/"buildSubsetExtraction",
+        /*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return all SSA values that are needed (i.e., must be in scope) at the
+          insertion point of the builder when calling `buildSubsetExtraction`.
+          Users of `buildSubsetExtraction` can use this helper method to find a
+          suitable insertion point.
+
+          Example: The SSA values needed to build the subset in the example of
+          `buildSubsetExtraction` are %t1 and %pos.
+        }],
+        /*retType=*/"::llvm::SmallVector<::mlir::Value>",
+        /*methodName=*/"getValuesNeededToBuildSubsetExtraction",
+        /*args=*/(ins)
+      >,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return "true" if this operation inserts into the same subset as defined
+    /// by `candidate`.
+    ///
+    /// Note: This function is useful outside of bufferization, where no tensor
+    /// equivalence information is available.
+    bool isSameSubset(OpResult candidate) {
+      auto subsetOp = cast<::mlir::SubsetInsertionOpInterface>(
+          getOperation());
+      return subsetOp.isEquivalentSubset(
+          candidate, [](Value v1, Value v2) { return v1 == v2; });
+    }
+  }];
+}
+
+#endif // SUBSET_OP_INTERFACE
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 8f11c563e0cbd91..8e2986a2d1f05f6 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -19,6 +19,7 @@
 #include <queue>
 
 namespace mlir {
+class OffsetSizeAndStrideOpInterface;
 
 using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
 
@@ -134,11 +135,11 @@ class ValueBoundsConstraintSet {
                           std::optional<int64_t> dim, ValueRange independencies,
                           bool closedUB = false);
 
-  /// Compute a constant bound for the given index-typed value or shape
-  /// dimension size.
+  /// Compute a constant bound for the given affine map, where dims and symbols
+  /// are bound to the given operands. The affine map must have exactly one
+  /// result.
   ///
-  /// `dim` must be `nullopt` if and only if `value` is index-typed. This
-  /// function traverses the backward slice of the given value in a
+  /// This function traverses the backward slice of the given operands in a
   /// worklist-driven manner until `stopCondition` evaluates to "true". The
   /// constraint set is populated according to `ValueBoundsOpInterface` for each
   /// visited value. (No constraints are added for values for which the stop
@@ -155,26 +156,12 @@ class ValueBoundsConstraintSet {
                        std::optional<int64_t> dim = std::nullopt,
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
-
-  /// Compute a constant bound for the given affine map, where dims and symbols
-  /// are bound to the given operands. The affine map must have exactly one
-  /// result.
-  ///
-  /// This function traverses the backward slice of the given operands in a
-  /// worklist-driven manner until `stopCondition` evaluates to "true". The
-  /// constraint set is populated according to `ValueBoundsOpInterface` for each
-  /// visited value. (No constraints are added for values for which the stop
-  /// condition evaluates to "true".)
-  ///
-  /// The stop condition is optional: If none is specified, the backward slice
-  /// is traversed in a breadth-first manner until a constant bound could be
-  /// computed.
-  ///
-  /// By default, lower/equal bounds are closed and upper bounds are open. If
-  /// `closedUB` is set to "true", upper bounds are also closed.
   static FailureOr<int64_t> computeConstantBound(
       presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
       StopConditionFn stopCondition = nullptr, bool closedUB = false);
+  static FailureOr<int64_t> computeConstantBound(
+      presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
+      StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
   /// Compute a constant delta between the given two values. Return "failure"
   /// if a constant delta could not be determined.
@@ -195,6 +182,13 @@ class ValueBoundsConstraintSet {
                                   std::optional<int64_t> dim1 = std::nullopt,
                                   std::optional<int64_t> dim2 = std::nullopt);
 
+  /// Return "true" if the given slices are guaranteed to be overlapping.
+  /// Return "false" if the given slices are guaranteed to be non-overlapping.
+  /// Return "failure" if unknown.
+  static FailureOr<bool>
+  areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1,
+                       OffsetSizeAndStrideOpInterface slice2);
+
   /// Add a bound for the given index-typed value or shaped value. This function
   /// returns a builder that adds the bound.
   BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 52ff6ceeee85b03..8f19245efdba6c8 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -643,6 +643,18 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
   return getOperation()->getOpOperand(0) /*source*/;
 }
 
+bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
+    SubsetOpInterface subsetOp,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  return false;
+}
+
+bool MaterializeInDestinationOp::operatesOnDisjointSubset(
+    SubsetOpInterface subsetOp,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  return false;
+}
+
 LogicalResult MaterializeInDestinationOp::verify() {
   if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
     return emitOpError("'dest' must be a tensor or a memref");
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 385d8dc9364e379..9895db9d93ce0bb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -23,7 +23,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   MLIRFunctionInterfaces
   MLIRIR
   MLIRSparseTensorDialect
-  MLIRSubsetInsertionOpInterface
+  MLIRSubsetOpInterface
   MLIRTensorDialect
   MLIRMemRefDialect
   )
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index a6876c7c824e0ce..8617c17e7a5e5e5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -36,7 +36,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   MLIRTensorDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
-  MLIRSubsetInsertionOpInterface
+  MLIRSubsetOpInterface
   MLIRTransforms
   MLIRViewLikeInterface
   MLIRSupport
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 6622cfefa76a26f..4a418a05e6ff565 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -15,7 +15,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dominance.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 0bbfdba2b6e6ef9..c48402c3742a77b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -54,7 +54,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index bad246c262979b7..bb90e5ee546d113 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -66,7 +66,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRPass
-  MLIRSubsetInsertionOpInterface
+  MLIRSubsetOpInterface
   MLIRSparseTensorDialect
   MLIRTensorDialect
   MLIRTensorTilingInterfaceImpl
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index e0819082102ef66..e9fe93efceec356 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -9,12 +9,38 @@
 #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 
 namespace {
+struct LinalgCopyOpSubsetOpInterface
+    : public SubsetOpInterface::ExternalModel<LinalgCopyOpSubsetOpInterface,
+                                              linalg::CopyOp> {
+  bool operatesOnEquivalentSubset(
+      Operation *op, SubsetOpInterface candidate,
+      function_ref<bool(Value, Value)> equivalenceFn) const {
+    // linalg.copy operates on the entire source tensor. This interface is
+    // currently needed only for One-Shot Bufferize, which only uses
+    // `SubsetInsertionOpInterface`, so this interface method is currently not
+    // needed. In the absence of an analysis, "false" is a conservative way to
+    // implement this interface method.
+    return false;
+  }
+
+  bool operatesOnDisjointSubset(
+      Operation *op, SubsetOpInterface candidate,
+      function_ref<bool(Value, Value)> equivalenceFn) const {
+    // linalg.copy operates on the entire source tensor. This interface is
+    // currently needed only for One-Shot Bufferize, which only uses
+    // `SubsetInsertionOpInterface`, so this interface method is currently not
+    // needed. In the absence of an analysis, "false" is a conservative way to
+    // implement this interface method.
+    return false;
+  }
+};
+
 struct LinalgCopyOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
                                                        linalg::CopyOp> {
@@ -48,9 +74,10 @@ struct LinalgCopyOpInterface
 };
 } // namespace
 
-void mlir::linalg::registerSubsetInsertionOpInterfaceExternalModels(
+void mlir::linalg::registerSubsetOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
+    linalg::CopyOp::attachInterface<LinalgCopyOpSubsetOpInterface>(*ctx);
     linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a95443db88b50b2..2cd57e7324b4dc5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1066,5 +1066,5 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
 
   // Bufferization requires SubsetInsertionOpInterface models. Make sure that
   // they are registered.
-  tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
+  tensor::registerSubsetOpInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index a0c172ac52e4be8..c5fd4e65bbf7028 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -30,7 +30,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRMemRefDialect
   MLIRPass
   MLIRSCFDialect
-  MLIRSubsetInsertionOpInterface
+  MLIRSubsetOpInterface
   MLIRTensorDialect
   MLIRTensorUtils
   MLIRTilingInterface
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index dbda9953684f41d..7a1bafd409eea60 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -9,17 +9,115 @@
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
 
 namespace {
 
+/// Return the tensor that the given subset op operates on.
+Value getContainerOperand(SubsetOpInterface op) {
+  if (auto extractionOp =
+          dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
+    return extractionOp.getSourceOperand().get();
+  if (auto insertionOp =
+          dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
+    return insertionOp.getDestinationOperand().get();
+  llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
+}
+
+/// Return "true" if the two ops operate on an equivalent subset.
+/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
+/// if the two ops operate on non-equivalent subsets, if equivalence cannot be
+/// determined or if `op2` is not an `OffsetSizeAndStrideOpInterface` op.
+template <typename OpTy>
+bool operateOnEquivalentSubsets(
+    OpTy op1, SubsetOpInterface op2,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  auto offsetsSizesAndStrides2 =
+      dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
+  if (!offsetsSizesAndStrides2)
+    return false;
+  if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
+                                  isEqualConstantIntOrValue))
+    return false;
+  return equivalenceFn(
+      getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
+      getContainerOperand(op2));
+}
+
+/// Return "true" if the two ops operate on disjoint subsets.
+/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
+/// if the two ops operate on non-disjoint subsets, if disjointness cannot be
+/// determined or if `op2` is not an `OffsetSizeAndStrideOpInterface` op.
+template <typename OpTy>
+bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
+                              function_ref<bool(Value, Value)> equivalenceFn) {
+  auto offsetsSizesAndStrides2 =
+      dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
+  if (!offsetsSizesAndStrides2)
+    return false;
+  FailureOr<bool> overlappingSlices =
+      ValueBoundsConstraintSet::areOverlappingSlices(op1,
+                                                     offsetsSizesAndStrides2);
+  if (failed(overlappingSlices) || *overlappingSlices)
+    return false;
+  return equivalenceFn(
+      getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
+      getContainerOperand(op2));
+}
+
+struct ExtractSliceOpSubsetOpInterface
+    : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
+                                              tensor::ExtractSliceOp> {
+  bool operatesOnEquivalentSubset(
+      Operation *op, SubsetOpInterface candidate,
+      function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
+  }
+
+  bool operatesOnDisjointSubset(
+      Operation *op, SubsetOpInterface candidate,
+      function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
+  }
+};
+
+struct ExtractSliceOpSubsetExtractionOpInterface
+    : public SubsetExtractionOpInterface::ExternalModel<
+          ExtractSliceOpSubsetExtractionOpInterface, tensor::ExtractSliceOp> {
+  OpOperand &getSourceOperand(Operation *op) const {
+    return cast<tensor::ExtractSliceOp>(op).getSourceMutable();
+  }
+};
+
 template <typename OpTy>
-struct InsertSliceLikeOpInterface
+struct InsertSliceLikeOpSubsetOpInterface
+    : public SubsetOpInterface::ExternalModel<
+          InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
+  bool operatesOnEquivalentSubset(
+      Operation *op, SubsetOpInterface candidate,
+      function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto insertSliceOp = cast<OpTy>(op);
+    return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
+  }
+
+  bool operatesOnDisjointSubset(
+      Operation *op, SubsetOpInterface candidate,
+      function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto insertSliceOp = cast<OpTy>(op);
+    return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
+  }
+};
+
+template <typename OpTy>
+struct InsertSliceLikeOpSubsetInsertionOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<
-          InsertSliceLikeOpInterface<OpTy>, OpTy> {
+          InsertSliceLikeOpSubsetInsertionOpInterface<OpTy>, OpTy> {
   OpOperand &getSourceOperand(Operation *op) const {
     return cast<OpTy>(op).getSourceMutable();
   }
@@ -28,23 +126,6 @@ struct InsertSliceLikeOpInterface
     return cast<OpTy>(op).getDestMutable();
   }
 
-  /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
-  /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
-  /// equivalence of tensors.
-  bool
-  isEquivalentSubset(Operation *op, Value candidate,
-                     function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<OpTy>(op);
-    // Look for a matching tensor.extract_slice op.
-    auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
-    if (!extractSliceOp)
-      return false;
-    if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
-      return false;
-    return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
-                                      isEqualConstantIntOrValue);
-  }
-
   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
                               Location loc) const {
     auto insertSliceOp = cast<OpTy>(op);
@@ -73,12 +154,22 @@ struct InsertSliceLikeOpInterface
 
 } // namespace
 
-void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
+void mlir::tensor::registerSubsetOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
-    InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
+    // Note: `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface`
+    // require `SubsetOpInterface`.
+    ExtractSliceOp::attachInterface<ExtractSliceOpSubsetOpInterface>(*ctx);
+    ExtractSliceOp::attachInterface<ExtractSliceOpSubsetExtractionOpInterface>(
         *ctx);
+    InsertSliceOp::attachInterface<
+        InsertSliceLikeOpSubsetOpInterface<InsertSliceOp>>(*ctx);
+    InsertSliceOp::attachInterface<
+        InsertSliceLikeOpSubsetInsertionOpInterface<InsertSliceOp>>(*ctx);
+    ParallelInsertSliceOp::attachInterface<
+        InsertSliceLikeOpSubsetOpInterface<ParallelInsertSliceOp>>(*ctx);
     ParallelInsertSliceOp::attachInterface<
-        InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
+        InsertSliceLikeOpSubsetInsertionOpInterface<ParallelInsertSliceOp>>(
+        *ctx);
   });
 }
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index f74306206d63f14..2652d261f480ba4 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -16,7 +16,7 @@ set(LLVM_OPTIONAL_SOURCES
   RuntimeVerifiableOpInterface.cpp
   ShapedOpInterfaces.cpp
   SideEffectInterfaces.cpp
-  SubsetInsertionOpInterface.cpp
+  SubsetOpInterface.cpp
   TilingInterface.cpp
   ValueBoundsOpInterface.cpp
   VectorInterfaces.cpp
@@ -84,15 +84,15 @@ add_mlir_interface_library(RuntimeVerifiableOpInterface)
 add_mlir_interface_library(ShapedOpInterfaces)
 add_mlir_interface_library(SideEffectInterfaces)
 
-add_mlir_library(MLIRSubsetInsertionOpInterface
-  SubsetInsertionOpInterface.cpp
+add_mlir_library(MLIRSubsetOpInterface
+  SubsetOpInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
 
   DEPENDS
   MLIRDestinationStyleOpInterface
-  MLIRSubsetInsertionOpInterfaceIncGen
+  MLIRSubsetOpInterfaceIncGen
 
   LINK_LIBS PUBLIC
   MLIRDestinationStyleOpInterface
@@ -112,6 +112,7 @@ add_mlir_library(MLIRValueBoundsOpInterface
   DEPENDS
   MLIRDestinationStyleOpInterface
   MLIRValueBoundsOpInterfaceIncGen
+  MLIRViewLikeInterface
 
   LINK_LIBS PUBLIC
   MLIRAnalysis
diff --git a/mlir/lib/Interfaces/SubsetInsertionOpInterface.cpp b/mlir/lib/Interfaces/SubsetInsertionOpInterface.cpp
deleted file mode 100644
index b2b092287f96ba6..000000000000000
--- a/mlir/lib/Interfaces/SubsetInsertionOpInterface.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-//===- SubsetInsertionOpInterface.cpp - Tensor Subsets --------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
-#include "mlir/Interfaces/DestinationStyleOpInterface.h"
-
-#include "mlir/Interfaces/SubsetInsertionOpInterface.cpp.inc"
-
-using namespace mlir;
-
-OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
-  auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
-  assert(dstOp && "getDestination must be implemented for non-DPS ops");
-  assert(
-      dstOp.getNumDpsInits() == 1 &&
-      "getDestination must be implemented for ops with 0 or more than 1 init");
-  return *dstOp.getDpsInitOperand(0);
-}
diff --git a/mlir/lib/Interfaces/SubsetOpInterface.cpp b/mlir/lib/Interfaces/SubsetOpInterface.cpp
new file mode 100644
index 000000000000000..7245ab20c499e20
--- /dev/null
+++ b/mlir/lib/Interfaces/SubsetOpInterface.cpp
@@ -0,0 +1,58 @@
+//===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
+
+using namespace mlir;
+
+OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
+  assert(dstOp && "getDestination must be implemented for non-DPS ops");
+  assert(
+      dstOp.getNumDpsInits() == 1 &&
+      "getDestination must be implemented for ops with 0 or more than 1 init");
+  return *dstOp.getDpsInitOperand(0);
+}
+
+OpResult detail::defaultGetUpdatedDestination(Operation *op) {
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
+  assert(dstOp && "getUpdatedDestination must be implemented for non-DPS ops");
+  auto insertionOp = cast<SubsetInsertionOpInterface>(op);
+  return dstOp.getTiedOpResult(&insertionOp.getDestinationOperand());
+}
+
+bool detail::defaultIsEquivalentSubset(
+    Operation *op, Value candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  assert(isa<SubsetInsertionOpInterface>(op) &&
+         "expected SubsetInsertionOpInterface");
+  if (!candidate.getDefiningOp<SubsetExtractionOpInterface>())
+    return false;
+  return cast<SubsetOpInterface>(op).operatesOnEquivalentSubset(
+      candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
+}
+
+LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
+  if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
+        isa<SubsetInsertionOpInterface>(op.getOperation())))
+    return op->emitOpError(
+        "SubsetOpInterface ops must implement either "
+        "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
+  return success();
+}
+
+LogicalResult
+detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) {
+  if (op->getNumResults() != 1)
+    return op->emitOpError(
+        "SubsetExtractionOpInterface ops must have one result");
+  return success();
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ff941115219f68b..200f7f3b25b731b 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/Support/Debug.h"
 
@@ -484,6 +485,17 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   return failure();
 }
 
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+    presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+    StopConditionFn stopCondition, bool closedUB) {
+  ValueDimList valueDims;
+  for (Value v : operands) {
+    assert(v.getType().isIndex() && "expected index type");
+    valueDims.emplace_back(v, std::nullopt);
+  }
+  return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
+}
+
 FailureOr<int64_t>
 ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
                                                std::optional<int64_t> dim1,
@@ -512,6 +524,70 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
   return *delta == 0;
 }
 
+FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
+    OffsetSizeAndStrideOpInterface slice1,
+    OffsetSizeAndStrideOpInterface slice2) {
+  assert(slice1.getStaticOffsets().size() == slice2.getStaticOffsets().size() &&
+         "expected slices of same rank");
+  assert(slice1.getStaticSizes().size() == slice2.getStaticSizes().size() &&
+         "expected slices of same rank");
+  assert(slice1.getStaticStrides().size() == slice2.getStaticStrides().size() &&
+         "expected slices of same rank");
+
+  Builder b(slice1.getContext());
+  bool foundUnknownBound = false;
+  for (int64_t i = 0, e = slice1.getStaticOffsets().size(); i < e; ++i) {
+    AffineMap map =
+        AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
+                       b.getAffineSymbolExpr(0) +
+                           b.getAffineSymbolExpr(1) * b.getAffineSymbolExpr(2) -
+                           b.getAffineSymbolExpr(3));
+    {
+      // Case 1: Slices are guaranteed to be non-overlapping if
+      // offset1 + size1 * stride1 <= offset2 (for at least one dimension).
+      SmallVector<OpFoldResult> ofrOperands;
+      ofrOperands.push_back(slice1.getMixedOffsets()[i]);
+      ofrOperands.push_back(slice1.getMixedSizes()[i]);
+      ofrOperands.push_back(slice1.getMixedStrides()[i]);
+      ofrOperands.push_back(slice2.getMixedOffsets()[i]);
+      SmallVector<Value> valueOperands;
+      AffineMap foldedMap =
+          foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
+      FailureOr<int64_t> constBound = computeConstantBound(
+          presburger::BoundType::EQ, foldedMap, valueOperands);
+      foundUnknownBound |= failed(constBound);
+      if (succeeded(constBound) && *constBound <= 0)
+        return false;
+    }
+    {
+      // Case 2: Slices are guaranteed to be non-overlapping if
+      // offset2 + size2 * stride2 <= offset1 (for at least one dimension).
+      SmallVector<OpFoldResult> ofrOperands;
+      ofrOperands.push_back(slice2.getMixedOffsets()[i]);
+      ofrOperands.push_back(slice2.getMixedSizes()[i]);
+      ofrOperands.push_back(slice2.getMixedStrides()[i]);
+      ofrOperands.push_back(slice1.getMixedOffsets()[i]);
+      SmallVector<Value> valueOperands;
+      AffineMap foldedMap =
+          foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
+      FailureOr<int64_t> constBound = computeConstantBound(
+          presburger::BoundType::EQ, foldedMap, valueOperands);
+      foundUnknownBound |= failed(constBound);
+      if (succeeded(constBound) && *constBound <= 0)
+        return false;
+    }
+  }
+
+  // If at least one bound could not be computed, we cannot be certain that the
+  // slices are really overlapping.
+  if (foundUnknownBound)
+    return failure();
+
+  // All bounds could be computed and none of the above cases applied.
+  // Therefore, the slices are guaranteed to overlap.
+  return true;
+}
+
 ValueBoundsConstraintSet::BoundBuilder &
 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
   assert(!this->dim.has_value() && "dim was already set");
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b1bd543ae7fce63..add146bf084d321 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6933,7 +6933,7 @@ cc_library(
         ":MemRefDialect",
         ":Pass",
         ":SCFDialect",
-        ":SubsetInsertionOpInterface",
+        ":SubsetOpInterface",
         ":TensorDialect",
         ":TensorPassIncGen",
         ":TensorUtils",
@@ -10189,9 +10189,9 @@ gentbl_cc_library(
 )
 
 td_library(
-    name = "SubsetInsertionOpInterfaceTdFiles",
+    name = "SubsetOpInterfaceTdFiles",
     srcs = [
-        "include/mlir/Interfaces/SubsetInsertionOpInterface.td",
+        "include/mlir/Interfaces/SubsetOpInterface.td",
     ],
     includes = ["include"],
     deps = [
@@ -10200,33 +10200,33 @@ td_library(
 )
 
 gentbl_cc_library(
-    name = "SubsetInsertionOpInterfaceIncGen",
+    name = "SubsetOpInterfaceIncGen",
     tbl_outs = [
         (
             ["-gen-op-interface-decls"],
-            "include/mlir/Interfaces/SubsetInsertionOpInterface.h.inc",
+            "include/mlir/Interfaces/SubsetOpInterface.h.inc",
         ),
         (
             ["-gen-op-interface-defs"],
-            "include/mlir/Interfaces/SubsetInsertionOpInterface.cpp.inc",
+            "include/mlir/Interfaces/SubsetOpInterface.cpp.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Interfaces/SubsetInsertionOpInterface.td",
+    td_file = "include/mlir/Interfaces/SubsetOpInterface.td",
     deps = [
-        ":SubsetInsertionOpInterfaceTdFiles",
+        ":SubsetOpInterfaceTdFiles",
     ],
 )
 
 cc_library(
-    name = "SubsetInsertionOpInterface",
-    srcs = ["lib/Interfaces/SubsetInsertionOpInterface.cpp"],
-    hdrs = ["include/mlir/Interfaces/SubsetInsertionOpInterface.h"],
+    name = "SubsetOpInterface",
+    srcs = ["lib/Interfaces/SubsetOpInterface.cpp"],
+    hdrs = ["include/mlir/Interfaces/SubsetOpInterface.h"],
     includes = ["include"],
     deps = [
         ":DestinationStyleOpInterface",
         ":IR",
-        ":SubsetInsertionOpInterfaceIncGen",
+        ":SubsetOpInterfaceIncGen",
         ":Support",
         "//llvm:Support",
     ],
@@ -10470,7 +10470,7 @@ cc_library(
         ":SCFTransforms",
         ":SCFUtils",
         ":SparseTensorDialect",
-        ":SubsetInsertionOpInterface",
+        ":SubsetOpInterface",
         ":Support",
         ":TensorDialect",
         ":TensorTilingInterfaceImpl",
@@ -10530,6 +10530,7 @@ cc_library(
         ":IR",
         ":Support",
         ":ValueBoundsOpInterfaceIncGen",
+        ":ViewLikeInterface",
         "//llvm:Support",
     ],
 )
@@ -12579,7 +12580,7 @@ gentbl_cc_library(
         ":BufferizableOpInterfaceTdFiles",
         ":BufferizationOpsTdFiles",
         ":DestinationStyleOpInterfaceTdFiles",
-        ":SubsetInsertionOpInterfaceTdFiles",
+        ":SubsetOpInterfaceTdFiles",
     ],
 )
 
@@ -12619,7 +12620,7 @@ cc_library(
         ":InferTypeOpInterface",
         ":MemRefDialect",
         ":SparseTensorDialect",
-        ":SubsetInsertionOpInterface",
+        ":SubsetOpInterface",
         ":Support",
         ":TensorDialect",
         "//llvm:Support",
@@ -12669,7 +12670,7 @@ cc_library(
         ":Pass",
         ":SCFDialect",
         ":SideEffectInterfaces",
-        ":SubsetInsertionOpInterface",
+        ":SubsetOpInterface",
         ":Support",
         ":TensorDialect",
         ":Transforms",
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index c83e4cc7ada23e3..348ee2beabeb061 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -413,7 +413,7 @@ gentbl_filegroup(
     deps = [
         ":BufferizationOpsPyTdFiles",
         "//mlir:DestinationStyleOpInterfaceTdFiles",
-        "//mlir:SubsetInsertionOpInterfaceTdFiles",
+        "//mlir:SubsetOpInterfaceTdFiles",
     ],
 )
 

>From 23ce8911cf57df4ad6cdeb9f2cd89bc7b0c000d5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 14:32:49 +0900
Subject: [PATCH 2/7] [mlir] Loop-invariant subset hoisting

---
 .../Transforms/LoopInvariantCodeMotionUtils.h |  39 +++
 mlir/include/mlir/Transforms/Passes.h         |   2 +
 mlir/include/mlir/Transforms/Passes.td        |   5 +
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |   4 +-
 .../Transforms/LoopInvariantCodeMotion.cpp    |  20 ++
 mlir/lib/Transforms/Utils/CMakeLists.txt      |   1 +
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 253 +++++++++++++++++-
 .../loop-invariant-subset-hoisting.mlir       | 237 ++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 9 files changed, 556 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Transforms/loop-invariant-subset-hoisting.mlir

diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index c7b816eb28faf5f..579054070f729b0 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -71,6 +71,45 @@ size_t moveLoopInvariantCode(
 /// methods provided by the interface.
 size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
 
+/// Hoist loop-invariant tensor subsets (subset extraction and subset insertion
+/// ops) from loop-like ops. Extraction ops are moved before the loop. Insertion
+/// ops are moved after the loop. The loop body operates on newly added region
+/// iter_args (one per extraction-insertion pair).
+///
+/// A subset extraction op (`SubsetExtractionOpInterface`) extracts from a
+/// tensor value at a subset. The result of the op may have an arbitrary type,
+/// i.e., not necessarily a tensor type. Example: "tensor.extract_slice".
+///
+/// A subset insertion op (`SubsetInsertionOpInterface`) inserts into a tensor
+/// value ("destination") at a subset. Example: "tensor.insert_slice".
+///
+/// Matching extraction-insertion subset ops can be hoisted from a loop if there
+/// are no other ops within the loop that operate on the same or on an
+/// overlapping subset. In particular, non-subset ops can prevent hoisting
+/// because the analysis does not know what subset they operate on.
+///
+/// Example:
+/// ```
+/// %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+///   %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+///   %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+///   %2 = tensor.insert_slice %1 into %t[0][5][1]
+///       : tensor<5xf32> into tensor<?xf32>
+///   scf.yield %2 : tensor<?xf32>
+/// }
+/// ```
+/// Is rewritten to:
+/// ```
+/// %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+/// %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>) {
+///   %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+///   scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+/// }
+/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
+///     : tensor<5xf32> into tensor<?xf32>
+/// ```
+LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOPINVARIANTCODEMOTIONUTILS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 320932bb999561f..ab04478ae076694 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -78,6 +78,8 @@ std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
 
+std::unique_ptr<Pass> createLoopInvariantSubsetHoistingPass();
+
 /// Creates a pass to strip debug information from a function.
 std::unique_ptr<Pass> createStripDebugInfoPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 26d2ff3c30ded57..2d2d54fb8fb5eaa 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -329,6 +329,11 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
   let constructor = "mlir::createLoopInvariantCodeMotionPass()";
 }
 
+def LoopInvariantSubsetHoisting : Pass<"loop-invariant-subset-hoisting"> {
+  let summary = "Hoist loop invariant subset ops outside of the loop";
+  let constructor = "mlir::createLoopInvariantSubsetHoistingPass()";
+}
+
 def Mem2Reg : Pass<"mem2reg"> {
   let summary = "Promotes memory slots into values.";
   let description = [{
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 200f7f3b25b731b..f0c37c872e6d31d 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -86,7 +86,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
       return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
   } else {
     // Constant index value: return directly.
-    if (auto constInt = getConstantIntValue(value))
+    if (auto constInt = ::getConstantIntValue(value))
       return builder.getAffineConstantExpr(*constInt);
   }
 
@@ -103,7 +103,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
 AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
   if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
     return getExpr(value, /*dim=*/std::nullopt);
-  auto constInt = getConstantIntValue(ofr);
+  auto constInt = ::getConstantIntValue(ofr);
   assert(constInt.has_value() && "expected Integer constant");
   return builder.getAffineConstantExpr(*constInt);
 }
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index 854fde09bac796e..e6d8af8f05832d3 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -18,6 +18,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_LOOPINVARIANTCODEMOTION
+#define GEN_PASS_DEF_LOOPINVARIANTSUBSETHOISTING
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -29,6 +30,12 @@ struct LoopInvariantCodeMotion
     : public impl::LoopInvariantCodeMotionBase<LoopInvariantCodeMotion> {
   void runOnOperation() override;
 };
+
+struct LoopInvariantSubsetHoisting
+    : public impl::LoopInvariantSubsetHoistingBase<
+          LoopInvariantSubsetHoisting> {
+  void runOnOperation() override;
+};
 } // namespace
 
 void LoopInvariantCodeMotion::runOnOperation() {
@@ -39,6 +46,19 @@ void LoopInvariantCodeMotion::runOnOperation() {
       [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
 }
 
+void LoopInvariantSubsetHoisting::runOnOperation() {
+  // Walk through all loops in a function in innermost-loop-first order. This
+  // way, we first hoist from the inner loop, and place the ops in the outer
+  // loop, which in turn can be further hoisted from.
+  getOperation()->walk([&](LoopLikeOpInterface loopLike) {
+    (void)hoistLoopInvariantSubsets(loopLike);
+  });
+}
+
 std::unique_ptr<Pass> mlir::createLoopInvariantCodeMotionPass() {
   return std::make_unique<LoopInvariantCodeMotion>();
 }
+
+std::unique_ptr<Pass> mlir::createLoopInvariantSubsetHoistingPass() {
+  return std::make_unique<LoopInvariantSubsetHoisting>();
+}
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index efc7a5160b2399e..1c608e0634a67e2 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -20,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
+  MLIRSubsetOpInterface
   MLIRRewrite
   )
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 080492da6ae4b97..f39f8ec3598f322 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 #include "llvm/Support/Debug.h"
 #include <queue>
 
@@ -26,7 +29,7 @@ using namespace mlir;
 ///   loop (by means of calling definedOutside).
 /// - the op has no side-effects.
 static bool canBeHoisted(Operation *op,
-                         function_ref<bool(Value)> definedOutside) {
+                         function_ref<bool(OpOperand &)> condition) {
   // Do not move terminators.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return false;
@@ -35,11 +38,11 @@ static bool canBeHoisted(Operation *op,
   // defined outside of the loop or in a nested region, but not at the level of
   // the loop body.
   auto walkFn = [&](Operation *child) {
-    for (Value operand : child->getOperands()) {
+    for (OpOperand &operand : child->getOpOperands()) {
       // Ignore values defined in a nested region.
-      if (op->isAncestor(operand.getParentRegion()->getParentOp()))
+      if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
         continue;
-      if (!definedOutside(operand))
+      if (!condition(operand))
         return WalkResult::interrupt();
     }
     return WalkResult::advance();
@@ -47,6 +50,12 @@ static bool canBeHoisted(Operation *op,
   return !op->walk(walkFn).wasInterrupted();
 }
 
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  return canBeHoisted(
+      op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
+}
+
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
@@ -105,3 +114,239 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
       },
       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
 }
+
+namespace {
+/// Helper data structure that keeps track of equivalent/disjoint subset ops.
+class MatchingSubsets {
+public:
+  /// Insert a subset op.
+  void insert(SubsetOpInterface op) {
+    allSubsetOps.push_back(op);
+    if (auto extractionOp =
+            dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
+      insertExtractionOp(extractionOp);
+    if (auto insertionOp =
+            dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
+      insertInsertionOp(insertionOp);
+  }
+
+  /// Return a range of matching extraction-insertion subset ops. If there is no
+  /// matching extraction/insertion op, the respective value is empty. Ops are
+  /// skipped if there are other subset ops that are not guaranteed to operate
+  /// on disjoint subsets.
+  auto getHoistableSubsetOps() {
+    return llvm::make_filter_range(
+        llvm::zip(extractions, insertions), [&](auto pair) {
+          auto [extractionOp, insertionOp] = pair;
+          // Hoist only if the extracted and inserted values have the same type.
+          if (extractionOp && insertionOp &&
+              extractionOp->getResult(0).getType() !=
+                  insertionOp.getSourceOperand().get().getType())
+            return false;
+          // Hoist only if there are no conflicting subset ops.
+          return allDisjoint(extractionOp, insertionOp);
+        });
+  }
+
+private:
+  /// Helper function for equivalence of tensor values. Since only insertion
+  /// subset ops (that are also destination style ops) are followed when
+  /// traversing the SSA use-def chain, all tensor values are equivalent.
+  static bool isEquivalent(Value v1, Value v2) { return true; }
+
+  /// Return "true" if the subsets of the given extraction and insertion ops
+  /// are disjoint from the subsets that all other known subset ops are
+  /// operating on.
+  bool allDisjoint(SubsetExtractionOpInterface extractionOp,
+                   SubsetInsertionOpInterface insertionOp) const {
+    for (SubsetOpInterface other : allSubsetOps) {
+      if (other == extractionOp || other == insertionOp)
+        continue;
+      if (extractionOp &&
+          !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
+        return false;
+      if (insertionOp &&
+          !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
+        return false;
+    }
+    return true;
+  }
+
+  /// Insert a subset extraction op. If the subset is equivalent to an existing
+  /// subset insertion op, pair them up. (If there is already a paired up subset
+  /// extraction op, overwrite the subset extraction op.)
+  void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
+    for (auto it : llvm::enumerate(insertions)) {
+      if (!it.value())
+        continue;
+      auto other = cast<SubsetOpInterface>(it.value().getOperation());
+      if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
+        extractions[it.index()] = extractionOp;
+        return;
+      }
+    }
+    // There is no known equivalent insertion op. Create a new entry.
+    extractions.push_back(extractionOp);
+    insertions.push_back({});
+  }
+
+  /// Insert a subset insertion op. If the subset is equivalent to an existing
+  /// subset extraction op, pair them up. (If there is already a paired up
+  /// subset insertion op, overwrite the subset insertion op.)
+  void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
+    for (auto it : llvm::enumerate(extractions)) {
+      if (!it.value())
+        continue;
+      auto other = cast<SubsetOpInterface>(it.value().getOperation());
+      if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
+        insertions[it.index()] = insertionOp;
+        return;
+      }
+    }
+    // There is no known equivalent extraction op. Create a new entry.
+    extractions.push_back({});
+    insertions.push_back(insertionOp);
+  }
+
+  SmallVector<SubsetExtractionOpInterface> extractions;
+  SmallVector<SubsetInsertionOpInterface> insertions;
+  SmallVector<SubsetOpInterface> allSubsetOps;
+};
+} // namespace
+
+/// If the given value has a single use by an op that is a terminator, return
+/// that use. Otherwise, return nullptr.
+static OpOperand *getSingleTerminatorUse(Value value) {
+  if (!value.hasOneUse())
+    return nullptr;
+  OpOperand &use = *value.getUses().begin();
+  if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
+    return &use;
+  return nullptr;
+}
+
+/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
+/// loop-like op and index into loop-invariant subset locations. Return the
+/// newly created loop op (that has extra iter_args) or the original loop op if
+/// nothing was hoisted.
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+                                                BlockArgument iterArg) {
+  IRRewriter rewriter(loopLike.getContext());
+  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+  auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+  Value value = iterArg;
+  MatchingSubsets subsets;
+
+  // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
+  // use-def chain starting from the region iter_arg are subset extraction or
+  // subset insertion ops. The chain must terminate at the corresponding yield
+  // operand (e.g., no swapping of iter_args).
+  OpOperand *yieldedOperand = nullptr;
+  // Iterate until the single use of the current SSA value is a terminator,
+  // which is expected to be the yielding operation of the loop.
+  while (!(yieldedOperand = getSingleTerminatorUse(value))) {
+    Value nextValue = {};
+
+    for (OpOperand &use : value.getUses()) {
+      auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
+      if (!subsetOp)
+        return loopLike;
+      subsets.insert(subsetOp);
+
+      if (auto insertionOp =
+              dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
+        // The value must be used as a destination. (In case of a source, the
+        // entire tensor would be read, which would prevent any hoisting.)
+        if (&use != &insertionOp.getDestinationOperand())
+          return loopLike;
+        // There must be a single use-def chain from the region iter_arg to the
+        // terminator. I.e., only one insertion op. Branches are not supported.
+        if (nextValue)
+          return loopLike;
+        nextValue = insertionOp.getUpdatedDestination();
+      }
+    }
+
+    // Nothing can be hoisted if the chain does not continue with a loop
+    // yielding op or a subset insertion op.
+    if (!nextValue)
+      return loopLike;
+    value = nextValue;
+  }
+
+  // Hoist only if the SSA use-def chain ends in the yielding terminator of the
+  // loop and the yielded operand is the one tied to the region iter_arg. (I.e.,
+  // there is no swapping yield.)
+  if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+    return loopLike;
+
+  // Hoist all matching extraction-insertion pairs one-by-one.
+  for (auto it : subsets.getHoistableSubsetOps()) {
+    auto extractionOp = std::get<0>(it);
+    auto insertionOp = std::get<1>(it);
+
+    // Ops cannot be hoisted if they depend on loop-variant values.
+    if (extractionOp) {
+      if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
+            return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+                   &operand == &extractionOp.getSourceOperand();
+          }))
+        extractionOp = {};
+    }
+    if (insertionOp) {
+      if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
+            return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+                   &operand == &insertionOp.getSourceOperand() ||
+                   &operand == &insertionOp.getDestinationOperand();
+          }))
+        insertionOp = {};
+    }
+
+    // Only hoist extraction-insertion pairs for now. Standalone extractions/
+    // insertions that are loop-invariant could be hoisted, but there may be
+    // easier ways to canonicalize the IR.
+    if (extractionOp && insertionOp) {
+      // Create a new loop with an additional iter_arg.
+      NewYieldValuesFn newYieldValuesFn =
+          [&](OpBuilder &b, Location loc,
+              ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
+        return {insertionOp.getSourceOperand().get()};
+      };
+      FailureOr<LoopLikeOpInterface> newLoop =
+          loopLike.replaceWithAdditionalYields(
+              rewriter, extractionOp.getResult(),
+              /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
+      if (failed(newLoop))
+        return loopLike;
+      loopLike = *newLoop;
+
+      // Hoist the extraction/insertion ops.
+      iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
+      OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
+      OpResult newLoopResult = loopLike.getLoopResults()->back();
+      extractionOp->moveBefore(loopLike);
+      insertionOp->moveAfter(loopLike);
+      insertionOp.getUpdatedDestination().replaceAllUsesWith(
+          insertionOp.getDestinationOperand().get());
+      extractionOp.getSourceOperand().set(
+          loopLike.getTiedLoopInit(iterArg)->get());
+      loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+      insertionOp.getSourceOperand().set(newLoopResult);
+      insertionOp.getDestinationOperand().set(loopResult);
+    }
+  }
+
+  return loopLike;
+}
+
+LoopLikeOpInterface
+mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+  // Note: As subset ops are getting hoisted, the number of region iter_args
+  // increases. This can enable further hoisting opportunities on the new
+  // iter_args.
+  for (int64_t i = 0; i < loopLike.getRegionIterArgs().size(); ++i) {
+    loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+  }
+  return loopLike;
+}
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
new file mode 100644
index 000000000000000..5cded4c99182c14
--- /dev/null
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt %s  -split-input-file -loop-invariant-subset-hoisting | FileCheck %s
+
+// CHECK-LABEL: func @hoist_matching_extract_insert(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // CHECK: tensor.extract_slice %[[t]][9] [5] [1]
+    %standalone = tensor.extract_slice %t[9][5][1] : tensor<?xf32> to tensor<5xf32>
+    "test.foo"(%standalone) : (tensor<5xf32>) -> ()
+
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo]]
+    scf.yield %3 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+
+  // CHECK: return %[[insert]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+func.func @subset_of_subset(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[extract1]]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted1:.*]] = %[[extract1]], %[[hoisted2:.*]] = %[[extract2]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %extract1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %extract2 = tensor.extract_slice %extract1[1][2][1] : tensor<5xf32> to tensor<2xf32>
+
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted2]])
+    %2 = "test.foo"(%extract2) : (tensor<2xf32>) -> (tensor<2xf32>)
+
+    %insert1 = tensor.insert_slice %2 into %extract1[1][2][1] : tensor<2xf32> into tensor<5xf32>
+    %insert2 = tensor.insert_slice %insert1 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+
+    // CHECK: scf.yield %[[t]], %[[hoisted1]], %[[foo]]
+    scf.yield %insert2 : tensor<?xf32>
+  }
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#1[1] [2] [1]
+  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[insert2]] into %[[for]]#0[0] [5] [1]
+
+  // CHECK: return %[[insert1]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_matching_chain(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_chain(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  %sz = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][%{{.*}}] [5] [1]
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]][0] [%{{.*}}] [1]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted2:.*]] = %[[extract2]], %[[hoisted1:.*]] = %[[extract1]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
+    %2 = tensor.extract_slice %t[%sz][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK-DAG: %[[foo1:.*]] = "test.foo"(%[[hoisted1]])
+    // CHECK-DAG: %[[foo2:.*]] = "test.foo"(%[[hoisted2]])
+    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %5 = tensor.insert_slice %foo2 into %t[%sz][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %5[0][%sz][1] : tensor<?xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo2]], %[[foo1]]
+    scf.yield %6 : tensor<?xf32>
+  }
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[0] [%{{.*}}] [1]
+  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert2]][%{{.*}}] [5] [1]
+
+  // CHECK: return %[[insert1]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_overlapping_subsets(
+func.func @do_not_hoist_overlapping_subsets(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  %sz1 = "test.foo"() : () -> (index)
+  %sz2 = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // These two slices are potentially overlapping. Do not hoist.
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[0][%sz1][1] : tensor<?xf32> to tensor<?xf32>
+    %2 = tensor.extract_slice %t[10][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<?xf32>) -> (tensor<?xf32>)
+    // CHECK: tensor.insert_slice
+    // CHECK: tensor.insert_slice
+    %5 = tensor.insert_slice %foo2 into %t[0][%sz1][1] : tensor<?xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %5[10][%sz2][1] : tensor<?xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %6 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multiple_yields(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @multiple_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice
+  // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[arg]], %{{.*}} = %[[arg]], %{{.*}} = %[[extract1]], %{{.*}} = %[[extract2]])
+  %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %5, %6 : tensor<?xf32>, tensor<?xf32>
+  }
+  // CHECK: tensor.insert_slice
+  // CHECK: tensor.insert_slice
+
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_swapping_yields(
+func.func @do_not_hoist_swapping_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    // CHECK: tensor.insert_slice
+    %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+    // Swapping yields: do not hoist.
+    // CHECK: scf.yield
+    scf.yield %6, %5 : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_subset_op(
+func.func @non_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // If any value along the use-def chain from the region iter_arg to the
+    // terminator is used by a non-subset op, no subset op along that chain can
+    // be hoisted. That is because it is unknown which parts of the value are
+    // accessed by the non-subset op.
+    // CHECK: "test.non_subset_op"
+    "test.non_subset_op"(%t) : (tensor<?xf32>) -> ()
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %3 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_loop_invariant_subset_op(
+func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // Subset ops that are not loop-invariant cannot be hoisted.
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[%iv][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    %3 = tensor.insert_slice %2 into %t[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %3 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index add146bf084d321..a9aa1d848fce4f5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7035,6 +7035,7 @@ cc_library(
         ":MemorySlotInterfaces",
         ":Rewrite",
         ":SideEffectInterfaces",
+        ":SubsetOpInterface",
         ":Support",
         ":TransformsPassIncGen",
         ":config",

>From 959e9254c112c339993b96c4ff2152bbf1388257 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 15:06:41 +0900
Subject: [PATCH 3/7] [mlir][WIP] Bypass analysis for loops

---
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 72 ++++++++++++++-----
 .../loop-invariant-subset-hoisting.mlir       | 35 +++++++++
 2 files changed, 91 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index f39f8ec3598f322..bb4e6dc62b9c935 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -120,8 +120,10 @@ namespace {
 class MatchingSubsets {
 public:
   /// Insert a subset op.
-  void insert(SubsetOpInterface op) {
+  void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
     allSubsetOps.push_back(op);
+    if (!collectHoistableOps)
+      return;
     if (auto extractionOp =
             dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
       insertExtractionOp(extractionOp);
@@ -148,6 +150,15 @@ class MatchingSubsets {
         });
   }
 
+  /// Populate subset ops starting from the given region iter_arg. Return
+  /// "failure" if non-subset ops are found along the path to the loop yielding
+  /// op or if there is no single path to the tied yielded operand. If
+  /// `collectHoistableOps` is set to "false", subset ops are gathered
+  /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
+  LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
+                                           BlockArgument iterArg,
+                                           bool collectHoistableOps = true);
+
 private:
   /// Helper function for equivalence of tensor values. Since only insertion
   /// subset ops (that are also destination style ops) are followed when
@@ -225,18 +236,12 @@ static OpOperand *getSingleTerminatorUse(Value value) {
   return nullptr;
 }
 
-/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
-/// loop-like op and index into loop-invariant subset locations. Return the
-/// newly created loop op (that has extra iter_args) or the original loop op if
-/// nothing was hoisted.
-static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
-                                                BlockArgument iterArg) {
-  IRRewriter rewriter(loopLike.getContext());
+LogicalResult
+MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
+                                            BlockArgument iterArg,
+                                            bool collectHoistableOps) {
   assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
-  auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
-  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
   Value value = iterArg;
-  MatchingSubsets subsets;
 
   // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
   // use-def chain starting from the region iter_arg are subset extraction or
@@ -249,21 +254,39 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
     Value nextValue = {};
 
     for (OpOperand &use : value.getUses()) {
+      if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
+        // Subset ops in nested loops are collected to check if there are only
+        // disjoint subset ops, but such subset ops are not subject to hoisting.
+        // To hoist subset ops from nested loops, the hoisting transformation
+        // should be run on the nested loop.
+        auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
+        if (!nestedIterArg)
+          return failure();
+        // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
+        // use-def chain starting at `nestedIterArg` and terminating in the
+        // tied, yielding operand.
+        if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
+                                              /*collectHoistableOps=*/false)))
+          return failure();
+        nextValue = nestedLoop.getTiedLoopResult(&use);
+        continue;
+      }
+
       auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
       if (!subsetOp)
-        return loopLike;
-      subsets.insert(subsetOp);
+        return failure();
+      insert(subsetOp);
 
       if (auto insertionOp =
               dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
         // The value must be used as a destination. (In case of a source, the
         // entire tensor would be read, which would prevent any hoisting.)
         if (&use != &insertionOp.getDestinationOperand())
-          return loopLike;
+          return failure();
         // There must be a single use-def chain from the region iter_arg to the
         // terminator. I.e., only one insertion op. Branches are not supported.
         if (nextValue)
-          return loopLike;
+          return failure();
         nextValue = insertionOp.getUpdatedDestination();
       }
     }
@@ -271,7 +294,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
     // Nothing can be hoisted if the chain does not continue with loop yielding
     // op or a subset insertion op.
     if (!nextValue)
-      return loopLike;
+      return failure();
     value = nextValue;
   }
 
@@ -279,6 +302,23 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
   // loop and the yielded value is the `idx`-th operand. (I.e., there is no
   // swapping yield.)
   if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+    return failure();
+
+  return success();
+}
+
+/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
+/// loop-like op and index into loop-invariant subset locations. Return the
+/// newly created loop op (that has extra iter_args) or the original loop op if
+/// nothing was hoisted.
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+                                                BlockArgument iterArg) {
+  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+  auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+  IRRewriter rewriter(loopLike.getContext());
+  MatchingSubsets subsets;
+  if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
     return loopLike;
 
   // Hoist all matching extraction-insertion pairs one-by-one.
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index 5cded4c99182c14..b9161f4e20d1927 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -235,3 +235,38 @@ func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
 
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @nested_hoisting(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]][0] [5] [1]
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][5] [5] [1]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]], %[[hoisted2:.*]] = %[[extract2]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: %[[for2:.*]]:2 = {{.*}} iter_args(%[[t2:.*]] = %[[t]], %[[hoisted2_nested:.*]] = %[[hoisted2]])
+    %4 = scf.for %iv2 = %lb to %ub step %step iter_args(%t2 = %3) -> (tensor<?xf32>) {
+      %5 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+      // CHECK: %[[foo2:.*]] = "test.foo"(%[[hoisted2_nested]])
+      %6 = "test.foo"(%5) : (tensor<5xf32>) -> (tensor<5xf32>)
+      %7 = tensor.insert_slice %6 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+      // CHECK: scf.yield %[[t2]], %[[foo2]]
+      scf.yield %7 : tensor<?xf32>
+    }
+    // CHECK: scf.yield %[[for2]]#0, %[[foo]], %[[for2]]#1
+    scf.yield %4 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[5] [5] [1]
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert]][0] [5] [1]
+  // CHECK: return %[[insert2]]
+  return %0 : tensor<?xf32>
+}

>From 387be39420a1830683d3b6291d95b0beeff5a5b4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 15:17:02 +0900
Subject: [PATCH 4/7] [mlir][Interfaces] `SubsetOpInterface`: Add helpers for
 hyperrectangular subsets

---
 .../Bufferization/IR/BufferizationOps.td      |   3 +-
 mlir/include/mlir/IR/OpDefinition.h           |   5 +
 .../mlir/Interfaces/SubsetOpInterface.h       |  16 ++-
 .../mlir/Interfaces/SubsetOpInterface.td      |  53 +++++++--
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  53 ++++++++-
 .../SubsetInsertionOpInterfaceImpl.cpp        |  82 +------------
 mlir/lib/Interfaces/CMakeLists.txt            |   2 +
 mlir/lib/Interfaces/SubsetOpInterface.cpp     |  49 ++++++++
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 109 ++++++++++++++++--
 .../loop-invariant-subset-hoisting.mlir       |   9 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 11 files changed, 285 insertions(+), 97 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index e6b6d052df96a8c..9dc6afcaab31c86 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -220,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
          AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
-         DeclareOpInterfaceMethods<SubsetOpInterface>,
+         DeclareOpInterfaceMethods<SubsetOpInterface,
+            ["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
             ["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
              "buildSubsetExtraction", "isEquivalentSubset"]>,
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 8ab37c1d51d6b6c..bd68c27445744e3 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -268,6 +268,11 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
 
 public:
   void dump() const { llvm::errs() << *this << "\n"; }
+
+  MLIRContext *getContext() const {
+    return is<Attribute>() ? get<Attribute>().getContext()
+                           : get<Value>().getContext();
+  }
 };
 
 // Temporarily exit the MLIR namespace to add casting support as later code in
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.h b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
index 049cf2456a9c842..98c33ec65012fca 100644
--- a/mlir/include/mlir/Interfaces/SubsetOpInterface.h
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
@@ -10,6 +10,7 @@
 #define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 namespace mlir {
 class SubsetOpInterface;
@@ -27,10 +28,23 @@ OpOperand &defaultGetDestinationOperand(Operation *op);
 /// `DestinationStyleOpInterface`.
 OpResult defaultGetUpdatedDestination(Operation *op);
 
-/// Default implementation of `isEquivalentSubset`.
+/// Default implementation of `SubsetInsertionOpInterface::isEquivalentSubset`.
 bool defaultIsEquivalentSubset(Operation *op, Value candidate,
                                function_ref<bool(Value, Value)> equivalenceFn);
 
+/// Default implementation of `SubsetOpInterface::operatesOnEquivalentSubset`.
+bool defaultOperatesOnEquivalentSubset(
+    Operation *op, SubsetOpInterface candidate,
+    function_ref<bool(Value, Value)> equivalenceFn);
+
+/// Default implementation of `SubsetOpInterface::operatesOnDisjointSubset`.
+bool defaultOperatesOnDisjointSubset(
+    Operation *op, SubsetOpInterface candidate,
+    function_ref<bool(Value, Value)> equivalenceFn);
+
+/// Return the container that the given subset op is operating on.
+Value getTensorContainer(Operation *op);
+
 /// Verify `SubsetOpInterface`.
 LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
 
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.td b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
index 07d62b8319c2961..a00e398618a0118 100644
--- a/mlir/include/mlir/Interfaces/SubsetOpInterface.td
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
@@ -32,11 +32,6 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
       hyperrectangular slice.
     - `tensor.gather/scatter` describe the subset as list of indices. (Not
       implemented yet.)
-
-    Note: This interface does not expose any interface methods to get a
-    description of the accessed subset. That is because there is currently no
-    efficient way to describe arbitrary subsets. This interface merely provides
-    interface methods to check if two subsets are equivalent or disjoint.
   }];
 
   let cppNamespace = "::mlir";
@@ -46,24 +41,59 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
           Return "true" if this op and the given candidate subset op operate on
           an equivalent subset. Return "false" is the two subsets are disjoint
           or cannot be proven to be equivalent.
+
+          This interface method does not have to be implemented if
+          `getAccessedHyperrectangularSlice` is implemented.
         }],
         /*retType=*/"bool",
         /*methodName=*/"operatesOnEquivalentSubset",
         /*args=*/(ins
             "::mlir::SubsetOpInterface":$candidate,
-            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::detail::defaultOperatesOnEquivalentSubset(
+              $_op, candidate, equivalenceFn);
+        }]
       >,
       InterfaceMethod<
         /*desc=*/[{
           Return "true" if this op and the given candidate subset op operate on
           disjoint subsets. Return "false" is the two subsets are equivalent,
           overlapping or cannot be proven to be disjoint.
+
+          This interface method does not have to be implemented if
+          `getAccessedHyperrectangularSlice` is implemented.
         }],
         /*retType=*/"bool",
         /*methodName=*/"operatesOnDisjointSubset",
         /*args=*/(ins
             "::mlir::SubsetOpInterface":$candidate,
-            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::detail::defaultOperatesOnDisjointSubset(
+              $_op, candidate, equivalenceFn);
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          If this op operates on a hyperrectangular subset, return a
+          description of the subset in terms of offsets, sizes and strides.
+          Otherwise, return "failure".
+
+          This interface method is a convenience method for the most common case
+          of hyperrectangular subset ops. It is optional. If it is implemented,
+          `operatesOnEquivalentSubset` and `operatesOnDisjointSubset` do not
+          have to be implemented.
+        }],
+        /*retType=*/"::mlir::FailureOr<::mlir::HyperrectangularSlice>",
+        /*methodName=*/"getAccessedHyperrectangularSlice",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::failure();
+        }]
       >,
   ];
 
@@ -71,6 +101,15 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
     return ::mlir::detail::verifySubsetOpInterface(
         ::mlir::cast<::mlir::SubsetOpInterface>($_op));
   }];
+
+  let extraClassDeclaration = [{
+    /// Return the container that this operation is operating on. In case of an
+    /// extraction op, the container is the source tensor. In case of an
+    /// insertion op, the container is the destination tensor.
+    Value getTensorContainer() {
+      return ::mlir::detail::getTensorContainer(getOperation());
+    }
+  }];
 }
 
 def SubsetExtractionOpInterface
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 8e2986a2d1f05f6..28dadfb9ecf8688 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -21,6 +21,31 @@
 namespace mlir {
 class OffsetSizeAndStrideOpInterface;
 
+/// A hyperrectangular slice, represented as a list of offsets, sizes and
+/// strides.
+class HyperrectangularSlice {
+public:
+  HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        ArrayRef<OpFoldResult> strides);
+
+  /// Create a hyperrectangular slice with unit strides.
+  HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes);
+
+  /// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`.
+  HyperrectangularSlice(OffsetSizeAndStrideOpInterface op);
+
+  ArrayRef<OpFoldResult> getMixedOffsets() const { return mixedOffsets; }
+  ArrayRef<OpFoldResult> getMixedSizes() const { return mixedSizes; }
+  ArrayRef<OpFoldResult> getMixedStrides() const { return mixedStrides; }
+
+private:
+  SmallVector<OpFoldResult> mixedOffsets;
+  SmallVector<OpFoldResult> mixedSizes;
+  SmallVector<OpFoldResult> mixedStrides;
+};
+
 using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
 
 /// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
@@ -182,12 +207,34 @@ class ValueBoundsConstraintSet {
                                   std::optional<int64_t> dim1 = std::nullopt,
                                   std::optional<int64_t> dim2 = std::nullopt);
 
+  /// Compute whether the given values/attributes are equal. Return "failure" if
+  /// equality could not be determined.
+  ///
+  /// `ofr1`/`ofr2` must be of index type.
+  static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
+
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
   /// Return "failure" if unknown.
-  static FailureOr<bool>
-  areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1,
-                       OffsetSizeAndStrideOpInterface slice2);
+  ///
+  /// Slices are overlapping if for all dimensions:
+  /// *      offset1 + size1 * stride1 <= offset2
+  /// * and  offset2 + size2 * stride2 <= offset1
+  ///
+  /// Slices are non-overlapping if the above constraint is not satisfied for
+  /// at least one dimension.
+  static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
+                                              HyperrectangularSlice slice1,
+                                              HyperrectangularSlice slice2);
+
+  /// Return "true" if the given slices are guaranteed to be equivalent.
+  /// Return "false" if the given slices are guaranteed to be non-equivalent.
+  /// Return "failure" if unknown.
+  ///
+  /// Slices are equivalent if their offsets, sizes and strides are equal.
+  static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
+                                             HyperrectangularSlice slice1,
+                                             HyperrectangularSlice slice2);
 
   /// Add a bound for the given index-typed value or shaped value. This function
   /// returns a builder that adds the bound.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 7a1bafd409eea60..d50d7c62b789c8c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -17,73 +17,12 @@ using namespace mlir::tensor;
 
 namespace {
 
-/// Return the tensor that the given subset op operates on.
-Value getContainerOperand(SubsetOpInterface op) {
-  if (auto extractionOp =
-          dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
-    return extractionOp.getSourceOperand().get();
-  if (auto insertionOp =
-          dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
-    return insertionOp.getDestinationOperand().get();
-  llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
-}
-
-/// Return "true" if the two ops operate on an equivalent subset.
-/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
-/// if the two ops operate non-equivalent subsets, if equivalence cannot be
-/// determined or if `op1` is not a subset op.
-template <typename OpTy>
-bool operateOnEquivalentSubsets(
-    OpTy op1, SubsetOpInterface op2,
-    function_ref<bool(Value, Value)> equivalenceFn) {
-  auto offsetsSizesAndStrides2 =
-      dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
-  if (!offsetsSizesAndStrides2)
-    return false;
-  if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
-                                  isEqualConstantIntOrValue))
-    return false;
-  return equivalenceFn(
-      getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
-      getContainerOperand(op2));
-}
-
-/// Return "true" if the two ops operate on a disjoint subsets.
-/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
-/// if the two ops operate non-disjoint subsets, if disjointness cannot be
-/// determined or if `op1` is not a subset op.
-template <typename OpTy>
-bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
-                              function_ref<bool(Value, Value)> equivalenceFn) {
-  auto offsetsSizesAndStrides2 =
-      dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
-  if (!offsetsSizesAndStrides2)
-    return false;
-  FailureOr<bool> overlappingSlices =
-      ValueBoundsConstraintSet::areOverlappingSlices(op1,
-                                                     offsetsSizesAndStrides2);
-  if (failed(overlappingSlices) || *overlappingSlices)
-    return false;
-  return equivalenceFn(
-      getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
-      getContainerOperand(op2));
-}
-
 struct ExtractSliceOpSubsetOpInterface
     : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
                                               tensor::ExtractSliceOp> {
-  bool operatesOnEquivalentSubset(
-      Operation *op, SubsetOpInterface candidate,
-      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-    return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
-  }
-
-  bool operatesOnDisjointSubset(
-      Operation *op, SubsetOpInterface candidate,
-      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-    return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
+  FailureOr<HyperrectangularSlice>
+  getAccessedHyperrectangularSlice(Operation *op) const {
+    return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
   }
 };
 
@@ -99,18 +38,9 @@ template <typename OpTy>
 struct InsertSliceLikeOpSubsetOpInterface
     : public SubsetOpInterface::ExternalModel<
           InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
-  bool operatesOnEquivalentSubset(
-      Operation *op, SubsetOpInterface candidate,
-      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<OpTy>(op);
-    return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
-  }
-
-  bool operatesOnDisjointSubset(
-      Operation *op, SubsetOpInterface candidate,
-      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<OpTy>(op);
-    return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
+  FailureOr<HyperrectangularSlice>
+  getAccessedHyperrectangularSlice(Operation *op) const {
+    return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
   }
 };
 
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 2652d261f480ba4..e7c76e70ed6b5d7 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -93,10 +93,12 @@ add_mlir_library(MLIRSubsetOpInterface
   DEPENDS
   MLIRDestinationStyleOpInterface
   MLIRSubsetOpInterfaceIncGen
+  MLIRValueBoundsOpInterface
 
   LINK_LIBS PUBLIC
   MLIRDestinationStyleOpInterface
   MLIRIR
+  MLIRValueBoundsOpInterface
   )
 
 add_mlir_interface_library(TilingInterface)
diff --git a/mlir/lib/Interfaces/SubsetOpInterface.cpp b/mlir/lib/Interfaces/SubsetOpInterface.cpp
index 7245ab20c499e20..d0bdadf500f6f6c 100644
--- a/mlir/lib/Interfaces/SubsetOpInterface.cpp
+++ b/mlir/lib/Interfaces/SubsetOpInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 #include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
 
@@ -40,6 +41,54 @@ bool detail::defaultIsEquivalentSubset(
       candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
 }
 
+bool detail::defaultOperatesOnEquivalentSubset(
+    Operation *op, SubsetOpInterface candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  auto subsetOp = cast<SubsetOpInterface>(op);
+  FailureOr<HyperrectangularSlice> slice =
+      subsetOp.getAccessedHyperrectangularSlice();
+  assert(succeeded(slice) &&
+         "operatesOnEquivalentSubset must be implemented if "
+         "getAccessedHyperrectangularSlice is not implemented");
+  FailureOr<HyperrectangularSlice> otherSlice =
+      candidate.getAccessedHyperrectangularSlice();
+  if (failed(otherSlice))
+    return false;
+  if (!equivalenceFn(subsetOp.getTensorContainer(),
+                     candidate.getTensorContainer()))
+    return false;
+  FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
+      op->getContext(), *slice, *otherSlice);
+  return succeeded(equivalent) && *equivalent;
+}
+
+bool detail::defaultOperatesOnDisjointSubset(
+    Operation *op, SubsetOpInterface candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  auto subsetOp = cast<SubsetOpInterface>(op);
+  FailureOr<HyperrectangularSlice> slice =
+      subsetOp.getAccessedHyperrectangularSlice();
+  assert(succeeded(slice) &&
+         "operatesOnDisjointSubset must be implemented if "
+         "getAccessedHyperrectangularSlice is not implemented");
+  FailureOr<HyperrectangularSlice> otherSlice =
+      candidate.getAccessedHyperrectangularSlice();
+  if (failed(otherSlice))
+    return false;
+  if (!equivalenceFn(subsetOp.getTensorContainer(),
+                     candidate.getTensorContainer()))
+    return false;
+  FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
+      op->getContext(), *slice, *otherSlice);
+  return succeeded(overlapping) && !*overlapping;
+}
+
+Value detail::getTensorContainer(Operation *op) {
+  if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
+    return insertionOp.getDestinationOperand().get();
+  return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
+}
+
 LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
   if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
         isa<SubsetInsertionOpInterface>(op.getOperation())))
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index f0c37c872e6d31d..62ba63402925e01 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,32 @@ namespace mlir {
 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
 } // namespace mlir
 
+HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+                                             ArrayRef<OpFoldResult> sizes,
+                                             ArrayRef<OpFoldResult> strides)
+    : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
+  assert(offsets.size() == sizes.size() &&
+         "expected same number of offsets, sizes, strides");
+  assert(offsets.size() == strides.size() &&
+         "expected same number of offsets, sizes, strides");
+}
+
+HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+                                             ArrayRef<OpFoldResult> sizes)
+    : mixedOffsets(offsets), mixedSizes(sizes) {
+  assert(offsets.size() == sizes.size() &&
+         "expected same number of offsets and sizes");
+  // Assume that all strides are 1.
+  if (offsets.empty())
+    return;
+  MLIRContext *ctx = offsets.front().getContext();
+  mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
+}
+
+HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
+    : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
+                            op.getMixedStrides()) {}
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.
@@ -524,19 +550,44 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
   return *delta == 0;
 }
 
-FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
-    OffsetSizeAndStrideOpInterface slice1,
-    OffsetSizeAndStrideOpInterface slice2) {
-  assert(slice1.getStaticOffsets().size() == slice1.getStaticOffsets().size() &&
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
+                                                   OpFoldResult ofr2) {
+  Builder b(ofr1.getContext());
+  AffineMap map =
+      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
+                     b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1));
+  SmallVector<OpFoldResult> ofrOperands;
+  ofrOperands.push_back(ofr1);
+  ofrOperands.push_back(ofr2);
+  SmallVector<Value> valueOperands;
+  AffineMap foldedMap =
+      foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
+  ValueDimList valueDims;
+  for (Value v : valueOperands) {
+    assert(v.getType().isIndex() && "expected index type");
+    valueDims.emplace_back(v, std::nullopt);
+  }
+  FailureOr<int64_t> delta =
+      computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims);
+  if (failed(delta))
+    return failure();
+  return *delta == 0;
+}
+
+FailureOr<bool>
+ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
+                                               HyperrectangularSlice slice1,
+                                               HyperrectangularSlice slice2) {
+  assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
          "expected slices of same rank");
-  assert(slice1.getStaticSizes().size() == slice1.getStaticSizes().size() &&
+  assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
          "expected slices of same rank");
-  assert(slice1.getStaticStrides().size() == slice1.getStaticStrides().size() &&
+  assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
          "expected slices of same rank");
 
-  Builder b(slice1.getContext());
+  Builder b(ctx);
   bool foundUnknownBound = false;
-  for (int64_t i = 0, e = slice1.getStaticOffsets().size(); i < e; ++i) {
+  for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
     AffineMap map =
         AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
                        b.getAffineSymbolExpr(0) +
@@ -588,6 +639,48 @@ FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
   return true;
 }
 
+FailureOr<bool>
+ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
+                                              HyperrectangularSlice slice1,
+                                              HyperrectangularSlice slice2) {
+  assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
+         "expected slices of same rank");
+  assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
+         "expected slices of same rank");
+  assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
+         "expected slices of same rank");
+
+  // The two slices are equivalent if all of their offsets, sizes and strides
+  // are equal. If equality cannot be determined for at least one of those
+  // values, equivalence cannot be determined and this function returns
+  // "failure".
+  for (auto [offset1, offset2] :
+       llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
+    FailureOr<bool> equal = areEqual(offset1, offset2);
+    if (failed(equal))
+      return failure();
+    if (!equal.value())
+      return false;
+  }
+  for (auto [size1, size2] :
+       llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
+    FailureOr<bool> equal = areEqual(size1, size2);
+    if (failed(equal))
+      return failure();
+    if (!equal.value())
+      return false;
+  }
+  for (auto [stride1, stride2] :
+       llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
+    FailureOr<bool> equal = areEqual(stride1, stride2);
+    if (failed(equal))
+      return failure();
+    if (!equal.value())
+      return false;
+  }
+  return true; // Every offset, size and stride pair was proven equal.
+}
+
 ValueBoundsConstraintSet::BoundBuilder &
 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
   assert(!this->dim.has_value() && "dim was already set");
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index b9161f4e20d1927..bb60eeaba52455c 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -7,6 +7,11 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
   %ub = "test.foo"() : () -> (index)
   %step = "test.foo"() : () -> (index)
 
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %add = arith.addi %c0, %c1 : index
+  %sub = arith.subi %add, %c1 : index
+
   // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
   // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
@@ -17,7 +22,9 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
     %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
     // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
     %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
-    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // Obfuscate the IR by inserting at offset %sub instead of 0; both of them
+    // have the same value.
+    %3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor<?xf32>
     // CHECK: scf.yield %[[t]], %[[foo]]
     scf.yield %3 : tensor<?xf32>
   }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a9aa1d848fce4f5..4e814428f0cc457 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10229,6 +10229,7 @@ cc_library(
         ":IR",
         ":SubsetOpInterfaceIncGen",
         ":Support",
+        ":ValueBoundsOpInterface",
         "//llvm:Support",
     ],
 )

>From e903f76f1ec4612a1262bc57d1d2c8c4b84d1a9a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 15:31:47 +0900
Subject: [PATCH 5/7] [mlir][vector] Implement subset op interface for xfer ops

---
 .../Vector/Transforms/SubsetOpInterfaceImpl.h |  20 +
 mlir/include/mlir/InitAllDialects.h           |   2 +
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   2 +
 .../Transforms/SubsetOpInterfaceImpl.cpp      |  82 +++
 mlir/test/Dialect/Linalg/hoisting.mlir        | 471 ------------------
 .../loop-invariant-subset-hoisting.mlir       | 318 ++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 7 files changed, 425 insertions(+), 471 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h
new file mode 100644
index 000000000000000..74bde485fa17a99
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- SubsetOpInterfaceImpl.h - Tensor subsets -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H
+#define MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace vector {
+void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 7c2ffb7408d9afd..621110d130818d3 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -85,6 +85,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CastInterfaces.h"
@@ -171,6 +172,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
+  vector::registerSubsetOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
   ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index e6e7fb465aec3b9..513340096a5c1fc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorShapeCast.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
+  SubsetOpInterfaceImpl.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
   VectorEmulateNarrowType.cpp
@@ -40,6 +41,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRMemRefUtils
   MLIRSCFDialect
   MLIRSideEffectInterfaces
+  MLIRSubsetOpInterface
   MLIRTensorDialect
   MLIRTransforms
   MLIRVectorDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..b450d5b78a46663
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
@@ -0,0 +1,82 @@
+//===- SubsetOpInterfaceImpl.cpp - Tensor subsets -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+template <typename OpTy> // Used with TransferReadOp/TransferWriteOp.
+struct XferOpSubsetOpInterface
+    : public SubsetOpInterface::ExternalModel<XferOpSubsetOpInterface<OpTy>,
+                                              OpTy> {
+  FailureOr<HyperrectangularSlice>
+  getAccessedHyperrectangularSlice(Operation *op) const {
+    auto xferOp = cast<OpTy>(op);
+    Builder b(xferOp->getContext());
+    SmallVector<OpFoldResult> offsets = llvm::map_to_vector( // Op indices.
+        xferOp.getIndices(), [](Value v) -> OpFoldResult { return v; });
+    SmallVector<OpFoldResult> sizes = llvm::map_to_vector( // As index attrs.
+        xferOp.getTransferChunkAccessed(),
+        [&](int64_t sz) -> OpFoldResult { return b.getIndexAttr(sz); });
+    return HyperrectangularSlice(offsets, sizes); // Offsets + sizes only.
+  }
+};
+
+struct TransferReadOpSubsetExtractionOpInterface
+    : public SubsetExtractionOpInterface::ExternalModel<
+          TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
+  OpOperand &getSourceOperand(Operation *op) const { // Op's "source" operand.
+    return cast<vector::TransferReadOp>(op).getSourceMutable();
+  }
+};
+
+struct TransferWriteOpSubsetInsertionOpInterface
+    : public SubsetInsertionOpInterface::ExternalModel<
+          TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> {
+  OpOperand &getSourceOperand(Operation *op) const { // The "vector" operand.
+    return cast<vector::TransferWriteOp>(op).getVectorMutable();
+  }
+
+  OpOperand &getDestinationOperand(Operation *op) const { // Op's "source".
+    return cast<vector::TransferWriteOp>(op).getSourceMutable();
+  }
+
+  Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
+                              Location loc) const {
+    // TODO: Implement when needed.
+    return Value(); // Placeholder: a default-constructed (null) Value.
+  }
+
+  SmallVector<Value>
+  getValuesNeededToBuildSubsetExtraction(Operation *op) const {
+    // TODO: Implement when needed.
+    return {}; // Placeholder: empty list.
+  }
+};
+
+} // namespace
+
+void mlir::vector::registerSubsetOpInterfaceExternalModels(
+    DialectRegistry &registry) { // Attaches the models via an extension.
+  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+    TransferReadOp::attachInterface<XferOpSubsetOpInterface<TransferReadOp>>(
+        *ctx); // transfer_read: SubsetOpInterface model.
+    TransferReadOp::attachInterface<TransferReadOpSubsetExtractionOpInterface>(
+        *ctx); // transfer_read: SubsetExtractionOpInterface model.
+    TransferWriteOp::attachInterface<XferOpSubsetOpInterface<TransferWriteOp>>(
+        *ctx); // transfer_write: SubsetOpInterface model.
+    TransferWriteOp::attachInterface<TransferWriteOpSubsetInsertionOpInterface>(
+        *ctx); // transfer_write: SubsetInsertionOpInterface model.
+  });
+}
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 3623952a08df024..550ffbc7bab678a 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -224,477 +224,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
-func.func @hoist_vector_transfer_pairs_tensor(
-    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
-    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
-    %val: index, %lb : index, %ub : index, %step: index) ->
-    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
-     tensor<?x?xf32>, tensor<?x?xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.0 : f32
-
-// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
-// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
-// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
-// CHECK:   vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
-// CHECK:   scf.for {{.*}} iter_args({{.*}}) ->
-// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
-// CHECK:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
-// CHECK:     "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
-// CHECK:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK:     "some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
-// CHECK:     "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
-// CHECK:     scf.yield {{.*}} :
-// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
-// CHECK:   }
-// CHECK:   vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
-// CHECK:   scf.yield {{.*}} :
-// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
-// CHECK: }
-// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
-  %0:6 = scf.for %i = %lb to %ub step %step
-  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
-            %arg3 = %tensor3,  %arg4 = %tensor4, %arg5 = %tensor5)
-  -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
-     tensor<?x?xf32>, tensor<?x?xf32>)  {
-    %1:6 = scf.for %j = %lb to %ub step %step
-    iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
-              %arg9 = %arg3,  %arg10 = %arg4, %arg11 = %arg5)
-    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
-       tensor<?x?xf32>, tensor<?x?xf32>)  {
-      %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
-      %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
-      %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
-      "some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
-      %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
-      %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
-      "some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
-      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
-      %u2 = "some_use"(%arg8) : (tensor<?x?xf32>) -> vector<3xf32>
-      %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
-      %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
-      %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
-      %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
-      %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
-      %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
-      %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
-      %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
-      %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
-      "some_crippling_use"(%w3) : (tensor<?x?xf32>) -> ()
-      scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
-        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
-        tensor<?x?xf32>, tensor<?x?xf32>
-      }
-      scf.yield %1#0,  %1#1, %1#2, %1#3, %1#4, %1#5 :
-        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
-        tensor<?x?xf32>, tensor<?x?xf32>
-  }
-  return %0#0,  %0#1, %0#2, %0#3, %0#4,  %0#5 :
-        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
-        tensor<?x?xf32>, tensor<?x?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["func.func"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    transform.structured.hoist_redundant_tensor_subsets %0
-      : (!transform.any_op) -> ()
-    transform.yield
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
-//  CHECK-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-func.func @hoist_vector_transfer_pairs_disjoint_tensor(
-    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
-    %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
-    %val: index, %lb : index, %ub : index, %step: index,
-    %random_index : index) ->
-    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
-  %cst = arith.constant 0.0 : f32
-
-// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
-// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
-// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
-// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
-// CHECK: %[[R:.*]]:6 = scf.for {{.*}} iter_args({{.*}}) ->
-// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
-// CHECK:   scf.for {{.*}} iter_args({{.*}}) ->
-// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
-// CHECK:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
-// CHECK:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
-// CHECK:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
-// CHECK:     scf.yield {{.*}} :
-// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
-// CHECK:   }
-// CHECK:   scf.yield {{.*}} :
-// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
-// CHECK: }
-// CHECK: %[[TENSOR4:.*]] = vector.transfer_write %[[R]]#5, %[[TENSOR3]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// CHECK:                   vector.transfer_write %[[R]]#4, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// CHECK: %[[TENSOR5:.*]] = vector.transfer_write %[[R]]#3, %[[TENSOR2]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
-// CHECK:                   vector.transfer_write %[[R]]#2, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
-  %0:4 = scf.for %i = %lb to %ub step %step
-  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
-            %arg3 = %tensor3)
-  -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
-    %1:4 = scf.for %j = %lb to %ub step %step
-    iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
-              %arg7 = %arg3)
-    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
-      %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
-      %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
-      %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
-      %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
-      %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
-      %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
-      %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
-      %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
-      %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
-      %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
-      %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
-      %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
-      %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
-      %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
-      %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
-      %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
-      %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
-      %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
-      %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
-      %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
-      %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : vector<4xf32>, tensor<?x?xf32>
-      %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
-      %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
-      %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
-      scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-    }
-    scf.yield %1#0,  %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-  }
-  return %0#0,  %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["func.func"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    transform.structured.hoist_redundant_tensor_subsets %0
-      : (!transform.any_op) -> ()
-    transform.yield
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices
-//  CHECK-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[TENSOR4:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-//  CHECK-SAME:   %[[TENSOR5:[a-zA-Z0-9]*]]: tensor<?x?xf32>
-func.func @hoist_vector_transfer_pairs_tensor_and_slices(
-    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
-    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
-    %val: index, %lb : index, %ub : index, %step: index) ->
-    (
-      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>//, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-    ) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.0 : f32
-
-  //      CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
-  // CHECK-SAME:   %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
-  // CHECK-SAME:   %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
-  // CHECK-SAME:   %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
-  // CHECK-SAME: ) ->
-  // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-  %0:3 = scf.for %i = %lb to %ub step %step
-  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
-    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)  {
-
-    // Hoisted
-    // CHECK:   %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
-    // CHECK:   %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
-
-    //      CHECK:   %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args(
-    // CHECK-SAME:   %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
-    // CHECK-SAME:   %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
-    // CHECK-SAME:   %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
-    // CHECK-SAME: ) ->
-    // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
-    %1:3 = scf.for %j = %lb to %ub step %step
-    iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
-    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)  {
-      // Hoists.
-      %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-      %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
-
-      // CHECK:     %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
-      // CHECK:     %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
-      // Does not hoist (slice depends on %j)
-      %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-      %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
-
-      // CHECK:     %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
-      // CHECK:     %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
-      // Does not hoist, 2 slice %arg8.
-      %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-      %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
-
-      // CHECK:     %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
-      // CHECK:     %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
-      // CHECK:     %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
-      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
-      %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
-
-      // Hoists
-      %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
-
-      // CHECK-DAG:     %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
-      // Does not hoist (associated slice depends on %j).
-      %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
-
-      // CHECK-DAG:     %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
-      // Does not hoist, 2 slice / insert_slice for %arg8.
-      %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
-
-      // Hoists.
-      %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-
-      // CHECK-DAG:     tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
-      // Does not hoist (depends on %j).
-      %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-
-      // CHECK-DAG:     tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
-      // Does not hoist, 2 slice / insert_slice for %arg8.
-      %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-      // Extract with a different stride to make sure we cannot fold this extract with the above insert.
-      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-      %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-
-      // CHECK:     scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
-      // CHECK:   }
-      scf.yield %sti0, %sti1, %sti22:
-        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-    }
-
-    // Hoisted
-    // CHECK:   %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor<?x?xf32>
-    // CHECK:   tensor.insert_slice %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<?x?xf32>
-
-    // CHECK:   scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-    scf.yield %1#0, %1#1, %1#2 :
-      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-
-    // CHECK: }
-  }
-  return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["func.func"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    transform.structured.hoist_redundant_tensor_subsets %0
-      : (!transform.any_op) -> ()
-    transform.yield
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
-//  CHECK-SAME:   %[[T:.*]]: tensor<?x?xf32>,
-//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
-//   CHECK-DAG:   %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
-//       CHECK:   %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[R3:.*]] = %[[R1:.*]], %[[R2:.*]] = %[[R0]]) -> (vector<2xf32>, vector<2xf32>) {
-//       CHECK:     %[[R4:.*]] = "some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
-//       CHECK:     %[[R5:.*]] = "some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
-//       CHECK:     scf.yield %[[R5]], %[[R4]] : vector<2xf32>, vector<2xf32>
-//       CHECK:   }
-//       CHECK:   %[[W0:.*]] = vector.transfer_write %[[F]]#1, %[[T]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
-//       CHECK:   %[[W1:.*]] = vector.transfer_write %[[F]]#0, %[[W0]][%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
-//       CHECK:  return %[[W1]] : tensor<?x?xf32>
-func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
-    %tensor: tensor<?x?xf32>,
-    %val: index, %lb : index, %ub : index, %step: index) ->
-    (tensor<?x?xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
-  %cst = arith.constant 0.0 : f32
-  %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
-    -> (tensor<?x?xf32>) {
-    %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
-    %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
-    %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
-
-    // Hoist by properly bypassing the disjoint write %w10.
-    %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
-    %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
-    %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
-    scf.yield %w11 : tensor<?x?xf32>
-  }
-  return %1 : tensor<?x?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["func.func"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    transform.structured.hoist_redundant_tensor_subsets %0
-      : (!transform.any_op) -> ()
-    transform.yield
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor
-//  CHECK-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<100x100xf32>,
-//  CHECK-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<200x200xf32>,
-//  CHECK-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<300x300xf32>
-func.func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor(
-    %tensor0: tensor<100x100xf32>, %tensor1: tensor<200x200xf32>, %tensor2: tensor<300x300xf32>,
-    %val: index, %lb : index, %ub : index, %step: index) ->
-    (
-      tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
-    ) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.0 : f32
-
-  //      CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
-  // CHECK-SAME:   %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
-  // CHECK-SAME:   %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
-  // CHECK-SAME:   %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
-  // CHECK-SAME: ) ->
-  // CHECK-SAME: (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
-  %0:3 = scf.for %i = %lb to %ub step %step
-  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
-    -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>)  {
-
-    // Hoisted
-    // CHECK:   %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<100x100xf32> to tensor<?x?xf32>
-    // CHECK:   %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
-
-    //      CHECK:   %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args(
-    // CHECK-SAME:   %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
-    // CHECK-SAME:   %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
-    // CHECK-SAME:   %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
-    // CHECK-SAME: ) ->
-    // CHECK-SAME: (tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32>
-    %1:3 = scf.for %j = %lb to %ub step %step
-    iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
-    -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>)  {
-      // Hoists.
-      %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<100x100xf32> to tensor<?x?xf32>
-      %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
-
-      // CHECK:     %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<200x200xf32> to tensor<?x?xf32>
-      // CHECK:     %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
-      // Does not hoist (slice depends on %j)
-      %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<200x200xf32> to tensor<?x?xf32>
-      %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
-
-      // CHECK:     %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<300x300xf32> to tensor<?x?xf32>
-      // CHECK:     %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
-      // Does not hoist, 2 slice %arg8.
-      %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<300x300xf32> to tensor<?x?xf32>
-      %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
-
-      // CHECK:     %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
-      // CHECK:     %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
-      // CHECK:     %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
-      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
-      %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
-
-      // Hoists
-      %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
-
-      // CHECK-DAG:     %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
-      // Does not hoist (associated slice depends on %j).
-      %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
-
-      // CHECK-DAG:     %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
-      // Does not hoist, 2 slice / insert_slice for %arg8.
-      %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
-
-      // Hoists.
-      %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<100x100xf32>
-
-      // CHECK-DAG:     tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<200x200xf32>
-      // Does not hoist (depends on %j).
-      %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<200x200xf32>
-
-      // CHECK-DAG:     tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<300x300xf32>
-      // Does not hoist, 2 slice / insert_slice for %arg8.
-      %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<300x300xf32>
-      // Extract with a different stride to make sure we cannot fold this extract with the above insert.
-      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<300x300xf32> to tensor<?x?xf32>
-      %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<300x300xf32>
-
-      // CHECK:     scf.yield {{.*}} : tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32>
-      // CHECK:   }
-      scf.yield %sti0, %sti1, %sti22:
-        tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
-    }
-
-    // Hoisted
-    // CHECK:   %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor<?x?xf32>
-    // CHECK:   tensor.insert_slice %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<100x100xf32>
-
-    // CHECK:   scf.yield {{.*}} : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
-    scf.yield %1#0, %1#1, %1#2 :
-      tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
-
-    // CHECK: }
-  }
-  return %0#0, %0#1, %0#2 : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["func.func"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    transform.structured.hoist_redundant_tensor_subsets %0
-      : (!transform.any_op) -> ()
-    transform.yield
-  }
-}
-
-// -----
-
 // CHECK-LABEL:  func.func @hoist_vector_transfer_read(
 // CHECK-DAG:      %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:      %[[C128:.+]] = arith.constant 128 : index
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index bb60eeaba52455c..3a78287a0dcad2f 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -277,3 +277,321 @@ func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
   // CHECK: return %[[insert2]]
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
+func.func @hoist_vector_transfer_pairs_tensor(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+     tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+
+// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
+// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
+// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
+// CHECK:   vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK:   scf.for {{.*}} iter_args({{.*}}) ->
+// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
+// CHECK:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// CHECK:     "test.some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// CHECK:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// CHECK:     vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// CHECK:     vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// CHECK:     vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
+// CHECK:     "test.some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// CHECK:     scf.yield {{.*}} :
+// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
+// CHECK:   }
+// CHECK:   vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// CHECK:   scf.yield {{.*}} :
+// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
+  %0:6 = scf.for %i = %lb to %ub step %step
+  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+            %arg3 = %tensor3,  %arg4 = %tensor4, %arg5 = %tensor5)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+     tensor<?x?xf32>, tensor<?x?xf32>)  {
+    %1:6 = scf.for %j = %lb to %ub step %step
+    iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
+              %arg9 = %arg3,  %arg10 = %arg4, %arg11 = %arg5)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+       tensor<?x?xf32>, tensor<?x?xf32>)  {
+      %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+      %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
+      "test.some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
+      %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
+      %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
+      "test.some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
+      %u0 = "test.some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u1 = "test.some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+      %u2 = "test.some_use"(%arg8) : (tensor<?x?xf32>) -> vector<3xf32>
+      %u3 = "test.some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
+      %u4 = "test.some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+      %u5 = "test.some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
+      %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+      %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+      %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+      %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
+      %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
+      %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
+      "test.some_crippling_use"(%w3) : (tensor<?x?xf32>) -> ()
+      scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+        tensor<?x?xf32>, tensor<?x?xf32>
+      }
+      scf.yield %1#0,  %1#1, %1#2, %1#3, %1#4, %1#5 :
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+        tensor<?x?xf32>, tensor<?x?xf32>
+  }
+  return %0#0,  %0#1, %0#2, %0#3, %0#4,  %0#5 :
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+        tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
+//  CHECK-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+func.func @hoist_vector_transfer_pairs_disjoint_tensor(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
+    %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index,
+    %random_index : index) ->
+    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.0 : f32
+
+// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// CHECK: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
+// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK:   scf.for {{.*}} iter_args({{.*}}) ->
+// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// CHECK:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// CHECK:     scf.yield {{.*}} :
+// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// CHECK:   }
+// CHECK:   scf.yield {{.*}} :
+// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// CHECK: }
+// CHECK: %[[TENSOR4:.*]] = vector.transfer_write %[[R]]#7, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// CHECK:                   vector.transfer_write %[[R]]#6, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// CHECK: %[[TENSOR5:.*]] = vector.transfer_write %[[R]]#5, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// CHECK:                   vector.transfer_write %[[R]]#4, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+  %0:4 = scf.for %i = %lb to %ub step %step
+  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+            %arg3 = %tensor3)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+    %1:4 = scf.for %j = %lb to %ub step %step
+    iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
+              %arg7 = %arg3)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+      %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+      %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
+      %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+      %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+      %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
+      %u00 = "test.some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+      %u01 = "test.some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+      %u20 = "test.some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
+      %u21 = "test.some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
+      %u30 = "test.some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
+      %u31 = "test.some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
+      %u10 = "test.some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
+      %u11 = "test.some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
+      %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+      %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
+      %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+      %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
+      %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : vector<4xf32>, tensor<?x?xf32>
+      %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
+      %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+      %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
+      scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    }
+    scf.yield %1#0,  %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+  }
+  return %0#0,  %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices
+//  CHECK-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[TENSOR4:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[TENSOR5:[a-zA-Z0-9]*]]: tensor<?x?xf32>
+func.func @hoist_vector_transfer_pairs_tensor_and_slices(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+    (
+      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>//, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    ) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+
+  //      CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
+  // CHECK-SAME:   %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
+  // CHECK-SAME:   %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
+  // CHECK-SAME:   %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
+  // CHECK-SAME: ) ->
+  // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+  %0:3 = scf.for %i = %lb to %ub step %step
+  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)  {
+
+    // Hoisted
+    // CHECK:   %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+    // CHECK:   %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
+
+    //      CHECK:   %[[R:.*]]:5 = scf.for %[[J:.*]] = {{.*}} iter_args(
+    // CHECK-SAME:   %[[TENSOR0_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR0_ARG]]
+    // CHECK-SAME:   %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
+    // CHECK-SAME:   %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
+    // CHECK-SAME:   %[[ST0_ARG_L2:[0-9a-zA-Z]+]] = %[[ST0]]
+    // CHECK-SAME:   %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
+    // CHECK-SAME: ) ->
+    // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>)
+    %1:3 = scf.for %j = %lb to %ub step %step
+    iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)  {
+      // Hoists.
+      %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+
+      // CHECK:     %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+      // CHECK:     %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+      // Does not hoist (slice depends on %j)
+      %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+
+      // CHECK:     %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+      // CHECK:     %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+      // Does not hoist, 2 slice / insert_slice for %arg8.
+      %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+
+      // CHECK:     %[[U0:.*]] = "test.some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
+      // CHECK:     %[[U1:.*]] = "test.some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
+      // CHECK:     %[[U2:.*]] = "test.some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
+      %u0 = "test.some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u1 = "test.some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+      %u2 = "test.some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
+
+      // Hoists
+      %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+
+      // CHECK-DAG:     %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+      // Does not hoist (associated slice depends on %j).
+      %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+
+      // CHECK-DAG:     %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+      // Does not hoist, 2 slice / insert_slice for %arg8.
+      %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+
+      // Hoists.
+      %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // CHECK-DAG:     tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
+      // Does not hoist (depends on %j).
+      %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // CHECK-DAG:     tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
+      // Does not hoist, 2 slice / insert_slice for %arg8.
+      %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+      // Extract with a different stride to make sure we cannot fold this extract with the above insert.
+      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // CHECK:     scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+      // CHECK:   }
+      scf.yield %sti0, %sti1, %sti22:
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    }
+
+    // Hoisted
+    // CHECK:   %[[STI0:.*]] = vector.transfer_write %[[R]]#4, %[[R]]#3{{.*}} : vector<1xf32>, tensor<?x?xf32>
+    // CHECK:   tensor.insert_slice %[[STI0]] into %[[R]]#0[%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<?x?xf32>
+
+    // CHECK:   scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    scf.yield %1#0, %1#1, %1#2 :
+      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+
+    // CHECK: }
+  }
+  return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+//  CHECK-SAME:   %[[T:.*]]: tensor<?x?xf32>,
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+//   CHECK-DAG:   %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+//       CHECK:   %[[F:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[TL:.*]] = %[[T]], %[[R2:.*]] = %[[R0]], %[[R3:.*]] = %[[R1]]) -> (tensor<?x?xf32>, vector<2xf32>, vector<2xf32>) {
+//       CHECK:     %[[R4:.*]] = "test.some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
+//       CHECK:     %[[R5:.*]] = "test.some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
+//       CHECK:     scf.yield %[[TL]], %[[R4]], %[[R5]] : tensor<?x?xf32>, vector<2xf32>, vector<2xf32>
+//       CHECK:   }
+//       CHECK:   %[[W0:.*]] = vector.transfer_write %[[F]]#2, %[[F]]#0[%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
+//       CHECK:   %[[W1:.*]] = vector.transfer_write %[[F]]#1, %[[W0]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
+//       CHECK:  return %[[W1]] : tensor<?x?xf32>
+func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+    %tensor: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+    (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.0 : f32
+  %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
+    -> (tensor<?x?xf32>) {
+    %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+    %u00 = "test.some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+    %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+
+    // Hoist by properly bypassing the disjoint write %w10.
+    %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
+    %u01 = "test.some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+    %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
+    scf.yield %w11 : tensor<?x?xf32>
+  }
+  return %1 : tensor<?x?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 4e814428f0cc457..daa3e56d7d4f06b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4598,6 +4598,7 @@ cc_library(
         ":Pass",
         ":SCFDialect",
         ":SideEffectInterfaces",
+        ":SubsetOpInterface",
         ":Support",
         ":TensorDialect",
         ":Transforms",

>From 62af2908562759413c6955148c6486d06b739ee0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 15:51:20 +0900
Subject: [PATCH 6/7] [mlir][transform] Add transform op for loop-invariant
 subset hoisting

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 50 ----------------
 .../mlir/Dialect/Transform/IR/TransformOps.td | 59 +++++++++++++++++++
 .../Transforms/LoopInvariantCodeMotionUtils.h |  4 +-
 .../TransformOps/LinalgTransformOps.cpp       | 29 ---------
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 20 +++++++
 .../Transforms/LoopInvariantCodeMotion.cpp    |  4 +-
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 17 +++---
 .../Dialect/Transform/test-interpreter.mlir   | 45 ++++++++++++++
 8 files changed, 140 insertions(+), 88 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..732b6fe95c837d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2210,56 +2210,6 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// HoistRedundantTensorSubsetsOp
-//===----------------------------------------------------------------------===//
-
-def HoistRedundantTensorSubsetsOp :
-  Op<Transform_Dialect, "structured.hoist_redundant_tensor_subsets",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-     TransformEachOpTrait,
-     TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
-  let description = [{
-    Hoists supported tensor subset extract/insert operation pairs out of
-    immediately enclosing loop iteratively, if the following conditions
-    are true:
-       1. The 2 ops access the same tensor subset.
-       2. All operands are invariant under the enclosing loop.
-
-    The supported subset extract/insert operation pairs currently comprise:
-       - tensor.extract_slice / tensor.insert_slice
-       - vector.transfer_read / vector.transfer_write on tensors
-
-    Only scf.for loops are currently supported.
-
-    When applied to:
-       1. an scf.for loop, hoist out of this loop only.
-       2. a non-loop op, apply hoisting to all the contained loop ops.
-
-    #### Return modes:
-
-    The operation always succeeds and returns nothing.
-  }];
-
-  let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs);
-
-  let assemblyFormat = [{
-    $target
-    attr-dict
-    `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b14c89eadb097d9..6d57e104a90285a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -691,6 +691,65 @@ def GetTypeOp : TransformDialectOp<"get_type",
                        "functional-type(operands, results)";
 }
 
+def HoistLoopInvariantSubsetsOp
+    : TransformDialectOp<"hoist_loop_invariant_subsets",
+        [TransformOpInterface, TransformEachOpTrait,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Hoist loop invariant subset ops";
+  let description = [{
+    This transform hoists loop-invariant subset ops out of loop-like ops. It
+    looks for matching subset extraction/insertion op pairs and hoists them. The
+    loop body operates on a newly introduced region iter_arg.
+
+    Example:
+    ```
+    %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+      %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+      %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+      %2 = tensor.insert_slice %1 into %t[0][5][1]
+          : tensor<5xf32> into tensor<?xf32>
+      scf.yield %2 : tensor<?xf32>
+    }
+    ```
+    Is transformed to:
+    ```
+    %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>, tensor<5xf32>) {
+      %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+      scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+    }
+    %r = tensor.insert_slice %new_loop#1 into %new_loop#0[0][5][1]
+        : tensor<5xf32> into tensor<?xf32>
+    ```
+
+    Subset ops are hoisted only if there are no conflicting subset ops. E.g.,
+    if there were a second overlapping extraction in the above example, no ops
+    could be hoisted safely.
+
+    This transform looks for `LoopLikeOpInterface` ops within the targeted op,
+    including the target op itself. It attempts hoisting on all found loop-like
+    ops.
+
+    This transform reads the target handle and modifies the payload.
+
+    TODO: Make this op more targeted if needed. I.e., apply the transformation
+    only to the targeted `LoopLikeOpInterface` op.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::Operation *target,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 def IncludeOp : TransformDialectOp<"include",
     [CallOpInterface,
      MatchOpInterface,
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 579054070f729b0..3ceef44d799e893 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -18,6 +18,7 @@ namespace mlir {
 class LoopLikeOpInterface;
 class Operation;
 class Region;
+class RewriterBase;
 class Value;
 
 /// Given a list of regions, perform loop-invariant code motion. An operation is
@@ -108,7 +109,8 @@ size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
 /// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
 ///     : tensor<5xf32> into tensor<?xf32>
 /// ```
-LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter,
+                                              LoopLikeOpInterface loopLike);
 
 } // end namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..79bb708eea67572 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3139,35 +3139,6 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
-//===----------------------------------------------------------------------===//
-// HoistRedundantTensorSubsetsOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::HoistRedundantTensorSubsetsOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  auto forOp = dyn_cast<scf::ForOp>(target);
-  if (forOp) {
-    linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
-    return DiagnosedSilenceableFailure::success();
-  }
-
-  // TODO: walking in some reverse / inside-out order would be more efficient
-  // and would capture more cases.
-  target->walk([&](scf::ForOp forOp) {
-    hoistRedundantSubsetExtractInsert(rewriter, forOp);
-  });
-  return DiagnosedSilenceableFailure::success();
-}
-
-void transform::HoistRedundantTensorSubsetsOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getTarget(), effects);
-  transform::modifiesPayload(effects);
-}
-
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 8db77b6059dd2e3..ff71518e4358472 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1395,6 +1395,26 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// HoistLoopInvariantSubsetsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::HoistLoopInvariantSubsetsOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  target->walk([&](LoopLikeOpInterface loopLike) {
+    (void)hoistLoopInvariantSubsets(rewriter, loopLike);
+  });
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::HoistLoopInvariantSubsetsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // IncludeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index e6d8af8f05832d3..02c3ea1ce9b650c 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
@@ -47,11 +48,12 @@ void LoopInvariantCodeMotion::runOnOperation() {
 }
 
 void LoopInvariantSubsetHoisting::runOnOperation() {
+  IRRewriter rewriter(getOperation()->getContext());
   // Walk through all loops in a function in innermost-loop-first order. This
   // way, we first hoist from the inner loop, and place the ops in the outer
   // loop, which in turn can be further hoisted from.
   getOperation()->walk([&](LoopLikeOpInterface loopLike) {
-    (void)hoistLoopInvariantSubsets(loopLike);
+    (void)hoistLoopInvariantSubsets(rewriter, loopLike);
   });
 }
 
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index bb4e6dc62b9c935..4f67cca74088248 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -311,12 +311,12 @@ MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
 /// loop-like op and index into loop-invariant subset locations. Return the
 /// newly created loop op (that has extra iter_args) or the original loop op if
 /// nothing was hoisted.
-static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
+                                                LoopLikeOpInterface loopLike,
                                                 BlockArgument iterArg) {
   assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
   auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
   int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
-  IRRewriter rewriter(loopLike.getContext());
   MatchingSubsets subsets;
   if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
     return loopLike;
@@ -367,11 +367,12 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
       OpResult newLoopResult = loopLike.getLoopResults()->back();
       extractionOp->moveBefore(loopLike);
       insertionOp->moveAfter(loopLike);
-      insertionOp.getUpdatedDestination().replaceAllUsesWith(
-          insertionOp.getDestinationOperand().get());
+      rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
+                                  insertionOp.getDestinationOperand().get());
       extractionOp.getSourceOperand().set(
           loopLike.getTiedLoopInit(iterArg)->get());
-      loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+      rewriter.replaceAllUsesWith(loopResult,
+                                  insertionOp.getUpdatedDestination());
       insertionOp.getSourceOperand().set(newLoopResult);
       insertionOp.getDestinationOperand().set(loopResult);
     }
@@ -381,12 +382,14 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
 }
 
 LoopLikeOpInterface
-mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
+                                LoopLikeOpInterface loopLike) {
   // Note: As subset ops are getting hoisted, the number of region iter_args
   // increases. This can enable further hoisting opportunities on the new
   // iter_args.
   for (int64_t i = 0; i < loopLike.getRegionIterArgs().size(); ++i) {
-    loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+    loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
+                                    loopLike.getRegionIterArgs()[i]);
   }
   return loopLike;
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 3891c16b4115595..fd0c0fd2117ab63 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2109,3 +2109,48 @@ transform.sequence failures(propagate) {
   transform.yield 
 }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_loop_invariant_subset_hoisting(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @test_loop_invariant_subset_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+  // expected-remark @below{{new loop op}}
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // The insertion writes the same subset ([0][5][1]) of the iter_arg that
+    // the extraction above reads, so the pair can be hoisted out of the loop.
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo]]
+    scf.yield %3 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+  // CHECK: return %[[insert]]
+  return %0 : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+  transform.hoist_loop_invariant_subsets %0 : !transform.any_op
+  // Make sure that the handles are still valid (and were updated in case of
+  // the loop).
+
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+  test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+}

>From 3d0de8d6ae9adbef0bc1dc0f4bd673088121540a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 18:12:15 +0900
Subject: [PATCH 7/7] [mlir][linalg] Remove subset hoisting on tensors

---
 .../mlir/Dialect/Linalg/Transforms/Hoisting.h | 103 ----
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 -
 .../Linalg/Transforms/HoistPadding.cpp        |   5 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |   9 -
 .../Linalg/Transforms/SubsetHoisting.cpp      | 553 ------------------
 5 files changed, 3 insertions(+), 668 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index d4444c3f869e5cc..921c3c3e8c7db69 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -45,109 +45,6 @@ namespace linalg {
 /// when used on distributed loops with memref semantics!
 void hoistRedundantVectorTransfers(func::FuncOp func);
 
-/// Greedily hoist redundant subset extract/insert operations on tensors outside
-/// of `forOp`. The logic follows:
-///   1. Look for a write walking back from the `forOp` yield.
-///   2. Check the uses of the matching block argument and look for a matching
-///      read (i.e. extract_slice of transfer_read) with matching indices.
-///   3. In the case of a transfer_write, we can bypass other non-conflicting
-///      operations and find more hoisting opportunities.
-///   4. Hoist the read/write pair and update the tensor SSA links.
-///
-/// Return the unmodified `forOp` if no hoisting occured.
-/// Return a new scf::ForOp if hoisting on tensors occured.
-///
-/// After this transformation the returned scf::ForOp may have unused arguments
-/// that can be removed by application of canonicalization patterns.
-///
-/// Example:
-/// ========
-/// IR Resembling:
-///
-/// ```
-/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0)->(tensor<10xf32>) {
-///  %1 = scf.for %j = %l to %u step %s iter_args(%a6 = %a0)->(tensor<10xf32>) {
-///   %e = tensor.extract_slice %a6[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
-///   %r = vector.transfer_read %e[%c0], %cst: tensor<?xf32>, vector<4xf32>
-///   %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
-///   %w = vector.transfer_write %u, %e[%c0] : vector<4xf32>, tensor<?xf32>
-///   %st = tensor.insert_slice %w into %a6[%i][%sz][1]
-///     : tensor<?xf32> into tensor<10xf32>
-///   scf.yield %st: tensor<10xf32>
-///  }
-///  scf.yield %1: tensor<10xf32>
-/// }
-/// ```
-///
-/// Progressively hoists to:
-///
-/// ```
-/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
-///  %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
-///  %1:2 = scf.for %j = %l to %u step %s iter_args(%a6 = a0, %a7 = %e)
-///     -> (tensor<10xf32>, tensor<?xf32>) {
-///   %r = vector.transfer_read %a7[%c0], %cst: tensor<?xf32>, vector<4xf32>
-///   %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
-///   %w = vector.transfer_write %u, %a7[%c0] : vector<4xf32>, tensor<?xf32>
-///   scf.yield %a6, %w: tensor<10xf32>, tensor<?xf32>
-///  }
-///  %st = tensor.insert_slice %1#1 into %1#0[%i][%sz][1]
-///    : tensor<?xf32> into tensor<10xf32>
-///  scf.yield %1: tensor<10xf32>
-/// }
-/// ```
-///
-/// and
-///
-/// ```
-/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
-///  %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
-///  %r = vector.transfer_read %a7[%c0], %cst: tensor<?xf32>, vector<4xf32>
-///  %1:3 = scf.for %j = %l to %u step %s iter_args(%a6 = a0, %a7 = %e, %a7 = r)
-///     -> (tensor<10xf32>, tensor<?xf32>, vector<4xf32>) {
-///   %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
-///   scf.yield %a6, %a7, %u: tensor<10xf32>, tensor<?xf32>, vector<4xf32>
-///  }
-///  %w = vector.transfer_write %1#2, %1#1[%c0] : vector<4xf32>, tensor<?xf32>
-///  %st = tensor.insert_slice %w into %1#0[%i][%sz][1]
-///    : tensor<?xf32> into tensor<10xf32>
-///  scf.yield %1: tensor<10xf32>
-/// }
-/// ```
-///
-/// It can then canonicalize to:
-///
-/// ```
-/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
-///  %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
-///  %r = vector.transfer_read %a7[%c0], %cst: tensor<?xf32>, vector<4xf32>
-///  %1 = scf.for %j = %l to %u step %s iter_args(%a7 = r)
-///     -> (tensor<10xf32>, tensor<?xf32>, vector<4xf32>) {
-///   %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
-///   scf.yield %u: vector<4xf32>
-///  }
-///  %w = vector.transfer_write %1, %e[%c0] : vector<4xf32>, tensor<?xf32>
-///  %st = tensor.insert_slice %w into %a0[%i][%sz][1]
-///    : tensor<?xf32> into tensor<10xf32>
-///  scf.yield %1: tensor<10xf32>
-/// }
-/// ```
-///
-// TODO: This should be further generalized along a few different axes:
-//   - Other loops than scf.ForOp that operate on tensors (both sequential and
-//     parallel loops).
-//   - Other subset extract/insert pairs than tensor.extract/insert_slice and
-//     vector.transfer_read/write.
-//   - More general areSubsetDisjoint analysis/interface to work across all
-//     subset op types and allow bypassing non-WAW-conflicting operations in
-//     more cases.
-scf::ForOp hoistRedundantSubsetExtractInsert(RewriterBase &rewriter,
-                                             scf::ForOp forOp);
-
-/// Call into `hoistRedundantSubsetInsertExtract` without a RewriterBase.
-// TODO: obsolete and should be retired
-void hoistRedundantVectorTransfersOnTensor(func::FuncOp func);
-
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index bb90e5ee546d113..e5776a24d096f55 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -26,7 +26,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Promotion.cpp
   Split.cpp
   SplitReduction.cpp
-  SubsetHoisting.cpp
   SubsetInsertionOpInterfaceImpl.cpp
   SwapExtractSliceWithFillPatterns.cpp
   Tiling.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 19f704f5232ed81..facb71d756877ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/Support/Debug.h"
 
@@ -292,8 +293,8 @@ void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
   // enclosing loop, try to apply hoisting on this outermost loop.
   // TODO: we may want finer-grained hoisting of only that particular `sliceOp`.
   if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
-    outermostEnclosingForOp =
-        hoistRedundantSubsetExtractInsert(rewriter, outermostEnclosingForOp);
+    outermostEnclosingForOp = cast<scf::ForOp>(
+        hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp));
   }
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index cbb2c507de69f9e..80ce97ee3437a5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,15 +43,6 @@ using llvm::dbgs;
 using namespace mlir;
 using namespace mlir::linalg;
 
-void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) {
-  IRRewriter rewriter(func->getContext());
-  // TODO: walking in some reverse / inside-out order would be more efficient
-  // and would capture more cases.
-  func.walk([&](scf::ForOp forOp) {
-    hoistRedundantSubsetExtractInsert(rewriter, forOp);
-  });
-}
-
 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
                                 LoopLikeOpInterface loop) {
   Value source = transferRead.getSource();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
deleted file mode 100644
index 91e0d139ec5c2f0..000000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
+++ /dev/null
@@ -1,553 +0,0 @@
-//===- SubsetHoisting.cpp - Linalg hoisting transformations----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements functions concerned with hoisting invariant subset
-// operations in the context of Linalg transformations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-
-#define DEBUG_TYPE "subset-hoisting"
-
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-/// Return true if the location of the subset defined by the op is invariant of
-/// the loop iteration.
-static bool
-isSubsetLocationLoopInvariant(scf::ForOp forOp,
-                              vector::TransferWriteOp transferWriteOp) {
-  for (Value operand : transferWriteOp.getIndices())
-    if (!forOp.isDefinedOutsideOfLoop(operand))
-      return false;
-  return true;
-}
-
-/// Return true if the location of the subset defined by the op is invariant of
-/// the loop iteration.
-static bool isSubsetLocationLoopInvariant(scf::ForOp forOp,
-                                          tensor::InsertSliceOp insertSliceOp) {
-  for (Value operand : insertSliceOp->getOperands().drop_front(
-           tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
-    if (!forOp.isDefinedOutsideOfLoop(operand))
-      return false;
-  return true;
-}
-
-/// Given an `srcTensor` that is a block argument belong to a loop.
-/// Greedily look for the first read that can be hoisted out of the loop (i.e.
-/// that satisfied the conditions):
-///   - The read is of type `tensor.extract_slice`.
-///   - The read is one of the uses of `srcTensor`.
-///   - The read is to the same subset that `tensor.insert_slice` writes.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static FailureOr<tensor::ExtractSliceOp>
-findHoistableMatchingExtractSlice(RewriterBase &rewriter,
-                                  tensor::InsertSliceOp insertSliceOp,
-                                  BlockArgument srcTensor) {
-  assert(isa<RankedTensorType>(srcTensor.getType()) && "not a ranked tensor");
-
-  auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
-
-  LLVM_DEBUG(DBGS() << "--find matching read for: " << insertSliceOp << "\n";
-             DBGS() << "--amongst users of: " << srcTensor << "\n");
-
-  SmallVector<Operation *> users(srcTensor.getUsers());
-  if (forOp.isDefinedOutsideOfLoop(insertSliceOp.getDest()))
-    llvm::append_range(users, insertSliceOp.getDest().getUsers());
-
-  for (Operation *user : users) {
-    LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n");
-    auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
-    // Skip ops other than extract_slice with an exact matching of their tensor
-    // subset.
-    if (extractSliceOp) {
-      auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
-      if (extractSliceOp.getResultType() != insertSliceOp.getSourceType() ||
-          !extractSliceOp.isSameAs(insertSliceOp, isSame)) {
-        LLVM_DEBUG(DBGS() << "------not a matching extract_slice\n";
-                   DBGS() << *user << " vs " << *insertSliceOp << "\n");
-        continue;
-      }
-
-      // Skip insert_slice whose vector is defined within the loop: we need to
-      // hoist that definition first otherwise dominance violations trigger.
-      if (!isa<BlockArgument>(extractSliceOp.getSource()) &&
-          !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
-        LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n");
-        continue;
-      }
-      return extractSliceOp;
-    }
-
-    // TODO: Look through disjoint subsets, similar to vector.transfer_write
-    // and unify implementations.
-  }
-
-  LLVM_DEBUG(DBGS() << "----no matching extract_slice");
-  return failure();
-}
-
-/// Given an `srcTensor` that is a block argument belong to a loop.
-/// Greedily look for the first read that can be hoisted out of the loop (i.e.
-/// that satisfied the conditions):
-///   - The read is of type `tensor.transfer_read`.
-///   - The read is one of the uses of `srcTensor`.
-///   - The read is to the same subset that `tensor.transfer_write` writes.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static FailureOr<vector::TransferReadOp>
-findHoistableMatchingTransferRead(RewriterBase &rewriter,
-                                  vector::TransferWriteOp transferWriteOp,
-                                  BlockArgument srcTensor) {
-  if (!isa<RankedTensorType>(srcTensor.getType()))
-    return failure();
-
-  auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
-
-  LLVM_DEBUG(DBGS() << "--find matching read for: " << transferWriteOp << "\n";
-             DBGS() << "--amongst users of: " << srcTensor << "\n";);
-
-  // vector.transfer_write is a bit peculiar: we look through dependencies
-  // to disjoint tensor subsets. This requires a while loop.
-  // TODO: Look through disjoint subsets for tensor.insert_slice and unify
-  // implementations.
-  SmallVector<Operation *> users(srcTensor.getUsers());
-  // TODO: transferWriteOp.getSource is actually the destination tensor!!
-  if (forOp.isDefinedOutsideOfLoop(transferWriteOp.getSource()))
-    llvm::append_range(users, transferWriteOp.getSource().getUsers());
-  while (!users.empty()) {
-    Operation *user = users.pop_back_val();
-    LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n");
-    auto read = dyn_cast<vector::TransferReadOp>(user);
-    if (read) {
-      // Skip ops other than transfer_read with an exact matching subset.
-      if (read.getIndices() != transferWriteOp.getIndices() ||
-          read.getVectorType() != transferWriteOp.getVectorType()) {
-        LLVM_DEBUG(DBGS() << "------not a transfer_read that matches the "
-                             "transfer_write: "
-                          << *user << "\n\t(vs " << *transferWriteOp << ")\n");
-        continue;
-      }
-
-      // transfer_read may be of a vector that is defined within the loop: we
-      // traverse it by virtue of bypassing disjoint subset operations rooted at
-      // a bbArg and yielding a matching yield.
-      if (!isa<BlockArgument>(read.getSource()) &&
-          !forOp.isDefinedOutsideOfLoop(read.getSource())) {
-        LLVM_DEBUG(DBGS() << "------transfer_read vector appears loop "
-                             "dependent but will be tested for disjointness as "
-                             "part of the bypass analysis\n");
-      }
-      LLVM_DEBUG(DBGS() << "------found match\n");
-      return read;
-    }
-
-    // As an optimization, we look further through dependencies to disjoint
-    // tensor subsets. This creates more opportunities to find a matching read.
-    if (isa<vector::TransferWriteOp>(user)) {
-      // If we find a write with disjoint indices append all its uses.
-      // TODO: Generalize areSubsetsDisjoint and allow other bypass than
-      // just vector.transfer_write - vector.transfer_write.
-      if (vector::isDisjointTransferIndices(
-              cast<VectorTransferOpInterface>(user),
-              cast<VectorTransferOpInterface>(
-                  transferWriteOp.getOperation()))) {
-        LLVM_DEBUG(DBGS() << "----follow through disjoint write\n");
-        users.append(user->getUsers().begin(), user->getUsers().end());
-      } else {
-        LLVM_DEBUG(DBGS() << "----skip non-disjoint write\n");
-      }
-    }
-  }
-
-  LLVM_DEBUG(DBGS() << "--no matching transfer_read\n");
-  return rewriter.notifyMatchFailure(transferWriteOp,
-                                     "no matching transfer_read");
-}
-
-/// Return the `vector.transfer_write` that produces `yieldOperand`, if:
-///   - The write operates on tensors.
-///   - All indices are defined outside of the loop.
-/// Return failure otherwise.
-///
-/// This is sufficient condition to hoist the `vector.transfer_write`; other
-/// operands can always be yielded by the loop where needed.
-// TODO: generalize beyond scf::ForOp.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static FailureOr<vector::TransferWriteOp>
-getLoopInvariantTransferWriteDefining(RewriterBase &rewriter, scf::ForOp forOp,
-                                      BlockArgument bbArg,
-                                      OpOperand &yieldOperand) {
-  assert(bbArg.getArgNumber() ==
-             forOp.getNumInductionVars() + yieldOperand.getOperandNumber() &&
-         "bbArg and yieldOperand must match");
-  assert(isa<scf::YieldOp>(yieldOperand.getOwner()) && "must be an scf.yield");
-
-  Value v = yieldOperand.get();
-  auto transferWriteOp = v.getDefiningOp<vector::TransferWriteOp>();
-  if (!transferWriteOp)
-    return rewriter.notifyMatchFailure(v.getLoc(), "not a transfer_write");
-
-  if (transferWriteOp->getNumResults() == 0) {
-    return rewriter.notifyMatchFailure(v.getLoc(),
-                                       "unsupported transfer_write on buffers");
-  }
-
-  // We do not explicitly check that the destination is a BBarg that matches the
-  // yield operand as this would prevent us from bypassing other non-conflicting
-  // writes.
-
-  // Indexing must not depend on `forOp`.
-  if (!isSubsetLocationLoopInvariant(forOp, transferWriteOp))
-    return rewriter.notifyMatchFailure(
-        v.getLoc(), "transfer_write indexing is loop-dependent");
-
-  return transferWriteOp;
-}
-
-/// Return the `tensor.insert_slice` that produces `yieldOperand`, if:
-///   1. Its destination tensor is a block argument of the `forOp`.
-///   2. The unique use of its result is a yield with operand number matching
-///   the block argument.
-///   3. All indices are defined outside of the loop.
-/// Return failure otherwise.
-///
-/// This is sufficient condition to hoist the `tensor.insert_slice`; other
-/// operands can always be yielded by the loop where needed.
-/// Note: 1. + 2. ensure that the yield / iter_args cycle results in proper
-/// semantics (i.e. no ping-ping between iter_args across iterations).
-// TODO: generalize beyond scf::ForOp.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static FailureOr<tensor::InsertSliceOp>
-getLoopInvariantInsertSliceDefining(RewriterBase &rewriter, scf::ForOp forOp,
-                                    BlockArgument bbArg,
-                                    OpOperand &yieldOperand) {
-  assert(bbArg.getArgNumber() ==
-             forOp.getNumInductionVars() + yieldOperand.getOperandNumber() &&
-         "bbArg and yieldOperand must match");
-  assert(isa<scf::YieldOp>(yieldOperand.getOwner()) && "must be an scf.yield");
-
-  Value v = yieldOperand.get();
-  auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>();
-  if (!insertSliceOp)
-    return rewriter.notifyMatchFailure(v.getLoc(), "not an insert_slice");
-
-  // Tensor inserted into must be a BBArg at position matching yield operand.
-  // TODO: In the future we should not perform this check if we want to bypass
-  // other non-conflicting writes.
-  if (bbArg != insertSliceOp.getDest())
-    return rewriter.notifyMatchFailure(v.getLoc(), "not a matching bbarg");
-
-  // Indexing inserted into must not depend on `forOp`.
-  if (!isSubsetLocationLoopInvariant(forOp, insertSliceOp))
-    return rewriter.notifyMatchFailure(
-        v.getLoc(), "insert_slice indexing is loop-dependent");
-
-  return insertSliceOp;
-}
-
-/// Check if the chunk of data inserted by the `writeOp` is read by any other
-/// op than the candidateReadOp. This conflicting operation prevents hoisting,
-/// return it or nullptr if none is found.
-// TODO: Generalize subset disjunction analysis/interface.
-// TODO: Support more subset op types.
-static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp,
-                                                   Operation *candidateReadOp,
-                                                   BlockArgument tensorArg) {
-  // Make sure none of the other uses read the part of the tensor modified
-  // by the transfer_write.
-  llvm::SmallVector<Value::use_range, 1> uses;
-  uses.push_back(tensorArg.getUses());
-  while (!uses.empty()) {
-    for (OpOperand &use : uses.pop_back_val()) {
-      Operation *user = use.getOwner();
-      // Skip the candidate use, only inspect the "other" uses.
-      if (user == candidateReadOp || user == writeOp)
-        continue;
-
-      // TODO: Consider all transitive uses through
-      // extract_slice/insert_slice. Atm we just bail because a stronger
-      // analysis is needed for these cases.
-      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
-        return user;
-
-      // Consider all transitive uses through a vector.transfer_write.
-      if (isa<vector::TransferWriteOp>(writeOp)) {
-        if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
-          uses.push_back(writeUser->getResult(0).getUses());
-          continue;
-        }
-      }
-
-      // Consider all nested uses through an scf::ForOp. We may have
-      // pass-through tensor arguments left from previous level of
-      // hoisting.
-      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
-        Value arg = forUser.getBody()->getArgument(
-            use.getOperandNumber() - forUser.getNumControlOperands() +
-            /*iv value*/ 1);
-        uses.push_back(arg.getUses());
-        continue;
-      }
-
-      // Follow the use yield, only if it doesn't escape the original region.
-      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
-      if (yieldUser &&
-          writeOp->getParentOp()->isAncestor(yieldUser->getParentOp())) {
-        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
-        uses.push_back(ret.getUses());
-        continue;
-      }
-
-      // If the write is a vector::TransferWriteOp, it may have been bypassed
-      // and we need to check subset disjunction
-      if (isa<vector::TransferWriteOp>(writeOp)) {
-        auto read = dyn_cast<vector::TransferReadOp>(user);
-        if (!read || !vector::isDisjointTransferIndices(
-                         cast<VectorTransferOpInterface>(read.getOperation()),
-                         cast<VectorTransferOpInterface>(writeOp))) {
-          return user;
-        }
-      }
-    }
-  }
-  return nullptr;
-}
-
-/// Mechanical hoisting of a matching read / write pair.
-/// Return the newly created scf::ForOp with an extra yields.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static scf::ForOp hoistTransferReadWrite(
-    RewriterBase &rewriter, vector::TransferReadOp transferReadOp,
-    vector::TransferWriteOp transferWriteOp, BlockArgument tensorBBArg) {
-  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
-  LLVM_DEBUG(DBGS() << "--Start hoisting\n";
-             DBGS() << "--Hoist read : " << transferReadOp << "\n";
-             DBGS() << "--Hoist write: " << transferWriteOp << "\n";
-             DBGS() << "--Involving  : " << tensorBBArg << "\n");
-
-  // TODO: don't hardcode /*numIvs=*/1.
-  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
-  int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
-
-  // 1. Hoist the read op. Thanks to our previous checks we know this will not
-  // trigger dominance violations once BBArgs are updated.
-  // TODO: should the rewriter ever want to track this move ?
-  transferReadOp->moveBefore(forOp);
-  if (!forOp.isDefinedOutsideOfLoop(transferReadOp.getSource())) {
-    rewriter.startRootUpdate(transferReadOp);
-    transferReadOp.getSourceMutable().assign(
-        forOp.getInitArgs()[initArgNumber]);
-    rewriter.finalizeRootUpdate(transferReadOp);
-  }
-
-  // 2. Rewrite `loop` with an additional yield. This is the quantity that is
-  // computed iteratively but whose storage has become loop-invariant.
-  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
-                                 ArrayRef<BlockArgument> newBBArgs) {
-    return SmallVector<Value>{transferWriteOp.getVector()};
-  };
-  auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
-      rewriter, {transferReadOp.getVector()},
-      /*replaceInitOperandUsesInLoop=*/true, yieldFn));
-
-  // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
-  auto yieldOp =
-      cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
-  // TODO: transferWriteOp.getSource is actually the destination tensor!!
-  rewriter.startRootUpdate(yieldOp);
-  yieldOp->setOperand(initArgNumber, transferWriteOp.getSource());
-  rewriter.finalizeRootUpdate(yieldOp);
-
-  // 4. Hoist write after and make uses of newForOp.getResult(initArgNumber)
-  // flow through it.
-  // TODO: should the rewriter ever want to track this move ?
-  transferWriteOp->moveAfter(newForOp);
-  rewriter.startRootUpdate(transferWriteOp);
-  transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
-  // TODO: transferWriteOp.getSource is actually the destination tensor!!
-  transferWriteOp.getSourceMutable().assign(newForOp.getResult(initArgNumber));
-  rewriter.finalizeRootUpdate(transferWriteOp);
-  rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber),
-                                transferWriteOp.getResult(), transferWriteOp);
-  return newForOp;
-}
-
-/// Mechanical hoisting of a matching read / write pair.
-/// Return the newly created scf::ForOp with an extra yields.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter,
-                                          tensor::ExtractSliceOp extractSliceOp,
-                                          tensor::InsertSliceOp insertSliceOp,
-                                          BlockArgument tensorBBArg) {
-  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
-  LLVM_DEBUG(DBGS() << "--Start hoisting\n";
-             DBGS() << "--Hoist read : " << extractSliceOp << "\n";
-             DBGS() << "--Hoist write: " << insertSliceOp << "\n";
-             DBGS() << "--Involving  : " << tensorBBArg << "\n");
-
-  // TODO: don't hardcode /*numIvs=*/1.
-  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
-  int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
-
-  // 1. Hoist the read op. Thanks to our previous checks we know this will not
-  // trigger dominance violations once BBArgs are updated.
-  // TODO: should the rewriter ever want to track this move ?
-  extractSliceOp->moveBefore(forOp);
-  if (!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
-    assert(extractSliceOp.getSource() == tensorBBArg &&
-           "extractSlice source not defined above must be the tracked bbArg");
-    rewriter.startRootUpdate(extractSliceOp);
-    extractSliceOp.getSourceMutable().assign(
-        forOp.getInitArgs()[initArgNumber]);
-    rewriter.finalizeRootUpdate(extractSliceOp);
-  }
-
-  // 2. Rewrite `loop` with an additional yield. This is the quantity that is
-  // computed iteratively but whose storage has become loop-invariant.
-  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
-                                 ArrayRef<BlockArgument> newBBArgs) {
-    return SmallVector<Value>{insertSliceOp.getSource()};
-  };
-  auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
-      rewriter, extractSliceOp.getResult(),
-      /*replaceInitOperandUsesInLoop=*/true, yieldFn));
-
-  // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
-  auto yieldOp =
-      cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
-  // TODO: should the rewriter ever want to track this ?
-  rewriter.startRootUpdate(yieldOp);
-  yieldOp->setOperand(initArgNumber, insertSliceOp.getDest());
-  rewriter.finalizeRootUpdate(yieldOp);
-
-  // 4. Hoist write after and make uses of newForOp.getResult(initArgNumber)
-  // flow through it.
-  // TODO: should the rewriter ever want to track this move ?
-  insertSliceOp->moveAfter(newForOp);
-  rewriter.startRootUpdate(insertSliceOp);
-  insertSliceOp.getSourceMutable().assign(newForOp.getResults().back());
-  insertSliceOp.getDestMutable().assign(newForOp.getResult(initArgNumber));
-  rewriter.finalizeRootUpdate(insertSliceOp);
-  rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber),
-                                insertSliceOp.getResult(), insertSliceOp);
-  return newForOp;
-}
-
-/// Greedily hoist redundant subset extract/insert operations on tensors
-/// outside `forOp`.
-/// Return the unmodified `forOp` if no hoisting occurred.
-/// Return a new scf::ForOp if hoisting on tensors occurred.
-scf::ForOp
-mlir::linalg::hoistRedundantSubsetExtractInsert(RewriterBase &rewriter,
-                                                scf::ForOp forOp) {
-  LLVM_DEBUG(DBGS() << "Enter hoistRedundantSubsetExtractInsert scf.for\n");
-  Operation *yield = forOp.getBody()->getTerminator();
-
-  LLVM_DEBUG(DBGS() << "\n"; DBGS() << "Consider " << forOp << "\n");
-
-  scf::ForOp newForOp = forOp;
-  do {
-    forOp = newForOp;
-    for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
-      LLVM_DEBUG(DBGS() << "Consider " << it.value() << "\n");
-
-      // 1. Find a loop invariant subset write yielding `ret` that we can
-      // consider for hoisting.
-      // TODO: TypeSwitch when we add more cases.
-      OpOperand &ret = yield->getOpOperand(it.index());
-      FailureOr<vector::TransferWriteOp> transferWriteOp =
-          getLoopInvariantTransferWriteDefining(rewriter, forOp, it.value(),
-                                                ret);
-      FailureOr<tensor::InsertSliceOp> insertSliceOp =
-          getLoopInvariantInsertSliceDefining(rewriter, forOp, it.value(), ret);
-      if (failed(transferWriteOp) && failed(insertSliceOp)) {
-        LLVM_DEBUG(DBGS() << "no loop invariant write defining iter_args "
-                          << it.value() << "\n");
-        continue;
-      }
-
-      Operation *writeOp = succeeded(transferWriteOp)
-                               ? transferWriteOp->getOperation()
-                               : insertSliceOp->getOperation();
-
-      // 2. Only accept writes with a single use (i.e. the yield).
-      if (!writeOp->hasOneUse()) {
-        LLVM_DEBUG(DBGS() << "write with more than 1 use " << *writeOp << "\n");
-        continue;
-      }
-
-      LLVM_DEBUG(DBGS() << "Write to hoist: " << *writeOp << "\n");
-
-      // 3. Find a matching read that can also be hoisted.
-      Operation *matchingReadOp = nullptr;
-      // TODO: TypeSwitch.
-      if (succeeded(transferWriteOp)) {
-        auto maybeTransferRead = findHoistableMatchingTransferRead(
-            rewriter, *transferWriteOp, it.value());
-        if (succeeded(maybeTransferRead))
-          matchingReadOp = maybeTransferRead->getOperation();
-      } else if (succeeded(insertSliceOp)) {
-        auto maybeExtractSlice = findHoistableMatchingExtractSlice(
-            rewriter, *insertSliceOp, it.value());
-        if (succeeded(maybeExtractSlice))
-          matchingReadOp = maybeExtractSlice->getOperation();
-      } else {
-        llvm_unreachable("unexpected case");
-      }
-      if (!matchingReadOp) {
-        LLVM_DEBUG(DBGS() << "No matching read\n");
-        continue;
-      }
-
-      // 4. Make sure no other use reads the part of the modified tensor.
-      // This is necessary to guard against hazards when non-conflicting subset
-      // ops are bypassed.
-      Operation *maybeUnknownOp =
-          isTensorChunkAccessedByUnknownOp(writeOp, matchingReadOp, it.value());
-      if (maybeUnknownOp) {
-        LLVM_DEBUG(DBGS() << "Tensor chunk accessed by unknown op, skip: "
-                          << *maybeUnknownOp << "\n");
-        continue;
-      }
-
-      // 5. Perform the actual mechanical hoisting.
-      // TODO: TypeSwitch.
-      LLVM_DEBUG(DBGS() << "Read to hoist: " << *matchingReadOp << "\n");
-      if (succeeded(transferWriteOp)) {
-        newForOp = hoistTransferReadWrite(
-            rewriter, cast<vector::TransferReadOp>(matchingReadOp),
-            *transferWriteOp, it.value());
-      } else if (succeeded(insertSliceOp)) {
-        newForOp = hoistExtractInsertSlice(
-            rewriter, cast<tensor::ExtractSliceOp>(matchingReadOp),
-            *insertSliceOp, it.value());
-      } else {
-        llvm_unreachable("unexpected case");
-      }
-      break;
-    }
-  } while (forOp != newForOp);
-
-  return newForOp;
-}



More information about the Mlir-commits mailing list