[Mlir-commits] [mlir] [mlir][bufferization] Generalize tensor slice rules to subset ops (PR #65619)

Matthias Springer llvmlistbot at llvm.org
Thu Sep 7 08:17:21 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/65619:

This commit generalizes the special tensor.extract_slice/tensor.insert_slice bufferization rules to tensor subset ops.

Ops that insert a tensor into a tensor at a specified subset (e.g., tensor.insert_slice, tensor.scatter) can implement the `SubsetOpInterface`.

Apart from adding a new op interface (extending the API), this change is NFC. The only ops that currently implement the new interface are tensor.insert_slice and tensor.parallel_insert_slice, and those ops were already supported by One-Shot Bufferize.

>From 41ff6b366fb4ad1ada3e5e4e4c8b1917bb5e84be Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 7 Sep 2023 17:16:26 +0200
Subject: [PATCH] [mlir][bufferization] Generalize tensor slice rules to subset
 ops

This commit generalizes the special tensor.extract_slice/tensor.insert_slice bufferization rules to tensor subset ops.

Ops that insert a tensor into a tensor at a specified subset (e.g., tensor.insert_slice, tensor.scatter) can implement the `SubsetOpInterface`.

Apart from adding a new op interface (extending the API), this change is NFC. The only ops that currently implement the new interface are tensor.insert_slice and tensor.parallel_insert_slice, and those ops were already supported by One-Shot Bufferize.
---
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../Bufferization/IR/SubsetOpInterface.h      |  31 +++
 .../Bufferization/IR/SubsetOpInterface.td     | 119 +++++++++++
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../Bufferization/IR/SubsetOpInterface.cpp    |  23 +++
 .../Transforms/OneShotAnalysis.cpp            | 103 ++++++++++
 .../BufferizableOpInterfaceImpl.cpp           | 184 ++++++------------
 .../llvm-project-overlay/mlir/BUILD.bazel     |  37 +++-
 8 files changed, 372 insertions(+), 127 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
 create mode 100644 mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index aa93534a78fea3f..4a97493cbf42a49 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect(BufferizationOps bufferization)
 add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(SubsetOpInterface)
 
 set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
 mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
new file mode 100644
index 000000000000000..9ab7f11d42fbb56
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
@@ -0,0 +1,31 @@
+//===- SubsetOpInterface.h - Tensor subsets ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class OpBuilder;
+
+namespace bufferization {
+namespace detail {
+
+/// Return the destination/"init" operand of the op if it implements the
+/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
+/// otherwise.
+OpOperand &defaultGetDestination(Operation *op);
+
+} // namespace detail
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
new file mode 100644
index 000000000000000..e7686ac8db9f8dc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
@@ -0,0 +1,119 @@
+//===-- SubsetOpInterface.td - Tensor subsets --------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SUBSET_OP_INTERFACE
+#define SUBSET_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
+  let description = [{
+    This interface can be implemented by ops that insert a source tensor into
+    a destination tensor.
+
+    The elements in the destination tensor that are overwritten by this
+    insertion are called the "subset". How the subset is defined is up to the
+    op. E.g., "tensor.insert_slice" defines the subset via a hyperrectangular
+    slice. A scatter operation could define the subset via a list of indices.
+
+    Ops that deal with tensor subsets come in two flavors:
+    - Insertion flavor: Ops that insert a source tensor into a destination
+      tensor at the specified subset. Such ops usually return a new destination
+      tensor and implement the `DestinationStyleOpInterface`. Insertion ops can
+      implement the `SubsetOpInterface`. Example: "tensor.insert_slice"
+    - Extraction flavor: Ops that define a tensor subset. They extract a
+      specified subset from a tensor. There is currently no op interface for
+      such ops. Example: "tensor.extract_slice"
+
+    This interface provides helper methods for efficient bufferization of
+    subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
+    "aliases" (in contrast to one or more less efficient buffer allocations).
+
+    This interface is queried by One-Shot Bufferize to detect cases where a
+    seeming read-after-write is not actually a conflict because the respective
+    ops are operating on equivalent subsets. More details can be found in the
+    documentation of One-Shot Analysis.
+
+    Note: This interface currently assumes that a subset op inserts a single
+    tensor (source) into a destination tensor at a single subset.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the source tensor operand.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getSource",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          llvm_unreachable("getSource not implemented");
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the destination tensor operand.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getDestination",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::bufferization::detail::defaultGetDestination(
+              $_op.getOperation());
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if this operation inserts into a subset that is
+          equivalent to the subset defined by `candidate`.
+
+          Two subsets are "equivalent" and "same" if they can bufferize to the
+          same buffer views/aliases. If they are "equivalent", the tensor IR
+          may be expressed in terms of different SSA values.
+
+          Example:
+          ```
+          // %1 and %2 are equivalent (but not "same"):
+          %0 = arith.select %c, %t, %t : tensor<?xf32>
+          %1 = tensor.extract_slice %0[0][5][1] : tensor<?xf32> to tensor<5xf32>
+          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+
+          // %1 and %2 are equivalent and "same"
+          %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+          ```
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isEquivalentSubset",
+        /*args=*/(ins
+            "::mlir::OpResult":$candidate,
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          llvm_unreachable("isEquivalentSubset not implemented");
+        }]
+      >,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return "true" if this operation inserts into the same subset as defined
+    /// by `candidate`.
+    ///
+    /// Note: This function is useful outside of bufferization, where no tensor
+    /// equivalence information is available.
+    bool isSameSubset(OpResult candidate) {
+      return cast<::mlir::bufferization::SubsetOpInterface>(getOperation())
+          .isEquivalentSubset(candidate,
+                              [](Value v1, Value v2) { return v1 == v2; });
+    }
+  }];
+}
+
+#endif // SUBSET_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 2d8d09b9c41d993..f406595fc1e9fc9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferizableOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
+  SubsetOpInterface.cpp
   UnstructuredControlFlow.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp
