[Mlir-commits] [mlir] [mlir][bufferization] Empty tensor elimination for materialize_in_destination (PR #65468)

Matthias Springer llvmlistbot at llvm.org
Wed Sep 6 05:26:47 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/65468:

This revision adds support for empty tensor elimination to "bufferization.materialize_in_destination" by implementing the `InferDestinationOpInterface`.

Furthermore, the One-Shot Bufferize conflict detection is improved for "bufferization.materialize_in_destination".

Depends on D159415 and #65467. Review only the top commit.


>From 0198fcbf4bdaf339c1001d8d3881f68ef0b0a827 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 6 Sep 2023 14:24:02 +0200
Subject: [PATCH 1/3] [mlir][bufferization] Add interface for empty tensor
 elimination

This revision generalizes empty tensor elimination to arbitrary ops that implement the new interface.

Empty tensor elimination looks for ops where a "source" tensor is transferred to a new destination and the source tensor can be traced back to a "tensor.empty". Such a "tensor.empty" can be replaced with the destination (or a slice thereof), such that the computation takes place on the destination buffer directly (after bufferization).

There are currently two ops that are supported in empty tensor elimination: "tensor.insert_slice" and "linalg.generic". Furthermore, downstream projects (e.g., IREE) may have their own ops. We would like to support "bufferization.copy_tensor" in the future. To avoid code duplication around empty tensor elimination, the `InferDestinationOpInterface` is added. This interface retrieves or builds the destination into which the result of the computation is written. It is different from `DestinationStyleOpInterface` because slices of the destination tensor may be built and because users can specify cases in which empty tensor elimination does not pay off (and in which it should not be applied).
---
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../IR/InferDestinationOpInterface.h          |  20 ++++
 .../IR/InferDestinationOpInterface.td         | 103 +++++++++++++++++
 .../Bufferization/Transforms/Transforms.h     |  19 +---
 .../Linalg/TransformOps/LinalgTransformOps.td |  68 -----------
 .../Dialect/Linalg/Transforms/Transforms.h    |  29 -----
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../IR/InferDestinationOpInterface.cpp        |  10 ++
 .../BufferizationTransformOps.cpp             |   3 +-
 .../Transforms/EmptyTensorElimination.cpp     |  93 ++-------------
 .../TransformOps/LinalgTransformOps.cpp       |  30 -----
 .../BufferizableOpInterfaceImpl.cpp           |  59 ++++++++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 -
 .../Transforms/EliminateEmptyTensors.cpp      | 107 ------------------
 .../BufferizableOpInterfaceImpl.cpp           |  46 ++++++++
 ...ot-bufferize-empty-tensor-elimination.mlir |  18 ++-
 .../llvm-project-overlay/mlir/BUILD.bazel     |  33 ++++++
 17 files changed, 298 insertions(+), 343 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.td
 create mode 100644 mlir/lib/Dialect/Bufferization/IR/InferDestinationOpInterface.cpp
 delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index aa93534a78fea3f..7be6c5e3a9dc327 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect(BufferizationOps bufferization)
 add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(InferDestinationOpInterface)
 
 set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
 mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h
new file mode 100644
index 000000000000000..83f5d87bb1ea091
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h
@@ -0,0 +1,20 @@
+//===- InferDestinationOpInterface.h - Infer destination --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_INFERDESTINATIONOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_INFERDESTINATIONOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class OpBuilder;
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_INFERDESTINATIONOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.td
new file mode 100644
index 000000000000000..d50f153cc2be9d7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.td
@@ -0,0 +1,103 @@
+//===-- InferDestinationOpInterface.td - Infer destination -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef INFER_DESTINATION_OP_INTERFACE
+#define INFER_DESTINATION_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferDestinationOpInterface : OpInterface<"InferDestinationOpInterface"> {
+  let description = [{
+    This interface can be implemented by ops that read data from an "input"
+    tensor and store the result into an "output"/"init" tensor (i.e., the
+    "destination" tensor). It drives the empty tensor elimination pass.
+
+    The `getOrBuildDestination` interface method returns the destination tensor
+    (or a slice thereof). Assuming that the op does not bufferize to a memory
+    read on the destination tensor (or the aforementioned slice), if the source
+    originates from a "tensor.empty", that "tensor.empty" can be replaced with
+    the result of `getOrBuildDestination`. This can reduce the number of
+    allocations and memcpys during bufferization: instead of computing data in
+    a temporary buffer (the result of "tensor.empty") and then copying it into
+    the destination buffer, the result will be computed in the destination
+    buffer directly.
+
+    Example:
+    ```
+    %0 = tensor.empty() : tensor<5xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+    %2 = tensor.insert_slice %1 into %dst [%offset][5][1]
+        : tensor<5xf32> into tensor<?xf32>
+    ```
+
+    "tensor.insert_slice" transfers %1 into a slice of %dst.
+    `getOrBuildDestination` should return the slice. `getNeededValues` should
+    return %dst and %offset, because they are needed to build the destination.
+
+    After empty tensor elimination:
+    ```
+    %0 = tensor.extract_slice %dst [%offset][5][1]
+        : tensor<?xf32> to tensor<5xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+    %2 = tensor.insert_slice %1 into %dst [%offset][5][1]
+        : tensor<5xf32> into tensor<?xf32>
+    ```
+
+    This IR will bufferize without an allocation and without a memcpy.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Build or return the destination tensor (or a slice thereof) into which
+          the given operand (or an element-wise computation thereof) will be
+          stored. Only values returned by `getNeededValues` may be used to build
+          the destination. This interface method will be called only on tensor
+          operands for which `isAnchor` returns "true".
+        }],
+        /*retType=*/"::mlir::Value",
+        /*methodName=*/"getOrBuildDestination",
+        /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                      "::mlir::Location":$loc,
+                      "::mlir::OpOperand &":$operand),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          llvm_unreachable("getOrBuildDestination not implemented");
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the SSA values that are needed to build the destination tensor.
+          This interface method will be called only on tensor operands for which
+          `isAnchor` returns "true".
+        }],
+        /*retType=*/"::llvm::SmallVector<::mlir::Value>",
+        /*methodName=*/"getNeededValues",
+        /*args=*/(ins "::mlir::OpOperand &":$operand),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          llvm_unreachable("getNeededValues not implemented");
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if the given operand is an "input" tensor on which
+          empty tensor elimination can be applied.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isAnchor",
+        /*args=*/(ins "::mlir::OpOperand &":$operand),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          llvm_unreachable("isAnchor not implemented");
+        }]
+      >,
+  ];
+}
+
+#endif // INFER_DESTINATION_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a0cfc811a0b50a5..61a775fbaac5128 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -19,15 +19,6 @@ struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 
-/// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
-/// If an OpOperand is matched, the function should populate the SmallVector
-/// with all values that are needed during `RewriteFn` to produce the
-/// replacement value.
-using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
-
-/// A function that rewrites matched anchors.
-using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
-
 /// Try to eliminate tensor::EmptyOps inside `op`.
 ///
 /// * `rewriteFunc` generates the replacement for the tensor::EmptyOp.
