[Mlir-commits] [mlir] [mlir][scf] Align `scf.while` `before` block args in canonicalizer (PR #76195)

Mehdi Amini llvmlistbot at llvm.org
Fri Dec 22 02:18:20 PST 2023


================
@@ -3872,14 +3872,89 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
     return success();
   }
 };
+
+/// If the ranges are permutations of each other (values assumed distinct),
+/// return `ret` where `args2[j] == args1[ret[j]]`. Otherwise std::nullopt.
+static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
+                                                           ValueRange args2) {
+  if (args1.size() != args2.size())
+    return std::nullopt;
+
+  SmallVector<unsigned> ret(args1.size());
+  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
+    auto it = llvm::find(args2, arg1);
+    if (it == args2.end())
+      return std::nullopt;
+
+    auto j = it - args2.begin();
+    ret[j] = static_cast<unsigned>(i); // args2[j] == args1[i]
+  }
+
+  return ret;
+}
+
+/// If `before` block args are directly forwarded to `scf.condition`, rearrange
+/// `scf.condition` args into same order as block args. Update `after` block
+/// args and result values accordingly.
+/// Needed to simplify `scf.while` -> `scf.for` uplifting.
+struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    auto oldBefore = loop.getBeforeBody();
+    ConditionOp oldTerm = loop.getConditionOp();
+    ValueRange beforeArgs = oldBefore->getArguments();
+    ValueRange termArgs = oldTerm.getArgs();
+    if (beforeArgs == termArgs)
+      return failure();
+
+    auto mapping = getArgsMapping(beforeArgs, termArgs);
+    if (!mapping)
+      return failure();
+
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(oldTerm);
+      rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
+                                               beforeArgs);
+    }
+
+    auto oldAfter = loop.getAfterBody();
+
+    SmallVector<Type> newResultTypes(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping))
+      newResultTypes[j] = loop.getResult(i).getType();
+
+    auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
+                                            loop.getInits(), nullptr, nullptr);
----------------
joker-eph wrote:

Better: we should be able to remove these two `nullptr` arguments entirely; there is another `build` overload that should match.

https://github.com/llvm/llvm-project/pull/76195


More information about the Mlir-commits mailing list