[llvm] [mlir][bufferization] Empty tensor elimination for materialize_in_destination (PR #65468)

Matthias Springer via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 8 07:50:17 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/65468:

>From b4277bb5542483d776496cc63b250ade66a48904 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 8 Sep 2023 16:47:58 +0200
Subject: [PATCH 1/4] [mlir][bufferization] Generalize tensor slice rules to
 subset ops

This commit generalizes the special tensor.extract_slice/tensor.insert_slice bufferization rules to tensor subset ops.

Ops that insert a tensor into a tensor at a specified subset (e.g., tensor.insert_slice, tensor.scatter) can implement the `SubsetOpInterface`.

Apart from adding a new op interface (extending the API), this change is NFC. The only ops that currently implement the new interface are tensor.insert_slice and tensor.parallel_insert_slice, and those ops were already supported by One-Shot Bufferize.
---
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../Bufferization/IR/SubsetOpInterface.h      |  31 +++
 .../Bufferization/IR/SubsetOpInterface.td     | 115 +++++++++++
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../Bufferization/IR/SubsetOpInterface.cpp    |  23 +++
 .../Transforms/OneShotAnalysis.cpp            | 104 ++++++++++
 .../BufferizableOpInterfaceImpl.cpp           | 190 ++++++------------
 .../llvm-project-overlay/mlir/BUILD.bazel     |  37 +++-
 8 files changed, 375 insertions(+), 127 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
 create mode 100644 mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index aa93534a78fea3f..4a97493cbf42a49 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect(BufferizationOps bufferization)
 add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(SubsetOpInterface)
 
 set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
 mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
new file mode 100644
index 000000000000000..e18b00e2a4afa8f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
@@ -0,0 +1,31 @@
+//===- SubsetOpInterface.h - Tensor subsets ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class OpBuilder;
+
+namespace bufferization {
+namespace detail {
+
+/// Return the destination/"init" operand of the op if it implements the
+/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
+/// otherwise.
+OpOperand &defaultGetDestinationOperand(Operation *op);
+
+} // namespace detail
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
new file mode 100644
index 000000000000000..97dd4d2102a9323
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
@@ -0,0 +1,115 @@
+//===-- SubsetOpInterface.td - Tensor subsets --------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SUBSET_OP_INTERFACE
+#define SUBSET_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
+  let description = [{
+    This interface can be implemented by ops that insert a source tensor into
+    a destination tensor.
+
+    The elements in the destination tensor that are overwritten by this
+    insertion are called the "subset". How the subset is defined is up to the
+    op. E.g., "tensor.insert_slice" defines the subset via a hyperrectangular
+    slice. A scatter operation could define the subset via a list of indices.
+
+    Ops that deal with tensor subsets come in two flavors:
+    - Insertion flavor: Ops that insert a source tensor into a destination
+      tensor at the specified subset. Such ops usually return a new destination
+      tensor and implement the `DestinationStyleOpInterface`. Insertion ops can
+      implement the `SubsetOpInterface`. Example: "tensor.insert_slice"
+    - Extraction flavor: Ops that define a tensor subset. They extract a
+      specified subset from a tensor. There is currently no op interface for
+      such ops. Example: "tensor.extract_slice"
+
+    This interface provides helper methods for efficient bufferization of
+    subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
+    "aliases" (in contrast to one or more less efficient buffer allocations).
+
+    This interface is queried by One-Shot Bufferize to detect cases where a
+    seeming read-after-write is not actually a conflict because the respective
+    ops are operating on equivalent subsets. More details can be found in the
+    documentation of One-Shot Analysis.
+
+    Note: This interface currently assumes that a subset op inserts a single
+    tensor (source) into a destination tensor at a single subset.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the source tensor operand.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getSourceOperand",
+        /*args=*/(ins)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the destination tensor operand.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getDestinationOperand",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::bufferization::detail::defaultGetDestinationOperand(
+              $_op.getOperation());
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if this operation inserts into a subset that is
+          equivalent to the subset defined by `candidate`.
+
+          Two subsets are "equivalent" and "same" if they can bufferize to the
+          same buffer views/aliases. If they are "equivalent", the tensor IR
+          may be expressed in terms of different SSA values.
+
+          Example:
+          ```
+          // The subset of the SubsetOpInterface op %1 is equivalent to the
+          // subset defined by %2 (but not "same"):
+          %0 = arith.select %c, %t, %t : tensor<?xf32>
+          %1 = tensor.insert_slice %x into %0[0][5][1]
+              : tensor<5xf32> into tensor<?xf32>
+          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+
+          // The subset of the SubsetOpInterface op %1 is equivalent to and
+          // "same" as the subset defined by %2.
+          %1 = tensor.insert_slice %x into %t[0][5][1]
+              : tensor<5xf32> into tensor<?xf32>
+          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+          ```
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isEquivalentSubset",
+        /*args=*/(ins
+            "::mlir::Value":$candidate,
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+      >,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return "true" if this operation inserts into the same subset as defined
+    /// by `candidate`.
+    ///
+    /// Note: This function is useful outside of bufferization, where no tensor
+    /// equivalence information is available.
+    bool isSameSubset(OpResult candidate) {
+      return cast<::mlir::bufferization::SubsetOpInterface>(getOperation())
+          .isEquivalentSubset(candidate,
+                              [](Value v1, Value v2) { return v1 == v2; });
+    }
+  }];
+}
+
+#endif // SUBSET_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 2d8d09b9c41d993..f406595fc1e9fc9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferizableOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
+  SubsetOpInterface.cpp
   UnstructuredControlFlow.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp
new file mode 100644
index 000000000000000..7c5920d8ef23f02
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp
@@ -0,0 +1,23 @@
+//===- SubsetOpInterface.cpp - Tensor subsets -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.cpp.inc"
+
+using namespace mlir;
+
+OpOperand &bufferization::detail::defaultGetDestinationOperand(Operation *op) {
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
+  assert(dstOp && "getDestination must be implemented for non-DPS ops");
+  assert(
+      dstOp.getNumDpsInits() == 1 &&
+      "getDestination must be implemented for ops with 0 or more than 1 init");
+  return *dstOp.getDpsInitOperand(0);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 49b5ebdf722a1a7..6ce3df5903eebbe 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -45,6 +45,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -531,6 +532,103 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
               .empty();
 }
 
+/// Return "true" if `value` is originating from a subset that is equivalent to
+/// the subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state, Value value,
+                                     SubsetOpInterface subsetOp) {
+  auto matchingSubset = [&](Value val) {
+    if (auto opResult = dyn_cast<OpResult>(val))
+      if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
+            return state.areEquivalentBufferizedValues(v1, v2);
+          }))
+        return true;
+    return false;
+  };
+  // There may be multiple leaves at which the reverse SSA use-def chain lookup
+  // terminates. All of them must be equivalent subsets.
+  SetVector<Value> backwardSlice =
+      state.findValueInReverseUseDefChain(value, matchingSubset);
+  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
+}
+
+/// Return "true" if the given "read" and potentially conflicting "write" are
+/// not conflicting due to their subset relationship. The comments in this
+/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
+/// pairs, but apply to any subset ops that implement the `SubsetOpInterface`.
+static bool areNonConflictingSubsets(OpOperand *uRead,
+                                     OpOperand *uConflictingWrite,
+                                     const AnalysisState &state) {
+  Operation *readingOp = uRead->getOwner();
+  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+  // uRead is an InsertSliceOp...
+  if (auto subsetOp = dyn_cast<SubsetOpInterface>(readingOp)) {
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+    // %1 = linalg.fill %cst, %0 {inplace= [true] }
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace= [true] }
+
+    if (uRead == &subsetOp.getDestinationOperand() &&
+        matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+      // Case 1: The main insight is that InsertSliceOp reads only part of
+      // the destination tensor. The overwritten area is not read. If
+      // uConflictingWrite writes into exactly the memory location that is
+      // being read by uRead, this is not a conflict.
+      //
+      // In the above example:
+      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+      //
+      // The read of %t does not conflict with the write of the FillOp
+      // (same aliases!) because the area that the FillOp operates on is
+      // exactly the one that is *not* read via %t.
+      return true;
+
+    if (uRead == &subsetOp.getSourceOperand() &&
+        uConflictingWrite == &subsetOp.getDestinationOperand() &&
+        matchesInsertDestination(state, uRead->get(), subsetOp))
+      // Case 2: The read of the source tensor and the write to the dest
+      // tensor via an InsertSliceOp is not a conflict if the read is
+      // reading exactly that part of an equivalent tensor that the
+      // InsertSliceOp is writing.
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      return true;
+  }
+
+  // If uConflictingWrite is an InsertSliceOp...
+  if (auto subsetOp = dyn_cast<SubsetOpInterface>(conflictingWritingOp))
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+    // %1 = linalg.fill %cst, %0 {inplace= [true] }
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace= [true] }
+    // %3 = vector.transfer_read %1, %cst
+    //
+    // In the above example:
+    // uRead             = OpOperand 0 (%1) of vector.transfer_read
+    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+    // definition        = %1
+    //
+    // This is not a conflict because the InsertSliceOp overwrites the
+    // memory segment of %1 with the exact same data. (Effectively, there
+    // is no memory write here.)
+    if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
+        state.areEquivalentBufferizedValues(
+            uRead->get(), subsetOp.getSourceOperand().get()) &&
+        matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
+                                 subsetOp))
+      return true;
+
+  return false;
+}
+
 /// Given sets of uses and writes, return true if there is a RaW conflict under
 /// the assumption that all given reads/writes alias the same buffer and that
 /// all given writes bufferize inplace.
