[Mlir-commits] [mlir] d69293c - [mlir][SCF] `ForOp`: Remove `getIterArgNumberForOpOperand` (#66629)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 19 08:33:45 PDT 2023


Author: Matthias Springer
Date: 2023-09-19T17:33:40+02:00
New Revision: d69293c1c80f0f0e3eb012bc006573d4d5cb820f

URL: https://github.com/llvm/llvm-project/commit/d69293c1c80f0f0e3eb012bc006573d4d5cb820f
DIFF: https://github.com/llvm/llvm-project/commit/d69293c1c80f0f0e3eb012bc006573d4d5cb820f.diff

LOG: [mlir][SCF] `ForOp`: Remove `getIterArgNumberForOpOperand` (#66629)

This function was inconsistent with the remaining API because it
accepted `OpOperand &` arguments that do not belong to the op (returning
`std::nullopt` for them), whereas all the other helpers assert in that
case. This helper function is also not really necessary, as the iter_arg
number is identical to the result number.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8a9ce949a750d43..89c1a06412947b2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -263,16 +263,6 @@ def ForOp : SCF_Op<"for",
     }
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
-    /// Get the iter arg number for an operand. If it isnt an iter arg
-    /// operand return std::nullopt.
-    std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
-      if (opOperand.getOwner() != getOperation())
-        return std::nullopt;
-      unsigned operandNumber = opOperand.getOperandNumber();
-      if (operandNumber < getNumControlOperands())
-        return std::nullopt;
-      return operandNumber - getNumControlOperands();
-    }
 
     /// Get the region iter arg that corresponds to an OpOperand.
     /// This helper prevents internal op implementation detail leakage to

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 21bc0554e717692..a9debb7bbc489a4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
-  std::optional<unsigned> maybeOperandNumber =
-      forOp.getIterArgNumberForOpOperand(*pUse);
-  assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
-
-  int64_t operandNumber = maybeOperandNumber.value();
+  unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
   auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
-  auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+  auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
                                     .getDefiningOp<tensor::ExtractSliceOp>();
   if (!yieldingExtractSliceOp)
     return tensor::ExtractSliceOp();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
     return tensor::ExtractSliceOp();
 
   SmallVector<Value> initArgs = forOp.getInitArgs();
-  initArgs[operandNumber] = hoistedPackedTensor;
+  initArgs[iterArgNumber] = hoistedPackedTensor;
   SmallVector<Value> yieldOperands = yieldOp.getOperands();
-  yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
 
   int64_t numOriginalForOpResults = initArgs.size();
   LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
         hoistedPackedTensor.getLoc(), hoistedPackedTensor,
         outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
         outerSliceOp.getMixedStrides());
-    rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
   }
   scf::ForOp newForOp =
       replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                     << "\n");
   LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
   LLVM_DEBUG(DBGS() << "with result #"
-                    << numOriginalForOpResults + operandNumber
+                    << numOriginalForOpResults + iterArgNumber
                     << " of forOp, giving us: " << extracted << "\n");
   rewriter.startRootUpdate(extracted);
   extracted.getSourceMutable().assign(
-      newForOp.getResult(numOriginalForOpResults + operandNumber));
+      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
   rewriter.finalizeRootUpdate(extracted);
 
   LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                     << "\n");
   LLVM_DEBUG(DBGS() << "with region iter arg #"
-                    << numOriginalForOpResults + operandNumber << "\n");
+                    << numOriginalForOpResults + iterArgNumber << "\n");
   rewriter.replaceAllUsesWith(
       paddedValueBeforeHoisting,
-      newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));
 
   return extracted;
 }

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f5586712a84fae3..6cfba3fef15ebda 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -569,8 +569,9 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   scf::ForOp outerMostLoop = loops.front();
   if (destinationInitArg &&
       (*destinationInitArg)->getOwner() == outerMostLoop) {
-    std::optional<unsigned> iterArgNumber =
-        outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
+    unsigned iterArgNumber =
+        outerMostLoop.getResultForOpOperand(**destinationInitArg)
+            .getResultNumber();
     int64_t resultNumber = fusableProducer.getResultNumber();
     if (auto dstOp =
             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
@@ -584,7 +585,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
       scf::ForOp innerMostLoop = loops.back();
       updateDestinationOperandsForTiledOp(
           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
-          innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+          innerMostLoop.getRegionIterArgs()[iterArgNumber]);
     }
   }
   return scf::SCFFuseProducerOfSliceResult{fusableProducer,


        


More information about the Mlir-commits mailing list