[Mlir-commits] [mlir] [mlir][scf] Align `scf.while` `before` block args in canonicalizer (PR #76195)
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 21 18:41:09 PST 2023
================
@@ -3872,14 +3872,89 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
return success();
}
};
+
+/// If `args1` and `args2` hold the same values (one is a permutation of the
+/// other, with repeated values matched one-to-one), return a mapping `m` such
+/// that `args2[j] == args1[m[j]]` for every index `j`. Otherwise return
+/// std::nullopt.
+static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
+                                                           ValueRange args2) {
+  if (args1.size() != args2.size())
+    return std::nullopt;
+
+  SmallVector<unsigned> ret(args1.size());
+  // Slots of `args2` that have already been matched. Without this, repeated
+  // values in `args1` would all resolve to the first matching slot, leaving
+  // other slots at their default of 0 and producing a bogus mapping.
+  SmallVector<bool> claimed(args2.size(), false);
+  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
+    const size_t e = args2.size();
+    size_t j = 0;
+    // Find the first not-yet-claimed position in `args2` holding `arg1`.
+    while (j < e && (claimed[j] || args2[j] != arg1))
+      ++j;
+    if (j == e)
+      return std::nullopt;
+
+    claimed[j] = true;
+    ret[j] = static_cast<unsigned>(i);
+  }
+
+  return ret;
+}
+
+/// If `before` block args are directly forwarded to `scf.condition`, rearrange
+/// the `scf.condition` args into the same order as the block args. Update the
+/// `after` block args and the result values accordingly.
+/// Needed to simplify `scf.while` -> `scf.for` uplifting.
+struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WhileOp loop,
+ PatternRewriter &rewriter) const override {
+ auto oldBefore = loop.getBeforeBody();
+ ConditionOp oldTerm = loop.getConditionOp();
+ ValueRange beforeArgs = oldBefore->getArguments();
+ ValueRange termArgs = oldTerm.getArgs();
+ if (beforeArgs == termArgs)
+ return failure();
+
+ auto mapping = getArgsMapping(beforeArgs, termArgs);
+ if (!mapping)
+ return failure();
+
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(oldTerm);
+ rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
+ beforeArgs);
+ }
+
+ auto oldAfter = loop.getAfterBody();
+
+ SmallVector<Type> newResultTypes(beforeArgs.size());
+ for (auto &&[i, j] : llvm::enumerate(*mapping))
+ newResultTypes[j] = loop.getResult(i).getType();
+
+ auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
+ loop.getInits(), nullptr, nullptr);
----------------
matthias-springer wrote:
nit: spell out variable names for `nullptr` arguments, i.e., `/*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr`
https://github.com/llvm/llvm-project/pull/76195
More information about the Mlir-commits
mailing list