[Mlir-commits] [mlir] [mlir][SCF] `ForOp`: Remove `getIterArgNumberForOpOperand` (PR #66629)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 03:37:22 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

This function was inconsistent with the rest of the API because it accepted `OpOperand &` arguments that do not belong to the op (returning `std::nullopt` for them), whereas all the other functions assert in that case. This helper function is also not really necessary, as the iter_arg number is identical to the result number.

Depends on #<!-- -->66622. Review only the top commit.

---
Full diff: https://github.com/llvm/llvm-project/pull/66629.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (-17) 
- (modified) mlir/include/mlir/IR/ValueRange.h (+3-7) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp (+5-1) 
- (modified) mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp (+5-5) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp (+9-13) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+3-3) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+9-10) 
- (modified) mlir/lib/IR/OperationSupport.cpp (+8) 
- (modified) mlir/lib/Transforms/Utils/CFGToSCF.cpp (+20-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 6d8aaf64e3263b9..89c1a06412947b2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
         "expected an index less than the number of region iter args");
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
-    MutableArrayRef<OpOperand> getIterOpOperands() {
-      return
-        getOperation()->getOpOperands().drop_front(getNumControlOperands());
-    }
 
     void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
     void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
     void setStep(Value step) { getOperation()->setOperand(2, step); }
-    void setIterArg(unsigned iterArgNum, Value iterArgValue) {
-      getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
-    }
 
     /// Number of induction variables, always 1 for scf::ForOp.
     unsigned getNumInductionVars() { return 1; }
