[Mlir-commits] [mlir] [mlir][IR] Change `MutableArrayRange` to enumerate `OpOperand &` (PR #66622)

Matthias Springer llvmlistbot at llvm.org
Mon Sep 18 02:55:55 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/66622

In line with #66515, change `MutableOperandRange::begin`/`end` to enumerate `OpOperand &` instead of `Value`. Also remove `ForOp::getIterOpOperands`/`setIterArg`, which are now redundant.

Note: `MutableOperandRange` cannot be made a derived class of `indexed_accessor_range_base` (like `OperandRange`), because `MutableOperandRange::assign` can change the number of operands in the range.

>From 5046dd723deb659288cbfef0242f1517d0e765ba Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Sep 2023 11:52:06 +0200
Subject: [PATCH] [mlir][IR] Change `MutableArrayRange` to enumerate `OpOperand
 &`

In line with #66515, change `MutableOperandRange::begin`/`end` to enumerate `OpOperand &` instead of `Value`. Also remove `ForOp::getIterOpOperands`/`setIterArg`, which are now redundant.

Note: `MutableOperandRange` cannot be made a derived class of `indexed_accessor_range_base` (like `OperandRange`), because `MutableOperandRange::assign` can change the number of operands in the range.
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  7 -----
 mlir/include/mlir/IR/ValueRange.h             | 10 ++-----
 .../OwnershipBasedBufferDeallocation.cpp      |  6 +++-
 .../BufferDeallocationOpInterfaceImpl.cpp     | 10 +++----
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  4 +--
 .../BufferizableOpInterfaceImpl.cpp           |  6 ++--
 .../SCF/Transforms/TileUsingInterface.cpp     | 16 +++++------
 mlir/lib/IR/OperationSupport.cpp              |  8 ++++++
 mlir/lib/Transforms/Utils/CFGToSCF.cpp        | 28 +++++++++++++------
 9 files changed, 53 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 6d8aaf64e3263b9..8a9ce949a750d43 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
         "expected an index less than the number of region iter args");
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
-    MutableArrayRef<OpOperand> getIterOpOperands() {
-      return
-        getOperation()->getOpOperands().drop_front(getNumControlOperands());
-    }
 
     void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
     void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
     void setStep(Value step) { getOperation()->setOperand(2, step); }
-    void setIterArg(unsigned iterArgNum, Value iterArgValue) {
-      getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
-    }
 
     /// Number of induction variables, always 1 for scf::ForOp.
     unsigned getNumInductionVars() { return 1; }
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f1a1f1841f179e7..9c11178f9cd9cae 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -165,13 +165,9 @@ class MutableOperandRange {
   /// Returns the OpOperand at the given index.
   OpOperand &operator[](unsigned index) const;
 
-  OperandRange::iterator begin() const {
-    return static_cast<OperandRange>(*this).begin();
-  }
-
-  OperandRange::iterator end() const {
-    return static_cast<OperandRange>(*this).end();
-  }
+  /// Iterators enumerate OpOperands.
+  MutableArrayRef<OpOperand>::iterator begin() const;
+  MutableArrayRef<OpOperand>::iterator end() const;
 
 private:
   /// Update the length of this range to the one provided.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 43ba11cf132cb92..09d30835828084d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
 
 static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
 
+static bool isMemrefOperand(OpOperand &operand) {
+  return isMemref(operand.get());
+}
+
 //===----------------------------------------------------------------------===//
 // Backedges analysis
 //===----------------------------------------------------------------------===//
@@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
 
   // Add an additional operand for every MemRef for the ownership indicator.
   if (!funcWithoutDynamicOwnership) {
-    unsigned numMemRefs = llvm::count_if(operands, isMemref);
+    unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
     SmallVector<Value> newOperands{OperandRange(operands)};
     auto ownershipValues =
         deallocOp.getUpdatedConditions().take_front(numMemRefs);
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index e847e946eef1b5d..9423af2542690d9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -96,12 +96,12 @@ struct CondBranchOpInterface
         mapping[retained] = ownership;
       }
       SmallVector<Value> replacements, ownerships;
-      for (Value operand : destOperands) {
-        replacements.push_back(operand);
-        if (isMemref(operand)) {
-          assert(mapping.contains(operand) &&
+      for (OpOperand &operand : destOperands) {
+        replacements.push_back(operand.get());
+        if (isMemref(operand.get())) {
+          assert(mapping.contains(operand.get()) &&
                  "Should be contained at this point");
-          ownerships.push_back(mapping[operand]);
+          ownerships.push_back(mapping[operand.get()]);
         }
       }
       replacements.append(ownerships);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5565aefbad18db5..2a760c76d2f6867 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
   assert(operand.get().getType() != replacement.getType() &&
          "Expected a different type");
   SmallVector<Value> newIterOperands;
-  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
       newIterOperands.push_back(replacement);
       continue;
@@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
-    for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+    for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
       OpOperand &iterOpOperand = std::get<0>(it);
       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
       if (!incomingCast ||
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 11cfefed890c669..8c04a8887013c8f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -332,7 +332,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
 /// Helper function for loop bufferization. Return the bufferized values of the
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static FailureOr<SmallVector<Value>>
-getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
            const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
@@ -606,7 +606,7 @@ struct ForOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, forOp.getIterOpOperands(), options);
+        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
@@ -825,7 +825,7 @@ struct WhileOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, whileOp->getOpOperands(), options);
+        getBuffers(rewriter, whileOp.getInitsMutable(), options);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ce25565edcaf61..ceec0756e421ffd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -508,7 +508,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                                       MutableArrayRef<scf::ForOp> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
-  auto [fusableProducer, destinationIterArg] =
+  auto [fusableProducer, destinationInitArg] =
       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
                                         loops);
   if (!fusableProducer)
@@ -575,17 +575,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
   // Update to use that when it does become available.
   scf::ForOp outerMostLoop = loops.front();
-  std::optional<unsigned> iterArgNumber;
-  if (destinationIterArg) {
-    iterArgNumber =
-        outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
-  }
-  if (iterArgNumber) {
+  if (destinationInitArg &&
+      (*destinationInitArg)->getOwner() == outerMostLoop) {
+    std::optional<unsigned> iterArgNumber =
+        outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
     int64_t resultNumber = fusableProducer.getResultNumber();
     if (auto dstOp =
             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
-      outerMostLoop.setIterArg(iterArgNumber.value(),
-                               dstOp.getTiedOpOperand(fusableProducer)->get());
+      (*destinationInitArg)
+          ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
     }
     for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
       auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 7b17e231ce1065f..b0c50f3d6e29855 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
   return owner->getOpOperand(start + index);
 }
 
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
+  return owner->getOpOperands().slice(start, length).begin();
+}
+
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
+  return owner->getOpOperands().slice(start, length).end();
+}
+
 //===----------------------------------------------------------------------===//
 // MutableOperandRangeRange
 
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 9aab89ed7553600..e7bf6628ccbd7e0 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
   return succOps.getMutableForwardedOperands();
 }
 