@@ -37,20 +28,12 @@ using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
 ///   following the aliasing  OpOperand, that eventually ends at a single
 ///   tensor::EmptyOp.
 LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
-                                    OneShotAnalysisState &state,
-                                    AnchorMatchFn anchorMatchFunc,
-                                    RewriteFn rewriteFunc);
+                                    OneShotAnalysisState &state);
 
 /// Within the given operation, hoist buffers from loops where possible. See
 /// "BufferLoopHoistingPass" for more information.
 void hoistBuffersFromLoops(Operation *op);
 
-/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an
-/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
-/// (and some other conditions are met).
-LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);
-
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting
 /// additional buffer allocations.
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ee6e12f72b80bab..858be9e60621aa7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -198,74 +198,6 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// EliminateLinalgOpAnchoredEmptyTensorsOp
-//===----------------------------------------------------------------------===//
-
-def EliminateLinalgOpAnchoredEmptyTensorsOp
-    : Op<Transform_Dialect, "structured.eliminate_empty_tensors",
-        [DeclareOpInterfaceMethods<TransformOpInterface>,
-         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let description = [{
-    Try to eliminate all `tensor.empty` op uses that are anchored on a LinalgOp
-    within the targeted op.
-
-    This op is similar to `bufferization.eliminate_empty_tensors`, but specific
-    to LinalgOps.
-
-    `tensor.empty` ops cannot be bufferized. They can either be converted to
-    `bufferization.alloc_tensor` or replaced with another tensor (via this
-    transform). `tensor.empty` does not specify the contents of the returned
-    tensor so their results can be replaced with arbitrary tensor values as long
-    as the dimensions match.
-
-    This transform looks for `tensor.empty` ops where the SSA use-def chain of
-    the result ends in a supported LinalgOp (always following the aliasing
-    OpOperand/OpResult chain). The following LinalgOps are supported:
-    - Only parallel iterator types.
-    - The use-def chain ends in an input operand of the LinalgOp.
-    - The LinalgOp has an unused output operand with the same shape and
-      indexing map.
-
-    Example:
-
-    ```
-    %0 = tensor.empty()
-    %1 = linalg.matmul ins(...) outs(%0)
-    %2 = linalg.generic ins(%1) outs(%dest) {
-      ^bb0(%in: f32, %out: f32):
-      // out not used
-    }
-    ```
-
-    Is rewritten with:
-    ```
-    %0 = tensor.empty()
-    %1 = linalg.matmul ins(...) outs(%dest)
-    %2 = linalg.generic ins(%0) outs(%1) {
-      ^bb0(%in: f32, %out: f32):
-      // Use %out instead of %in
-    }
-    ```
-
-    After this transformation, the "ins" operand has no uses inside the body of
-    the LinalgOp and can be folded away with existing cleanup patterns.
-    Afterwards, the tensor::EmptyOp can also fold away, so that the example can
-    bufferize without an allocation (in the absence of other conflicts).
-
-    #### Return modes
-
-    This transform reads the target handle and modifies the payload. It does
-    not produce any handle.
-  }];
-
-  let arguments = (ins TransformHandleTypeInterface:$target);
-
-  let results = (outs);
-
-  let assemblyFormat = "$target attr-dict `:` type($target)";
-}
-
 //===----------------------------------------------------------------------===//
 // FuseOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fd82c67ede5fa97..82f7915395ba2d9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -125,35 +125,6 @@ Value bufferizeToAllocation(RewriterBase &rewriter,
                             Operation *op, Attribute memorySpace = {},
                             Operation *insertionPoint = nullptr);
 
-/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a
-/// LinalgOp. This transforms looks for LinalgOps that have an unused output
-/// operand and an input operand that is rooted in a tensor::EmptyOp. The
-/// tensor::EmptyOp uses are replaced with the output operand and the two
-/// operands of the LinalgOp are swapped.
-///
-/// Example:
-/// %0 = tensor.empty()
-/// %1 = linalg.matmul ins(...) outs(%0)
-/// %2 = linalg.generic ins(%1) outs(%dest) {
-///   ^bb0(%in: f32, %out: f32):
-///   // out not used
-/// }
-///
-/// The IR is transformed as follows:
-/// %0 = tensor.empty()
-/// %1 = linalg.matmul ins(...) outs(%dest)
-/// %2 = linalg.generic ins(%0) outs(%1) {
-///   ^bb0(%in: f32, %out: f32):
-///   // Use %out instead of %in
-/// }
-///
-/// The "ins" operand has no uses inside the body of the LinalgOp and can be
-/// folded away with existing cleanup patterns. Afterwards, the tensor::EmptyOp
-/// can also fold away.
-LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op,
-    bufferization::OneShotAnalysisState &state);
-
 //===----------------------------------------------------------------------===//
 // Structs that configure the behavior of various transformations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 2d8d09b9c41d993..e5f2bc4cff20662 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferizableOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
