[Mlir-commits] [mlir] [mlir][IR] Make `OpOperand` comparable (PR #70410)

Matthias Springer llvmlistbot at llvm.org
Thu Oct 26 23:50:49 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70410

>From 37915761ec2988fd478161d357e91fa45744bebc Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 27 Oct 2023 13:09:01 +0900
Subject: [PATCH] [mlir][IR] Make `OpOperand` comparable

Two `OpOperand`s are the same if they belong to the same owner and have the same operand number. There are currently no comparison operators defined on `OpOperand`, and we work around this in multiple places by comparing pointers.

Note: `OpOperand`s are stored inside their owning op, so it is valid to compare their pointers to determine whether they are the same operand. For example, `getOperandNumber` is also implemented via pointer arithmetic.
---
 mlir/include/mlir/IR/UseDefLists.h                     | 10 ++++++++++
 mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp |  6 +++---
 .../Tensor/Transforms/BufferizableOpInterfaceImpl.cpp  |  8 ++++----
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index e3e3e86231465dc..2d60036716611b1 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -146,6 +146,16 @@ class IROperand : public detail::IROperandBase {
     return *this;
   }
 
+  /// Two operands are equal if they have the same owner and the same operand
+  /// number. They are stored inside of ops, so it is valid to compare their
+  /// pointers to determine equality.
+  bool operator==(const IROperand<DerivedT, IRValueT> &other) const {
+    return this == &other;
+  }
+  bool operator!=(const IROperand<DerivedT, IRValueT> &other) const {
+    return !(*this == other);
+  }
+
   /// Return the current value being used by this operand.
   IRValueT get() const { return value; }
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5716dcc9d905016..52ff6ceeee85b03 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -537,12 +537,12 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
 
 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getSourceMutable();
+  return opOperand == getSourceMutable();
 }
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()) {
+  if (opOperand == getDestMutable()) {
     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
     return true;
   }
@@ -560,7 +560,7 @@ bool MaterializeInDestinationOp::mustBufferizeInPlace(
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()) {
+  if (opOperand == getDestMutable()) {
     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
   }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9386d0fd0f04faf..a95443db88b50b2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
     RankedTensorType destType = insertSliceOp.getDestType();
 
     // The source is always read.
-    if (&opOperand == &insertSliceOp.getSourceMutable())
+    if (opOperand == insertSliceOp.getSourceMutable())
       return true;
 
     // For the destination, it depends...
-    assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
+    assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
 
     // Dest is not read if it is entirely overwritten. E.g.:
     // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -849,7 +849,7 @@ struct ReshapeOpInterface
                               const AnalysisState &state) const {
     // Depending on the layout map, the source buffer may have to be copied.
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
-    return &opOperand == &reshapeOp.getShapeMutable();
+    return opOperand == reshapeOp.getShapeMutable();
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -931,7 +931,7 @@ struct ParallelInsertSliceOpInterface
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    return &opOperand == &parallelInsertSliceOp.getDestMutable();
+    return opOperand == parallelInsertSliceOp.getDestMutable();
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,



More information about the Mlir-commits mailing list