+/// Return the operand range used to transfer operands from `block` to its
+/// successor with the given index.
+static OperandRange getSuccessorOperands(Block *block,
+                                         unsigned successorIndex) {
+  return getMutableSuccessorOperands(block, successorIndex);
+}
+
 /// Appends all the block arguments from `other` to the block arguments of
 /// `block`, copying their types and locations.
 static void addBlockArgumentsFromOther(Block *block, Block *other) {
@@ -175,8 +182,14 @@ class Edge {
 
   /// Returns the arguments of this edge that are passed to the block arguments
   /// of the successor.
-  MutableOperandRange getSuccessorOperands() const {
-    return getMutableSuccessorOperands(fromBlock, successorIndex);
+  MutableOperandRange getMutableSuccessorOperands() const {
+    return ::getMutableSuccessorOperands(fromBlock, successorIndex);
+  }
+
+  /// Returns the arguments of this edge that are passed to the block arguments
+  /// of the successor.
+  OperandRange getSuccessorOperands() const {
+    return ::getSuccessorOperands(fromBlock, successorIndex);
   }
 };
 
@@ -262,7 +275,7 @@ class EdgeMultiplexer {
     assert(result != blockArgMapping.end() &&
            "Edge was not originally passed to `create` method.");
 
-    MutableOperandRange successorOperands = edge.getSuccessorOperands();
+    MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
 
     // Extra arguments are always appended at the end of the block arguments.
     unsigned extraArgsBeginIndex =
@@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
   // invalidated when mutating the operands through a different
   // `MutableOperandRange` of the same operation.
   SmallVector<Value> loopHeaderSuccessorOperands =
-      llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
+      llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
 
   // Add all values used in the next iteration to the exit block. Replace
   // any uses that are outside the loop with the newly created exit block.
@@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
 
           loopHeaderSuccessorOperands.push_back(argument);
           for (Edge edge : successorEdges(latch))
-            edge.getSuccessorOperands().append(argument);
+            edge.getMutableSuccessorOperands().append(argument);
         }
 
         use.set(blockArgument);
@@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
   if (regionEntry->getNumSuccessors() == 1) {
     // Single successor we can just splice together.
     Block *successor = regionEntry->getSuccessor(0);
-    for (auto &&[oldValue, newValue] :
-         llvm::zip(successor->getArguments(),
-                   getMutableSuccessorOperands(regionEntry, 0)))
+    for (auto &&[oldValue, newValue] : llvm::zip(
+             successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
       oldValue.replaceAllUsesWith(newValue);
     regionEntry->getTerminator()->erase();
 



More information about the Mlir-commits mailing list