+  InferDestinationOpInterface
   UnstructuredControlFlow.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/IR/InferDestinationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/InferDestinationOpInterface.cpp
new file mode 100644
index 000000000000000..024d82c24383e4a
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/InferDestinationOpInterface.cpp
@@ -0,0 +1,10 @@
+//===- InferDestinationOpInterface.cpp -  Infer destination ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.cpp.inc"
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 097f75a7bc50f5b..b84cc452d0141cd 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -121,8 +121,7 @@ DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
     if (failed(analyzeOp(target, state)))
       return mlir::emitSilenceableFailure(target->getLoc())
              << "failed to analyze op";
-    if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-            rewriter, target, state)))
+    if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state)))
       return mlir::emitSilenceableFailure(target->getLoc())
              << "failed to eliminate insert_slice anchored tensor.empty ops";
   }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 4e0781dae0c2523..3ba7747444195f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -114,20 +115,21 @@ findValidInsertionPoint(Operation *emptyTensorOp,
 /// op. When tracing back the reverse use-def chain, we end up at a
 /// tensor.empty op.
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
-    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   OpBuilder::InsertionGuard g(rewriter);
 
-  op->walk([&](Operation *op) {
+  op->walk([&](InferDestinationOpInterface op) {
     for (OpOperand &operand : op->getOpOperands()) {
+      if (!isa<RankedTensorType>(operand.get().getType()))
+        continue;
       // Skip operands that do not bufferize inplace.
       if (!state.isInPlace(operand))
         continue;
-      // All values that are needed to create the replacement op.
-      SmallVector<Value> neededValues;
       // Is this an anchor?
-      if (!anchorMatchFunc(operand, neededValues))
+      if (!op.isAnchor(operand))
         continue;
+      // All values that are needed to create the replacement op.
+      SmallVector<Value> neededValues = op.getNeededValues(operand);
 
       // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
       // equivalent tensors. I.e., stop when there are ops such as extract_slice
@@ -159,8 +161,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
           continue;
 
         rewriter.setInsertionPoint(insertionPoint);
-        Value replacement =
-            rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
+        Value replacement = op.getOrBuildDestination(
+            rewriter, emptyTensorOp->getLoc(), operand);
         if (!replacement)
           continue;
         if (replacement.getType() != v.getType()) {
@@ -178,78 +180,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
   return success();
 }
 
-/// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be
-/// eliminated if it is eventually inserted into another tensor (and some other
-/// conditions are met).
-///
-/// E.g.:
-/// %0 = tensor.empty()
-/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
-/// %2 = tensor.insert_slice %1 into %t[10][20][1]
-///
-/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a
-/// new allocation %0 and inserting it into %t. This is done by replacing the
-/// tensor::EmptyOp with:
-///
-/// %0 = tensor.extract_slice %t[10][20][1]
-///
-/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
-/// those bufferize inplace in the absence of other conflicts.
-///
-/// Starting from an InsertSliceOp, an tensor::EmptyOp at the end of the insert
-/// source's reverse use-def chain is eliminated if:
-/// * On the reverse use-def chain path from the InsertSliceOp to the
-///   tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer
-///   relation is "equivalent" (TODO: can be relaxed if needed).
-/// * The reverse use-def chain has exactly one end, which is the
-///   tensor::EmptyOp.
-template <typename OpTy>
-static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
-  return eliminateEmptyTensors(
-      rewriter, op, state,
-      /*anchorMatchFunc=*/
-      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
-        auto insertSliceOp = dyn_cast<OpTy>(operand.getOwner());
-        if (!insertSliceOp)
-          return false;
-        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
-          return false;
-
-        // Collect all values that are needed to construct the replacement op.
-        neededValues.append(insertSliceOp.getOffsets().begin(),
-                            insertSliceOp.getOffsets().end());
-        neededValues.append(insertSliceOp.getSizes().begin(),
-                            insertSliceOp.getSizes().end());
-        neededValues.append(insertSliceOp.getStrides().begin(),
-                            insertSliceOp.getStrides().end());
-        neededValues.push_back(insertSliceOp.getDest());
-
-        return true;
-      },
-      /*rewriteFunc=*/
-      [](OpBuilder &b, Location loc, OpOperand &operand) {
-        auto insertOp = cast<OpTy>(operand.getOwner());
-        auto extractOp = b.create<tensor::ExtractSliceOp>(
-            loc, insertOp.getSourceType(), insertOp.getDest(),
-            insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
-            insertOp.getMixedStrides());
-        return extractOp.getResult();
-      });
-}
-
-LogicalResult
-mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
-  if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
-             tensor::InsertSliceOp>(rewriter, op, state)))
-    return failure();
-  if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
-             tensor::ParallelInsertSliceOp>(rewriter, op, state)))
-    return failure();
-  return success();
-}
-
 namespace {
 struct EmptyTensorElimination
     : public bufferization::impl::EmptyTensorEliminationBase<
@@ -276,8 +206,7 @@ void EmptyTensorElimination::runOnOperation() {
   }
 
   IRRewriter rewriter(op->getContext());
-  if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-          rewriter, op, state)))
+  if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state)))
     signalPassFailure();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6549c27b0d0dfb4..6aa575dd865cf75 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -373,36 +373,6 @@ DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
