[llvm] [mlir][bufferization] Generalize tensor slice rules to subset ops (PR #65619)

Matthias Springer via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 13 02:52:37 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/65619:

>From 5ca4a37f1efac7c86b123073cfec500d875b2dda Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 13 Sep 2023 11:42:34 +0200
Subject: [PATCH] [mlir][bufferization] Generalize tensor slice rules to subset
 ops

This commit generalizes the special tensor.extract_slice/tensor.insert_slice bufferization rules to tensor subset ops.

Ops that insert a tensor into a tensor at a specified subset (e.g., tensor.insert_slice, tensor.scatter) can implement the `SubsetInsertionOpInterface`.

Apart from adding a new op interface (extending the API), this change is NFC. The only ops that currently implement the new interface are tensor.insert_slice and tensor.parallel_insert_slice, and those ops were already supported by One-Shot Bufferize.
---
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../IR/SubsetInsertionOpInterface.h           |  29 ++++
 .../IR/SubsetInsertionOpInterface.td          | 119 ++++++++++++++++
 .../SubsetInsertionOpInterfaceImpl.h          |  21 +++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../IR/SubsetInsertionOpInterface.cpp         |  23 +++
 .../Transforms/OneShotAnalysis.cpp            | 106 ++++++++++++++
 .../BufferizableOpInterfaceImpl.cpp           | 131 +-----------------
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |   1 +
 .../SubsetInsertionOpInterfaceImpl.cpp        |  82 +++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  40 +++++-
 12 files changed, 425 insertions(+), 131 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
 create mode 100644 mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 31a553f9a32f554..38057d4910d2958 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferDeallocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(SubsetInsertionOpInterface)
 
 set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
 mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h
new file mode 100644
index 000000000000000..e5b06d746e74bfd
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h
@@ -0,0 +1,29 @@
+//===- SubsetInsertionOpInterface.h - Tensor Subsets ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace bufferization {
+namespace detail {
+
+/// Return the destination/"init" operand of the op if it implements the
+/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
+/// otherwise.
+OpOperand &defaultGetDestinationOperand(Operation *op);
+
+} // namespace detail
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
new file mode 100644
index 000000000000000..edf652537795771
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
@@ -0,0 +1,119 @@
+//===-- SubsetInsertionOpInterface.td - Tensor Subsets -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SUBSET_INSERTION_OP_INTERFACE
+#define SUBSET_INSERTION_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> {
+  let description = [{
+    This interface can be implemented by ops that insert a source tensor into
+    a destination tensor.
+
+    The elements in the destination tensor that are overwritten by this
+    insertion are called the "subset". How the subset is defined is up to the
+    op. E.g., "tensor.insert_slice" defines the subset via a hyperrectangular
+    slice. A scatter operation could define the subset via a list of indices.
+
+    Ops that deal with tensor subsets come in two flavours:
+    - Insertion flavor: Ops that insert a source tensor into a destination
+      tensor at the specified subset. Such ops usually return a new destination
+      tensor and implement the `DestinationStyleOpInterface`. Insertion ops can
+      implement the `SubsetInsertionOpInterface`. Example: "tensor.insert_slice"
+    - Extraction flavor: Ops that define a tensor subset. They extract a
+      specified subset from a tensor. There is currently no op interface for
+      such ops. Example: "tensor.extract_slice"
+
+    This interface provides helper methods for efficient bufferization of
+    subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
+    "aliases" (in contrast to one or multiple less efficient buffer allocation).
+
+    This interface is queried by One-Shot Bufferize to detect cases where a
+    seeming read-after-write is not actually a conflict because the respective
+    ops are operating on equivalent subsets. More details can be found in the
+    documentation of One-Shot Analysis (see `areNonConflictingSubsets`).
+
+    Note: This interface currently assumes that a subset op inserts a single
+    tensor (source) into a destination tensor at a single subset.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the source tensor operand.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getSourceOperand",
+        /*args=*/(ins)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the destination tensor operand.
+        }],
+        /*retType=*/"::mlir::OpOperand &",
+        /*methodName=*/"getDestinationOperand",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::bufferization::detail::defaultGetDestinationOperand(
+              $_op.getOperation());
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if this operation inserts into a subset that is
+          equivalent to the subset defined by `candidate`.
+
+          Two subsets are "equivalent" and "same" if they can bufferize to the
+          same buffer views/aliases. If they are "equivalent", the tensor IR
+          may be expressed in terms of different SSA values (but they could
+          bufferize to MemRef SSA values that can CSE without breaking
+          correctness). `equivalenceFn` should return "true" if the two given
+          values are equivalent.
+
+          Example:
+          ```
+          // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
+          // the subset defined by %2 (but not "same"):
+          %0 = arith.select %c, %t, %t : tensor<?xf32>
+          %1 = tensor.insert_slice %x into %0[0][5][1]
+              : tensor<5xf32> into tensor<?xf32>
+          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+
+          // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
+          // and "same" as the subset defined by %2.
+          %1 = tensor.insert_slice %x into %t[0][5][1]
+              : tensor<5xf32> into tensor<?xf32>
+          %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+          ```
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isEquivalentSubset",
+        /*args=*/(ins
+            "::mlir::Value":$candidate,
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+      >,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return "true" if this operation inserts into the same subset as defined
+    /// by `candidate`.
+    ///
+    /// Note: This function is useful outside of bufferization, where no tensor
+    /// equivalence information is available.
+    bool isSameSubset(OpResult candidate) {
+      auto subsetOp = cast<::mlir::bufferization::SubsetInsertionOpInterface>(
+          getOperation());
+      return subsetOp.isEquivalentSubset(
+          candidate, [](Value v1, Value v2) { return v1 == v2; });
+    }
+  }];
+}
+
+#endif // SUBSET_INSERTION_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
new file mode 100644
index 000000000000000..e21b07d8a2705a0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- SubsetInsertionOpInterfaceImpl.h - Tensor subsets ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_SUBSETINSERTIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_TENSOR_SUBSETINSERTIONOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace tensor {
+void registerSubsetInsertionOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_SUBSETINSERTIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 06e3a654e573ffa..5b2b1ed24d5173d 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -74,6 +74,7 @@
 #include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
+  tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 492173b59616c2e..b1940e40ba34114 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferDeallocationOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
+  SubsetInsertionOpInterface.cpp
   UnstructuredControlFlow.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp
new file mode 100644
index 000000000000000..19a6fbba403c779
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp
@@ -0,0 +1,23 @@
+//===- SubsetInsertionOpInterface.cpp - Tensor Subsets --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp.inc"
+
+using namespace mlir;
+
+OpOperand &bufferization::detail::defaultGetDestinationOperand(Operation *op) {
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
+  assert(dstOp && "getDestination must be implemented for non-DPS ops");
+  assert(
+      dstOp.getNumDpsInits() == 1 &&
+      "getDestination must be implemented for ops with 0 or more than 1 init");
+  return *dstOp.getDpsInitOperand(0);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index e314193ff2b567d..09205388a644720 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -45,6 +45,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -531,6 +532,105 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
               .empty();
 }
 
+/// Return "true" if `value` is originating from a subset that is equivalent to
+/// the subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state, Value value,
+                                     SubsetInsertionOpInterface subsetOp) {
+  auto matchingSubset = [&](Value val) {
+    if (auto opResult = dyn_cast<OpResult>(val))
+      if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
+            return state.areEquivalentBufferizedValues(v1, v2);
+          }))
+        return true;
+    return false;
+  };
+  // There may be multiple leaves at which the reverse SSA use-def chain lookup
+  // terminates. All of them must be equivalent subsets.
+  SetVector<Value> backwardSlice =
+      state.findValueInReverseUseDefChain(value, matchingSubset);
+  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
+}
+
+/// Return "true" if the given "read" and potentially conflicting "write" are
+/// not conflicting due to their subset relationship. The comments in this
+/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
+/// pairs, but apply to any subset ops that implement the
+/// `SubsetInsertionOpInterface`.
+static bool areNonConflictingSubsets(OpOperand *uRead,
+                                     OpOperand *uConflictingWrite,
+                                     const AnalysisState &state) {
+  Operation *readingOp = uRead->getOwner();
+  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+  // uRead is an InsertSliceOp...
+  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+    // %1 = linalg.fill %cst, %0 {inplace= [true] }
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace= [true] }
+
+    if (uRead == &subsetOp.getDestinationOperand() &&
+        matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+      // Case 1: The main insight is that InsertSliceOp reads only part of
+      // the destination tensor. The overwritten area is not read. If
+      // uConflictingWrite writes into exactly the memory location that is
+      // being read by uRead, this is not a conflict.
+      //
+      // In the above example:
+      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+      //
+      // The read of %t does not conflict with the write of the FillOp
+      // (same aliases!) because the area that the FillOp operates on is
+      // exactly the one that is *not* read via %t.
+      return true;
+
+    if (uRead == &subsetOp.getSourceOperand() &&
+        uConflictingWrite == &subsetOp.getDestinationOperand() &&
+        matchesInsertDestination(state, uRead->get(), subsetOp))
+      // Case 2: The read of the source tensor and the write to the dest
+      // tensor via an InsertSliceOp is not a conflict if the read is
+      // reading exactly that part of an equivalent tensor that the
+      // InsertSliceOp is writing.
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      return true;
+  }
+
+  // If uConflictingWrite is an InsertSliceOp...
+  if (auto subsetOp =
+          dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+    // %1 = linalg.fill %cst, %0 {inplace= [true] }
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace= [true] }
+    // %3 = vector.transfer_read %1, %cst
+    //
+    // In the above example:
+    // uRead             = OpOperand 0 (%1) of vector.transfer_read
+    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+    // definition        = %1
+    //
+    // This is not a conflict because the InsertSliceOp overwrites the
+    // memory segment of %1 with the exact same data. (Effectively, there
+    // is no memory write here.)
+    if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
+        state.areEquivalentBufferizedValues(
+            uRead->get(), subsetOp.getSourceOperand().get()) &&
+        matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
+                                 subsetOp))
+      return true;
+
+  return false;
+}
+
 /// Given sets of uses and writes, return true if there is a RaW conflict under
 /// the assumption that all given reads/writes alias the same buffer and that
 /// all given writes bufferize inplace.
