[Mlir-commits] [mlir] [mlir][bufferization] Generalize tensor slice rules to subset ops (PR #65619)
Matthias Springer
llvmlistbot at llvm.org
Thu Sep 7 08:17:21 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/65619:
This commit generalizes the special tensor.extract_slice/tensor.insert_slice bufferization rules to tensor subset ops.
Ops that insert a tensor into a tensor at a specified subset (e.g., tensor.insert_slice, tensor.scatter) can implement the `SubsetOpInterface`.
Apart from adding a new op interface (extending the API), this change is NFC. The only ops that currently implement the new interface are tensor.insert_slice and tensor.parallel_insert_slice, and those ops were already supported by One-Shot Bufferize.
>From 41ff6b366fb4ad1ada3e5e4e4c8b1917bb5e84be Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 7 Sep 2023 17:16:26 +0200
Subject: [PATCH] [mlir][bufferization] Generalize tensor slice rules to subset
ops
This commit generalizes the special tensor.extract_slice/tensor.insert_slice bufferization rules to tensor subset ops.
Ops that insert a tensor into a tensor at a specified subset (e.g., tensor.insert_slice, tensor.scatter) can implement the `SubsetOpInterface`.
Apart from adding a new op interface (extending the API), this change is NFC. The only ops that currently implement the new interface are tensor.insert_slice and tensor.parallel_insert_slice, and those ops were already supported by One-Shot Bufferize.
---
.../Dialect/Bufferization/IR/CMakeLists.txt | 1 +
.../Bufferization/IR/SubsetOpInterface.h | 31 +++
.../Bufferization/IR/SubsetOpInterface.td | 119 +++++++++++
.../Dialect/Bufferization/IR/CMakeLists.txt | 1 +
.../Bufferization/IR/SubsetOpInterface.cpp | 23 +++
.../Transforms/OneShotAnalysis.cpp | 103 ++++++++++
.../BufferizableOpInterfaceImpl.cpp | 184 ++++++------------
.../llvm-project-overlay/mlir/BUILD.bazel | 37 +++-
8 files changed, 372 insertions(+), 127 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
create mode 100644 mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index aa93534a78fea3f..4a97493cbf42a49 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect(BufferizationOps bufferization)
add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
add_mlir_interface(AllocationOpInterface)
add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(SubsetOpInterface)
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
new file mode 100644
index 000000000000000..9ab7f11d42fbb56
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h
@@ -0,0 +1,31 @@
+//===- SubsetOpInterface.h - Tensor subsets ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class OpBuilder;
+
+namespace bufferization {
+namespace detail {
+
+/// Return the destination/"init" operand of the op if it implements the
+/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
+/// otherwise.
+OpOperand &defaultGetDestination(Operation *op);
+
+} // namespace detail
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_SUBSETOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
new file mode 100644
index 000000000000000..e7686ac8db9f8dc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td
@@ -0,0 +1,119 @@
+//===-- SubsetOpInterface.td - Tensor subsets --------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SUBSET_OP_INTERFACE
+#define SUBSET_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
+ let description = [{
+ This interface can be implemented by ops that insert a source tensor into
+ a destination tensor.
+
+ The elements in the destination tensor that are overwritten by this
+ insertion are called the "subset". How the subset is defined is up to the
+ op. E.g., "tensor.insert_slice" defines the subset via a hyperrectangular
+ slice. A scatter operation could define the subset via a list of indices.
+
+ Ops that deal with tensor subsets come in two flavors:
+ - Insertion flavor: Ops that insert a source tensor into a destination
+ tensor at the specified subset. Such ops usually return a new destination
+ tensor and implement the `DestinationStyleOpInterface`. Insertion ops can
+ implement the `SubsetOpInterface`. Example: "tensor.insert_slice"
+ - Extraction flavor: Ops that define a tensor subset. They extract a
+ specified subset from a tensor. There is currently no op interface for
+ such ops. Example: "tensor.extract_slice"
+
+ This interface provides helper methods for efficient bufferization of
+ subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
+ "aliases" (in contrast to one or multiple less efficient buffer allocations).
+
+ This interface is queried by One-Shot Bufferize to detect cases where a
+ seeming read-after-write is not actually a conflict because the respective
+ ops are operating on equivalent subsets. More details can be found in the
+ documentation of One-Shot Analysis.
+
+ Note: This interface currently assumes that a subset op inserts a single
+ tensor (source) into a destination tensor at a single subset.
+ }];
+ let cppNamespace = "::mlir::bufferization";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the source tensor operand.
+ }],
+ /*retType=*/"::mlir::OpOperand &",
+ /*methodName=*/"getSource",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ llvm_unreachable("getSource not implemented");
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the destination tensor operand.
+ }],
+ /*retType=*/"::mlir::OpOperand &",
+ /*methodName=*/"getDestination",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::bufferization::detail::defaultGetDestination(
+ $_op.getOperation());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return "true" if this operation inserts into a subset that is
+ equivalent to the subset defined by `candidate`.
+
+ Two subsets are "equivalent" and "same" if they can bufferize to the
+ same buffer views/aliases. If they are "equivalent", the tensor IR
+ may be expressed in terms of different SSA values.
+
+ Example:
+ ```
+ // %1 and %2 are equivalent (but not "same"):
+ %0 = arith.select %c, %t, %t : tensor<?xf32>
+ %1 = tensor.extract_slice %0[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+
+ // %1 and %2 are equivalent and "same"
+ %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ ```
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"isEquivalentSubset",
+ /*args=*/(ins
+ "::mlir::OpResult":$candidate,
+ "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ llvm_unreachable("isEquivalentSubset not implemented");
+ }]
+ >,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return "true" if this operation inserts into the same subset as defined
+ /// by `candidate`.
+ ///
+ /// Note: This function is useful outside of bufferization, where no tensor
+ /// equivalence information is available.
+ bool isSameSubset(OpResult candidate) {
+ return cast<::mlir::bufferization::SubsetOpInterface>(getOperation())
+ .isEquivalentSubset(candidate,
+ [](Value v1, Value v2) { return v1 == v2; });
+ }
+ }];
+}
+
+#endif // SUBSET_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 2d8d09b9c41d993..f406595fc1e9fc9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferizableOpInterface.cpp
BufferizationOps.cpp
BufferizationDialect.cpp
+ SubsetOpInterface.cpp
UnstructuredControlFlow.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp
new file mode 100644
index 000000000000000..87804c50a530fbc
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp
@@ -0,0 +1,23 @@
+//===- SubsetOpInterface.cpp - Tensor subsets -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.cpp.inc"
+
+using namespace mlir;
+
+OpOperand &bufferization::detail::defaultGetDestination(Operation *op) {
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
+ assert(dstOp && "getDestination must be implemented for non-DPS ops");
+ assert(
+ dstOp.getNumDpsInits() == 1 &&
+ "getDestination must be implemented for ops with 0 or more than 1 init");
+ return *dstOp.getDpsInitOperand(0);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 49b5ebdf722a1a7..c613b77c10dca57 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -45,6 +45,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -531,6 +532,102 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
.empty();
}
+/// Return "true" if `value` is originating from a subset that is equivalent to
+/// the subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state, Value value,
+ SubsetOpInterface subsetOp) {
+ auto matchingSubset = [&](Value val) {
+ if (auto opResult = dyn_cast<OpResult>(val))
+ if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
+ return state.areEquivalentBufferizedValues(v1, v2);
+ }))
+ return true;
+ return false;
+ };
+ // There may be multiple leaves at which the reverse SSA use-def chain lookup
+ // terminates. All of them must be equivalent subsets.
+ SetVector<Value> backwardSlice =
+ state.findValueInReverseUseDefChain(value, matchingSubset);
+ return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
+}
+
+/// Return "true" if the given "read" and potentially conflicting "write" are
+/// not conflicting due to their subset relationship. The comments in this
+/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
+/// pairs, but apply to any subset ops that implement the `SubsetOpInterface`.
+static bool areNonConflictingSubsets(OpOperand *uRead,
+ OpOperand *uConflictingWrite,
+ const AnalysisState &state) {
+ Operation *readingOp = uRead->getOwner();
+ Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+ // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+ // uRead is an InsertSliceOp...
+ if (auto subsetOp = dyn_cast<SubsetOpInterface>(readingOp)) {
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+
+ if (uRead == &subsetOp.getDestination() &&
+ matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+ // Case 1: The main insight is that InsertSliceOp reads only part of
+ // the destination tensor. The overwritten area is not read. If
+ // uConflictingWrite writes into exactly the memory location that is
+ // being read by uRead, this is not a conflict.
+ //
+ // In the above example:
+ // uRead = OpOperand 1 (%t) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+ //
+ // The read of %t does not conflict with the write of the FillOp
+ // (same aliases!) because the area that the FillOp operates on is
+ // exactly the one that is *not* read via %t.
+ return true;
+
+ if (uRead == &subsetOp.getSource() &&
+ uConflictingWrite == &subsetOp.getDestination() &&
+ matchesInsertDestination(state, uRead->get(), subsetOp))
+ // Case 2: The read of the source tensor and the write to the dest
+ // tensor via an InsertSliceOp is not a conflict if the read is
+ // reading exactly that part of an equivalent tensor that the
+ // InsertSliceOp is writing.
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ return true;
+ }
+
+ // If uConflictingWrite is an InsertSliceOp...
+ if (auto subsetOp = dyn_cast<SubsetOpInterface>(conflictingWritingOp))
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+ // %3 = vector.transfer_read %1, %cst
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of vector.transfer_read
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ // definition = %1
+ //
+ // This is not a conflict because the InsertSliceOp overwrites the
+ // memory segment of %1 with the exact same data. (Effectively, there
+ // is no memory write here.)
+ if (uConflictingWrite == &subsetOp.getDestination() &&
+ state.areEquivalentBufferizedValues(uRead->get(),
+ subsetOp.getSource().get()) &&
+ matchesInsertDestination(state, subsetOp.getSource().get(), subsetOp))
+ return true;
+
+ return false;
+}
+
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
@@ -647,6 +744,12 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
}
}
+ // No conflict if the operands are non-conflicting subsets.
+ if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
+ LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
+ continue;
+ }
+
// No conflict if the op interface says so.
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a67ea0334b22b9b..23bb4c34928e6ac 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetOpInterface.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -628,117 +629,6 @@ struct InsertOpInterface
}
};
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-template <typename OpTy>
-static bool areEquivalentSlices(const AnalysisState &state,
- ExtractSliceOp extractSliceOp,
- OpTy insertSliceOp) {
- if (!extractSliceOp || !insertSliceOp)
- return false;
- if (extractSliceOp != insertSliceOp &&
- !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
- insertSliceOp.getDest()))
- return false;
- if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
- isEqualConstantIntOrValue))
- return false;
- return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-template <typename OpTy>
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
- OpTy insertSliceOp) {
- // Look for matching slices.
- auto matchesSlice = [&](Value val) {
- if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentSlices(state, extractSliceOp, insertSliceOp))
- return true;
- return false;
- };
- return static_cast<bool>(llvm::all_of(
- state.findValueInReverseUseDefChain(value, matchesSlice), matchesSlice));
-}
-
-template <typename OpTy>
-static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
- OpOperand *uConflictingWrite,
- const AnalysisState &state) {
- Operation *readingOp = uRead->getOwner();
- Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
- // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
- // uRead is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
-
- // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
- if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- matchesInsertDestination(state, uConflictingWrite->get(),
- insertSliceOp))
- // Case 1: The main insight is that InsertSliceOp reads only part of
- // the destination tensor. The overwritten area is not read. If
- // uConflictingWrite writes into exactly the memory location that is
- // being read by uRead, this is not a conflict.
- //
- // In the above example:
- // uRead = OpOperand 1 (%t) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
- //
- // The read of %t does not conflict with the write of the FillOp
- // (same aliases!) because the area that the FillOp operates on is
- // exactly the one that is *not* read via %t.
- return true;
-
- if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
- uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- matchesInsertDestination(state, uRead->get(), insertSliceOp))
- // Case 2: The read of the source tensor and the write to the dest
- // tensor via an InsertSliceOp is not a conflict if the read is
- // reading exactly that part of an equivalent tensor that the
- // InsertSliceOp is writing.
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- return true;
- }
-
- // If uConflictingWrite is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
- // %3 = vector.transfer_read %1, %cst
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of vector.transfer_read
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- // definition = %1
- //
- // This is not a conflict because the InsertSliceOp overwrites the
- // memory segment of %1 with the exact same data. (Effectively, there
- // is no memory write here.)
- if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- state.areEquivalentBufferizedValues(uRead->get(),
- insertSliceOp.getSource()) &&
- matchesInsertDestination(state, insertSliceOp.getSource(),
- insertSliceOp))
- return true;
-
- return false;
-}
-
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
@@ -777,13 +667,6 @@ struct InsertSliceOpInterface
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}
- bool isNotConflicting(Operation *op, OpOperand *uRead,
- OpOperand *uConflictingWrite,
- const AnalysisState &state) const {
- return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
- op, uRead, uConflictingWrite, state);
- }
-
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
@@ -1092,13 +975,6 @@ struct ParallelInsertSliceOpInterface
rewriter.eraseOp(op);
return success();
}
-
- bool isNotConflicting(Operation *op, OpOperand *uRead,
- OpOperand *uConflictingWrite,
- const AnalysisState &state) const {
- return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
- op, uRead, uConflictingWrite, state);
- }
};
/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
@@ -1147,6 +1023,58 @@ struct SplatOpInterface
}
};
+namespace {
+/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+/// equivalence of tensors.
+template <typename OpTy>
+bool isSubsetEquivalentToInsertSliceLikeOp(
+ OpTy insertSliceOp, OpResult candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) {
+ // Look for a matching tensor.extract_slice op.
+ auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractSliceOp)
+ return false;
+ if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+ return false;
+ if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+ isEqualConstantIntOrValue))
+ return false;
+ return true;
+}
+} // namespace
+
+struct InsertSliceOpSubsetOpInterface
+ : public SubsetOpInterface::ExternalModel<InsertSliceOpSubsetOpInterface,
+ tensor::InsertSliceOp> {
+ OpOperand &getSource(Operation *op) const { return op->getOpOperand(0); }
+
+ bool
+ isEquivalentSubset(Operation *op, OpResult candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) const {
+ auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
+ equivalenceFn);
+ }
+};
+
+struct ParallelInsertSliceOpSubsetOpInterface
+ : public SubsetOpInterface::ExternalModel<
+ ParallelInsertSliceOpSubsetOpInterface,
+ tensor::ParallelInsertSliceOp> {
+ OpOperand &getSource(Operation *op) const { return op->getOpOperand(0); }
+
+ OpOperand &getDestination(Operation *op) const { return op->getOpOperand(1); }
+
+ bool
+ isEquivalentSubset(Operation *op, OpResult candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) const {
+ auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+ return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
+ equivalenceFn);
+ }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -1154,6 +1082,7 @@ struct SplatOpInterface
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+ // BufferizableOpInterface models.
CastOp::attachInterface<CastOpInterface>(*ctx);
CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
@@ -1172,6 +1101,11 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
SplatOp::attachInterface<SplatOpInterface>(*ctx);
+ // SubsetOpInterface models.
+ InsertSliceOp::attachInterface<InsertSliceOpSubsetOpInterface>(*ctx);
+ ParallelInsertSliceOp::attachInterface<
+ ParallelInsertSliceOpSubsetOpInterface>(*ctx);
+
// Load additional dialects of which ops may get created.
ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
});
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7829bb0ffbd2932..b3645710232f2fa 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7,14 +7,14 @@
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@bazel_skylib//rules:write_file.bzl", "write_file")
-load(":tblgen.bzl", "gentbl_cc_library", "td_library")
-load(":linalggen.bzl", "genlinalg")
load(
":build_defs.bzl",
"cc_headers_only",
"if_cuda_available",
"mlir_c_api_cc_library",
)
+load(":linalggen.bzl", "genlinalg")
+load(":tblgen.bzl", "gentbl_cc_library", "td_library")
package(
default_visibility = ["//visibility:public"],
@@ -9705,6 +9705,36 @@ gentbl_cc_library(
],
)
+td_library(
+ name = "SubsetOpInterfaceTdFiles",
+ srcs = [
+ "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td",
+ ],
+ includes = ["include"],
+ deps = [
+ ":OpBaseTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "SubsetOpInterfaceIncGen",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.td",
+ deps = [
+ ":SubsetOpInterfaceTdFiles",
+ ],
+)
+
td_library(
name = "LinalgDocTdFiles",
srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -11972,12 +12002,14 @@ cc_library(
"lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
"lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
"lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
+ "lib/Dialect/Bufferization/IR/SubsetOpInterface.cpp",
"lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp",
],
hdrs = [
"include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
"include/mlir/Dialect/Bufferization/IR/Bufferization.h",
"include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
+ "include/mlir/Dialect/Bufferization/IR/SubsetOpInterface.h",
"include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h",
],
includes = ["include"],
@@ -11998,6 +12030,7 @@ cc_library(
":InferTypeOpInterface",
":MemRefDialect",
":SparseTensorDialect",
+ ":SubsetOpInterfaceIncGen",
":Support",
":TensorDialect",
"//llvm:Support",
More information about the Mlir-commits
mailing list