[Mlir-commits] [mlir] [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (PR #106436)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 3 16:27:17 PDT 2024
================
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
return success();
}
-void LoopPipelinerInternal::emitEpilogue(
- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues) {
+ Location loc = forOp.getLoc();
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
+
+ // bounds_range = ub - lb
+ // total_iterations = (bounds_range + step - 1) / step
+ Type t = lb.getType();
+ Value minus1 =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+ Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+ Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
+ Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
+ Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+
+ SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
- Location loc = forOp.getLoc();
- Type t = lb.getType();
- Value minusOne =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
- // number of iterations = ((ub - 1) - lb) / step
- Value totalNumIteration = rewriter.create<arith::DivUIOp>(
- loc,
- rewriter.create<arith::SubIOp>(
- loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
- step);
- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+ // iterI = total_iters - 1 - i
+ // May go negative...
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+ Value iterI = rewriter.create<arith::AddIOp>(
+ loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+ minusI);
+ // newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
- loc, lb,
- rewriter.create<arith::MulIOp>(
- loc, step,
- rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+ loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+ if (dynamicLoop) {
+ // pred = iterI >= lb
+ predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, iterI, lb);
+ }
}
+
// Emit `maxStage - 1` epilogue part that includes operations from stages
// [i; maxStage].
for (int64_t i = 1; i <= maxStage; i++) {
+ SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
for (Operation *op : opOrder) {
if (stages[op] < i)
continue;
+ unsigned currentVersion = maxStage - stages[op] + i;
+ unsigned nextVersion = currentVersion + 1;
Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
auto it = valueMapping.find(newOperand->get());
if (it != valueMapping.end()) {
- Value replacement = it->second[maxStage - stages[op] + i];
+ Value replacement = it->second[currentVersion];
newOperand->set(replacement);
}
});
+ if (dynamicLoop) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+ if (!newOp)
+ return failure();
+ }
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
- for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
- setValueMapping(op->getResult(destId), newOp->getResult(destId),
- maxStage - stages[op] + i);
+
+ for (auto [opRes, newRes] :
+ llvm::zip(op->getResults(), newOp->getResults())) {
+ setValueMapping(opRes, newRes, currentVersion);
// If the value is a loop carried dependency update the loop argument
// mapping and keep track of the last version to replace the original
// forOp uses.
for (OpOperand &operand :
forOp.getBody()->getTerminator()->getOpOperands()) {
- if (operand.get() != op->getResult(destId))
+ if (operand.get() != opRes)
continue;
- unsigned version = maxStage - stages[op] + i + 1;
// If the version is greater than maxStage it means it maps to the
// original forOp returned value.
- if (version > maxStage) {
- returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
- continue;
- }
- setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
- newOp->getResult(destId), version);
+ unsigned ri = operand.getOperandNumber();
+ returnValues[ri] = newRes;
+ Value mapVal = forOp.getRegionIterArgs()[ri];
+ returnMap[ri] = std::make_pair(mapVal, currentVersion);
+ if (nextVersion <= maxStage)
+ setValueMapping(mapVal, newRes, nextVersion);
+ }
+ }
+ }
+ if (dynamicLoop) {
+ // Select return values from this stage (live outs) based on predication.
+ // If the stage is valid select the peeled value, else use previous stage
+ // value.
+ for (auto pair : llvm::enumerate(returnValues)) {
+ unsigned ri = pair.index();
+ auto [mapVal, currentVersion] = returnMap[ri];
+ if (mapVal) {
----------------
sjw36 wrote:
For example:
```
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf0 = arith.constant 1.0 : f32
%cf1 = arith.constant 33.0 : f32
%cst = arith.constant 0 : index
%res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
%A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
scf.yield %A2_elem : f32
} { __test_pipelining_loop__ }
memref.store %res#0, %result[%cst] : memref<?xf32>
return
}
```
I see now the example predicates every operation using the predicateFn, not just the side-effecting ops. So this becomes:
```
func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
%c-1 = arith.constant -1 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%cst_1 = arith.constant 3.300000e+01 : f32
%c0 = arith.constant 0 : index
%0 = arith.cmpi slt, %arg2, %arg3 : index
%1 = scf.if %0 -> (f32) {
%13 = memref.load %arg0[%arg2] : memref<?xf32>
scf.yield %13 : f32
} else {
scf.yield %cst : f32
}
%2 = arith.subi %arg3, %arg4 : index
%3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
%13 = arith.addf %arg7, %arg6 : f32
%14 = arith.mulf %13, %cst_1 : f32
%15 = arith.addi %arg5, %arg4 : index
%16 = memref.load %arg0[%15] : memref<?xf32>
scf.yield %14, %16 : f32, f32
}
%4 = arith.subi %arg3, %arg2 : index
%5 = arith.addi %4, %arg4 : index
%6 = arith.addi %5, %c-1 : index
%7 = arith.divui %6, %arg4 : index
%8 = arith.addi %7, %c-1 : index
%9 = arith.cmpi sge, %8, %arg2 : index
%10 = scf.if %9 -> (f32) {
%13 = arith.addf %3#1, %3#0 : f32
scf.yield %13 : f32
} else {
scf.yield %cst : f32
}
%11 = scf.if %9 -> (f32) {
%13 = arith.mulf %10, %cst_1 : f32
scf.yield %13 : f32
} else {
scf.yield %cst : f32
}
%12 = arith.select %9, %11, %3#0 : f32 /// redundant
memref.store %12, %arg1[%c0] : memref<?xf32>
return
}
```
As you can see, every operation is guarded (including ops that do not produce a loop result), so this doesn't really perform speculative execution.
If only side-effecting ops are guarded and only results are selected based on stage range, results would be:
```
func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
%c-1 = arith.constant -1 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%cst_1 = arith.constant 3.300000e+01 : f32
%c0 = arith.constant 0 : index
%0 = arith.cmpi slt, %arg2, %arg3 : index
%1 = scf.if %0 -> (f32) {
%13 = memref.load %arg0[%arg2] : memref<?xf32>
scf.yield %13 : f32
} else {
scf.yield %cst : f32
}
%2 = arith.subi %arg3, %arg4 : index
%3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
%13 = arith.addf %arg7, %arg6 : f32
%14 = arith.mulf %13, %cst_1 : f32
%15 = arith.addi %arg5, %arg4 : index
%16 = memref.load %arg0[%15] : memref<?xf32>
scf.yield %14, %16 : f32, f32
}
%4 = arith.subi %arg3, %arg2 : index
%5 = arith.addi %4, %arg4 : index
%6 = arith.addi %5, %c-1 : index
%7 = arith.divui %6, %arg4 : index
%8 = arith.addi %7, %c-1 : index
%9 = arith.cmpi sge, %8, %arg2 : index
%10 = arith.addf %3#1, %3#0 : f32
%11 = arith.mulf %10, %cst_1 : f32
%12 = arith.select %9, %11, %3#0 : f32
memref.store %12, %arg1[%c0] : memref<?xf32>
return
}
```
And this seems to be what the Prologue logic is doing as well (see line 343).
https://github.com/llvm/llvm-project/pull/106436
More information about the Mlir-commits
mailing list