@@ -684,6 +784,12 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
         }
       }
 
+      // No conflict if the operands are non-conflicting subsets.
+      if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
+        LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
+        continue;
+      }
+
       // No conflict if the op interface says so.
       if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
         if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 89a929365cf0b89..b715cf054f1da5e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -16,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -619,117 +621,6 @@ struct InsertOpInterface
   }
 };
 
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-template <typename OpTy>
-static bool areEquivalentSlices(const AnalysisState &state,
-                                ExtractSliceOp extractSliceOp,
-                                OpTy insertSliceOp) {
-  if (!extractSliceOp || !insertSliceOp)
-    return false;
-  if (extractSliceOp != insertSliceOp &&
-      !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
-                                           insertSliceOp.getDest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
-                                  isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-template <typename OpTy>
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
-                                     OpTy insertSliceOp) {
-  // Look for matching slices.
-  auto matchesSlice = [&](Value val) {
-    if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentSlices(state, extractSliceOp, insertSliceOp))
-        return true;
-    return false;
-  };
-  return static_cast<bool>(llvm::all_of(
-      state.findValueInReverseUseDefChain(value, matchesSlice), matchesSlice));
-}
-
-template <typename OpTy>
-static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
-                                              OpOperand *uConflictingWrite,
-                                              const AnalysisState &state) {
-  Operation *readingOp = uRead->getOwner();
-  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-  // uRead is an InsertSliceOp...
-  if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
-    // As an example, consider the following IR.
-    //
-    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-    // %1 = linalg.fill %cst, %0 {inplace= [true] }
-    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-    //     {inplace= [true] }
-
-    // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-    if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        matchesInsertDestination(state, uConflictingWrite->get(),
-                                 insertSliceOp))
-      // Case 1: The main insight is that InsertSliceOp reads only part of
-      // the destination tensor. The overwritten area is not read. If
-      // uConflictingWrite writes into exactly the memory location that is
-      // being read by uRead, this is not a conflict.
-      //
-      // In the above example:
-      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-      //
-      // The read of %t does not conflict with the write of the FillOp
-      // (same aliases!) because the area that the FillOp operates on is
-      // exactly the one that is *not* read via %t.
-      return true;
-
-    if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-        uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        matchesInsertDestination(state, uRead->get(), insertSliceOp))
-      // Case 2: The read of the source tensor and the write to the dest
-      // tensor via an InsertSliceOp is not a conflict if the read is
-      // reading exactly that part of an equivalent tensor that the
-      // InsertSliceOp is writing.
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      return true;
-  }
-
-  // If uConflictingWrite is an InsertSliceOp...
-  if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
-    // As an example, consider the following IR.
-    //
-    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-    // %1 = linalg.fill %cst, %0 {inplace= [true] }
-    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-    //     {inplace= [true] }
-    // %3 = vector.transfer_read %1, %cst
-    //
-    // In the above example:
-    // uRead             = OpOperand 0 (%1) of vector.transfer_read
-    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-    // definition        = %1
-    //
-    // This is not a conflict because the InsertSliceOp overwrites the
-    // memory segment of %1 with the exact same data. (Effectively, there
-    // is no memory write here.)
-    if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        state.areEquivalentBufferizedValues(uRead->get(),
-                                            insertSliceOp.getSource()) &&
-        matchesInsertDestination(state, insertSliceOp.getSource(),
-                                 insertSliceOp))
-      return true;
-
-  return false;
-}
-
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
 ///