@@ -647,6 +745,12 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
         }
       }
 
+      // No conflict if the operands are non-conflicting subsets.
+      if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
+        LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
+        continue;
+      }
+
       // No conflict if the op interface says so.
       if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
         if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a67ea0334b22b9b..dee4c22305b1fd7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -628,117 +629,6 @@ struct InsertOpInterface
   }
 };
 
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-template <typename OpTy>
-static bool areEquivalentSlices(const AnalysisState &state,
-                                ExtractSliceOp extractSliceOp,
-                                OpTy insertSliceOp) {
-  if (!extractSliceOp || !insertSliceOp)
-    return false;
-  if (extractSliceOp != insertSliceOp &&
-      !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
-                                           insertSliceOp.getDest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
-                                  isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-template <typename OpTy>
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
-                                     OpTy insertSliceOp) {
-  // Look for matching slices.
-  auto matchesSlice = [&](Value val) {
-    if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentSlices(state, extractSliceOp, insertSliceOp))
-        return true;
-    return false;
-  };
-  return static_cast<bool>(llvm::all_of(
-      state.findValueInReverseUseDefChain(value, matchesSlice), matchesSlice));
-}
-
-template <typename OpTy>
-static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
-                                              OpOperand *uConflictingWrite,
-                                              const AnalysisState &state) {
-  Operation *readingOp = uRead->getOwner();
-  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-  // uRead is an InsertSliceOp...
-  if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
-    // As an example, consider the following IR.
-    //
-    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-    // %1 = linalg.fill %cst, %0 {inplace= [true] }
-    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-    //     {inplace= [true] }
-
-    // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-    if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        matchesInsertDestination(state, uConflictingWrite->get(),
-                                 insertSliceOp))
-      // Case 1: The main insight is that InsertSliceOp reads only part of
-      // the destination tensor. The overwritten area is not read. If
-      // uConflictingWrite writes into exactly the memory location that is
-      // being read by uRead, this is not a conflict.
-      //
-      // In the above example:
-      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-      //
-      // The read of %t does not conflict with the write of the FillOp
-      // (same aliases!) because the area that the FillOp operates on is
-      // exactly the one that is *not* read via %t.
-      return true;
-
-    if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-        uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        matchesInsertDestination(state, uRead->get(), insertSliceOp))
-      // Case 2: The read of the source tensor and the write to the dest
-      // tensor via an InsertSliceOp is not a conflict if the read is
-      // reading exactly that part of an equivalent tensor that the
-      // InsertSliceOp is writing.
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      return true;
-  }
-
-  // If uConflictingWrite is an InsertSliceOp...
-  if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
-    // As an example, consider the following IR.
-    //
-    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-    // %1 = linalg.fill %cst, %0 {inplace= [true] }
-    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-    //     {inplace= [true] }
-    // %3 = vector.transfer_read %1, %cst
-    //
-    // In the above example:
-    // uRead             = OpOperand 0 (%1) of vector.transfer_read
-    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-    // definition        = %1
-    //
-    // This is not a conflict because the InsertSliceOp overwrites the
-    // memory segment of %1 with the exact same data. (Effectively, there
-    // is no memory write here.)
-    if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        state.areEquivalentBufferizedValues(uRead->get(),
-                                            insertSliceOp.getSource()) &&
-        matchesInsertDestination(state, insertSliceOp.getSource(),
-                                 insertSliceOp))
-      return true;
-
-  return false;
-}
-
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
 ///