new file mode 100644
index 000000000000000..87804c50a530fbc
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp
@@ -0,0 +1,23 @@
+//===- SubsetOpInterface.cpp - Tensor subsets -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.cpp.inc"
+
+using namespace mlir;
+
+OpOperand &bufferization::detail::defaultGetDestination(Operation *op) {
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
+  assert(dstOp && "getDestination must be implemented for non-DPS ops");
+  assert(
+      dstOp.getNumDpsInits() == 1 &&
+      "getDestination must be implemented for ops with 0 or more than 1 init");
+  return *dstOp.getDpsInitOperand(0);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 49b5ebdf722a1a7..c613b77c10dca57 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -45,6 +45,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -531,6 +532,102 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
               .empty();
 }
 
+/// Return "true" if `value` is originating from a subset that is equivalent to
+/// the subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state, Value value,
+                                     SubsetOpInterface subsetOp) {
+  auto matchingSubset = [&](Value val) {
+    if (auto opResult = dyn_cast<OpResult>(val))
+      if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
+            return state.areEquivalentBufferizedValues(v1, v2);
+          }))
+        return true;
+    return false;
+  };
+  // There may be multiple leaves at which the reverse SSA use-def chain lookup
+  // terminates. All of them must be equivalent subsets.
+  SetVector<Value> backwardSlice =
+      state.findValueInReverseUseDefChain(value, matchingSubset);
+  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
+}
+
+/// Return "true" if the given "read" and potentially conflicting "write" are
+/// not conflicting due to their subset relationship. The comments in this
+/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
+/// pairs, but apply to any subset ops that implement the `SubsetOpInterface`.
+static bool areNonConflictingSubsets(OpOperand *uRead,
+                                     OpOperand *uConflictingWrite,
+                                     const AnalysisState &state) {
+  Operation *readingOp = uRead->getOwner();
+  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+  // uRead is an InsertSliceOp...
+  if (auto subsetOp = dyn_cast<SubsetOpInterface>(readingOp)) {
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+    // %1 = linalg.fill %cst, %0 {inplace= [true] }
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace= [true] }
+
+    if (uRead == &subsetOp.getDestination() &&
+        matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+      // Case 1: The main insight is that InsertSliceOp reads only part of
+      // the destination tensor. The overwritten area is not read. If
+      // uConflictingWrite writes into exactly the memory location that is
+      // being read by uRead, this is not a conflict.
+      //
+      // In the above example:
+      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+      //
+      // The read of %t does not conflict with the write of the FillOp
+      // (same aliases!) because the area that the FillOp operates on is
+      // exactly the one that is *not* read via %t.
+      return true;
+
+    if (uRead == &subsetOp.getSource() &&
+        uConflictingWrite == &subsetOp.getDestination() &&
+        matchesInsertDestination(state, uRead->get(), subsetOp))
+      // Case 2: The read of the source tensor and the write to the dest
+      // tensor via an InsertSliceOp is not a conflict if the read is
+      // reading exactly that part of an equivalent tensor that the
+      // InsertSliceOp is writing.
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      return true;
+  }
+
+  // If uConflictingWrite is an InsertSliceOp...
+  if (auto subsetOp = dyn_cast<SubsetOpInterface>(conflictingWritingOp))
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+    // %1 = linalg.fill %cst, %0 {inplace= [true] }
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace= [true] }
+    // %3 = vector.transfer_read %1, %cst
+    //
+    // In the above example:
+    // uRead             = OpOperand 0 (%1) of vector.transfer_read
+    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+    // definition        = %1
+    //
+    // This is not a conflict because the InsertSliceOp overwrites the
+    // memory segment of %1 with the exact same data. (Effectively, there
+    // is no memory write here.)
+    if (uConflictingWrite == &subsetOp.getDestination() &&
+        state.areEquivalentBufferizedValues(uRead->get(),
+                                            subsetOp.getSource().get()) &&
+        matchesInsertDestination(state, subsetOp.getSource().get(), subsetOp))
+      return true;
+
+  return false;
+}
+
 /// Given sets of uses and writes, return true if there is a RaW conflict under
 /// the assumption that all given reads/writes alias the same buffer and that
 /// all given writes bufferize inplace.
