[Mlir-commits] [mlir] 1defec8 - [mlir][tensor][bufferize][NFC] Remove duplicate code
Matthias Springer
llvmlistbot at llvm.org
Mon Jul 25 03:34:28 PDT 2022
Author: Matthias Springer
Date: 2022-07-25T12:34:16+02:00
New Revision: 1defec87306593e71057de7baf9bd8e2389c2419
URL: https://github.com/llvm/llvm-project/commit/1defec87306593e71057de7baf9bd8e2389c2419
DIFF: https://github.com/llvm/llvm-project/commit/1defec87306593e71057de7baf9bd8e2389c2419.diff
LOG: [mlir][tensor][bufferize][NFC] Remove duplicate code
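InsertSliceOp and ParallelInsertSliceOp are very similar and can share some of the bufferization analysis code. Note that this is not strictly NFC for InsertSliceOp. The removed InsertSliceOp overload of `areEquivalentExtractSliceOps` guarded its buffer-equivalence check with `sti != sti`, which is always false. It therefore never verified that the ExtractSliceOp source and the InsertSliceOp dest bufferize to equivalent buffers. The shared template always performs that check, because its `extractSliceOp != insertSliceOp` guard compares two non-null ops of different kinds and is always true. This matches the previous ParallelInsertSliceOp behavior.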
Differential Revision: https://reviews.llvm.org/D130465
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 44ac40abfb65..38044da2e4db 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -552,29 +552,30 @@ struct InsertOpInterface
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
-///
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and
-/// exposed as interfaces on the proper types.
+template <typename OpTy>
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
- ExtractSliceOp st, InsertSliceOp sti) {
- if (!st || !sti)
+ ExtractSliceOp extractSliceOp,
+ OpTy insertSliceOp) {
+ if (!extractSliceOp || !insertSliceOp)
return false;
- if (sti != sti &&
- !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
+ if (extractSliceOp != insertSliceOp &&
+ !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
+ insertSliceOp.getDest()))
return false;
- if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+ if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+ isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
+template <typename OpTy>
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
- InsertSliceOp insertOp) {
+ OpTy insertSliceOp) {
auto condition = [&](Value val) {
- if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+ if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
+ if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
return true;
return false;
};
@@ -583,6 +584,83 @@ static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
condition);
}
+template <typename OpTy>
+static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
+ OpOperand *uConflictingWrite,
+ const AnalysisState &state) {
+ Operation *readingOp = uRead->getOwner();
+ Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+ // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+ // uRead is an InsertSliceOp...
+ if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+
+ // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+ if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+ insertSliceOp))
+ // Case 1: The main insight is that InsertSliceOp reads only part of
+ // the destination tensor. The overwritten area is not read. If
+ // uConflictingWrite writes into exactly the memory location that is
+ // being read by uRead, this is not a conflict.
+ //
+ // In the above example:
+ // uRead = OpOperand 1 (%t) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+ //
+ // The read of %t does not conflict with the write of the FillOp
+ // (same aliases!) because the area that the FillOp operates on is
+ // exactly the one that is *not* read via %t.
+ return true;
+
+ if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+ uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+ // Case 2: The read of the source tensor and the write to the dest
+ // tensor via an InsertSliceOp is not a conflict if the read is
+ // reading exactly that part of an equivalent tensor that the
+ // InsertSliceOp is writing.
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ return true;
+ }
+
+ // If uConflictingWrite is an InsertSliceOp...
+ if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+ // %3 = vector.transfer_read %1, %cst
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of vector.transfer_read
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ // lastWrite = %1
+ //
+ // This is not a conflict because the InsertSliceOp overwrites the
+ // memory segment of %1 with the exact same data. (Effectively, there
+ // is no memory write here.)
+ if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ state.areEquivalentBufferizedValues(uRead->get(),
+ insertSliceOp.getSource()) &&
+ hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+ insertSliceOp))
+ return true;
+
+ return false;
+}
+
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
@@ -613,77 +691,8 @@ struct InsertSliceOpInterface
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const AnalysisState &state) const {
- Operation *readingOp = uRead->getOwner();
- Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
- // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
- // uRead is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
-
- // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
- if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
- insertSliceOp))
- // Case 1: The main insight is that InsertSliceOp reads only part of
- // the destination tensor. The overwritten area is not read. If
- // uConflictingWrite writes into exactly the memory location that is
- // being read by uRead, this is not a conflict.
- //
- // In the above example:
- // uRead = OpOperand 1 (%t) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
- //
- // The read of %t does not conflict with the write of the FillOp
- // (same aliases!) because the area that the FillOp operates on is
- // exactly the one that is *not* read via %t.
- return true;
-
- if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
- uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
- // Case 2: The read of the source tensor and the write to the dest
- // tensor via an InsertSliceOp is not a conflict if the read is
- // reading exactly that part of an equivalent tensor that the
- // InsertSliceOp is writing.
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- return true;
- }
-
- // If uConflictingWrite is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
- // %3 = vector.transfer_read %1, %cst
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of vector.transfer_read
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- // lastWrite = %1
- //
- // This is not a conflict because the InsertSliceOp overwrites the
- // memory segment of %1 with the exact same data. (Effectively, there
- // is no memory write here.)
- if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- state.areEquivalentBufferizedValues(uRead->get(),
- insertSliceOp.getSource()) &&
- hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
- insertSliceOp))
- return true;
-
- return false;
+ return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
+ op, uRead, uConflictingWrite, state);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -805,36 +814,6 @@ struct ReshapeOpInterface
}
};
-/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-static bool areEquivalentExtractSliceOps(const AnalysisState &state,
- ExtractSliceOp st,
- ParallelInsertSliceOp sti) {
- if (!st || !sti)
- return false;
- if (st != sti &&
- !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
- return false;
- if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
- return false;
- return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
- ParallelInsertSliceOp insertOp) {
- auto condition = [&](Value val) {
- if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
- return true;
- return false;
- };
-
- return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
- condition);
-}
-
/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<
@@ -978,83 +957,11 @@ struct ParallelInsertSliceOpInterface
return success();
}
- // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
- // the code.
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const AnalysisState &state) const {
- Operation *readingOp = uRead->getOwner();
- Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
- // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
- // uRead is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
-
- // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
- if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
- insertSliceOp))
- // Case 1: The main insight is that InsertSliceOp reads only part of
- // the destination tensor. The overwritten area is not read. If
- // uConflictingWrite writes into exactly the memory location that is
- // being read by uRead, this is not a conflict.
- //
- // In the above example:
- // uRead = OpOperand 1 (%t) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
- //
- // The read of %t does not conflict with the write of the FillOp
- // (same aliases!) because the area that the FillOp operates on is
- // exactly the one that is *not* read via %t.
- return true;
-
- if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
- uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
- // Case 2: The read of the source tensor and the write to the dest
- // tensor via an InsertSliceOp is not a conflict if the read is
- // reading exactly that part of an equivalent tensor that the
- // InsertSliceOp is writing.
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- return true;
- }
-
- // If uConflictingWrite is an InsertSliceOp...
- if (auto insertSliceOp =
- dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
- // %3 = vector.transfer_read %1, %cst
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of vector.transfer_read
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- // lastWrite = %1
- //
- // This is not a conflict because the InsertSliceOp overwrites the
- // memory segment of %1 with the exact same data. (Effectively, there
- // is no memory write here.)
- if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- state.areEquivalentBufferizedValues(uRead->get(),
- insertSliceOp.getSource()) &&
- hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
- insertSliceOp))
- return true;
-
- return false;
+ return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
+ op, uRead, uConflictingWrite, state);
}
};
More information about the Mlir-commits
mailing list