@@ -777,13 +667,6 @@ struct InsertSliceOpInterface
     return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
   }
 
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const AnalysisState &state) const {
-    return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
-        op, uRead, uConflictingWrite, state);
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     // insert_slice ops arise from tiling and bufferizing them out-of-place is
@@ -1092,13 +975,6 @@ struct ParallelInsertSliceOpInterface
     rewriter.eraseOp(op);
     return success();
   }
-
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const AnalysisState &state) const {
-    return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
-        op, uRead, uConflictingWrite, state);
-  }
 };
 
 /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
@@ -1147,6 +1023,64 @@ struct SplatOpInterface
   }
 };
 
+namespace {
+/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+/// equivalence of tensors.
+template <typename OpTy>
+bool isSubsetEquivalentToInsertSliceLikeOp(
+    OpTy insertSliceOp, Value candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  // Look for a matching tensor.extract_slice op.
+  auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!extractSliceOp)
+    return false;
+  if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                  isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+} // namespace
+
+struct InsertSliceOpSubsetOpInterface
+    : public SubsetOpInterface::ExternalModel<InsertSliceOpSubsetOpInterface,
+                                              tensor::InsertSliceOp> {
+  OpOperand &getSourceOperand(Operation *op) const {
+    return op->getOpOperand(0);
+  }
+
+  bool
+  isEquivalentSubset(Operation *op, Value candidate,
+                     function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
+                                                 equivalenceFn);
+  }
+};
+
+struct ParallelInsertSliceOpSubsetOpInterface
+    : public SubsetOpInterface::ExternalModel<
+          ParallelInsertSliceOpSubsetOpInterface,
+          tensor::ParallelInsertSliceOp> {
+  OpOperand &getSourceOperand(Operation *op) const {
+    return op->getOpOperand(0);
+  }
+
+  OpOperand &getDestinationOperand(Operation *op) const {
+    return op->getOpOperand(1);
+  }
+
+  bool
+  isEquivalentSubset(Operation *op, Value candidate,
+                     function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
+                                                 equivalenceFn);
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1154,6 +1088,7 @@ struct SplatOpInterface
 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+    // BufferizableOpInterface models.
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
@@ -1172,6 +1107,11 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
     SplatOp::attachInterface<SplatOpInterface>(*ctx);
 
+    // SubsetOpInterface models.
+    InsertSliceOp::attachInterface<InsertSliceOpSubsetOpInterface>(*ctx);
+    ParallelInsertSliceOp::attachInterface<
+        ParallelInsertSliceOpSubsetOpInterface>(*ctx);
+
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
   });
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7829bb0ffbd2932..b3645710232f2fa 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7,14 +7,14 @@
 
 load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
 load("@bazel_skylib//rules:write_file.bzl", "write_file")
-load(":tblgen.bzl", "gentbl_cc_library", "td_library")
-load(":linalggen.bzl", "genlinalg")
 load(
     ":build_defs.bzl",
     "cc_headers_only",
     "if_cuda_available",
     "mlir_c_api_cc_library",
 )
+load(":linalggen.bzl", "genlinalg")
+load(":tblgen.bzl", "gentbl_cc_library", "td_library")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -9705,6 +9705,36 @@ gentbl_cc_library(
     ],
 )
 