@@ -768,13 +659,6 @@ struct InsertSliceOpInterface
     return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
   }
 
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const AnalysisState &state) const {
-    return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
-        op, uRead, uConflictingWrite, state);
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     // insert_slice ops arise from tiling and bufferizing them out-of-place is
@@ -1079,13 +963,6 @@ struct ParallelInsertSliceOpInterface
     rewriter.eraseOp(op);
     return success();
   }
-
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const AnalysisState &state) const {
-    return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
-        op, uRead, uConflictingWrite, state);
-  }
 };
 
 /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
@@ -1157,4 +1034,8 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
   });
+
+  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
+  // they are registered.
+  tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 251c129b5383f06..0d148278cec519e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   ReshapePatterns.cpp
   RewriteAsConstant.cpp
   SwapExtractSliceWithProducerPatterns.cpp
+  SubsetInsertionOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..1156f2501a96ee1
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -0,0 +1,82 @@
+//===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+/// equivalence of tensors.
+template <typename OpTy>
+bool isSubsetEquivalentToInsertSliceLikeOp(
+    OpTy insertSliceOp, Value candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  // Look for a matching tensor.extract_slice op.
+  auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!extractSliceOp)
+    return false;
+  if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+    return false;
+  return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                    isEqualConstantIntOrValue);
+}
+
+struct InsertSliceOpInterface
+    : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
+                                                       tensor::InsertSliceOp> {
+  OpOperand &getSourceOperand(Operation *op) const {
+    return op->getOpOperand(0);
+  }
+
+  bool
+  isEquivalentSubset(Operation *op, Value candidate,
+                     function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
+                                                 equivalenceFn);
+  }
+};
+
+struct ParallelInsertSliceOpInterface
+    : public SubsetInsertionOpInterface::ExternalModel<
+          ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
+  OpOperand &getSourceOperand(Operation *op) const {
+    return op->getOpOperand(0);
+  }
+
+  OpOperand &getDestinationOperand(Operation *op) const {
+    return op->getOpOperand(1);
+  }
+
+  bool
+  isEquivalentSubset(Operation *op, Value candidate,
+                     function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
+                                                 equivalenceFn);
+  }
+};
+
+} // namespace
+
+void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
+    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+        *ctx);
+  });
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index bfaaa8a82094d0e..0b475bcb117b62f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6638,12 +6638,7 @@ cc_library(
             "lib/Dialect/Tensor/Transforms/*.h",
         ],
     ),
