[Mlir-commits] [mlir] d69293c - [mlir][SCF] `ForOp`: Remove `getIterArgNumberForOpOperand` (#66629)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 19 08:33:45 PDT 2023
Author: Matthias Springer
Date: 2023-09-19T17:33:40+02:00
New Revision: d69293c1c80f0f0e3eb012bc006573d4d5cb820f
URL: https://github.com/llvm/llvm-project/commit/d69293c1c80f0f0e3eb012bc006573d4d5cb820f
DIFF: https://github.com/llvm/llvm-project/commit/d69293c1c80f0f0e3eb012bc006573d4d5cb820f.diff
LOG: [mlir][SCF] `ForOp`: Remove `getIterArgNumberForOpOperand` (#66629)
This function was inconsistent with the rest of the API because it
accepted `OpOperand &` arguments that do not belong to the op (returning
`std::nullopt` for them), whereas all the other functions assert. This
helper function is also not really necessary, as the iter_arg number is
identical to the result number.
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8a9ce949a750d43..89c1a06412947b2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -263,16 +263,6 @@ def ForOp : SCF_Op<"for",
}
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }
- /// Get the iter arg number for an operand. If it isnt an iter arg
- /// operand return std::nullopt.
- std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
- if (opOperand.getOwner() != getOperation())
- return std::nullopt;
- unsigned operandNumber = opOperand.getOperandNumber();
- if (operandNumber < getNumControlOperands())
- return std::nullopt;
- return operandNumber - getNumControlOperands();
- }
/// Get the region iter arg that corresponds to an OpOperand.
/// This helper prevents internal op implementation detail leakage to
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 21bc0554e717692..a9debb7bbc489a4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
- std::optional<unsigned> maybeOperandNumber =
- forOp.getIterArgNumberForOpOperand(*pUse);
- assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
-
- int64_t operandNumber = maybeOperandNumber.value();
+ unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
- auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+ auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
.getDefiningOp<tensor::ExtractSliceOp>();
if (!yieldingExtractSliceOp)
return tensor::ExtractSliceOp();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
return tensor::ExtractSliceOp();
SmallVector<Value> initArgs = forOp.getInitArgs();
- initArgs[operandNumber] = hoistedPackedTensor;
+ initArgs[iterArgNumber] = hoistedPackedTensor;
SmallVector<Value> yieldOperands = yieldOp.getOperands();
- yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+ yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
int64_t numOriginalForOpResults = initArgs.size();
LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
hoistedPackedTensor.getLoc(), hoistedPackedTensor,
outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
outerSliceOp.getMixedStrides());
- rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+ rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
}
scf::ForOp newForOp =
replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
<< "\n");
LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
LLVM_DEBUG(DBGS() << "with result #"
- << numOriginalForOpResults + operandNumber
+ << numOriginalForOpResults + iterArgNumber
<< " of forOp, giving us: " << extracted << "\n");
rewriter.startRootUpdate(extracted);
extracted.getSourceMutable().assign(
- newForOp.getResult(numOriginalForOpResults + operandNumber));
+ newForOp.getResult(numOriginalForOpResults + iterArgNumber));
rewriter.finalizeRootUpdate(extracted);
LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
<< "\n");
LLVM_DEBUG(DBGS() << "with region iter arg #"
- << numOriginalForOpResults + operandNumber << "\n");
+ << numOriginalForOpResults + iterArgNumber << "\n");
rewriter.replaceAllUsesWith(
paddedValueBeforeHoisting,
- newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+ newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));
return extracted;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f5586712a84fae3..6cfba3fef15ebda 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -569,8 +569,9 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
scf::ForOp outerMostLoop = loops.front();
if (destinationInitArg &&
(*destinationInitArg)->getOwner() == outerMostLoop) {
- std::optional<unsigned> iterArgNumber =
- outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
+ unsigned iterArgNumber =
+ outerMostLoop.getResultForOpOperand(**destinationInitArg)
+ .getResultNumber();
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
@@ -584,7 +585,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
scf::ForOp innerMostLoop = loops.back();
updateDestinationOperandsForTiledOp(
rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
- innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+ innerMostLoop.getRegionIterArgs()[iterArgNumber]);
}
}
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
More information about the Mlir-commits
mailing list