@@ -647,6 +744,12 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
         }
       }
 
+      // No conflict if the operands are non-conflicting subsets.
+      if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
+        LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
+        continue;
+      }
+
       // No conflict if the op interface says so.
       if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
         if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a67ea0334b22b9b..23bb4c34928e6ac 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -628,117 +629,6 @@ struct InsertOpInterface
   }
 };
 
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-template <typename OpTy>
-static bool areEquivalentSlices(const AnalysisState &state,
-                                ExtractSliceOp extractSliceOp,
-                                OpTy insertSliceOp) {
-  if (!extractSliceOp || !insertSliceOp)
-    return false;
-  if (extractSliceOp != insertSliceOp &&
-      !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
-                                           insertSliceOp.getDest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
-                                  isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-template <typename OpTy>
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
-                                     OpTy insertSliceOp) {
-  // Look for matching slices.
-  auto matchesSlice = [&](Value val) {
-    if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentSlices(state, extractSliceOp, insertSliceOp))
-        return true;
-    return false;
-  };
-  return static_cast<bool>(llvm::all_of(
-      state.findValueInReverseUseDefChain(value, matchesSlice), matchesSlice));
-}
-
-template <typename OpTy>
-static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
-                                              OpOperand *uConflictingWrite,
-                                              const AnalysisState &state) {
-  Operation *readingOp = uRead->getOwner();
-  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-  // uRead is an InsertSliceOp...
-  if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
-    // As an example, consider the following IR.
-    //
-    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-    // %1 = linalg.fill %cst, %0 {inplace= [true] }
-    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-    //     {inplace= [true] }
-
-    // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-    if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        matchesInsertDestination(state, uConflictingWrite->get(),
-                                 insertSliceOp))
-      // Case 1: The main insight is that InsertSliceOp reads only part of
-      // the destination tensor. The overwritten area is not read. If
-      // uConflictingWrite writes into exactly the memory location that is
-      // being read by uRead, this is not a conflict.
-      //
-      // In the above example:
-      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-      //
-      // The read of %t does not conflict with the write of the FillOp
-      // (same aliases!) because the area that the FillOp operates on is
-      // exactly the one that is *not* read via %t.
-      return true;
-
-    if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-        uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        matchesInsertDestination(state, uRead->get(), insertSliceOp))
-      // Case 2: The read of the source tensor and the write to the dest
-      // tensor via an InsertSliceOp is not a conflict if the read is
-      // reading exactly that part of an equivalent tensor that the
-      // InsertSliceOp is writing.
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      return true;
-  }
-
-  // If uConflictingWrite is an InsertSliceOp...
-  if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
-    // As an example, consider the following IR.
-    //
-    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-    // %1 = linalg.fill %cst, %0 {inplace= [true] }
-    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-    //     {inplace= [true] }
-    // %3 = vector.transfer_read %1, %cst
-    //
-    // In the above example:
-    // uRead             = OpOperand 0 (%1) of vector.transfer_read
-    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-    // definition        = %1
-    //
-    // This is not a conflict because the InsertSliceOp overwrites the
-    // memory segment of %1 with the exact same data. (Effectively, there
-    // is no memory write here.)
-    if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        state.areEquivalentBufferizedValues(uRead->get(),
-                                            insertSliceOp.getSource()) &&
-        matchesInsertDestination(state, insertSliceOp.getSource(),
-                                 insertSliceOp))
-      return true;
-
-  return false;
-}
-
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
 ///