-//===----------------------------------------------------------------------===//
-// EliminateLinalgOpAnchoredEmptyTensorsOp
-//===----------------------------------------------------------------------===//
-
-void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  onlyReadsHandle(getTarget(), effects);
-  modifiesPayload(effects);
-}
-
-DiagnosedSilenceableFailure
-transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
-    transform::TransformRewriter &rewriter, TransformResults &transformResults,
-    TransformState &state) {
-  bufferization::OneShotBufferizationOptions options;
-  options.allowReturnAllocs = true;
-
-  for (Operation *target : state.getPayloadOps(getTarget())) {
-    bufferization::OneShotAnalysisState state(target, options);
-    if (failed(analyzeOp(target, state)))
-      return mlir::emitSilenceableFailure(target->getLoc())
-             << "failed to analyze op";
-    if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
-            rewriter, target, state)))
-      return mlir::emitSilenceableFailure(target->getLoc())
-             << "failed to eliminate LinalgOp anchored tensor.empty ops";
-  }
-  return DiagnosedSilenceableFailure::success();
-}
-
 //===----------------------------------------------------------------------===//
 // FuseOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 0577441bdd28d27..119c4b4ded0d0ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
@@ -143,12 +145,69 @@ struct LinalgOpInterface
   }
 };
 
+template <typename OpTy>
+struct LinalgOpInferDestinationInterface
+    : public InferDestinationOpInterface::ExternalModel<
+          LinalgOpInferDestinationInterface<OpTy>, OpTy> {
+  static Value findMatchingDestinationOperand(LinalgOp linalgOp,
+                                              OpOperand &inOperand) {
+    // Only element-wise ops with a single result are supported. "tensor.empty"
+    // anchored on non-element-wise ops could still be eliminated but this op
+    // would bufferize out-of-place because of a RaW between `inOperand` and the
+    // "out" operand. (Both operands are aliasing after empty tensor
+    // elimination.)
+    if (linalgOp.getNumDpsInits() != 1)
+      return {};
+    if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+      return {};
+    OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+    // Operand must be unused. If it is used, the "tensor.empty" could still be
+    // eliminated but an OpOperand would have to bufferize out-of-place due to
+    // a RaW conflict.
+    if (linalgOp.payloadUsesValueFromOperand(outOperand))
+      return {};
+    // Types must match. Other cases are not supported by empty tensor
+    // elimination.
+    if (outOperand->get().getType() != inOperand.get().getType())
+      return {};
+    // Indexing maps must match.
+    if (linalgOp.getMatchingIndexingMap(outOperand) !=
+        linalgOp.getMatchingIndexingMap(&inOperand))
+      return {};
+    return outOperand->get();
+  }
+
+  Value getOrBuildDestination(Operation *op, OpBuilder &builder, Location loc,
+                              OpOperand &operand) const {
+    return findMatchingDestinationOperand(cast<linalg::LinalgOp>(op), operand);
+  }
+
+  SmallVector<Value> getNeededValues(Operation *op, OpOperand &operand) const {
+    SmallVector<Value> neededValues;
+    neededValues.push_back(
+        findMatchingDestinationOperand(cast<LinalgOp>(op), operand));
+    return neededValues;
+  }
+
+  bool isAnchor(Operation *op, OpOperand &operand) const {
+    // Check if `operand` (or a computation on its data) is transferred into an
+    // unused "init".
+    auto linalgOp = cast<LinalgOp>(op);
+    if (!linalgOp.isDpsInput(&operand))
+      return false;
+    return static_cast<bool>(findMatchingDestinationOperand(linalgOp, operand));
+  }
+};
+
 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers
 /// the `BufferizableOpInterface` with each of them.
 template <typename... Ops>
 struct LinalgOpInterfaceHelper {
   static void registerOpInterface(MLIRContext *ctx) {
     (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
+    (Ops::template attachInterface<LinalgOpInferDestinationInterface<Ops>>(
+         *ctx),
+     ...);
   }
 };
 } // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 5ae9b7f7b1efc3f..82787a3f70eb9a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
   ElementwiseToLinalg.cpp
-  EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
deleted file mode 100644
index 4b754065e318971..000000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-
-using namespace mlir;
-using namespace mlir::bufferization;
-using namespace mlir::linalg;
-
-/// Get an output operand that matches the given input operand and can be used
-/// to eliminate a tensor.empty op.
-static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
-  for (OpOperand *operand : op.getDpsInitOperands()) {
-    // Operand must be unused.
-    if (op.payloadUsesValueFromOperand(operand))
-      continue;
-    // Types must match.
-    if (operand->get().getType() != in->get().getType())
-      continue;
-    // Indexing maps must match.
-    if (op.getMatchingIndexingMap(operand) != op.getMatchingIndexingMap(in))
-      continue;
-    return operand;
-  }
-  return nullptr;
-}
-
-LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
-  OpBuilder::InsertionGuard g(rewriter);
-  DominanceInfo domInfo;
-
-  op->walk([&](LinalgOp op) {
-    // Only ops with all "parallel" iterator types are supported.
-    if (op.getNumParallelLoops() != op.getNumLoops())
-      return WalkResult::skip();
-
-    for (OpOperand *in : op.getDpsInputOperands()) {
-      // Skip non-tensor operands.
-      if (!in->get().getType().isa<RankedTensorType>())
-        continue;
-
-      // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
-      // equivalent tensors. I.e., stop when there are ops such as extract_slice
-      // on the path.
-      TraversalConfig config;
-      config.followEquivalentOnly = true;
-      config.alwaysIncludeLeaves = false;
-      SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-          in->get(), /*condition=*/
-          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
-          config);
-      if (emptyTensors.empty())
-        continue;
-
-      // Find matching out operand.
-      OpOperand *out = getUnusedOutOperand(op, in);
-      if (!out)
-        continue;
-
-      // Check if this transform would violate dominance.
-      if (!llvm::all_of(emptyTensors, [&](Value v) {
-            return domInfo.properlyDominates(out->get(), v.getDefiningOp());
-          }))
-        continue;
-
-      // Replace all uses of the tensor.empty, but do not delete it yet. It will
-      // fold away later (to not invalidate DominanceInfo).
-      for (Value v : emptyTensors) {
-        assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
-        rewriter.replaceAllUsesWith(v, out->get());
-      }
-
-      // Turn the "in" into an "out".
-      rewriter.updateRootInPlace(op, [&]() {
-        out->set(in->get());
-        // The original "in" could be removed entirely here (because it will no
-        // longer have any uses in the payload), but we delegate this to
-        // existing cleanup patterns that remove unused operands.
-        in->set(emptyTensors.front());
-        BlockArgument outArg = op.getMatchingBlockArgument(out);
-        assert(outArg.getUses().empty() && "expected that out has no uses");
-        BlockArgument inArg = op.getMatchingBlockArgument(in);
-        rewriter.replaceAllUsesWith(inArg, outArg);
-        assert(!op.payloadUsesValueFromOperand(in) &&
-               "expected that the in operand is now unused");
-      });
-
-      state.resetCache();
-    }
-
-    return WalkResult::advance();
-  });
-  return success();
-}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a67ea0334b22b9b..84a31a9872d2486 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -7,11 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -1147,6 +1149,44 @@ struct SplatOpInterface
   }
 };
 
