[Mlir-commits] [mlir] [mlir][scf] Allow different forwarding ordering in uplift (PR #133117)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 26 09:22:20 PDT 2025


https://github.com/darkbuck created https://github.com/llvm/llvm-project/pull/133117

- Allow 'before' arguments to be forwarded to the 'after' body in a different order when uplifting `scf.while` to `scf.for`.

>From 25d6e13cba7b94c51afb094253d15a547f98f3d9 Mon Sep 17 00:00:00 2001
From: Michael Liao <michael.hliao at gmail.com>
Date: Wed, 26 Mar 2025 02:31:16 -0400
Subject: [PATCH] [mlir][scf] Allow different forwarding ordering in uplift

- Allow 'before' arguments to be forwarded to the 'after' body in a
  different order when uplifting `scf.while` to `scf.for`.
---
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 58 ++++++++++++++++++-
 mlir/test/Dialect/SCF/uplift-while.mlir       | 30 ++++++++++
 2 files changed, 85 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 7b4024b6861a7..9c4fe702de119 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -48,10 +48,43 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
       diag << "Expected single condition use: " << *cmp;
     });
 
+  std::optional<SmallVector<unsigned>> argReorder;
   // All `before` block args must be directly forwarded to ConditionOp.
   // They will be converted to `scf.for` `iter_vars` except induction var.
-  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
-    return rewriter.notifyMatchFailure(loop, "Invalid args order");
+  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
+    auto getArgReordering =
+        [](Block *beforeBody,
+           scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
+      // Skip further checking if their sizes mismatch.
+      if (beforeBody->getNumArguments() != cond.getArgs().size())
+        return std::nullopt;
+      // Bitset on which 'before' argument is forwarded.
+      BitVector forwarded(beforeBody->getNumArguments(), false);
+      // The forwarding order of 'before' arguments.
+      SmallVector<unsigned> order;
+      for (Value a : cond.getArgs()) {
+        BlockArgument arg = dyn_cast<BlockArgument>(a);
+        // Skip if 'arg' is not a 'before' argument.
+        if (!arg || arg.getOwner() != beforeBody)
+          return std::nullopt;
+        unsigned idx = arg.getArgNumber();
+        // Skip if 'arg' is already forwarded in another place.
+        if (forwarded[idx])
+          return std::nullopt;
+        // Record the presence of 'arg' and its order.
+        forwarded[idx] = true;
+        order.push_back(idx);
+      }
+      // Skip if not all 'before' arguments are forwarded.
+      if (!forwarded.all())
+        return std::nullopt;
+      return order;
+    };
+    // Check if 'before' arguments are all forwarded but just reordered.
+    argReorder = getArgReordering(beforeBody, beforeTerm);
+    if (!argReorder)
+      return rewriter.notifyMatchFailure(loop, "Invalid args order");
+  }
 
   using Pred = arith::CmpIPredicate;
   Pred predicate = cmp.getPredicate();
@@ -100,6 +133,17 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
     });
 
   Block *afterBody = loop.getAfterBody();
+  if (argReorder) {
+    // 'after' arg #i carries 'before' arg #(*argReorder)[i]. Append new args
+    // in 'before' order; old arg #i's uses move to new arg #(*argReorder)[i]
+    // (taking old arg #(*argReorder)[k] as new arg #k is wrong beyond swaps).
+    for (BlockArgument beforeArg : beforeBody->getArguments())
+      afterBody->addArgument(beforeArg.getType(), beforeArg.getLoc());
+    for (auto [afterIdx, beforeIdx] : llvm::enumerate(*argReorder))
+      afterBody->getArgument(afterIdx).replaceAllUsesWith(
+          afterBody->getArgument(argReorder->size() + beforeIdx));
+    afterBody->eraseArguments(0, argReorder->size());
+  }
   scf::YieldOp afterTerm = loop.getYieldOp();
   unsigned argNumber = inductionVar.getArgNumber();
   Value afterTermIndArg = afterTerm.getResults()[argNumber];
@@ -130,7 +174,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   assert(lb.getType() == ub.getType());
   assert(lb.getType() == step.getType());
 
-  llvm::SmallVector<Value> newArgs;
+  SmallVector<Value> newArgs;
 
   // Populate inits for new `scf.for`, skip induction var.
   newArgs.reserve(loop.getInits().size());
@@ -205,6 +249,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   newArgs.clear();
   llvm::append_range(newArgs, newLoop.getResults());
   newArgs.insert(newArgs.begin() + argNumber, res);
+  if (argReorder) {
+    // 'newArgs' is in 'before' order; while result #i is its #(*argReorder)[i].
+    SmallVector<Value> results;
+    results.reserve(argReorder->size());
+    for (unsigned order : *argReorder)
+      results.push_back(newArgs[order]);
+    newArgs = std::move(results);
+  }
   rewriter.replaceOp(loop, newArgs);
   return newLoop;
 }
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 25ea6142a332d..cbe2ce5076ad2 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -155,3 +155,33 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
 //       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
 //       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
 //       CHECK:     return %[[R7]] : i64
+
+// -----
+
+// A case where all 'before' arguments are forwarded but reordered.
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2.0 : f32
+  %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (index, i32, f32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3, %arg4, %arg5 : index, i32, f32
+  } do {
+  ^bb0(%arg3: index, %arg4: i32, %arg5: f32):
+    %1 = "test.test1"(%arg4) : (i32) -> i32
+    %added = arith.addi %arg3, %arg2 : index
+    %2 = "test.test2"(%arg5) : (f32) -> f32
+    scf.yield %1, %added, %2 : i32, index, f32
+  }
+  return %0#1, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:     %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:     %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+//  CHECK-SAME:     iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+//       CHECK:     %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+//       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+//       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
+//       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32



More information about the Mlir-commits mailing list