-    hdrs = [
-        "include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h",
-        "include/mlir/Dialect/Tensor/Transforms/Passes.h",
-        "include/mlir/Dialect/Tensor/Transforms/TransformUtils.h",
-        "include/mlir/Dialect/Tensor/Transforms/Transforms.h",
-    ],
+    hdrs = glob(["include/mlir/Dialect/Tensor/Transforms/*.h"]),
     includes = ["include"],
     deps = [
         ":AffineDialect",
@@ -9786,6 +9781,17 @@ td_library(
     ],
 )
 
+td_library(
+    name = "SubsetInsertionOpInterfaceTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+    ],
+)
+
 gentbl_cc_library(
     name = "BufferDeallocationOpInterfaceIncGen",
     tbl_outs = [
@@ -9805,6 +9811,25 @@ gentbl_cc_library(
     ],
 )
 
+gentbl_cc_library(
+    name = "SubsetInsertionOpInterfaceIncGen",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td",
+    deps = [
+        ":SubsetInsertionOpInterfaceTdFiles",
+    ],
+)
+
 td_library(
     name = "LinalgDocTdFiles",
     srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -12076,6 +12101,7 @@ cc_library(
         "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
+        "lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp",
     ],
     hdrs = [
@@ -12083,6 +12109,7 @@ cc_library(
         "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/Bufferization.h",
         "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
+        "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h",
     ],
     includes = ["include"],
@@ -12105,6 +12132,7 @@ cc_library(
         ":InferTypeOpInterface",
         ":MemRefDialect",
         ":SparseTensorDialect",
+        ":SubsetInsertionOpInterfaceIncGen",
         ":Support",
         ":TensorDialect",
         "//llvm:Support",



More information about the llvm-commits mailing list