[Mlir-commits] [mlir] [mlir][bufferization] Generalize tensor slice rules to subset ops (PR #65619)
Matthias Springer
llvmlistbot at llvm.org
Wed Sep 13 01:20:18 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/65619:
>From d14933a3d1e75b196b3489ecd94c1c582cc9e3f3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 13 Sep 2023 10:19:35 +0200
Subject: [PATCH] [mlir][bufferization] Generalize tensor slice rules to subset
ops
This commit generalizes the special tensor.extract_slice/tensor.insert_slice bufferization rules to tensor subset ops.
Ops that insert a tensor into a tensor at a specified subset (e.g., tensor.insert_slice, tensor.scatter) can implement the `SubsetInsertionOpInterface`.
Apart from adding a new op interface (extending the API), this change is NFC. The only ops that currently implement the new interface are tensor.insert_slice and tensor.parallel_insert_slice, and those ops were already supported by One-Shot Bufferize.
---
.../Dialect/Bufferization/IR/CMakeLists.txt | 1 +
.../IR/SubsetInsertionOpInterface.h | 29 +++
.../IR/SubsetInsertionOpInterface.td | 119 +++++++++++
.../Dialect/Bufferization/IR/CMakeLists.txt | 1 +
.../IR/SubsetInsertionOpInterface.cpp | 23 +++
.../Transforms/OneShotAnalysis.cpp | 106 ++++++++++
.../BufferizableOpInterfaceImpl.cpp | 189 ++++++------------
.../llvm-project-overlay/mlir/BUILD.bazel | 37 +++-
8 files changed, 378 insertions(+), 127 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h
create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
create mode 100644 mlir/lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index aa93534a78fea3f..440125031b1acc5 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect(BufferizationOps bufferization)
add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
add_mlir_interface(AllocationOpInterface)
add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(SubsetInsertionOpInterface)
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h
new file mode 100644
index 000000000000000..e5b06d746e74bfd
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h
@@ -0,0 +1,29 @@
+//===- SubsetInsertionOpInterface.h - Tensor Subsets ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir::bufferization::detail {
+
+/// Default implementation of the `getDestinationOperand` method of
+/// `SubsetInsertionOpInterface`.
+///
+/// `op` is expected to implement `DestinationStyleOpInterface` with exactly
+/// one "init" operand; that operand is returned. Any other op triggers an
+/// assertion failure.
+OpOperand &defaultGetDestinationOperand(Operation *op);
+
+} // namespace mlir::bufferization::detail
+
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
new file mode 100644
index 000000000000000..edf652537795771
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
@@ -0,0 +1,119 @@
+//===-- SubsetInsertionOpInterface.td - Tensor Subsets -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SUBSET_INSERTION_OP_INTERFACE
+#define SUBSET_INSERTION_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> {
+  let description = [{
+    This interface can be implemented by ops that insert a source tensor into
+    a destination tensor.
+
+    The elements in the destination tensor that are overwritten by this
+    insertion are called the "subset". How the subset is defined is up to the
+    op. E.g., "tensor.insert_slice" defines the subset via a hyperrectangular
+    slice. A scatter operation could define the subset via a list of indices.
+
+    Ops that deal with tensor subsets come in two flavors:
+    - Insertion flavor: Ops that insert a source tensor into a destination
+      tensor at the specified subset. Such ops usually return a new destination
+      tensor and implement the `DestinationStyleOpInterface`. Insertion ops can
+      implement the `SubsetInsertionOpInterface`. Example: "tensor.insert_slice"
+    - Extraction flavor: Ops that define a tensor subset. They extract a
+      specified subset from a tensor. There is currently no op interface for
+      such ops. Example: "tensor.extract_slice"
+
+    This interface provides helper methods for efficient bufferization of
+    subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
+    "aliases" (in contrast to one or multiple less efficient buffer
+    allocations).
+
+    This interface is queried by One-Shot Bufferize to detect cases where a
+    seeming read-after-write is not actually a conflict because the respective
+    ops are operating on equivalent subsets. More details can be found in the
+    documentation of One-Shot Analysis (see `areNonConflictingSubsets`).
+
+    Note: This interface currently assumes that a subset op inserts a single
+    tensor (source) into a destination tensor at a single subset.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the source tensor operand, i.e., the tensor that is inserted
+        into the destination.
+      }],
+      /*retType=*/"::mlir::OpOperand &",
+      /*methodName=*/"getSourceOperand",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the destination tensor operand, i.e., the tensor into which the
+        source is inserted.
+
+        The default implementation returns the single "init" operand of ops
+        that implement the `DestinationStyleOpInterface` and have exactly one
+        "init" operand. All other ops must override this method.
+      }],
+      /*retType=*/"::mlir::OpOperand &",
+      /*methodName=*/"getDestinationOperand",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::bufferization::detail::defaultGetDestinationOperand(
+            $_op.getOperation());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return "true" if this operation inserts into a subset that is
+        equivalent to the subset defined by `candidate`.
+
+        Two subsets are "equivalent" if they can bufferize to the same buffer
+        views/aliases. The tensor IR of two equivalent subsets may be expressed
+        in terms of different SSA values, as long as those values bufferize to
+        MemRef SSA values that could be CSE'd without breaking correctness.
+        Two subsets are the "same" if they are equivalent and, in addition,
+        are expressed in terms of identical SSA values.
+
+        `equivalenceFn` should return "true" if the two given values are
+        equivalent.
+
+        Example:
+        ```
+        // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
+        // the subset defined by %2 (but not "same"):
+        %0 = arith.select %c, %t, %t : tensor<?xf32>
+        %1 = tensor.insert_slice %x into %0[0][5][1]
+            : tensor<5xf32> into tensor<?xf32>
+        %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+
+        // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
+        // and "same" as the subset defined by %2.
+        %1 = tensor.insert_slice %x into %t[0][5][1]
+            : tensor<5xf32> into tensor<?xf32>
+        %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+        ```
+      }],
+      /*retType=*/"bool",
+      /*methodName=*/"isEquivalentSubset",
+      /*args=*/(ins
+          "::mlir::Value":$candidate,
+          "::llvm::function_ref<bool(::mlir::Value, ::mlir::Value)>":$equivalenceFn)
+    >,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return "true" if this operation inserts into the same subset as defined
+    /// by `candidate`. Two tensors are considered equivalent here only if
+    /// they are the same SSA value.
+    ///
+    /// Note: This function is useful outside of bufferization, where no tensor
+    /// equivalence information is available. It accepts any `Value`
+    /// (including `OpResult`), matching the parameter of `isEquivalentSubset`.
+    bool isSameSubset(::mlir::Value candidate) {
+      return isEquivalentSubset(candidate,
+                                [](::mlir::Value v1, ::mlir::Value v2) {
+                                  return v1 == v2;
+                                });
+    }
+  }];
+}
+
+#endif // SUBSET_INSERTION_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 2d8d09b9c41d993..3fd9221624d0f88 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferizableOpInterface.cpp
BufferizationOps.cpp
BufferizationDialect.cpp
+ SubsetInsertionOpInterface.cpp
UnstructuredControlFlow.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp
new file mode 100644
index 000000000000000..19a6fbba403c779
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp
@@ -0,0 +1,23 @@
+//===- SubsetInsertionOpInterface.cpp - Tensor Subsets --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp.inc"
+
+using namespace mlir;
+
+/// Default implementation of `SubsetInsertionOpInterface::getDestinationOperand`:
+/// return the single "init" operand of a destination-style op.
+OpOperand &bufferization::detail::defaultGetDestinationOperand(Operation *op) {
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
+  // Ops that are not destination-style have no canonical destination operand,
+  // so they must provide their own `getDestinationOperand` implementation.
+  assert(dstOp &&
+         "getDestinationOperand must be implemented for non-DPS ops");
+  // With zero or several inits it is ambiguous which operand is the
+  // destination of the insertion.
+  assert(dstOp.getNumDpsInits() == 1 &&
+         "getDestinationOperand must be implemented for ops with 0 or more "
+         "than 1 init");
+  return *dstOp.getDpsInitOperand(0);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 49b5ebdf722a1a7..bcc667086f489bf 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -45,6 +45,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -531,6 +532,105 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
.empty();
}
+/// Return "true" if every value at which the reverse use-def chain lookup
+/// starting from `value` terminates defines a subset that is equivalent to the
+/// subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state, Value value,
+                                     SubsetInsertionOpInterface subsetOp) {
+  // Tensor equivalence is decided by the bufferization analysis state.
+  auto isEquivalent = [&](Value lhs, Value rhs) {
+    return state.areEquivalentBufferizedValues(lhs, rhs);
+  };
+  // Only op results are considered as subset candidates.
+  auto isMatchingSubset = [&](Value candidate) -> bool {
+    auto result = dyn_cast<OpResult>(candidate);
+    return result && subsetOp.isEquivalentSubset(result, isEquivalent);
+  };
+  // The lookup may stop at several leaves; every one of them must match.
+  SetVector<Value> leaves =
+      state.findValueInReverseUseDefChain(value, isMatchingSubset);
+  return llvm::all_of(leaves, isMatchingSubset);
+}
+
+/// Return "true" if the given "read" and potentially conflicting "write" are
+/// not conflicting due to their subset relationship. The comments in this
+/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
+/// pairs, but apply to any subset ops that implement the
+/// `SubsetInsertionOpInterface`.
+static bool areNonConflictingSubsets(OpOperand *uRead,
+                                     OpOperand *uConflictingWrite,
+                                     const AnalysisState &state) {
+  Operation *readingOp = uRead->getOwner();
+  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If the op
+  // that owns uRead is an InsertSliceOp...
+  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
+    // %1 = linalg.fill %cst, %0 {inplace = [true]}
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace = [true]}
+
+    if (uRead == &subsetOp.getDestinationOperand() &&
+        matchesInsertDestination(state, uConflictingWrite->get(), subsetOp)) {
+      // Case 1: The main insight is that InsertSliceOp reads only part of
+      // the destination tensor. The overwritten area is not read. If
+      // uConflictingWrite writes into exactly the memory location that is
+      // being read by uRead, this is not a conflict.
+      //
+      // In the above example:
+      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+      //
+      // The read of %t does not conflict with the write of the FillOp
+      // (same aliases!) because the area that the FillOp operates on is
+      // exactly the one that is *not* read via %t.
+      return true;
+    }
+
+    if (uRead == &subsetOp.getSourceOperand() &&
+        uConflictingWrite == &subsetOp.getDestinationOperand() &&
+        matchesInsertDestination(state, uRead->get(), subsetOp)) {
+      // Case 2: The read of the source tensor and the write to the dest
+      // tensor via an InsertSliceOp is not a conflict if the read is
+      // reading exactly that part of an equivalent tensor that the
+      // InsertSliceOp is writing.
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      return true;
+    }
+  }
+
+  // If the op that owns uConflictingWrite is an InsertSliceOp...
+  if (auto subsetOp =
+          dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp)) {
+    // Case 3: As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
+    // %1 = linalg.fill %cst, %0 {inplace = [true]}
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace = [true]}
+    // %3 = vector.transfer_read %1, %cst
+    //
+    // In the above example:
+    // uRead             = OpOperand 0 (%1) of vector.transfer_read
+    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+    // definition        = %1
+    //
+    // This is not a conflict because the InsertSliceOp overwrites the
+    // memory segment of %1 with the exact same data. (Effectively, there
+    // is no memory write here.)
+    if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
+        state.areEquivalentBufferizedValues(
+            uRead->get(), subsetOp.getSourceOperand().get()) &&
+        matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
+                                 subsetOp)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
@@ -647,6 +747,12 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
}
}
+ // No conflict if the operands are non-conflicting subsets.
+ if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
+ LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
+ continue;
+ }
+
// No conflict if the op interface says so.
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a67ea0334b22b9b..48ca0007ef4898f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -628,117 +629,6 @@ struct InsertOpInterface
}
};
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-template <typename OpTy>
-static bool areEquivalentSlices(const AnalysisState &state,
- ExtractSliceOp extractSliceOp,
- OpTy insertSliceOp) {
- if (!extractSliceOp || !insertSliceOp)
- return false;
- if (extractSliceOp != insertSliceOp &&
- !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
- insertSliceOp.getDest()))
- return false;
- if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
- isEqualConstantIntOrValue))
- return false;
- return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-template <typename OpTy>
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
- OpTy insertSliceOp) {
- // Look for matching slices.
- auto matchesSlice = [&](Value val) {
- if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentSlices(state, extractSliceOp, insertSliceOp))
- return true;
- return false;
- };
- return static_cast<bool>(llvm::all_of(
- state.findValueInReverseUseDefChain(value, matchesSlice), matchesSlice));
-}
-
-template <typename OpTy>
-static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
- OpOperand *uConflictingWrite,
- const AnalysisState &state) {
- Operation *readingOp = uRead->getOwner();
- Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
- // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
- // uRead is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
-
- // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
- if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- matchesInsertDestination(state, uConflictingWrite->get(),
- insertSliceOp))
- // Case 1: The main insight is that InsertSliceOp reads only part of
- // the destination tensor. The overwritten area is not read. If
- // uConflictingWrite writes into exactly the memory location that is
- // being read by uRead, this is not a conflict.
- //
- // In the above example:
- // uRead = OpOperand 1 (%t) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
- //
- // The read of %t does not conflict with the write of the FillOp
- // (same aliases!) because the area that the FillOp operates on is
- // exactly the one that is *not* read via %t.
- return true;
-
- if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
- uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- matchesInsertDestination(state, uRead->get(), insertSliceOp))
- // Case 2: The read of the source tensor and the write to the dest
- // tensor via an InsertSliceOp is not a conflict if the read is
- // reading exactly that part of an equivalent tensor that the
- // InsertSliceOp is writing.
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- return true;
- }
-
- // If uConflictingWrite is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
- // %3 = vector.transfer_read %1, %cst
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of vector.transfer_read
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- // definition = %1
- //
- // This is not a conflict because the InsertSliceOp overwrites the
- // memory segment of %1 with the exact same data. (Effectively, there
- // is no memory write here.)
- if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- state.areEquivalentBufferizedValues(uRead->get(),
- insertSliceOp.getSource()) &&
- matchesInsertDestination(state, insertSliceOp.getSource(),
- insertSliceOp))
- return true;
-
- return false;
-}
-
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
@@ -777,13 +667,6 @@ struct InsertSliceOpInterface
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}
- bool isNotConflicting(Operation *op, OpOperand *uRead,
- OpOperand *uConflictingWrite,
- const AnalysisState &state) const {
- return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
- op, uRead, uConflictingWrite, state);
- }
-
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
@@ -1092,13 +975,6 @@ struct ParallelInsertSliceOpInterface
rewriter.eraseOp(op);
return success();
}
-
- bool isNotConflicting(Operation *op, OpOperand *uRead,
- OpOperand *uConflictingWrite,
- const AnalysisState &state) const {
- return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
- op, uRead, uConflictingWrite, state);
- }
};
/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
@@ -1147,6 +1023,62 @@ struct SplatOpInterface
}
};
+/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+/// equivalence of tensors.
+///
+/// `candidate` matches only if it is produced by a tensor.extract_slice op
+/// whose source is equivalent (per `equivalenceFn`) to the destination of
+/// `insertSliceOp`, and whose offsets, sizes and strides agree with those of
+/// `insertSliceOp`.
+template <typename OpTy>
+static bool isSubsetEquivalentToInsertSliceLikeOp(
+    OpTy insertSliceOp, Value candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  // Look for a matching tensor.extract_slice op.
+  auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!extractSliceOp)
+    return false;
+  // Both ops must operate on equivalent tensors...
+  if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+    return false;
+  // ...and describe the same slice of them.
+  return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                    isEqualConstantIntOrValue);
+}
+
+/// External model that exposes tensor.insert_slice as a subset insertion op.
+/// `getDestinationOperand` is not overridden here, so the interface's default
+/// implementation is used for the destination.
+struct InsertSliceOpSubsetInsertionOpInterface
+    : public SubsetInsertionOpInterface::ExternalModel<
+          InsertSliceOpSubsetInsertionOpInterface, tensor::InsertSliceOp> {
+  /// The inserted tensor is operand #0.
+  OpOperand &getSourceOperand(Operation *op) const {
+    OpOperand &source = op->getOpOperand(0);
+    return source;
+  }
+
+  /// A candidate matches if it is a tensor.extract_slice of an equivalent
+  /// tensor with the same offsets/sizes/strides.
+  bool
+  isEquivalentSubset(Operation *op, Value candidate,
+                     function_ref<bool(Value, Value)> equivalenceFn) const {
+    return isSubsetEquivalentToInsertSliceLikeOp(
+        cast<tensor::InsertSliceOp>(op), candidate, equivalenceFn);
+  }
+};
+
+/// External model that exposes tensor.parallel_insert_slice as a subset
+/// insertion op. Operand #0 is the inserted source and operand #1 is the
+/// destination. The destination accessor is spelled out explicitly rather than
+/// relying on the default implementation, which requires a destination-style
+/// op with exactly one "init" operand.
+struct ParallelInsertSliceOpSubsetInsertionOpInterface
+    : public SubsetInsertionOpInterface::ExternalModel<
+          ParallelInsertSliceOpSubsetInsertionOpInterface,
+          tensor::ParallelInsertSliceOp> {
+  OpOperand &getSourceOperand(Operation *op) const {
+    constexpr unsigned kSourceIdx = 0;
+    return op->getOpOperand(kSourceIdx);
+  }
+
+  OpOperand &getDestinationOperand(Operation *op) const {
+    constexpr unsigned kDestIdx = 1;
+    return op->getOpOperand(kDestIdx);
+  }
+
+  bool
+  isEquivalentSubset(Operation *op, Value candidate,
+                     function_ref<bool(Value, Value)> equivalenceFn) const {
+    auto parallelInsertOp = cast<tensor::ParallelInsertSliceOp>(op);
+    return isSubsetEquivalentToInsertSliceLikeOp(parallelInsertOp, candidate,
+                                                 equivalenceFn);
+  }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -1154,6 +1086,7 @@ struct SplatOpInterface
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+ // BufferizableOpInterface models.
CastOp::attachInterface<CastOpInterface>(*ctx);
CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
@@ -1172,6 +1105,12 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
SplatOp::attachInterface<SplatOpInterface>(*ctx);
+ // SubsetInsertionOpInterface models.
+ InsertSliceOp::attachInterface<InsertSliceOpSubsetInsertionOpInterface>(
+ *ctx);
+ ParallelInsertSliceOp::attachInterface<
+ ParallelInsertSliceOpSubsetInsertionOpInterface>(*ctx);
+
// Load additional dialects of which ops may get created.
ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
});
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7829bb0ffbd2932..27b57907a315cf1 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7,14 +7,14 @@
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@bazel_skylib//rules:write_file.bzl", "write_file")
-load(":tblgen.bzl", "gentbl_cc_library", "td_library")
-load(":linalggen.bzl", "genlinalg")
load(
":build_defs.bzl",
"cc_headers_only",
"if_cuda_available",
"mlir_c_api_cc_library",
)
+load(":linalggen.bzl", "genlinalg")
+load(":tblgen.bzl", "gentbl_cc_library", "td_library")
package(
default_visibility = ["//visibility:public"],
@@ -9705,6 +9705,36 @@ gentbl_cc_library(
],
)
+# TableGen sources that define the bufferization `SubsetInsertionOpInterface`.
+td_library(
+ name = "SubsetInsertionOpInterfaceTdFiles",
+ srcs = [
+ "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td",
+ ],
+ includes = ["include"],
+ deps = [
+ ":OpBaseTdFiles",
+ ],
+)
+
+# Generates the interface declarations (.h.inc) and definitions (.cpp.inc)
+# from SubsetInsertionOpInterface.td via mlir-tblgen.
+gentbl_cc_library(
+ name = "SubsetInsertionOpInterfaceIncGen",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td",
+ deps = [
+ ":SubsetInsertionOpInterfaceTdFiles",
+ ],
+)
+
td_library(
name = "LinalgDocTdFiles",
srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -11972,12 +12002,14 @@ cc_library(
"lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
"lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
"lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
+ "lib/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp",
"lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp",
],
hdrs = [
"include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
"include/mlir/Dialect/Bufferization/IR/Bufferization.h",
"include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
+ "include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h",
"include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h",
],
includes = ["include"],
@@ -11998,6 +12030,7 @@ cc_library(
":InferTypeOpInterface",
":MemRefDialect",
":SparseTensorDialect",
+ ":SubsetInsertionOpInterfaceIncGen",
":Support",
":TensorDialect",
"//llvm:Support",
More information about the Mlir-commits
mailing list