@@ -777,13 +667,6 @@ struct InsertSliceOpInterface
     return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
   }
 
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const AnalysisState &state) const {
-    return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
-        op, uRead, uConflictingWrite, state);
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     // insert_slice ops arise from tiling and bufferizing them out-of-place is
@@ -1092,13 +975,6 @@ struct ParallelInsertSliceOpInterface
     rewriter.eraseOp(op);
     return success();
   }
-
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const AnalysisState &state) const {
-    return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
-        op, uRead, uConflictingWrite, state);
-  }
 };
 
 /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
@@ -1147,6 +1023,58 @@ struct SplatOpInterface
   }
 };
 
+namespace {
+/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+/// equivalence of tensors.
+template <typename OpTy>
+bool isSubsetEquivalentToInsertSliceLikeOp(
+    OpTy insertSliceOp, OpResult candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  // Look for a matching tensor.extract_slice op.
+  auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!extractSliceOp)
+    return false;
+  if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                  isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+} // namespace
+
+struct InsertSliceOpSubsetOpInterface
+    : public SubsetOpInterface::ExternalModel<InsertSliceOpSubsetOpInterface,
+                                              tensor::InsertSliceOp> {
+  OpOperand &getSource(Operation *op) const { return op->getOpOperand(0); }
+
+  bool
+  isEquivalentSubset(Operation *op, OpResult candidate,
+                     function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
+                                                 equivalenceFn);
+  }
+};
+
+struct ParallelInsertSliceOpSubsetOpInterface
+    : public SubsetOpInterface::ExternalModel<
+          ParallelInsertSliceOpSubsetOpInterface,
+          tensor::ParallelInsertSliceOp> {
+  OpOperand &getSource(Operation *op) const { return op->getOpOperand(0); }
+
+  OpOperand &getDestination(Operation *op) const { return op->getOpOperand(1); }
+
+  bool
+  isEquivalentSubset(Operation *op, OpResult candidate,
+                     function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
+                                                 equivalenceFn);
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1154,6 +1082,7 @@ struct SplatOpInterface
 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+    // BufferizableOpInterface models.
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
@@ -1172,6 +1101,11 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
     SplatOp::attachInterface<SplatOpInterface>(*ctx);
 
+    // SubsetOpInterface models.
+    InsertSliceOp::attachInterface<InsertSliceOpSubsetOpInterface>(*ctx);
+    ParallelInsertSliceOp::attachInterface<
+        ParallelInsertSliceOpSubsetOpInterface>(*ctx);
+
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
   });
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7829bb0ffbd2932..b3645710232f2fa 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7,14 +7,14 @@
 
 load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
 load("@bazel_skylib//rules:write_file.bzl", "write_file")
-load(":tblgen.bzl", "gentbl_cc_library", "td_library")
-load(":linalggen.bzl", "genlinalg")
 load(
     ":build_defs.bzl",
     "cc_headers_only",
     "if_cuda_available",
     "mlir_c_api_cc_library",
 )
+load(":linalggen.bzl", "genlinalg")
+load(":tblgen.bzl", "gentbl_cc_library", "td_library")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -9705,6 +9705,36 @@ gentbl_cc_library(
     ],
 )
 
+td_library(
+    name = "SubsetOpInterfaceTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "SubsetOpInterfaceIncGen",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td",
+    deps = [
+        ":SubsetOpInterfaceTdFiles",
+    ],
+)
+
 td_library(
     name = "LinalgDocTdFiles",
     srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -11972,12 +12002,14 @@ cc_library(
         "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
+        "lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp",
     ],
     hdrs = [
         "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/Bufferization.h",
         "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
+        "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h",
     ],
     includes = ["include"],
@@ -11998,6 +12030,7 @@ cc_library(
         ":InferTypeOpInterface",
         ":MemRefDialect",
         ":SparseTensorDialect",
+        ":SubsetOpInterfaceIncGen",
         ":Support",
         ":TensorDialect",
         "//llvm:Support",



More information about the Mlir-commits mailing list