[Mlir-commits] [mlir] [mlir][SCF] `ForOp`: Remove `getIterArgNumberForOpOperand` (PR #66629)

Matthias Springer llvmlistbot at llvm.org
Tue Sep 19 02:23:37 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/66629

>From 30b28acd70337ff86c34bb50697eeb74b3a9e8be Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Sep 2023 12:34:21 +0200
Subject: [PATCH] [mlir][SCF] `ForOp`: Remove `getIterArgNumberForOpOperand`

This function was inconsistent with the rest of the API because it accepted `OpOperand &`s that do not belong to the op (returning `std::nullopt` for them), whereas all the other helpers assert in that case. This helper function is also not really necessary, as the iter_arg number is identical to the result number.
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 10 ---------
 .../Linalg/Transforms/HoistPadding.cpp        | 22 ++++++++-----------
 .../SCF/Transforms/TileUsingInterface.cpp     |  7 +++---
 3 files changed, 13 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8a9ce949a750d43..89c1a06412947b2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -263,16 +263,6 @@ def ForOp : SCF_Op<"for",
     }
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
-    /// Get the iter arg number for an operand. If it isnt an iter arg
-    /// operand return std::nullopt.
-    std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
-      if (opOperand.getOwner() != getOperation())
-        return std::nullopt;
-      unsigned operandNumber = opOperand.getOperandNumber();
-      if (operandNumber < getNumControlOperands())
-        return std::nullopt;
-      return operandNumber - getNumControlOperands();
-    }
 
     /// Get the region iter arg that corresponds to an OpOperand.
     /// This helper prevents internal op implementation detail leakage to
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 21bc0554e717692..a9debb7bbc489a4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
-  std::optional<unsigned> maybeOperandNumber =
-      forOp.getIterArgNumberForOpOperand(*pUse);
-  assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
-
-  int64_t operandNumber = maybeOperandNumber.value();
+  unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
   auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
-  auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+  auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
                                     .getDefiningOp<tensor::ExtractSliceOp>();
   if (!yieldingExtractSliceOp)
     return tensor::ExtractSliceOp();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
     return tensor::ExtractSliceOp();
 
   SmallVector<Value> initArgs = forOp.getInitArgs();
-  initArgs[operandNumber] = hoistedPackedTensor;
+  initArgs[iterArgNumber] = hoistedPackedTensor;
   SmallVector<Value> yieldOperands = yieldOp.getOperands();
-  yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
 
   int64_t numOriginalForOpResults = initArgs.size();
   LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
         hoistedPackedTensor.getLoc(), hoistedPackedTensor,
         outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
         outerSliceOp.getMixedStrides());
-    rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
   }
   scf::ForOp newForOp =
       replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                     << "\n");
   LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
   LLVM_DEBUG(DBGS() << "with result #"
-                    << numOriginalForOpResults + operandNumber
+                    << numOriginalForOpResults + iterArgNumber
                     << " of forOp, giving us: " << extracted << "\n");
   rewriter.startRootUpdate(extracted);
   extracted.getSourceMutable().assign(
-      newForOp.getResult(numOriginalForOpResults + operandNumber));
+      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
   rewriter.finalizeRootUpdate(extracted);
 
   LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                     << "\n");
   LLVM_DEBUG(DBGS() << "with region iter arg #"
-                    << numOriginalForOpResults + operandNumber << "\n");
+                    << numOriginalForOpResults + iterArgNumber << "\n");
   rewriter.replaceAllUsesWith(
       paddedValueBeforeHoisting,
-      newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));
 
   return extracted;
 }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f5586712a84fae3..6cfba3fef15ebda 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -569,8 +569,9 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   scf::ForOp outerMostLoop = loops.front();
   if (destinationInitArg &&
       (*destinationInitArg)->getOwner() == outerMostLoop) {
-    std::optional<unsigned> iterArgNumber =
-        outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
+    unsigned iterArgNumber =
+        outerMostLoop.getResultForOpOperand(**destinationInitArg)
+            .getResultNumber();
     int64_t resultNumber = fusableProducer.getResultNumber();
     if (auto dstOp =
             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
@@ -584,7 +585,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
       scf::ForOp innerMostLoop = loops.back();
       updateDestinationOperandsForTiledOp(
           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
-          innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+          innerMostLoop.getRegionIterArgs()[iterArgNumber]);
     }
   }
   return scf::SCFFuseProducerOfSliceResult{fusableProducer,



More information about the Mlir-commits mailing list