[Mlir-commits] [mlir] 5558504 - [mlir][IR] Make `OpOperand` comparable (#70410)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 26 23:51:49 PDT 2023


Author: Matthias Springer
Date: 2023-10-27T15:51:45+09:00
New Revision: 5558504374cf3364310e5f088c18ce9fb5a58d65

URL: https://github.com/llvm/llvm-project/commit/5558504374cf3364310e5f088c18ce9fb5a58d65
DIFF: https://github.com/llvm/llvm-project/commit/5558504374cf3364310e5f088c18ce9fb5a58d65.diff

LOG: [mlir][IR] Make `OpOperand` comparable (#70410)

Two `OpOperand`s are the same if they belong to the same owner and have
the same operand number. There are currently no comparison operators
defined on `OpOperand`, so we work around this in multiple places by
comparing pointers.

Note: `OpOperand`s are stored inside the op that owns them, so it is valid
to compare their pointers to determine whether they are the same operand.
E.g., `getOperandNumber` is also implemented via pointer arithmetic.

Added: 
    

Modified: 
    mlir/include/mlir/IR/UseDefLists.h
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index e3e3e86231465dc..2d60036716611b1 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -146,6 +146,16 @@ class IROperand : public detail::IROperandBase {
     return *this;
   }
 
+  /// Two operands are equal if they have the same owner and the same operand
+  /// number. They are stored inside of ops, so it is valid to compare their
+  /// pointers to determine equality.
+  bool operator==(const IROperand<DerivedT, IRValueT> &other) const {
+    return this == &other;
+  }
+  bool operator!=(const IROperand<DerivedT, IRValueT> &other) const {
+    return !(*this == other);
+  }
+
   /// Return the current value being used by this operand.
   IRValueT get() const { return value; }
 

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5716dcc9d905016..52ff6ceeee85b03 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -537,12 +537,12 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
 
 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getSourceMutable();
+  return opOperand == getSourceMutable();
 }
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()) {
+  if (opOperand == getDestMutable()) {
     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
     return true;
   }
@@ -560,7 +560,7 @@ bool MaterializeInDestinationOp::mustBufferizeInPlace(
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()) {
+  if (opOperand == getDestMutable()) {
     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
   }

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9386d0fd0f04faf..a95443db88b50b2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
     RankedTensorType destType = insertSliceOp.getDestType();
 
     // The source is always read.
-    if (&opOperand == &insertSliceOp.getSourceMutable())
+    if (opOperand == insertSliceOp.getSourceMutable())
       return true;
 
     // For the destination, it depends...
-    assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
+    assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
 
     // Dest is not read if it is entirely overwritten. E.g.:
     // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -849,7 +849,7 @@ struct ReshapeOpInterface
                               const AnalysisState &state) const {
     // Depending on the layout map, the source buffer may have to be copied.
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
-    return &opOperand == &reshapeOp.getShapeMutable();
+    return opOperand == reshapeOp.getShapeMutable();
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -931,7 +931,7 @@ struct ParallelInsertSliceOpInterface
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    return &opOperand == &parallelInsertSliceOp.getDestMutable();
+    return opOperand == parallelInsertSliceOp.getDestMutable();
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,


        


More information about the Mlir-commits mailing list