@@ -270,16 +263,6 @@ def ForOp : SCF_Op<"for",
     }
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
-    /// Get the iter arg number for an operand. If it isnt an iter arg
-    /// operand return std::nullopt.
-    std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
-      if (opOperand.getOwner() != getOperation())
-        return std::nullopt;
-      unsigned operandNumber = opOperand.getOperandNumber();
-      if (operandNumber < getNumControlOperands())
-        return std::nullopt;
-      return operandNumber - getNumControlOperands();
-    }
 
     /// Get the region iter arg that corresponds to an OpOperand.
     /// This helper prevents internal op implementation detail leakage to
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f1a1f1841f179e7..9c11178f9cd9cae 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -165,13 +165,9 @@ class MutableOperandRange {
   /// Returns the OpOperand at the given index.
   OpOperand &operator[](unsigned index) const;
 
-  OperandRange::iterator begin() const {
-    return static_cast<OperandRange>(*this).begin();
-  }
-
-  OperandRange::iterator end() const {
-    return static_cast<OperandRange>(*this).end();
-  }
+  /// Iterators enumerate OpOperands.
+  MutableArrayRef<OpOperand>::iterator begin() const;
+  MutableArrayRef<OpOperand>::iterator end() const;
 
 private:
   /// Update the length of this range to the one provided.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 43ba11cf132cb92..09d30835828084d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
 
 static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
 
+static bool isMemrefOperand(OpOperand &operand) {
+  return isMemref(operand.get());
+}
+
 //===----------------------------------------------------------------------===//
 // Backedges analysis
 //===----------------------------------------------------------------------===//
@@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
 
   // Add an additional operand for every MemRef for the ownership indicator.
   if (!funcWithoutDynamicOwnership) {
-    unsigned numMemRefs = llvm::count_if(operands, isMemref);
+    unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
     SmallVector<Value> newOperands{OperandRange(operands)};
     auto ownershipValues =
         deallocOp.getUpdatedConditions().take_front(numMemRefs);
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index e847e946eef1b5d..9423af2542690d9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -96,12 +96,12 @@ struct CondBranchOpInterface
         mapping[retained] = ownership;
       }
       SmallVector<Value> replacements, ownerships;
-      for (Value operand : destOperands) {
-        replacements.push_back(operand);
-        if (isMemref(operand)) {
-          assert(mapping.contains(operand) &&
+      for (OpOperand &operand : destOperands) {
+        replacements.push_back(operand.get());
+        if (isMemref(operand.get())) {
+          assert(mapping.contains(operand.get()) &&
                  "Should be contained at this point");
-          ownerships.push_back(mapping[operand]);
+          ownerships.push_back(mapping[operand.get()]);
         }
       }
       replacements.append(ownerships);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 21bc0554e717692..a9debb7bbc489a4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
-  std::optional<unsigned> maybeOperandNumber =
-      forOp.getIterArgNumberForOpOperand(*pUse);
-  assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
-
-  int64_t operandNumber = maybeOperandNumber.value();
+  unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
   auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
-  auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+  auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
                                     .getDefiningOp<tensor::ExtractSliceOp>();
   if (!yieldingExtractSliceOp)
     return tensor::ExtractSliceOp();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
     return tensor::ExtractSliceOp();
 
   SmallVector<Value> initArgs = forOp.getInitArgs();
-  initArgs[operandNumber] = hoistedPackedTensor;
+  initArgs[iterArgNumber] = hoistedPackedTensor;
   SmallVector<Value> yieldOperands = yieldOp.getOperands();
-  yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
 
   int64_t numOriginalForOpResults = initArgs.size();
   LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
         hoistedPackedTensor.getLoc(), hoistedPackedTensor,
         outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
         outerSliceOp.getMixedStrides());
-    rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
   }
   scf::ForOp newForOp =
       replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                     << "\n");
   LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
   LLVM_DEBUG(DBGS() << "with result #"
-                    << numOriginalForOpResults + operandNumber
+                    << numOriginalForOpResults + iterArgNumber
                     << " of forOp, giving us: " << extracted << "\n");
   rewriter.startRootUpdate(extracted);
   extracted.getSourceMutable().assign(
-      newForOp.getResult(numOriginalForOpResults + operandNumber));
+      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
   rewriter.finalizeRootUpdate(extracted);
 
   LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                     << "\n");
   LLVM_DEBUG(DBGS() << "with region iter arg #"
-                    << numOriginalForOpResults + operandNumber << "\n");
+                    << numOriginalForOpResults + iterArgNumber << "\n");
   rewriter.replaceAllUsesWith(
       paddedValueBeforeHoisting,
-      newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));
 
   return extracted;
 }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5565aefbad18db5..2a760c76d2f6867 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
   assert(operand.get().getType() != replacement.getType() &&
          "Expected a different type");
   SmallVector<Value> newIterOperands;
-  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
       newIterOperands.push_back(replacement);
       continue;
@@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
-    for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+    for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
       OpOperand &iterOpOperand = std::get<0>(it);
       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
       if (!incomingCast ||
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 11cfefed890c669..8c04a8887013c8f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -332,7 +332,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
 /// Helper function for loop bufferization. Return the bufferized values of the
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static FailureOr<SmallVector<Value>>
-getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
            const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
@@ -606,7 +606,7 @@ struct ForOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, forOp.getIterOpOperands(), options);
+        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
@@ -825,7 +825,7 @@ struct WhileOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, whileOp->getOpOperands(), options);
+        getBuffers(rewriter, whileOp.getInitsMutable(), options);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ce25565edcaf61..d7b986e052ae575 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -508,7 +508,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                                       MutableArrayRef<scf::ForOp> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
-  auto [fusableProducer, destinationIterArg] =
+  auto [fusableProducer, destinationInitArg] =
       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
                                         loops);
   if (!fusableProducer)
@@ -575,17 +575,16 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
   // Update to use that when it does become available.
   scf::ForOp outerMostLoop = loops.front();