+template <typename OpTy>
+struct InsertSliceLikeOpInferDestinationInterface
+    : public InferDestinationOpInterface::ExternalModel<
+          InsertSliceLikeOpInferDestinationInterface<OpTy>, OpTy> {
+  Value getOrBuildDestination(Operation *op, OpBuilder &builder, Location loc,
+                              OpOperand &operand) const {
+    auto insertSliceOp = cast<OpTy>(op);
+    assert(&operand == &op->getOpOperand(0) /*source*/ &&
+           "expected source operand");
+    auto extractOp = builder.create<tensor::ExtractSliceOp>(
+        loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
+    return extractOp.getResult();
+  }
+
+  SmallVector<Value> getNeededValues(Operation *op, OpOperand &operand) const {
+    auto insertSliceOp = cast<OpTy>(op);
+    assert(&operand == &op->getOpOperand(0) /*source*/ &&
+           "expected source operand");
+    SmallVector<Value> neededValues;
+    // Collect all values that are needed to construct the replacement op.
+    neededValues.append(insertSliceOp.getOffsets().begin(),
+                        insertSliceOp.getOffsets().end());
+    neededValues.append(insertSliceOp.getSizes().begin(),
+                        insertSliceOp.getSizes().end());
+    neededValues.append(insertSliceOp.getStrides().begin(),
+                        insertSliceOp.getStrides().end());
+    neededValues.push_back(insertSliceOp.getDest());
+    return neededValues;
+  }
+
+  bool isAnchor(Operation *op, OpOperand &operand) const {
+    // The source is transferred into the destination.
+    return &operand == &op->getOpOperand(0) /*source*/;
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1165,9 +1205,15 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
     InsertOp::attachInterface<InsertOpInterface>(*ctx);
     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
+    InsertSliceOp::attachInterface<
+        InsertSliceLikeOpInferDestinationInterface<tensor::InsertSliceOp>>(
+        *ctx);
     PadOp::attachInterface<PadOpInterface>(*ctx);
     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
         *ctx);
+    ParallelInsertSliceOp::attachInterface<
+        InsertSliceLikeOpInferDestinationInterface<
+            tensor::ParallelInsertSliceOp>>(*ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
     SplatOp::attachInterface<SplatOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
index 939eea37e9b7983..f5a4d66d66a0513 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -1,12 +1,21 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | \
+// RUN:     FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter \
+// RUN:             --one-shot-bufferize="bufferize-function-boundaries" \
+// RUN:             --split-input-file | \
+// RUN:     FileCheck %s --check-prefix=CHECK-BUFFERIZE
 
 // CHECK-LABEL: func.func @eliminate_tensor_empty(
 //  CHECK-SAME:     %[[arg0:.*]]: tensor<50x91xf32>,
 //   CHECK-NOT:   tensor.empty
 //       CHECK:   %[[filled:.*]] = linalg.fill {{.*}} outs(%[[arg0]]
 //       CHECK:   %[[matmul:.*]] = linalg.matmul {{.*}} outs(%[[filled]]
-//       CHECK:   %[[generic:.*]] = linalg.generic {{.*}} outs(%[[matmul]]
+//       CHECK:   %[[generic:.*]] = linalg.generic {{.*}} ins(%{{.*}}, %[[matmul]] : {{.*}}) outs(%[[arg0]]
 //       CHECK:   return %[[generic]]
+
+// CHECK-BUFFERIZE-LABEL: func @eliminate_tensor_empty(
+//   CHECK-BUFFERIZE-NOT:   memref.alloc
+//   CHECK-BUFFERIZE-NOT:   memref.copy
 func.func @eliminate_tensor_empty(
     %arg0: tensor<50x91xf32>, %arg1: tensor<91xf32>, %arg2: tensor<50x1280xf32>,
     %arg3: tensor<1280x91xf32>) -> tensor<50x91xf32>
@@ -35,8 +44,5 @@ func.func @eliminate_tensor_empty(
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.eliminate_empty_tensors %0 : !transform.any_op
-  transform.apply_patterns to %0 {
-    transform.apply_patterns.linalg.erase_unnecessary_inputs
-  } : !transform.any_op
+  transform.bufferization.eliminate_empty_tensors %0 : !transform.any_op
 }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7829bb0ffbd2932..ece4913c8753444 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9705,6 +9705,36 @@ gentbl_cc_library(
     ],
 )
 
+td_library(
+    name = "InferDestinationOpInterfaceTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "InferDestinationOpInterfaceIncGen",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.td",
+    deps = [
+        ":InferDestinationOpInterfaceTdFiles",
+    ],
+)
+
 td_library(
     name = "LinalgDocTdFiles",
     srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -11972,12 +12002,14 @@ cc_library(
         "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
+        "lib/Dialect/Bufferization/IR/InferDestinationOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp",
     ],
     hdrs = [
         "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/Bufferization.h",
         "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
+        "include/mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h",
     ],
     includes = ["include"],
@@ -11995,6 +12027,7 @@ cc_library(
         ":FuncDialect",
         ":FunctionInterfaces",
         ":IR",
+        ":InferDestinationOpInterfaceIncGen",
         ":InferTypeOpInterface",
         ":MemRefDialect",
         ":SparseTensorDialect",

>From 5f333312297e378fc004323955322d83cb3e9d6d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 6 Sep 2023 14:24:20 +0200
Subject: [PATCH 2/3] [mlir][bufferization][NFC] Rename copy_tensor op to
 materialize_in_destination

The previous name was badly chosen. The op is used to ensure that a computation materializes in the future buffer of a certain tensor.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 .../Bufferization/IR/BufferizationOps.td      | 27 ++++--
 .../Linalg/TransformOps/LinalgTransformOps.td |  6 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  4 +-
 .../Bufferization/IR/BufferizationOps.cpp     | 86 ++++++++++---------
 .../TransformOps/LinalgTransformOps.cpp       | 10 ++-
 .../lib/Dialect/Linalg/Transforms/Padding.cpp |  8 +-
 .../Transforms/one-shot-bufferize.mlir        |  2 +-
 mlir/test/Dialect/Bufferization/invalid.mlir  |  4 +-
 mlir/test/Dialect/Bufferization/ops.mlir      |  8 +-
 .../test/Dialect/Linalg/transform-op-pad.mlir |  8 +-
 10 files changed, 91 insertions(+), 72 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fec07af349b3a8d..a6a733dfce13251 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -209,17 +209,30 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 }
 
 //===----------------------------------------------------------------------===//
-// CopyTensorOp
+// MaterializeInDestinationOp
 //===----------------------------------------------------------------------===//
 
-def Bufferization_CopyTensorOp : Bufferization_Op<"copy_tensor",
-    [BufferizableOpInterface, SameOperandsAndResultType,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+def Bufferization_MaterializeInDestinationOp
+    : Bufferization_Op<"materialize_in_destination",
+        [BufferizableOpInterface, SameOperandsAndResultType,
+         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "copy a tensor";
 
   let description = [{
-    Copy the contents of the source tensor into the destination tensor. This
-    operation is guaranteed to bufferize to a memory copy.
+    This op indicates that the data of the `source` tensor should materialize
+    in the future buffer of the `dest` tensor. Both tensors must have the same
+    shape and element type at runtime.
+
+    By default, this op bufferizes to a memcpy from the future buffer of the
+    `source` tensor to the future buffer of the `dest` tensor. However,
+    transformations such as "empty tensor elimination" may rewrite IR such that
+    a computation is performed directly in the future buffer of the `dest`
+    tensor and no memcpy is needed.
+
+    Note: "tensor.insert_slice" could be used for the same purpose, but since
+    tensor dialect ops only indicate *what* should be computed but not *where*,
+    it could fold away, causing the computation to materialize in a different
+    buffer.
   }];
 
   let arguments = (ins AnyTensor:$source,
@@ -245,7 +258,7 @@ def Bufferization_CopyTensorOp : Bufferization_Op<"copy_tensor",
     }
   }];
 
-  let assemblyFormat = "$source `,` $dest attr-dict `:` type($source)";
+  let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 858be9e60621aa7..a1b1934d6862b85 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -866,7 +866,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
     the original destination tensor of the targeted op. The op that copies back
     the result can be customized with `copy_back_op`:
 
-    * "bufferization.copy_tensor" (default)
+    * "bufferization.materialize_in_destination" (default)
     * "linalg.copy"
     * "none" (no copy back)
 
@@ -891,7 +891,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
          DefaultValuedAttr<
           TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
           "{}">:$transpose_paddings,
-         DefaultValuedAttr<StrAttr, "::mlir::bufferization::CopyTensorOp::getOperationName()">:$copy_back_op);
+         DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op);
   let results = (outs TransformHandleTypeInterface:$padded,
                       TransformHandleTypeInterface:$pad,
                       TransformHandleTypeInterface:$copy);
@@ -911,7 +911,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
                    CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                   CArg<"StringRef", "::mlir::bufferization::CopyTensorOp::getOperationName()">:$copyBackOp)>
