[Mlir-commits] [mlir] [MLIR][SCF] Loop pipelining fails on failed predication (no assert) (PR #107442)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 5 11:18:58 PDT 2024
https://github.com/sjw36 created https://github.com/llvm/llvm-project/pull/107442
The SCF loop pipeliner allows predication of peeled ops and of ops inside the loop. When `predicateFn` returns a nullptr, this signals that the op type is unsupported, and the pipeliner fails — everywhere except in `emitPrologue`, where it asserts instead.
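This patch makes `emitPrologue` fail gracefully in that case as well: it now returns a `LogicalResult`, and `pipelineForLoop` propagates the failure. Note that `*modifiedIR` is already set to true before the prologue is emitted, so a prologue failure still reports that the IR may have been modified.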
>From 50276c685d845b754c58c56dd5e7969e905fdab2 Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Thu, 5 Sep 2024 18:02:30 +0000
Subject: [PATCH] [MLIR][SCF] Loop pipelining fails on failed predication (no
assert)
---
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index d8e1cc0ecef88e..07bec5ee1ce1f7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -77,7 +77,7 @@ struct LoopPipelinerInternal {
bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
/// Emits the prologue, this creates `maxStage - 1` part which will contain
/// operations from stages [0; i], where i is the part index.
- void emitPrologue(RewriterBase &rewriter);
+ LogicalResult emitPrologue(RewriterBase &rewriter);
/// Gather liverange information for Values that are used in a different stage
/// than its definition.
llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
@@ -267,7 +267,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
return clone;
}
-void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
+LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
// Initialize the iteration argument to the loop initial values.
for (auto [arg, operand] :
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
@@ -314,7 +314,8 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
int predicateIdx = i - stages[op];
if (predicates[predicateIdx]) {
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
- assert(newOp && "failed to predicate op.");
+ if (newOp == nullptr)
+ return failure();
}
rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
@@ -343,6 +344,7 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
}
}
}
+ return success();
}
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -733,7 +735,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
*modifiedIR = true;
// 1. Emit prologue.
- pipeliner.emitPrologue(rewriter);
+ if (failed(pipeliner.emitPrologue(rewriter)))
+ return failure();
// 2. Track values used across stages. When a value cross stages it will
// need to be passed as loop iteration arguments.
More information about the Mlir-commits
mailing list