[Mlir-commits] [mlir] [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (PR #106436)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 4 10:50:22 PDT 2024
https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/106436
>From 9be66a1a091d268b9d61bceb39d43cd2b66e63bd Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Thu, 29 Aug 2024 19:45:57 +0000
Subject: [PATCH 1/4] [MLIR][SCF] Add support for loop pipeline peeling for
dynamic loops. * Allow speculative execution and predicate results per stage.
---
.../Dialect/SCF/Transforms/LoopPipelining.cpp | 126 ++++++++++++------
mlir/test/Dialect/SCF/loop-pipelining.mlir | 43 +++++-
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 4 +-
3 files changed, 129 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index d8e1cc0ecef88e..258e075e263259 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
RewriterBase &rewriter);
/// Emits the epilogue, this creates `maxStage - 1` part which will contain
/// operations from stages [i; maxStage], where i is the part index.
- void emitEpilogue(RewriterBase &rewriter,
- llvm::SmallVector<Value> &returnValues);
+ LogicalResult emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues);
};
bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
LDBG("--no epilogue or predicate set -> BAIL");
return false;
}
- if (dynamicLoop && peelEpilogue) {
- LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
- return false;
- }
std::vector<std::pair<Operation *, unsigned>> schedule;
options.getScheduleFn(forOp, schedule);
if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
});
int predicateIdx = i - stages[op];
if (predicates[predicateIdx]) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
assert(newOp && "failed to predicate op.");
}
- rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,6 +557,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
}
if (predicates[useStage]) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[useStage]);
if (!newOp)
return failure();
@@ -568,7 +565,6 @@ LogicalResult LoopPipelinerInternal::createKernel(
for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
mapping.map(std::get<0>(values), std::get<1>(values));
}
- rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
}
@@ -640,70 +636,123 @@ LogicalResult LoopPipelinerInternal::createKernel(
return success();
}
-void LoopPipelinerInternal::emitEpilogue(
- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues) {
+ Location loc = forOp.getLoc();
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
+
+ // bounds_range = ub - lb
+ // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
+ Type t = lb.getType();
+ Value minus1 =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+
+ Value const_0 =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+ Value const_1 =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
+ Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+ Value boundsRem = rewriter.create<arith::RemUIOp>(loc, boundsRange, step);
+ Value hasRem = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
+ boundsRem, const_0);
+ Value selRem =
+ rewriter.create<arith::SelectOp>(loc, hasRem, const_1, const_0);
+ Value boundsDiv = rewriter.create<arith::DivUIOp>(loc, boundsRange, step);
+ Value totalIterations =
+ rewriter.create<arith::AddIOp>(loc, boundsDiv, selRem);
+
+ SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
- Location loc = forOp.getLoc();
- Type t = lb.getType();
- Value minusOne =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
- // number of iterations = ((ub - 1) - lb) / step
- Value totalNumIteration = rewriter.create<arith::DivUIOp>(
- loc,
- rewriter.create<arith::SubIOp>(
- loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
- step);
- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+ // iterI = total_iters - 1 - i
+ // May go negative...
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+ Value iterI = rewriter.create<arith::AddIOp>(
+ loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+ minusI);
+ // newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
- loc, lb,
- rewriter.create<arith::MulIOp>(
- loc, step,
- rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+ loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+ if (dynamicLoop) {
+ // pred = iterI >= lb
+ predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, iterI, lb);
+ }
}
+
// Emit `maxStage - 1` epilogue part that includes operations from stages
// [i; maxStage].
for (int64_t i = 1; i <= maxStage; i++) {
+ SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
for (Operation *op : opOrder) {
if (stages[op] < i)
continue;
+ unsigned currentVersion = maxStage - stages[op] + i;
+ unsigned nextVersion = currentVersion + 1;
Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
auto it = valueMapping.find(newOperand->get());
if (it != valueMapping.end()) {
- Value replacement = it->second[maxStage - stages[op] + i];
+ Value replacement = it->second[currentVersion];
newOperand->set(replacement);
}
});
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
- for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
- setValueMapping(op->getResult(destId), newOp->getResult(destId),
- maxStage - stages[op] + i);
+ if (dynamicLoop) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+ if (!newOp)
+ return failure();
+ }
+
+ for (auto [opRes, newRes] :
+ llvm::zip(op->getResults(), newOp->getResults())) {
+ setValueMapping(opRes, newRes, currentVersion);
// If the value is a loop carried dependency update the loop argument
// mapping and keep track of the last version to replace the original
// forOp uses.
for (OpOperand &operand :
forOp.getBody()->getTerminator()->getOpOperands()) {
- if (operand.get() != op->getResult(destId))
+ if (operand.get() != opRes)
continue;
- unsigned version = maxStage - stages[op] + i + 1;
// If the version is greater than maxStage it means it maps to the
// original forOp returned value.
- if (version > maxStage) {
- returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
- continue;
- }
- setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
- newOp->getResult(destId), version);
+ unsigned ri = operand.getOperandNumber();
+ returnValues[ri] = newRes;
+ Value mapVal = forOp.getRegionIterArgs()[ri];
+ returnMap[ri] = std::make_pair(mapVal, currentVersion);
+ if (nextVersion <= maxStage)
+ setValueMapping(mapVal, newRes, nextVersion);
+ }
+ }
+ }
+ if (dynamicLoop) {
+ // Select return values from this stage (live outs) based on predication.
+ // If the stage is valid select the peeled value, else use previous stage
+ // value.
+ for (auto pair : llvm::enumerate(returnValues)) {
+ unsigned ri = pair.index();
+ auto [mapVal, currentVersion] = returnMap[ri];
+ if (mapVal) {
+ unsigned nextVersion = currentVersion + 1;
+ Value pred = predicates[currentVersion];
+ Value prevValue = valueMapping[mapVal][currentVersion];
+ auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
+ prevValue);
+ returnValues[ri] = selOp;
+ if (nextVersion <= maxStage)
+ setValueMapping(mapVal, selOp, nextVersion);
}
}
}
}
+ return success();
}
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +809,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
if (options.peelEpilogue) {
// 4. Emit the epilogue after the new forOp.
rewriter.setInsertionPointAfter(newForOp);
- pipeliner.emitEpilogue(rewriter, returnValues);
+ if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
+ return failure();
}
// 5. Erase the original loop and replace the uses with the epilogue output.
if (forOp->getNumResults() > 0)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 9687f80f5ddfc8..957dc5295c0583 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -764,11 +764,46 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32
-// In case dynamic loop pipelining is off check that the transformation didn't
-// apply.
+// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: dynamic_loop(
-// CHECK-NOT: memref.load
-// CHECK: scf.for
+// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
+// CHECK: %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}}
+// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}}
+// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]]
+// CHECK: %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]]
+// CHECK: scf.yield %[[ADDF_26]], %[[LOAD_29]]
+// CHECK: }
+// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
+// CHECK: %[[REMUI_11:.*]] = arith.remui %[[SUBI_10]], %{{.*}}
+// CHECK: %[[CMPI_12:.*]] = arith.cmpi ne, %[[REMUI_11]], %{{.*}}
+// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_12]], %{{.*}}, %{{.*}}
+// CHECK: %[[DIVUI_14:.*]] = arith.divui %[[SUBI_10]], %{{.*}}
+// CHECK: %[[ADDI_15:.*]] = arith.addi %[[DIVUI_14]], %[[SELECT_13]]
+// CHECK: %[[ADDI_16:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1
+// CHECK: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[ADDI_16]]
+// CHECK: %[[ADDI_18:.*]] = arith.addi %{{.*}}, %[[MULI_17]]
+// CHECK: %[[CMPI_19:.*]] = arith.cmpi sge, %[[ADDI_16]], %{{.*}}
+// CHECK: %[[ADDI_20:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1
+// CHECK: %[[ADDI_21:.*]] = arith.addi %[[ADDI_20]], %{{.*}}-1
+// CHECK: %[[MULI_22:.*]] = arith.muli %{{.*}}, %[[ADDI_21]]
+// CHECK: %[[ADDI_23:.*]] = arith.addi %{{.*}}, %[[MULI_22]]
+// CHECK: %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}}
+// CHECK: scf.if %[[CMPI_19]] {
+// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_23]]]
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[IF_25:.*]] = scf.if %[[CMPI_24]] -> (f32) {
+// CHECK: %[[ADDF_26:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+// CHECK: scf.yield %[[ADDF_26]]
+// CHECK: } else {
+// CHECK: scf.yield %{{.*}}
+// CHECK: }
+// CHECK: scf.if %[[CMPI_24]] {
+// CHECK: memref.store %[[IF_25]], %{{.*}}[%[[ADDI_18]]]
+// CHECK: } else {
+// CHECK: }
+// CHECK: return
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf = arith.constant 1.0 : f32
scf.for %i0 = %lb to %ub step %step {
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 8a92d840ad1302..3ff7f9966e93da 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
RewritePatternSet patterns(&getContext());
mlir::scf::PipeliningOption options;
options.getScheduleFn = getSchedule;
+ options.supportDynamicLoops = true;
+ options.predicateFn = predicateOp;
if (annotatePipeline)
options.annotateFn = annotate;
if (noEpiloguePeeling) {
- options.supportDynamicLoops = true;
options.peelEpilogue = false;
- options.predicateFn = predicateOp;
}
scf::populateSCFLoopPipeliningPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
>From df8268d576f17625c87c8a3b7383ee4247eeabbd Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Thu, 29 Aug 2024 20:24:50 +0000
Subject: [PATCH 2/4] * annotate predicated op
---
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 258e075e263259..0615ffce072262 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -702,14 +702,14 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
newOperand->set(replacement);
}
});
- if (annotateFn)
- annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
if (dynamicLoop) {
OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
if (!newOp)
return failure();
}
+ if (annotateFn)
+ annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
for (auto [opRes, newRes] :
llvm::zip(op->getResults(), newOp->getResults())) {
>From 5603dedcde6f6f5dab3777c8106435b35192ac3b Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Fri, 30 Aug 2024 02:58:50 +0000
Subject: [PATCH 3/4] * strength reduce
---
.../Dialect/SCF/Transforms/LoopPipelining.cpp | 18 ++----
mlir/test/Dialect/SCF/loop-pipelining.mlir | 56 +++++++++----------
2 files changed, 31 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 0615ffce072262..a34542f0161aca 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -644,24 +644,14 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
// removed by dead code if not used.
// bounds_range = ub - lb
- // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
+ // total_iterations = (bounds_range + step - 1) / step
Type t = lb.getType();
Value minus1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-
- Value const_0 =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
- Value const_1 =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
- Value boundsRem = rewriter.create<arith::RemUIOp>(loc, boundsRange, step);
- Value hasRem = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
- boundsRem, const_0);
- Value selRem =
- rewriter.create<arith::SelectOp>(loc, hasRem, const_1, const_0);
- Value boundsDiv = rewriter.create<arith::DivUIOp>(loc, boundsRange, step);
- Value totalIterations =
- rewriter.create<arith::AddIOp>(loc, boundsDiv, selRem);
+ Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
+ Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
+ Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 957dc5295c0583..010c39f21afc30 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -767,40 +767,38 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: dynamic_loop(
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
-// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
-// CHECK: %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}}
-// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}}
-// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]]
-// CHECK: %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]]
-// CHECK: scf.yield %[[ADDF_26]], %[[LOAD_29]]
+// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
+// CHECK: %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
+// CHECK: %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
+// CHECK: %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %[[MULI_25]]
+// CHECK: %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
+// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
// CHECK: }
-// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
-// CHECK: %[[REMUI_11:.*]] = arith.remui %[[SUBI_10]], %{{.*}}
-// CHECK: %[[CMPI_12:.*]] = arith.cmpi ne, %[[REMUI_11]], %{{.*}}
-// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_12]], %{{.*}}, %{{.*}}
-// CHECK: %[[DIVUI_14:.*]] = arith.divui %[[SUBI_10]], %{{.*}}
-// CHECK: %[[ADDI_15:.*]] = arith.addi %[[DIVUI_14]], %[[SELECT_13]]
-// CHECK: %[[ADDI_16:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1
-// CHECK: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[ADDI_16]]
-// CHECK: %[[ADDI_18:.*]] = arith.addi %{{.*}}, %[[MULI_17]]
-// CHECK: %[[CMPI_19:.*]] = arith.cmpi sge, %[[ADDI_16]], %{{.*}}
-// CHECK: %[[ADDI_20:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1
-// CHECK: %[[ADDI_21:.*]] = arith.addi %[[ADDI_20]], %{{.*}}-1
-// CHECK: %[[MULI_22:.*]] = arith.muli %{{.*}}, %[[ADDI_21]]
-// CHECK: %[[ADDI_23:.*]] = arith.addi %{{.*}}, %[[MULI_22]]
-// CHECK: %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}}
-// CHECK: scf.if %[[CMPI_19]] {
-// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_23]]]
+// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
+// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
+// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
+// CHECK: %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
+// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
+// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
+// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
+// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
+// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+// CHECK: scf.if %[[CMPI_17]] {
+// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
// CHECK: }
-// CHECK: %[[IF_25:.*]] = scf.if %[[CMPI_24]] -> (f32) {
-// CHECK: %[[ADDF_26:.*]] = arith.addf %{{.*}}#1, %{{.*}}
-// CHECK: scf.yield %[[ADDF_26]]
+// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
+// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+// CHECK: scf.yield %[[ADDF_24]]
// CHECK: } else {
-// CHECK: scf.yield %{{.*}}
+// CHECK: scf.yield %{{.*}}
// CHECK: }
-// CHECK: scf.if %[[CMPI_24]] {
-// CHECK: memref.store %[[IF_25]], %{{.*}}[%[[ADDI_18]]]
+// CHECK: scf.if %[[CMPI_22]] {
+// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
// CHECK: } else {
// CHECK: }
// CHECK: return
>From 969e8bf3c9f99f39a0c46c57f4cf0b78c09c151b Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Wed, 4 Sep 2024 15:28:16 +0000
Subject: [PATCH 4/4] * added test with scf.for results
---
mlir/test/Dialect/SCF/loop-pipelining.mlir | 62 ++++++++++++++++++++++
1 file changed, 62 insertions(+)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 010c39f21afc30..4a1406faabce1b 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -814,6 +814,68 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// -----
+// NOEPILOGUE-LABEL: func.func @dynamic_loop_result
+// NOEPILOGUE: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+// NOEPILOGUE: %[[SUBI_3:.*]] = arith.subi %{{.*}}, %{{.*}}
+// NOEPILOGUE: %[[CMPI_4:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_3]]
+// NOEPILOGUE: %[[ADDF_5:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
+// NOEPILOGUE: %[[MULF_6:.*]] = arith.mulf %[[ADDF_5]], %{{.*}}
+// NOEPILOGUE: %[[ADDI_7:.*]] = arith.addi %[[ARG5]], %{{.*}}
+// NOEPILOGUE: %[[IF_8:.*]] = scf.if %[[CMPI_4]]
+// NOEPILOGUE: %[[LOAD_9:.*]] = memref.load %{{.*}}[%[[ADDI_7]]]
+// NOEPILOGUE: scf.yield %[[LOAD_9]]
+// NOEPILOGUE: } else {
+// NOEPILOGUE: scf.yield %{{.*}}
+// NOEPILOGUE: }
+// NOEPILOGUE: scf.yield %[[MULF_6]], %[[IF_8]]
+// NOEPILOGUE: }
+// NOEPILOGUE: memref.store %{{.*}}#0, %{{.*}}[%{{.*}}]
+
+// Check for predicated epilogue for dynamic loop.
+// CHECK-LABEL: func.func @dynamic_loop_result
+// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
+// CHECK: %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}
+// CHECK: %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}}
+// CHECK: %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]]
+// CHECK: scf.yield %[[MULF_14]], %[[LOAD_16]]
+// CHECK: }
+// CHECK: %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}}
+// CHECK: %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}}
+// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
+// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
+// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
+// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
+// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
+// CHECK: scf.yield %[[ADDF_13]]
+// CHECK: } else {
+// CHECK: scf.yield %{{.*}}
+// CHECK: }
+// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_9]]
+// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
+// CHECK: scf.yield %[[MULF_13]]
+// CHECK: } else {
+// CHECK: scf.yield %{{.*}}
+// CHECK: }
+// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0
+// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
+func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
+ %cf0 = arith.constant 1.0 : f32
+ %cf1 = arith.constant 33.0 : f32
+ %cst = arith.constant 0 : index
+ %res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ %A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+ scf.yield %A2_elem : f32
+ } { __test_pipelining_loop__ }
+ memref.store %res#0, %result[%cst] : memref<?xf32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: yield_constant_loop(
// CHECK-SAME: %[[A:.*]]: memref<?xf32>) -> f32 {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
More information about the Mlir-commits
mailing list