+                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 82f7915395ba2d9..cea9651d1f91cb5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -264,12 +264,12 @@ struct LinalgPaddingOptions {
   }
   enum class CopyBackOp : int8_t {
     None = 0,
-    BufferizationCopyTensor = 1,
+    BufferizationMaterializeInDestination = 1,
     LinalgCopy = 2
   };
   /// The op to be used for copying the padded result to the original
   /// destination tensor.
-  CopyBackOp copyBackOp = CopyBackOp::BufferizationCopyTensor;
+  CopyBackOp copyBackOp = CopyBackOp::BufferizationMaterializeInDestination;
   LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) {
     copyBackOp = op;
     return *this;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9a2a6d0f5c6d981..e5016c956804688 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -441,48 +441,6 @@ Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
   return getOperand(getIndexOfDynamicSize(idx));
 }
 
-//===----------------------------------------------------------------------===//
-// CopyTensorOp
-//===----------------------------------------------------------------------===//
-
-bool CopyTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
-                                          const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
-    return true;
-  return false;
-}
-
-bool CopyTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
-                                           const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
-    return true;
-  return false;
-}
-
-AliasingValueList CopyTensorOp::getAliasingValues(OpOperand &opOperand,
-                                                  const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
-    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
-  return {};
-}
-
-LogicalResult CopyTensorOp::bufferize(RewriterBase &rewriter,
-                                      const BufferizationOptions &options) {
-  FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
-  if (failed(buffer))
-    return failure();
-  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
-  replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
-  return success();
-}
-
-LogicalResult CopyTensorOp::reifyResultShapes(
-    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
-  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//
@@ -585,6 +543,50 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MaterializeInDestinationOp
+//===----------------------------------------------------------------------===//
+
+bool MaterializeInDestinationOp::bufferizesToMemoryRead(
+    OpOperand &opOperand, const AnalysisState &state) {
+  if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
+    return true;
+  return false;
+}
+
+bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
+    OpOperand &opOperand, const AnalysisState &state) {
+  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+    return true;
+  return false;
+}
+
+AliasingValueList
+MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
+                                              const AnalysisState &state) {
+  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
+  return {};
+}
+
+LogicalResult
+MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
+                                      const BufferizationOptions &options) {
+  FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
+  if (failed(buffer))
+    return failure();
+  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
+  replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
+  return success();
+}
+
+LogicalResult MaterializeInDestinationOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ToTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6aa575dd865cf75..9e5f81d0cffae99 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1649,9 +1649,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     options.padToMultipleOf = padToMultipleOf;
     options.paddingValues = paddingValues;
     options.packPaddings = packPaddings;
