[Mlir-commits] [mlir] [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (PR #106436)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 29 12:46:20 PDT 2024
https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/106436
>From 9be66a1a091d268b9d61bceb39d43cd2b66e63bd Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Thu, 29 Aug 2024 19:45:57 +0000
Subject: [PATCH] [MLIR][SCF] Add support for loop pipeline peeling for dynamic
loops. * Allow speculative execution and predicate results per stage.
---
.../Dialect/SCF/Transforms/LoopPipelining.cpp | 126 ++++++++++++------
mlir/test/Dialect/SCF/loop-pipelining.mlir | 43 +++++-
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 4 +-
3 files changed, 129 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index d8e1cc0ecef88e..258e075e263259 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
RewriterBase &rewriter);
/// Emits the epilogue, this creates `maxStage - 1` part which will contain
/// operations from stages [i; maxStage], where i is the part index.
- void emitEpilogue(RewriterBase &rewriter,
- llvm::SmallVector<Value> &returnValues);
+ LogicalResult emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues);
};
bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
LDBG("--no epilogue or predicate set -> BAIL");
return false;
}
- if (dynamicLoop && peelEpilogue) {
- LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
- return false;
- }
std::vector<std::pair<Operation *, unsigned>> schedule;
options.getScheduleFn(forOp, schedule);
if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
});
int predicateIdx = i - stages[op];
if (predicates[predicateIdx]) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
assert(newOp && "failed to predicate op.");
}
- rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,6 +557,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
}
if (predicates[useStage]) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[useStage]);
if (!newOp)
return failure();
@@ -568,7 +565,6 @@ LogicalResult LoopPipelinerInternal::createKernel(
for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
mapping.map(std::get<0>(values), std::get<1>(values));
}
- rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
}
@@ -640,70 +636,123 @@ LogicalResult LoopPipelinerInternal::createKernel(
return success();
}
-void LoopPipelinerInternal::emitEpilogue(
- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues) {
+ Location loc = forOp.getLoc();
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
+
+ // total_iterations = ceildiv(ub - lb, step) = (ub - lb + step - 1) / step.
+ // Signed division (step > 0) gives total_iterations <= 0 when ub <= lb.
+ Type t = lb.getType();
+ Value minusOne =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+ Value lastStage = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(t, maxStage - 1));
+ Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+ Value rangePlusStep = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
+ Value roundedRange =
+ rewriter.create<arith::AddIOp>(loc, rangePlusStep, minusOne);
+ Value totalIterations =
+ rewriter.create<arith::DivSIOp>(loc, roundedRange, step);
+ // The last peeled iteration is total_iterations - 1. With fewer than
+ // maxStage iterations the values leaving the kernel are still the prologue
+ // ramp-up of iterations [0, maxStage), so anchor on maxStage - 1 instead:
+ // lastIter = max(total_iterations - 1, maxStage - 1).
+ Value lastIter = rewriter.create<arith::MaxSIOp>(
+ loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
+ lastStage);
+ SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
- Location loc = forOp.getLoc();
- Type t = lb.getType();
- Value minusOne =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
- // number of iterations = ((ub - 1) - lb) / step
- Value totalNumIteration = rewriter.create<arith::DivUIOp>(
- loc,
- rewriter.create<arith::SubIOp>(
- loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
- step);
- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+ // Version maxStage - i of the induction variable is iteration
+ // iterI = lastIter - i.
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+ Value iterI = rewriter.create<arith::AddIOp>(loc, lastIter, minusI);
+ // iterI >= 0 since lastIter >= maxStage - 1 >= i.
+
+ // newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
- loc, lb,
- rewriter.create<arith::MulIOp>(
- loc, step,
- rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+ loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+ if (dynamicLoop) {
+ // Stages of iteration iterI may only run if that iteration exists, i.e.
+ // iterI < total_iterations. Index by version, like the induction variable.
+ predicates[maxStage - i] = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, iterI, totalIterations);
+ }
}
+
// Emit `maxStage - 1` epilogue part that includes operations from stages
// [i; maxStage].
for (int64_t i = 1; i <= maxStage; i++) {
+ SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
for (Operation *op : opOrder) {
if (stages[op] < i)
continue;
+ unsigned currentVersion = maxStage - stages[op] + i;
+ unsigned nextVersion = currentVersion + 1;
Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
auto it = valueMapping.find(newOperand->get());
if (it != valueMapping.end()) {
- Value replacement = it->second[maxStage - stages[op] + i];
+ Value replacement = it->second[currentVersion];
newOperand->set(replacement);
}
});
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
- for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
- setValueMapping(op->getResult(destId), newOp->getResult(destId),
- maxStage - stages[op] + i);
+ if (dynamicLoop) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+ if (!newOp)
+ return failure();
+ }
+
+ for (auto [opRes, newRes] :
+ llvm::zip(op->getResults(), newOp->getResults())) {
+ setValueMapping(opRes, newRes, currentVersion);
// If the value is a loop carried dependency update the loop argument
// mapping and keep track of the last version to replace the original
// forOp uses.
for (OpOperand &operand :
forOp.getBody()->getTerminator()->getOpOperands()) {
- if (operand.get() != op->getResult(destId))
+ if (operand.get() != opRes)
continue;
- unsigned version = maxStage - stages[op] + i + 1;
// If the version is greater than maxStage it means it maps to the
// original forOp returned value.
- if (version > maxStage) {
- returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
- continue;
- }
- setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
- newOp->getResult(destId), version);
+ unsigned ri = operand.getOperandNumber();
+ returnValues[ri] = newRes;
+ Value mapVal = forOp.getRegionIterArgs()[ri];
+ returnMap[ri] = std::make_pair(mapVal, currentVersion);
+ if (nextVersion <= maxStage)
+ setValueMapping(mapVal, newRes, nextVersion);
+ }
+ }
+ }
+ if (dynamicLoop) {
+ // Select return values from this stage (live outs) based on predication.
+ // If the stage is valid select the peeled value, else use previous stage
+ // value.
+ for (auto pair : llvm::enumerate(returnValues)) {
+ unsigned ri = pair.index();
+ auto [mapVal, currentVersion] = returnMap[ri];
+ if (mapVal) {
+ unsigned nextVersion = currentVersion + 1;
+ Value pred = predicates[currentVersion];
+ Value prevValue = valueMapping[mapVal][currentVersion];
+ auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
+ prevValue);
+ returnValues[ri] = selOp;
+ if (nextVersion <= maxStage)
+ setValueMapping(mapVal, selOp, nextVersion);
}
}
}
}
+ return success();
}
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +809,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
if (options.peelEpilogue) {
// 4. Emit the epilogue after the new forOp.
rewriter.setInsertionPointAfter(newForOp);
- pipeliner.emitEpilogue(rewriter, returnValues);
+ if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
+ return failure();
}
// 5. Erase the original loop and replace the uses with the epilogue output.
if (forOp->getNumResults() > 0)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 9687f80f5ddfc8..957dc5295c0583 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -764,11 +764,46 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32
-// In case dynamic loop pipelining is off check that the transformation didn't
-// apply.
+// Check for predicated epilogue for dynamic loop. Each peeled stage is guarded
+// by `iter < total_iters` for the same iteration whose induction variable it
+// uses; the epilogue anchors on max(total_iters - 1, maxStage - 1).
// CHECK-LABEL: dynamic_loop(
-// CHECK-NOT: memref.load
-// CHECK: scf.for
+// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
+// CHECK: %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}}
+// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}}
+// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]]
+// CHECK: %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]]
+// CHECK: scf.yield %[[ADDF_26]], %[[LOAD_29]]
+// CHECK: }
+// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
+// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
+// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
+// CHECK: %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %{{.*}}
+// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
+// CHECK: %[[MAXSI_15:.*]] = arith.maxsi %[[ADDI_14]], %{{.*}}
+// CHECK: %[[MULI_16:.*]] = arith.muli %{{.*}}, %[[MAXSI_15]]
+// CHECK: %[[ADDI_17:.*]] = arith.addi %{{.*}}, %[[MULI_16]]
+// CHECK: %[[CMPI_18:.*]] = arith.cmpi slt, %[[MAXSI_15]], %[[DIVSI_13]]
+// CHECK: %[[ADDI_19:.*]] = arith.addi %[[MAXSI_15]], %{{.*}}-1
+// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
+// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %[[DIVSI_13]]
+// CHECK: scf.if %[[CMPI_22]] {
+// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_18]] -> (f32) {
+// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+// CHECK: scf.yield %[[ADDF_24]]
+// CHECK: } else {
+// CHECK: scf.yield %{{.*}}
+// CHECK: }
+// CHECK: scf.if %[[CMPI_18]] {
+// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_17]]]
+// CHECK: } else {
+// CHECK: }
+// CHECK: return
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf = arith.constant 1.0 : f32
scf.for %i0 = %lb to %ub step %step {
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 8a92d840ad1302..3ff7f9966e93da 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
RewritePatternSet patterns(&getContext());
mlir::scf::PipeliningOption options;
options.getScheduleFn = getSchedule;
+ options.supportDynamicLoops = true;
+ options.predicateFn = predicateOp;
if (annotatePipeline)
options.annotateFn = annotate;
if (noEpiloguePeeling) {
- options.supportDynamicLoops = true;
options.peelEpilogue = false;
- options.predicateFn = predicateOp;
}
scf::populateSCFLoopPipeliningPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
More information about the Mlir-commits
mailing list