[Mlir-commits] [mlir] [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (PR #106436)
Lei Zhang
llvmlistbot at llvm.org
Sun Sep 1 11:58:54 PDT 2024
================
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
return success();
}
-void LoopPipelinerInternal::emitEpilogue(
- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues) {
+ Location loc = forOp.getLoc();
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
+
+ // bounds_range = ub - lb
+ // total_iterations = (bounds_range + step - 1) / step
+ Type t = lb.getType();
+ Value minus1 =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+ Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+ Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
+ Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
+ Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+
+ SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
- Location loc = forOp.getLoc();
- Type t = lb.getType();
- Value minusOne =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
- // number of iterations = ((ub - 1) - lb) / step
- Value totalNumIteration = rewriter.create<arith::DivUIOp>(
- loc,
- rewriter.create<arith::SubIOp>(
- loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
- step);
- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+ // iterI = total_iters - 1 - i
+ // May go negative...
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+ Value iterI = rewriter.create<arith::AddIOp>(
+ loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+ minusI);
+ // newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
- loc, lb,
- rewriter.create<arith::MulIOp>(
- loc, step,
- rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+ loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+ if (dynamicLoop) {
+ // pred = iterI >= lb
+ predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, iterI, lb);
+ }
}
+
// Emit `maxStage - 1` epilogue part that includes operations from stages
// [i; maxStage].
for (int64_t i = 1; i <= maxStage; i++) {
+ SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
for (Operation *op : opOrder) {
if (stages[op] < i)
continue;
+ unsigned currentVersion = maxStage - stages[op] + i;
+ unsigned nextVersion = currentVersion + 1;
Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
auto it = valueMapping.find(newOperand->get());
if (it != valueMapping.end()) {
- Value replacement = it->second[maxStage - stages[op] + i];
+ Value replacement = it->second[currentVersion];
newOperand->set(replacement);
}
});
+ if (dynamicLoop) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+ if (!newOp)
+ return failure();
+ }
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
- for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
- setValueMapping(op->getResult(destId), newOp->getResult(destId),
- maxStage - stages[op] + i);
+
+ for (auto [opRes, newRes] :
+ llvm::zip(op->getResults(), newOp->getResults())) {
+ setValueMapping(opRes, newRes, currentVersion);
// If the value is a loop carried dependency update the loop argument
// mapping and keep track of the last version to replace the original
// forOp uses.
for (OpOperand &operand :
forOp.getBody()->getTerminator()->getOpOperands()) {
- if (operand.get() != op->getResult(destId))
+ if (operand.get() != opRes)
continue;
- unsigned version = maxStage - stages[op] + i + 1;
// If the version is greater than maxStage it means it maps to the
// original forOp returned value.
- if (version > maxStage) {
- returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
- continue;
- }
- setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
- newOp->getResult(destId), version);
+ unsigned ri = operand.getOperandNumber();
+ returnValues[ri] = newRes;
+ Value mapVal = forOp.getRegionIterArgs()[ri];
+ returnMap[ri] = std::make_pair(mapVal, currentVersion);
+ if (nextVersion <= maxStage)
+ setValueMapping(mapVal, newRes, nextVersion);
+ }
+ }
+ }
+ if (dynamicLoop) {
----------------
antiagainst wrote:
Can we add a test to exercise this case?
https://github.com/llvm/llvm-project/pull/106436
More information about the Mlir-commits
mailing list