-    if (getCopyBackOp() == bufferization::CopyTensorOp::getOperationName()) {
-      options.copyBackOp =
-          LinalgPaddingOptions::CopyBackOp::BufferizationCopyTensor;
+    if (getCopyBackOp() ==
+        bufferization::MaterializeInDestinationOp::getOperationName()) {
+      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
+          BufferizationMaterializeInDestination;
     } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
       options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
     } else if (getCopyBackOp() == kCopyOpNone) {
@@ -1727,7 +1728,8 @@ LogicalResult transform::PadOp::verify() {
              << attr;
     }
   }
-  if (getCopyBackOp() != bufferization::CopyTensorOp::getOperationName() &&
+  if (getCopyBackOp() !=
+          bufferization::MaterializeInDestinationOp::getOperationName() &&
       getCopyBackOp() != linalg::CopyOp::getOperationName() &&
       getCopyBackOp() != kCopyOpNone)
     return emitOpError() << "invalid copy_back_op";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index b8ebf7dbb0fe72f..8fe745d97ca3dd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -245,9 +245,11 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                                                          std::get<1>(it)->get())
                                  .getResult(0));
     } else if (options.copyBackOp ==
-               LinalgPaddingOptions::CopyBackOp::BufferizationCopyTensor) {
-      replacements.push_back(rewriter.create<bufferization::CopyTensorOp>(
-          loc, std::get<0>(it), std::get<1>(it)->get()));
+               LinalgPaddingOptions::CopyBackOp::
+                   BufferizationMaterializeInDestination) {
+      replacements.push_back(
+          rewriter.create<bufferization::MaterializeInDestinationOp>(
+              loc, std::get<0>(it), std::get<1>(it)->get()));
     } else {
       llvm_unreachable("unsupported copy back op");
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index a261256c033fa41..f92c7b4ee585150 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -224,6 +224,6 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
   // CHECK: memref.dealloc %[[alloc]]
   // CHECK: return %[[r]]
   %dest = bufferization.alloc_tensor() : tensor<5xf32>
-  %0 = bufferization.copy_tensor %arg0, %dest : tensor<5xf32>
+  %0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
   return %0 : tensor<5xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 3b4bfee5622e9bb..7c92193ab068dba 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -99,9 +99,9 @@ func.func @invalid_writable_on_op() {
 // -----
 
 // expected-note @below{{prior use here}}
-func.func @invalid_tensor_copy(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
+func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
   // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
-  bufferization.copy_tensor %arg0, %arg1 : tensor<?xf32>
+  bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index 773f15c1ffcb89b..665f5697fdc5fdf 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -58,11 +58,11 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
   return
 }
 
-// CHECK-LABEL: func @test_copy_tensor_op
-func.func @test_copy_tensor_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
+// CHECK-LABEL: func @test_materialize_in_destination_op
+func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
     -> tensor<?xf32> {
-  // CHECK: bufferization.copy_tensor {{.*}} : tensor<?xf32>
-  %1 = bufferization.copy_tensor %arg0, %arg1 : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : tensor<?xf32>
+  %1 = bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
   return %1 : tensor<?xf32>
 }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index d8d0fc56f04406b..5c5d162b7c16f0a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -27,7 +27,7 @@ func.func @static_sizes_output_divisible(%arg0: tensor<24x12xf32>,
   // CHECK-SAME:              outs(%[[T2]] : tensor<4x5xf32>)
 
   //      CHECK: %[[T6:.*]] = tensor.extract_slice %[[T5]]
-  //      CHECK: %[[T7:.*]] = bufferization.copy_tensor %[[T6]], %[[T2]]
+  //      CHECK: %[[T7:.*]] = bufferization.materialize_in_destination %[[T6]] in %[[T2]]
   %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
   %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
   func.return %5 : tensor<24x25xf32>
@@ -40,9 +40,9 @@ transform.sequence failures(propagate) {
     padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 0]
-  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.copy_tensor">)
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
   // expected-remark @below {{1}}
-  test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.copy_tensor">
+  test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
 }
 
 // -----
@@ -272,7 +272,7 @@ func.func @pack_everything(%arg0: tensor<24x12xf32>,
   //      CHECK: %[[T6:.*]] = tensor.extract_slice %[[T5]]
   // Copy back result to the original buffer, so that the destination of the
   // computation does not change.
-  //      CHECK: %[[T7:.*]] = bufferization.copy_tensor %[[T6]], %[[T2]]
+  //      CHECK: %[[T7:.*]] = bufferization.materialize_in_destination %[[T6]] in %[[T2]]
   %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
 
   //      CHECK: %[[T8:.*]] = tensor.insert_slice %[[T7]] into %{{.*}}

>From 42cf0d236277602218fad0d51653ceee31ea9a0f Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 6 Sep 2023 14:24:38 +0200
Subject: [PATCH 3/3] [mlir][bufferization] Empty tensor elimination for
 materialize_in_destination

This revision adds support for empty tensor elimination to "bufferization.materialize_in_destination" by implementing the `InferDestinationOpInterface`.

Furthermore, the One-Shot Bufferize conflict detection is improved for "bufferization.materialize_in_destination".
---
 .../Dialect/Bufferization/IR/Bufferization.h  |  1 +
 .../Bufferization/IR/BufferizationOps.td      |  9 ++-
 .../Bufferization/IR/BufferizationOps.cpp     | 57 +++++++++++++++++++
 .../one-shot-bufferize-analysis.mlir          | 28 +++++++++
 ...ot-bufferize-empty-tensor-elimination.mlir | 14 +++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 6 files changed, 109 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 450dfb37ddb2e18..fe5eaf51cbcd34a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index a6a733dfce13251..f987c817f30b93b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/InferDestinationOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
@@ -215,7 +216,9 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
         [BufferizableOpInterface, SameOperandsAndResultType,
-         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+         DeclareOpInterfaceMethods<InferDestinationOpInterface,
+            ["getOrBuildDestination", "getNeededValues", "isAnchor"]>]> {
   let summary = "copy a tensor";
 
   let description = [{
@@ -250,6 +253,10 @@ def Bufferization_MaterializeInDestinationOp
     bool bufferizesToMemoryWrite(OpOperand &opOperand,
                                  const AnalysisState &state);
 
+    bool isNotConflicting(OpOperand *uRead,
+                          OpOperand *uConflictingWrite,
+                          const AnalysisState &state);
+
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index e5016c956804688..91521e5a2a7480e 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -580,6 +580,48 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
   return success();
 }
 
+bool MaterializeInDestinationOp::isNotConflicting(OpOperand *uRead,
+                                                  OpOperand *uConflictingWrite,
+                                                  const AnalysisState &state) {
+  // A transfer from a tensor into the same (equivalent) tensor is not a
+  // conflict. E.g.:
+  //
+  // %1 = tensor.insert %f into %0[%pos] : tensor<5xf32>
+  // %2 = bufferization.materialize_in_destination %1, %0 : tensor<5xf32>
+  //
+  // If "tensor.insert" bufferizes in-place, %1 and %0 are equivalent tensors.
+  // One-Shot Bufferize would detect a potential RaW conflict with
+  // - uRead = first operand (source) of materialize_in_destination
+  // - uConflictingWrite = second operand (dest) of materialize_in_destination
+  // This is not a conflict and this op is in fact a no-op; there is no transfer
+  // into a new destination.
+  //
+  // If "tensor.insert" bufferizes out-of-place, there is no RaW in the above
+  // example.
+
+  if (uRead != &getOperation()->getOpOperand(0) /*source*/ ||
+      uConflictingWrite != &getOperation()->getOpOperand(1) /*dest*/)
+    return false;
+
+  // Make sure that source and dest may be equivalent, i.e., that they do not
+  // merely alias (e.g., overlap). `areEquivalentBufferizedValues` cannot be
+  // used directly because the equivalence sets may not have been formed yet.
+  TraversalConfig config;
+  config.followEquivalentOnly = true;
+  config.followSameTypeOrCastsOnly = true;
+  config.alwaysIncludeLeaves = true;
+  SetVector<Value> defs = state.findValueInReverseUseDefChain(
+      uRead->get(),
+      [&](Value v) {
+        return state.areEquivalentBufferizedValues(v, uConflictingWrite->get());
+      },
+      config);
+  if (defs.size() == 1 && defs.front() == getDest())
+    return true;
+
+  return false;
+}
+
 LogicalResult MaterializeInDestinationOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
@@ -587,6 +629,21 @@ LogicalResult MaterializeInDestinationOp::reifyResultShapes(
   return success();
 }
 
+Value MaterializeInDestinationOp::getOrBuildDestination(OpBuilder &builder,
+                                                        Location loc,
+                                                        OpOperand &operand) {
+  return getDest();
+}
+
+SmallVector<Value>
+MaterializeInDestinationOp::getNeededValues(OpOperand &operand) {
+  return {getDest()};
+}
+
+bool MaterializeInDestinationOp::isAnchor(OpOperand &operand) {
+  return &operand == &getOperation()->getOpOperand(0) /*source*/;
+}
+
 //===----------------------------------------------------------------------===//
 // ToTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index 5a505c66892f113..a2fbb06d179ebda 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -158,3 +158,31 @@ func.func @bbarg_of_unknown_op_2(%f: f32) {
   // CHECK: {__inplace_operands_attr__ = ["false"]} : (tensor<10xf32>) -> ()
   return
 }
+
+// -----
+
+// CHECK: func @materialize_in_destination_aliasing(
+func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p2: index, %sz: index) -> tensor<5xf32> {
+  %buffer = tensor.empty(%sz) : tensor<?xf32>
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]}
+  %src = tensor.extract_slice %t[%p1][5][1] : tensor<?xf32> to tensor<5xf32>
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false", "none"]}
+  %dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
+  // CHECK: bufferization.materialize_in_destination
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
+  %r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
+  return %r : tensor<5xf32>
+}
+
+// -----
+
+// CHECK: func @materialize_in_destination(
+func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?xf32> {
+  %buffer = tensor.empty(%sz) : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
+  %r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
+  return %r : tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 3d15599915f0cfc..063ed07467f90d0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -291,3 +291,17 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
   %2 = tensor.insert_slice %filled into %t1 [%0, %1] [2, 5] [1, 1] : tensor<2x5xf32> into tensor<?x?xf32>
   return %2 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_destination(
+//  CHECK-SAME:     %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
+//       CHECK:   linalg.fill {{.*}} outs(%[[m]]
+//       CHECK:   return %[[m]]
+func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
+  %0 = tensor.empty() : tensor<5xf32>
+  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  %1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
+  //%1 = tensor.insert_slice %filled into %t[0][5][1] : tensor<5xf32> into tensor<5xf32>
+  return %1 : tensor<5xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ece4913c8753444..7f08b3601b37bdc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11993,6 +11993,7 @@ gentbl_cc_library(
     deps = [
         ":BufferizableOpInterfaceTdFiles",
         ":BufferizationOpsTdFiles",
+        ":InferDestinationOpInterfaceTdFiles",
     ],
 )
 



More information about the Mlir-commits mailing list