-  std::optional<unsigned> iterArgNumber;
-  if (destinationIterArg) {
-    iterArgNumber =
-        outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
-  }
-  if (iterArgNumber) {
+  if (destinationInitArg &&
+      (*destinationInitArg)->getOwner() == outerMostLoop) {
+    unsigned iterArgNumber =
+        outerMostLoop.getResultForOpOperand(**destinationInitArg)
+            .getResultNumber();
     int64_t resultNumber = fusableProducer.getResultNumber();
     if (auto dstOp =
             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
-      outerMostLoop.setIterArg(iterArgNumber.value(),
-                               dstOp.getTiedOpOperand(fusableProducer)->get());
+      (*destinationInitArg)
+          ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
     }
     for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
       auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
@@ -594,7 +593,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
       scf::ForOp innerMostLoop = loops.back();
       updateDestinationOperandsForTiledOp(
           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
-          innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+          innerMostLoop.getRegionIterArgs()[iterArgNumber]);
     }
   }
   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 7b17e231ce1065f..b0c50f3d6e29855 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
   return owner->getOpOperand(start + index);
 }
 
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
+  return owner->getOpOperands().slice(start, length).begin();
+}
+
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
+  return owner->getOpOperands().slice(start, length).end();
+}
+
 //===----------------------------------------------------------------------===//
 // MutableOperandRangeRange
 
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 9aab89ed7553600..e7bf6628ccbd7e0 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
   return succOps.getMutableForwardedOperands();
 }
 
+/// Return the operand range used to transfer operands from `block` to its
+/// successor with the given index.
+static OperandRange getSuccessorOperands(Block *block,
+                                         unsigned successorIndex) {
+  return getMutableSuccessorOperands(block, successorIndex);
+}
+
 /// Appends all the block arguments from `other` to the block arguments of
 /// `block`, copying their types and locations.
 static void addBlockArgumentsFromOther(Block *block, Block *other) {
@@ -175,8 +182,14 @@ class Edge {
 
   /// Returns the arguments of this edge that are passed to the block arguments
   /// of the successor.
-  MutableOperandRange getSuccessorOperands() const {
-    return getMutableSuccessorOperands(fromBlock, successorIndex);
+  MutableOperandRange getMutableSuccessorOperands() const {
+    return ::getMutableSuccessorOperands(fromBlock, successorIndex);
+  }
+
+  /// Returns the arguments of this edge that are passed to the block arguments
+  /// of the successor.
+  OperandRange getSuccessorOperands() const {
+    return ::getSuccessorOperands(fromBlock, successorIndex);
   }
 };
 
@@ -262,7 +275,7 @@ class EdgeMultiplexer {
     assert(result != blockArgMapping.end() &&
            "Edge was not originally passed to `create` method.");
 
-    MutableOperandRange successorOperands = edge.getSuccessorOperands();
+    MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
 
     // Extra arguments are always appended at the end of the block arguments.
     unsigned extraArgsBeginIndex =
@@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
   // invalidated when mutating the operands through a different
   // `MutableOperandRange` of the same operation.
   SmallVector<Value> loopHeaderSuccessorOperands =
-      llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
+      llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
 
   // Add all values used in the next iteration to the exit block. Replace
   // any uses that are outside the loop with the newly created exit block.
@@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
 
           loopHeaderSuccessorOperands.push_back(argument);
           for (Edge edge : successorEdges(latch))
-            edge.getSuccessorOperands().append(argument);
+            edge.getMutableSuccessorOperands().append(argument);
         }
 
         use.set(blockArgument);
@@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
   if (regionEntry->getNumSuccessors() == 1) {
     // Single successor we can just splice together.
     Block *successor = regionEntry->getSuccessor(0);
-    for (auto &&[oldValue, newValue] :
-         llvm::zip(successor->getArguments(),
-                   getMutableSuccessorOperands(regionEntry, 0)))
+    for (auto &&[oldValue, newValue] : llvm::zip(
+             successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
       oldValue.replaceAllUsesWith(newValue);
     regionEntry->getTerminator()->erase();
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/66629


More information about the Mlir-commits mailing list