+td_library(
+    name = "SubsetOpInterfaceTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "SubsetOpInterfaceIncGen",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td",
+    deps = [
+        ":SubsetOpInterfaceTdFiles",
+    ],
+)
+
 td_library(
     name = "LinalgDocTdFiles",
     srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -11972,12 +12002,14 @@ cc_library(
         "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
+        "lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp",
     ],
     hdrs = [
         "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/Bufferization.h",
         "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
+        "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h",
     ],
     includes = ["include"],
@@ -11998,6 +12030,7 @@ cc_library(
         ":InferTypeOpInterface",
         ":MemRefDialect",
         ":SparseTensorDialect",
+        ":SubsetOpInterfaceIncGen",
         ":Support",
         ":TensorDialect",
         "//llvm:Support",

>From 2fb451a5f28738d8b964fce5248a53d5473eb076 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 8 Sep 2023 16:48:15 +0200
Subject: [PATCH 2/4] [mlir][bufferization] Empty tensor elimination based on
 SubsetOpInterface

This commit generalizes empty tensor elimination to operate on subset ops. No new test cases are added because all current subset ops were already supported previously. From this perspective, this change is NFC.

A new interface method (and a helper method) are added to `SubsetOpInterface` to build the subset of the destination tensor.
---
 .../Bufferization/IR/SubsetOpInterface.td     |  44 ++++
 .../TransformOps/BufferizationTransformOps.td |  20 +-
 .../Bufferization/Transforms/Passes.td        |  21 +-
 .../Bufferization/Transforms/Transforms.h     |  38 ++--
 .../BufferizationTransformOps.cpp             |   3 +-
 .../Transforms/EmptyTensorElimination.cpp     | 193 +++++-------------
 .../BufferizableOpInterfaceImpl.cpp           |  45 ++++
 7 files changed, 182 insertions(+), 182 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
index 97dd4d2102a9323..ffea92dac0397d3 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
@@ -96,6 +96,50 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
             "::mlir::Value":$candidate,
             "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the subset of the destination tensor that this operation
+          inserts into.
+
+          Example:
+          ```
+          // SubsetOpInterface op:
+          %0 = tensor.insert_slice %t0 into %t1[%pos][5][1]
+              : tensor<5xf32> into tensor<?xf32>
+          // Subset (built by this function):
+          %1 = tensor.extract_slice %t1[%pos][5][1]
+              : tensor<?xf32> to tensor<5xf32>
+          ```
+
+          Note: Implementations do not necessarily have to build new IR. They
+          may return existing SSA values.
+        }],
+        /*retType=*/"::mlir::Value",
+        /*methodName=*/"getSubset",
+        /*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          llvm_unreachable("getSubset not implemented");
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return all SSA values that are needed (i.e., must be in scope) at the
+          insertion point of the builder when calling `getSubset`. Users of
+          `getSubset` can use this helper method to find a suitable insertion
+          point.
+
+          Example: The SSA values needed to build the subset in the example of
+          `getSubset` are %t1 and %pos.
+        }],
+        /*retType=*/"::llvm::SmallVector<::mlir::Value>",
+        /*methodName=*/"getValuesNeededToBuildSubset",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          llvm_unreachable("getValuesNeededToBuildSubset not implemented");
+        }]
+      >,
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
index 46a95ad8db2a6db..84bd047e6d51eed 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -109,19 +109,17 @@ def EliminateEmptyTensorsOp
          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let description = [{
     Try to eliminate all `tensor.empty` ops within the targeted op by replacing
-    them with a destination tensor.
+    them with another destination tensor.
 
-    `tensor.empty` ops cannot be bufferizes. They can either be converted to
-    `bufferization.alloc_tensor` or replaced with another tensor (via this
-    transform). `tensor.empty` does not specify the contents of the returned
+    "tensor.empty" ops cannot be bufferized. They can either be converted to
+    "bufferization.alloc_tensor" or replaced with another tensor (via this
+    transform). "tensor.empty" does not specify the contents of the returned
     tensor so their results can be replaced with arbitrary tensor values as long
     as the dimensions match.
 
-    This transform looks for `tensor.empty` ops where the SSA use-def chain of
-    the result ends in a supported "anchor op" (always following the aliasing
-    OpOperand/OpResult chain). Currently supported anchor ops are:
-    - `tensor.insert_slice`
-    - `bufferization.yield` (inside `bufferization.alloc_tensor`)
+    This transformation looks for subset ops that insert a tensor that
+    originates from a "tensor.empty" (as per the reverse use-def chain). Such
+    "tensor.empty" ops are replaced with the destination subset.
 
     Example:
 
@@ -138,6 +136,10 @@ def EliminateEmptyTensorsOp
     %2 = tensor.insert_slice %1 into %t[1][5][1]
     ```
 
+    In the above example, the subset op is "tensor.insert_slice". When tracing
+    back the reverse use-def chain of the source, we end up at a
+    "tensor.empty" op.
+
     The above example can bufferize without an allocation (in the absence of
     other conflicts) because there is no longer a `tensor.empty` op.
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index df9bfcbfc548806..ff43cff817b64a8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -402,11 +402,22 @@ def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "func::FuncOp"> {
 def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
   let summary = "Try to eliminate all tensor.empty ops.";
   let description = [{
-    This pass tries to eliminate all insert_slice op-anchored tensor.empty ops.
-    I.e., when a value that is equivalent to an tensor.empty op is inserted into
-    another tensor, this pass tries to rewrite the IR in such a way that the
-    destination tensor of the insert_slice op is used directly instead of the
-    tensor.empty result.
+    Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
+    for subset ops that insert a tensor that originates from a "tensor.empty"
+    (as per the reverse use-def chain). Such "tensor.empty" ops are replaced
+    with the destination subset.
+
+    E.g.:
+    ```
+    %0 = tensor.empty() : tensor<10xf32>
+    %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
+    %2 = tensor.insert_slice %0 into %t ...
+    ```
+
+    In the above example, the subset op is "tensor.insert_slice". When tracing
+    back the reverse use-def chain of the source, we end up at a
+    "tensor.empty" op. The "tensor.empty" op is replaced with a
+    "tensor.extract_slice" op.
   }];
   let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
 }
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a0cfc811a0b50a5..df866daf1ab1ff6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -19,38 +19,26 @@ struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 
-/// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
-/// If an OpOperand is matched, the function should populate the SmallVector
-/// with all values that are needed during `RewriteFn` to produce the
-/// replacement value.
-using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
-
-/// A function that rewrites matched anchors.
-using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
-
-/// Try to eliminate tensor::EmptyOps inside `op`.
+/// Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
+/// for subset ops that insert a tensor that originates from a "tensor.empty"
+/// (as per the reverse use-def chain). Such "tensor.empty" ops are replaced
+/// with the destination subset.
 ///
-/// * `rewriteFunc` generates the replacement for the tensor::EmptyOp.
-/// * Only tensor::EmptyOps that are anchored on a matching OpOperand as per
-///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
-///   on the reverse SSA use-def chain, starting from the OpOperand and always
-///   following the aliasing  OpOperand, that eventually ends at a single
-///   tensor::EmptyOp.
+/// E.g.:
+/// %0 = tensor.empty() : tensor<10xf32>
+/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
+/// %2 = tensor.insert_slice %0 into %t ...
+///
+/// In the above example, the subset op is "tensor.insert_slice". When tracing
+/// back the reverse use-def chain of the source, we end up at a
+/// "tensor.empty" op.
 LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
-                                    OneShotAnalysisState &state,
-                                    AnchorMatchFn anchorMatchFunc,
-                                    RewriteFn rewriteFunc);
+                                    OneShotAnalysisState &state);
 
 /// Within the given operation, hoist buffers from loops where possible. See
 /// "BufferLoopHoistingPass" for more information.
 void hoistBuffersFromLoops(Operation *op);
 
-/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an
-/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
-/// (and some other conditions are met).
-LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);
-
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting
 /// additional buffer allocations.
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 097f75a7bc50f5b..b84cc452d0141cd 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -121,8 +121,7 @@ DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
     if (failed(analyzeOp(target, state)))
       return mlir::emitSilenceableFailure(target->getLoc())
              << "failed to analyze op";
-    if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-            rewriter, target, state)))
+    if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state)))
       return mlir::emitSilenceableFailure(target->getLoc())
              << "failed to eliminate insert_slice anchored tensor.empty ops";
   }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 4e0781dae0c2523..f09ef61b5905319 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -99,154 +100,65 @@ findValidInsertionPoint(Operation *emptyTensorOp,
   return nullptr;
 }
 
-/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced
-/// with the result of `rewriteFunc` if it is anchored on a matching
-/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
-/// chain, starting from the OpOperand and always following the aliasing
-/// OpOperand, that eventually ends at the tensor::EmptyOp.
-///
-/// E.g.:
-/// %0 = tensor.empty() : tensor<10xf32>
-/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
-/// %2 = tensor.insert_slice %0 into %t ...
-///
-/// In the above example, the anchor is the source operand of the insert_slice
-/// op. When tracing back the reverse use-def chain, we end up at a
-/// tensor.empty op.
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
-    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   OpBuilder::InsertionGuard g(rewriter);
 
-  op->walk([&](Operation *op) {
-    for (OpOperand &operand : op->getOpOperands()) {
-      // Skip operands that do not bufferize inplace.
-      if (!state.isInPlace(operand))
-        continue;
-      // All values that are needed to create the replacement op.
-      SmallVector<Value> neededValues;
-      // Is this an anchor?
-      if (!anchorMatchFunc(operand, neededValues))
+  op->walk([&](SubsetOpInterface op) {
+    OpOperand &source = op.getSourceOperand();
+    // Skip operands that do not bufferize inplace. "tensor.empty" could still
+    // be replaced, but the transformation may not be beneficial.
+    if (!state.isInPlace(source))
+      return WalkResult::skip();
+    // All values that are needed to create the replacement op.
+    SmallVector<Value> neededValues = op.getValuesNeededToBuildSubset();
+
+    // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
+    // equivalent tensors. I.e., stop when there are ops such as extract_slice
+    // on the path.
+    TraversalConfig config;
+    config.followEquivalentOnly = true;
+    config.alwaysIncludeLeaves = false;
+    // Replace only if the types match or are static <-> dynamic casts. We do
+    // not support slices or reshapes.
+    // TODO: This could be extended to support IR such as:
+    // %0 = tensor.empty() : tensor<128xf32>
+    // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
+    // %2 = tensor.expand_shape %1 ...
+    // %3 = tensor.insert_slice %2 into ...
+    config.followSameTypeOrCastsOnly = true;
+    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
+        source.get(), /*condition=*/
+        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+        config);
+
+    for (Value v : emptyTensors) {
+      Operation *emptyTensorOp = v.getDefiningOp();
+
+      // Find a suitable insertion point. If no suitable insertion point for
+      // the replacement can be found, skip this replacement.
+      Operation *insertionPoint =
+          findValidInsertionPoint(emptyTensorOp, neededValues);
+      if (!insertionPoint)
         continue;
 
-      // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
-      // equivalent tensors. I.e., stop when there are ops such as extract_slice
-      // on the path.
-      TraversalConfig config;
-      config.followEquivalentOnly = true;
-      config.alwaysIncludeLeaves = false;
-      // Replace only if the types match or are static <-> dynamic casts. We do
-      // not support slices or reshapes.
-      // TODO: This could be extended to support IR such as:
-      // %0 = tensor.empty() : tensor<128xf32>
-      // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
-      // %2 = tensor.expand_shape %1 ...
-      // %3 = tensor.insert_slice %2 into ...
-      config.followSameTypeOrCastsOnly = true;
-      SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-          operand.get(), /*condition=*/
-          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
-          config);
-
-      for (Value v : emptyTensors) {
-        Operation *emptyTensorOp = v.getDefiningOp();
-
-        // Find a suitable insertion point. If no suitable insertion point for
-        // the replacement can be found, skip this replacement.
-        Operation *insertionPoint =
-            findValidInsertionPoint(emptyTensorOp, neededValues);
-        if (!insertionPoint)
-          continue;
-
-        rewriter.setInsertionPoint(insertionPoint);
-        Value replacement =
-            rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
-        if (!replacement)
-          continue;
-        if (replacement.getType() != v.getType()) {
-          rewriter.setInsertionPointAfterValue(replacement);
-          replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
-                                                        replacement);
-        }
-        // Replace the tensor::EmptyOp.
-        rewriter.replaceOp(emptyTensorOp, replacement);
-        state.resetCache();
+      rewriter.setInsertionPoint(insertionPoint);
+      Value replacement = op.getSubset(rewriter, emptyTensorOp->getLoc());
+      if (!replacement)
+        continue;
+      if (replacement.getType() != v.getType()) {
+        rewriter.setInsertionPointAfterValue(replacement);
+        replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
+                                                      replacement);
       }
+      // Replace the tensor::EmptyOp.
+      rewriter.replaceOp(emptyTensorOp, replacement);
+      state.resetCache();
     }
-  });
-
-  return success();
-}
-
-/// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be
-/// eliminated if it is eventually inserted into another tensor (and some other
-/// conditions are met).
-///
-/// E.g.:
-/// %0 = tensor.empty()
-/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
-/// %2 = tensor.insert_slice %1 into %t[10][20][1]
-///
-/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a
-/// new allocation %0 and inserting it into %t. This is done by replacing the
-/// tensor::EmptyOp with:
-///
-/// %0 = tensor.extract_slice %t[10][20][1]
-///
-/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
-/// those bufferize inplace in the absence of other conflicts.
-///
-/// Starting from an InsertSliceOp, an tensor::EmptyOp at the end of the insert
-/// source's reverse use-def chain is eliminated if:
-/// * On the reverse use-def chain path from the InsertSliceOp to the
-///   tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer
-///   relation is "equivalent" (TODO: can be relaxed if needed).
-/// * The reverse use-def chain has exactly one end, which is the
-///   tensor::EmptyOp.
-template <typename OpTy>
-static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
-  return eliminateEmptyTensors(
-      rewriter, op, state,
-      /*anchorMatchFunc=*/
-      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
-        auto insertSliceOp = dyn_cast<OpTy>(operand.getOwner());
-        if (!insertSliceOp)
-          return false;
-        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
-          return false;
 
-        // Collect all values that are needed to construct the replacement op.
-        neededValues.append(insertSliceOp.getOffsets().begin(),
-                            insertSliceOp.getOffsets().end());
-        neededValues.append(insertSliceOp.getSizes().begin(),
-                            insertSliceOp.getSizes().end());
-        neededValues.append(insertSliceOp.getStrides().begin(),
-                            insertSliceOp.getStrides().end());
-        neededValues.push_back(insertSliceOp.getDest());
-
-        return true;
-      },
-      /*rewriteFunc=*/
-      [](OpBuilder &b, Location loc, OpOperand &operand) {
-        auto insertOp = cast<OpTy>(operand.getOwner());
-        auto extractOp = b.create<tensor::ExtractSliceOp>(
-            loc, insertOp.getSourceType(), insertOp.getDest(),
-            insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
-            insertOp.getMixedStrides());
-        return extractOp.getResult();
-      });
-}
+    return WalkResult::advance();
+  });
 
-LogicalResult
-mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
-  if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
-             tensor::InsertSliceOp>(rewriter, op, state)))
-    return failure();
-  if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
-             tensor::ParallelInsertSliceOp>(rewriter, op, state)))
-    return failure();
   return success();
 }
 
@@ -276,8 +188,7 @@ void EmptyTensorElimination::runOnOperation() {
   }
 
   IRRewriter rewriter(op->getContext());
-  if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-          rewriter, op, state)))
+  if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state)))
     signalPassFailure();
 }
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index dee4c22305b1fd7..bbbbc682ee1f00f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1042,6 +1042,31 @@ bool isSubsetEquivalentToInsertSliceLikeOp(
     return false;
   return true;
 }
+
+template <typename OpTy>
+Value getSubsetOfInsertSliceLikeOp(OpBuilder &b, Location loc,
+                                   OpTy insertSliceOp) {
+  auto extractOp = b.create<tensor::ExtractSliceOp>(
+      loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
+      insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+      insertSliceOp.getMixedStrides());
+  return extractOp.getResult();
+}
+
+template <typename OpTy>
+SmallVector<Value>
+getValuesNeededToBuildSubsetOfInsertSliceLikeOp(OpTy insertSliceOp) {
+  SmallVector<Value> neededValues;
+  // Collect all values that are needed to construct the replacement op.
+  neededValues.append(insertSliceOp.getOffsets().begin(),
+                      insertSliceOp.getOffsets().end());
+  neededValues.append(insertSliceOp.getSizes().begin(),
+                      insertSliceOp.getSizes().end());
+  neededValues.append(insertSliceOp.getStrides().begin(),
+                      insertSliceOp.getStrides().end());
+  neededValues.push_back(insertSliceOp.getDest());
+  return neededValues;
+}
 } // namespace
 
 struct InsertSliceOpSubsetOpInterface
@@ -1058,6 +1083,16 @@ struct InsertSliceOpSubsetOpInterface
     return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
                                                  equivalenceFn);
   }
+
+  Value getSubset(Operation *op, OpBuilder &builder, Location loc) const {
+    return getSubsetOfInsertSliceLikeOp(builder, loc,
+                                        cast<tensor::InsertSliceOp>(op));
+  }
+
+  SmallVector<Value> getValuesNeededToBuildSubset(Operation *op) const {
+    return getValuesNeededToBuildSubsetOfInsertSliceLikeOp(
+        cast<tensor::InsertSliceOp>(op));
+  }
 };
 
 struct ParallelInsertSliceOpSubsetOpInterface
@@ -1079,6 +1114,16 @@ struct ParallelInsertSliceOpSubsetOpInterface
     return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
                                                  equivalenceFn);
   }
+
+  Value getSubset(Operation *op, OpBuilder &builder, Location loc) const {
+    return getSubsetOfInsertSliceLikeOp(
+        builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
+  }
+
+  SmallVector<Value> getValuesNeededToBuildSubset(Operation *op) const {
+    return getValuesNeededToBuildSubsetOfInsertSliceLikeOp(
+        cast<tensor::ParallelInsertSliceOp>(op));
+  }
 };
 
 } // namespace

>From fd006d283238ae9ff3a45752b1706be37117669d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 8 Sep 2023 16:48:33 +0200
Subject: [PATCH 3/4] [mlir][bufferization][NFC] Rename copy_tensor op to
 materialize_in_destination

The previous name was badly chosen. The op is used to ensure that a computation materializes in the future buffer of a certain tensor.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 .../Bufferization/IR/BufferizationOps.td      | 27 ++++--
 .../Linalg/TransformOps/LinalgTransformOps.td |  6 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  4 +-
 .../Bufferization/IR/BufferizationOps.cpp     | 86 ++++++++++---------
 .../TransformOps/LinalgTransformOps.cpp       | 10 ++-
 .../lib/Dialect/Linalg/Transforms/Padding.cpp |  8 +-
 .../Transforms/one-shot-bufferize.mlir        |  2 +-
 mlir/test/Dialect/Bufferization/invalid.mlir  |  4 +-
 mlir/test/Dialect/Bufferization/ops.mlir      |  8 +-
 .../test/Dialect/Linalg/transform-op-pad.mlir |  8 +-
 10 files changed, 91 insertions(+), 72 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fec07af349b3a8d..a6a733dfce13251 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -209,17 +209,30 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 }
 
 //===----------------------------------------------------------------------===//
-// CopyTensorOp
+// MaterializeInDestinationOp
 //===----------------------------------------------------------------------===//
 
-def Bufferization_CopyTensorOp : Bufferization_Op<"copy_tensor",
-    [BufferizableOpInterface, SameOperandsAndResultType,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+def Bufferization_MaterializeInDestinationOp
+    : Bufferization_Op<"materialize_in_destination",
+        [BufferizableOpInterface, SameOperandsAndResultType,
+         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "copy a tensor";
 
   let description = [{
-    Copy the contents of the source tensor into the destination tensor. This
-    operation is guaranteed to bufferize to a memory copy.
+    This op indicates that the data of the `source` tensor should materialize
+    in the future buffer of the `dest` tensor. Both tensors must have the same
+    shape and element type at runtime.
+
+    By default, this op bufferizes to a memcpy from the future buffer of the
+    `source` tensor to the future buffer of the `dest` tensor. However,
+    transformations such as "empty tensor elimination" may rewrite IR such that
+    a computation is performed directly in the future buffer of the `dest`
+    tensor and no memcpy is needed.
+
+    Note: "tensor.insert_slice" could be used for the same purpose, but since
+    tensor dialect ops only indicate *what* should be computed but not *where*,
+    it could fold away, causing the computation to materialize in a different
+    buffer.
   }];
 
   let arguments = (ins AnyTensor:$source,
@@ -245,7 +258,7 @@ def Bufferization_CopyTensorOp : Bufferization_Op<"copy_tensor",
     }
   }];
 
-  let assemblyFormat = "$source `,` $dest attr-dict `:` type($source)";
+  let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ee6e12f72b80bab..8091395b8b53fb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -934,7 +934,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
     the original destination tensor of the targeted op. The op that copies back
     the result can be customized with `copy_back_op`:
 
-    * "bufferization.copy_tensor" (default)
+    * "bufferization.materialize_in_destination" (default)
     * "linalg.copy"
     * "none" (no copy back)
 
@@ -959,7 +959,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
          DefaultValuedAttr<
           TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
           "{}">:$transpose_paddings,
-         DefaultValuedAttr<StrAttr, "::mlir::bufferization::CopyTensorOp::getOperationName()">:$copy_back_op);
+         DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op);
   let results = (outs TransformHandleTypeInterface:$padded,
                       TransformHandleTypeInterface:$pad,
                       TransformHandleTypeInterface:$copy);
@@ -979,7 +979,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
                    CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                   CArg<"StringRef", "::mlir::bufferization::CopyTensorOp::getOperationName()">:$copyBackOp)>
+                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fd82c67ede5fa97..759fc528bba37bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -293,12 +293,12 @@ struct LinalgPaddingOptions {
   }
   enum class CopyBackOp : int8_t {
     None = 0,
-    BufferizationCopyTensor = 1,
+    BufferizationMaterializeInDestination = 1,
     LinalgCopy = 2
   };
   /// The op to be used for copying the padded result to the original
   /// destination tensor.
-  CopyBackOp copyBackOp = CopyBackOp::BufferizationCopyTensor;
+  CopyBackOp copyBackOp = CopyBackOp::BufferizationMaterializeInDestination;
   LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) {
     copyBackOp = op;
     return *this;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9a2a6d0f5c6d981..e5016c956804688 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -441,48 +441,6 @@ Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
   return getOperand(getIndexOfDynamicSize(idx));
 }
 
-//===----------------------------------------------------------------------===//
-// CopyTensorOp
-//===----------------------------------------------------------------------===//
-
-bool CopyTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
-                                          const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
-    return true;
-  return false;
-}
-
-bool CopyTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
-                                           const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
-    return true;
-  return false;
-}
-
-AliasingValueList CopyTensorOp::getAliasingValues(OpOperand &opOperand,
-                                                  const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
-    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
-  return {};
-}
-
-LogicalResult CopyTensorOp::bufferize(RewriterBase &rewriter,
-                                      const BufferizationOptions &options) {
-  FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
-  if (failed(buffer))
-    return failure();
-  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
-  replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
-  return success();
-}
-
-LogicalResult CopyTensorOp::reifyResultShapes(
-    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
-  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//
@@ -585,6 +543,50 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MaterializeInDestinationOp
+//===----------------------------------------------------------------------===//
+
+bool MaterializeInDestinationOp::bufferizesToMemoryRead(
+    OpOperand &opOperand, const AnalysisState &state) {
+  if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
+    return true;
+  return false;
+}
+
+bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
+    OpOperand &opOperand, const AnalysisState &state) {
+  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+    return true;
+  return false;
+}
+
+AliasingValueList
+MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
+                                              const AnalysisState &state) {
+  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
+  return {};
+}
+
+LogicalResult
+MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
+                                      const BufferizationOptions &options) {
+  FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
+  if (failed(buffer))
+    return failure();
+  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
+  replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
+  return success();
+}
+
+LogicalResult MaterializeInDestinationOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ToTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6549c27b0d0dfb4..d62b5d1f5074ecf 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1679,9 +1679,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     options.padToMultipleOf = padToMultipleOf;
     options.paddingValues = paddingValues;
     options.packPaddings = packPaddings;
-    if (getCopyBackOp() == bufferization::CopyTensorOp::getOperationName()) {
-      options.copyBackOp =
-          LinalgPaddingOptions::CopyBackOp::BufferizationCopyTensor;
+    if (getCopyBackOp() ==
+        bufferization::MaterializeInDestinationOp::getOperationName()) {
+      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
+          BufferizationMaterializeInDestination;
     } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
       options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
     } else if (getCopyBackOp() == kCopyOpNone) {
@@ -1757,7 +1758,8 @@ LogicalResult transform::PadOp::verify() {
              << attr;
     }
   }
-  if (getCopyBackOp() != bufferization::CopyTensorOp::getOperationName() &&
+  if (getCopyBackOp() !=
+          bufferization::MaterializeInDestinationOp::getOperationName() &&
       getCopyBackOp() != linalg::CopyOp::getOperationName() &&
       getCopyBackOp() != kCopyOpNone)
     return emitOpError() << "invalid copy_back_op";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index b8ebf7dbb0fe72f..8fe745d97ca3dd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -245,9 +245,11 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                                                          std::get<1>(it)->get())
                                  .getResult(0));
     } else if (options.copyBackOp ==
-               LinalgPaddingOptions::CopyBackOp::BufferizationCopyTensor) {
-      replacements.push_back(rewriter.create<bufferization::CopyTensorOp>(
-          loc, std::get<0>(it), std::get<1>(it)->get()));
+               LinalgPaddingOptions::CopyBackOp::
+                   BufferizationMaterializeInDestination) {
+      replacements.push_back(
+          rewriter.create<bufferization::MaterializeInDestinationOp>(
+              loc, std::get<0>(it), std::get<1>(it)->get()));
     } else {
       llvm_unreachable("unsupported copy back op");
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index a261256c033fa41..f92c7b4ee585150 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -224,6 +224,6 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
   // CHECK: memref.dealloc %[[alloc]]
   // CHECK: return %[[r]]
   %dest = bufferization.alloc_tensor() : tensor<5xf32>
-  %0 = bufferization.copy_tensor %arg0, %dest : tensor<5xf32>
+  %0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
   return %0 : tensor<5xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 3b4bfee5622e9bb..7c92193ab068dba 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -99,9 +99,9 @@ func.func @invalid_writable_on_op() {
 // -----
 
 // expected-note @below{{prior use here}}
-func.func @invalid_tensor_copy(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
+func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
   // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
-  bufferization.copy_tensor %arg0, %arg1 : tensor<?xf32>
+  bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index 773f15c1ffcb89b..665f5697fdc5fdf 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -58,11 +58,11 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
   return
 }
 
-// CHECK-LABEL: func @test_copy_tensor_op
-func.func @test_copy_tensor_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
+// CHECK-LABEL: func @test_materialize_in_destination_op
+func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
     -> tensor<?xf32> {
-  // CHECK: bufferization.copy_tensor {{.*}} : tensor<?xf32>
-  %1 = bufferization.copy_tensor %arg0, %arg1 : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : tensor<?xf32>
+  %1 = bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
   return %1 : tensor<?xf32>
 }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index d8d0fc56f04406b..5c5d162b7c16f0a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -27,7 +27,7 @@ func.func @static_sizes_output_divisible(%arg0: tensor<24x12xf32>,
   // CHECK-SAME:              outs(%[[T2]] : tensor<4x5xf32>)
 
   //      CHECK: %[[T6:.*]] = tensor.extract_slice %[[T5]]
-  //      CHECK: %[[T7:.*]] = bufferization.copy_tensor %[[T6]], %[[T2]]
+  //      CHECK: %[[T7:.*]] = bufferization.materialize_in_destination %[[T6]] in %[[T2]]
   %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
   %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
   func.return %5 : tensor<24x25xf32>
@@ -40,9 +40,9 @@ transform.sequence failures(propagate) {
     padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 0]
-  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.copy_tensor">)
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
   // expected-remark @below {{1}}
-  test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.copy_tensor">
+  test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
 }
 
 // -----
@@ -272,7 +272,7 @@ func.func @pack_everything(%arg0: tensor<24x12xf32>,
   //      CHECK: %[[T6:.*]] = tensor.extract_slice %[[T5]]
   // Copy back result to the original buffer, so that the destination of the
   // computation does not change.
-  //      CHECK: %[[T7:.*]] = bufferization.copy_tensor %[[T6]], %[[T2]]
+  //      CHECK: %[[T7:.*]] = bufferization.materialize_in_destination %[[T6]] in %[[T2]]
   %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
 
   //      CHECK: %[[T8:.*]] = tensor.insert_slice %[[T7]] into %{{.*}}

>From cd1eaae2bc67127d1afdae1a860161665e85f169 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 8 Sep 2023 16:48:50 +0200
Subject: [PATCH 4/4] [mlir][bufferization] Empty tensor elimination for
 materialize_in_destination

This revision adds support for empty tensor elimination to "bufferization.materialize_in_destination" by implementing the `SubsetOpInterface`.

Furthermore, the One-Shot Bufferize conflict detection is improved for "bufferization.materialize_in_destination".
---
 .../Dialect/Bufferization/IR/Bufferization.h  |  2 ++
 .../Bufferization/IR/BufferizationOps.td      | 15 +++++++++-
 .../Bufferization/IR/BufferizationOps.cpp     | 25 +++++++++++++++++
 .../one-shot-bufferize-analysis.mlir          | 28 +++++++++++++++++++
 ...ot-bufferize-empty-tensor-elimination.mlir | 14 ++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  2 ++
 6 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 450dfb37ddb2e18..ff2bbf63d742a0a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,7 +12,9 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index a6a733dfce13251..855411c9ec3b934 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,8 @@
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
@@ -215,7 +217,11 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
         [BufferizableOpInterface, SameOperandsAndResultType,
-         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+         DestinationStyleOpInterface,
+         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+         DeclareOpInterfaceMethods<SubsetOpInterface,
+            ["getSourceOperand", "getValuesNeededToBuildSubset", "getSubset",
+             "isEquivalentSubset"]>]> {
   let summary = "copy a tensor";
 
   let description = [{
@@ -250,12 +256,19 @@ def Bufferization_MaterializeInDestinationOp
     bool bufferizesToMemoryWrite(OpOperand &opOperand,
                                  const AnalysisState &state);
 
+    bool bufferizesToElementwiseAccess(const AnalysisState &state,
+                                       ArrayRef<OpOperand *> opOperands);
+
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
     RankedTensorType getType() {
       return ::llvm::cast<RankedTensorType>(getResult().getType());
     }
+
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      return {1, 2};  // `dest` operand
+    }
   }];
 
   let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index e5016c956804688..9e16d45889aad6d 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -580,6 +580,13 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
   return success();
 }
 
+bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
+    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
+  // As elements are copied from the "source" buffer to the "dest" buffer,
+  // already copied elements are not read a second time.
+  return true;
+}
+
 LogicalResult MaterializeInDestinationOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
@@ -587,6 +594,24 @@ LogicalResult MaterializeInDestinationOp::reifyResultShapes(
   return success();
 }
 
+Value MaterializeInDestinationOp::getSubset(OpBuilder &builder, Location loc) {
+  // The subset is the entire destination tensor.
+  return getDest();
+}
+
+bool MaterializeInDestinationOp::isEquivalentSubset(
+    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
+  return equivalenceFn(getDest(), candidate);
+}
+
+SmallVector<Value> MaterializeInDestinationOp::getValuesNeededToBuildSubset() {
+  return {getDest()};
+}
+
+OpOperand &MaterializeInDestinationOp::getSourceOperand() {
+  return getOperation()->getOpOperand(0) /*source*/;
+}
+
 //===----------------------------------------------------------------------===//
 // ToTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index 5a505c66892f113..a2fbb06d179ebda 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -158,3 +158,31 @@ func.func @bbarg_of_unknown_op_2(%f: f32) {
   // CHECK: {__inplace_operands_attr__ = ["false"]} : (tensor<10xf32>) -> ()
   return
 }
+
+// -----
+
+// CHECK: func @materialize_in_destination_aliasing(
+func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p2: index, %sz: index) -> tensor<5xf32> {
+  %buffer = tensor.empty(%sz) : tensor<?xf32>
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]}
+  %src = tensor.extract_slice %t[%p1][5][1] : tensor<?xf32> to tensor<5xf32>
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false", "none"]}
+  %dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
+  // CHECK: bufferization.materialize_in_destination
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
+  %r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
+  return %r : tensor<5xf32>
+}
+
+// -----
+
+// CHECK: func @materialize_in_destination(
+func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?xf32> {
+  %buffer = tensor.empty(%sz) : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
+  %r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
+  return %r : tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 3d15599915f0cfc..063ed07467f90d0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -291,3 +291,17 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
   %2 = tensor.insert_slice %filled into %t1 [%0, %1] [2, 5] [1, 1] : tensor<2x5xf32> into tensor<?x?xf32>
   return %2 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_destination(
+//  CHECK-SAME:     %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
+//       CHECK:   linalg.fill {{.*}} outs(%[[m]]
+//       CHECK:   return %[[m]]
+func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
+  %0 = tensor.empty() : tensor<5xf32>
+  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  %1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
+  //%1 = tensor.insert_slice %filled into %t[0][5][1] : tensor<5xf32> into tensor<5xf32>
+  return %1 : tensor<5xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b3645710232f2fa..6b615c43aac7f0f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11993,6 +11993,8 @@ gentbl_cc_library(
     deps = [
         ":BufferizableOpInterfaceTdFiles",
         ":BufferizationOpsTdFiles",
+        ":DestinationStyleOpInterfaceTdFiles",
+        ":SubsetOpInterfaceTdFiles",
     ],
 )
 



More information about the llvm-commits mailing list