[Mlir-commits] [mlir] [mlir][IR] Make `OpOperand` comparable (PR #70410)
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 26 21:11:00 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/70410
Two `OpOperand`s are the same if they belong to the same owner and have the same operand number. There are currently no comparison operators defined on `OpOperand`, and we work around this in multiple places by comparing pointers.
Note: `OpOperand`s are stored in their owning op, so comparing their addresses is a valid way to determine whether they are the same operand. For example, `getOperandNumber` is also implemented via pointer arithmetic.
From 4a8c54773cadcd2be4289d7e308de9f64fb63b3e Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 27 Oct 2023 13:09:01 +0900
Subject: [PATCH] [mlir][IR] Make `OpOperand` comparable
Two `OpOperand`s are the same if they belong to the same owner and have the same operand number. There are currently no comparison operators defined on `OpOperand`, and we work around this in multiple places by comparing pointers.
Note: `OpOperand`s are stored in their owning op, so comparing their addresses is a valid way to determine whether they are the same operand. For example, `getOperandNumber` is also implemented via pointer arithmetic.
---
mlir/include/mlir/IR/UseDefLists.h | 7 +++++++
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp | 6 +++---
.../Tensor/Transforms/BufferizableOpInterfaceImpl.cpp | 8 ++++----
3 files changed, 14 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index e3e3e86231465dc..ae9287e6621b03f 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -146,6 +146,13 @@ class IROperand : public detail::IROperandBase {
return *this;
}
+ bool operator==(const IROperand<DerivedT, IRValueT> &other) const {
+ return this == &other;
+ }
+ bool operator!=(const IROperand<DerivedT, IRValueT> &other) const {
+ return !(*this == other);
+ }
+
/// Return the current value being used by this operand.
IRValueT get() const { return value; }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5716dcc9d905016..52ff6ceeee85b03 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -537,12 +537,12 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
OpOperand &opOperand, const AnalysisState &state) {
- return &opOperand == &getSourceMutable();
+ return opOperand == getSourceMutable();
}
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
OpOperand &opOperand, const AnalysisState &state) {
- if (&opOperand == &getDestMutable()) {
+ if (opOperand == getDestMutable()) {
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
return true;
}
@@ -560,7 +560,7 @@ bool MaterializeInDestinationOp::mustBufferizeInPlace(
AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &state) {
- if (&opOperand == &getDestMutable()) {
+ if (opOperand == getDestMutable()) {
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9386d0fd0f04faf..a95443db88b50b2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
RankedTensorType destType = insertSliceOp.getDestType();
// The source is always read.
- if (&opOperand == &insertSliceOp.getSourceMutable())
+ if (opOperand == insertSliceOp.getSourceMutable())
return true;
// For the destination, it depends...
- assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
+ assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -849,7 +849,7 @@ struct ReshapeOpInterface
const AnalysisState &state) const {
// Depending on the layout map, the source buffer may have to be copied.
auto reshapeOp = cast<tensor::ReshapeOp>(op);
- return &opOperand == &reshapeOp.getShapeMutable();
+ return opOperand == reshapeOp.getShapeMutable();
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -931,7 +931,7 @@ struct ParallelInsertSliceOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
- return &opOperand == &parallelInsertSliceOp.getDestMutable();
+ return opOperand == parallelInsertSliceOp.getDestMutable();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
More information about the Mlir-commits mailing list