[Mlir-commits] [mlir] 5558504 - [mlir][IR] Make `OpOperand` comparable (#70410)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 26 23:51:49 PDT 2023
Author: Matthias Springer
Date: 2023-10-27T15:51:45+09:00
New Revision: 5558504374cf3364310e5f088c18ce9fb5a58d65
URL: https://github.com/llvm/llvm-project/commit/5558504374cf3364310e5f088c18ce9fb5a58d65
DIFF: https://github.com/llvm/llvm-project/commit/5558504374cf3364310e5f088c18ce9fb5a58d65.diff
LOG: [mlir][IR] Make `OpOperand` comparable (#70410)
Two `OpOperand`s are the same if they belong to the same owner and have
the same operand number. There are currently no comparison operators
defined on `OpOperand`, so we work around this in multiple places by
comparing pointers.
Note: `OpOperand`s are stored inside an op, so it is valid to compare their
pointers to determine whether they are the same operand. E.g.,
`getOperandNumber` is also implemented via pointer arithmetic.
Added:
Modified:
mlir/include/mlir/IR/UseDefLists.h
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index e3e3e86231465dc..2d60036716611b1 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -146,6 +146,16 @@ class IROperand : public detail::IROperandBase {
return *this;
}
+ /// Two operands are equal if they have the same owner and the same operand
+ /// number. They are stored inside of ops, so it is valid to compare their
+ /// pointers to determine equality.
+ /// Note that this is an identity comparison: two distinct operands that
+ /// happen to use the same value compare unequal.
+ bool operator==(const IROperand<DerivedT, IRValueT> &other) const {
+ return this == &other;
+ }
+ /// Returns true if `other` is a different operand object; the negation of
+ /// operator==.
+ bool operator!=(const IROperand<DerivedT, IRValueT> &other) const {
+ return !(*this == other);
+ }
+
/// Return the current value being used by this operand.
IRValueT get() const { return value; }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5716dcc9d905016..52ff6ceeee85b03 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -537,12 +537,12 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
OpOperand &opOperand, const AnalysisState &state) {
- return &opOperand == &getSourceMutable();
+ // Only the source operand bufferizes to a memory read.
+ return opOperand == getSourceMutable();
}
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
OpOperand &opOperand, const AnalysisState &state) {
- if (&opOperand == &getDestMutable()) {
+ if (opOperand == getDestMutable()) {
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
return true;
}
@@ -560,7 +560,7 @@ bool MaterializeInDestinationOp::mustBufferizeInPlace(
AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &state) {
- if (&opOperand == &getDestMutable()) {
+ if (opOperand == getDestMutable()) {
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9386d0fd0f04faf..a95443db88b50b2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
RankedTensorType destType = insertSliceOp.getDestType();
// The source is always read.
- if (&opOperand == &insertSliceOp.getSourceMutable())
+ if (opOperand == insertSliceOp.getSourceMutable())
return true;
// For the destination, it depends...
- assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
+ assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -849,7 +849,7 @@ struct ReshapeOpInterface
const AnalysisState &state) const {
// Depending on the layout map, the source buffer may have to be copied.
auto reshapeOp = cast<tensor::ReshapeOp>(op);
- return &opOperand == &reshapeOp.getShapeMutable();
+ return opOperand == reshapeOp.getShapeMutable();
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -931,7 +931,7 @@ struct ParallelInsertSliceOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
- return &opOperand == &parallelInsertSliceOp.getDestMutable();
+ // Only the destination operand bufferizes to a memory write.
+ return opOperand == parallelInsertSliceOp.getDestMutable();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
More information about the Mlir-commits
mailing list