[Mlir-commits] [mlir] [mlir][scf] Allow different forwarding ordering in uplift (PR #133117)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 26 09:22:20 PDT 2025
https://github.com/darkbuck created https://github.com/llvm/llvm-project/pull/133117
- Allow the 'before' arguments to be forwarded to the 'after' body in a different order when uplifting `scf.while` to `scf.for`.
>From 25d6e13cba7b94c51afb094253d15a547f98f3d9 Mon Sep 17 00:00:00 2001
From: Michael Liao <michael.hliao at gmail.com>
Date: Wed, 26 Mar 2025 02:31:16 -0400
Subject: [PATCH] [mlir][scf] Allow different forwarding ordering in uplift
- Allow the 'before' arguments to be forwarded to the 'after' body in a
different order when uplifting `scf.while` to `scf.for`.
---
.../SCF/Transforms/UpliftWhileToFor.cpp | 58 ++++++++++++++++++-
mlir/test/Dialect/SCF/uplift-while.mlir | 30 ++++++++++
2 files changed, 85 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 7b4024b6861a7..9c4fe702de119 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -48,10 +48,49 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
diag << "Expected single condition use: " << *cmp;
});
+  // Set only when the ConditionOp forwards the 'before' arguments in a
+  // permuted order: entry `k` is the index of the 'before' argument that is
+  // forwarded as ConditionOp operand `k`.
+  std::optional<SmallVector<unsigned>> argReorder;
// All `before` block args must be directly forwarded to ConditionOp.
// They will be converted to `scf.for` `iter_vars` except induction var.
- if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
- return rewriter.notifyMatchFailure(loop, "Invalid args order");
+  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
+    // Returns the forwarding permutation, or std::nullopt unless the
+    // ConditionOp operands are exactly the arguments of `block`, each
+    // forwarded once.
+    auto computePermutation =
+        [](Block *block,
+           scf::ConditionOp condOp) -> std::optional<SmallVector<unsigned>> {
+      ValueRange operands = condOp.getArgs();
+      unsigned numArgs = block->getNumArguments();
+      if (operands.size() != numArgs)
+        return std::nullopt;
+      BitVector seen(numArgs, false);
+      SmallVector<unsigned> perm;
+      perm.reserve(numArgs);
+      for (Value operand : operands) {
+        auto blockArg = dyn_cast<BlockArgument>(operand);
+        // Only arguments of `block` itself may be forwarded.
+        if (!blockArg || blockArg.getOwner() != block)
+          return std::nullopt;
+        unsigned argIdx = blockArg.getArgNumber();
+        // Forwarding the same argument twice is not a permutation.
+        if (seen.test(argIdx))
+          return std::nullopt;
+        seen.set(argIdx);
+        perm.push_back(argIdx);
+      }
+      // With matching sizes and no duplicates this always holds; it is kept
+      // as an explicit safeguard.
+      if (!seen.all())
+        return std::nullopt;
+      return perm;
+    };
+    // Accept the loop only if the 'before' arguments are merely reordered.
+    argReorder = computePermutation(beforeBody, beforeTerm);
+    if (!argReorder)
+      return rewriter.notifyMatchFailure(loop, "Invalid args order");
+  }
using Pred = arith::CmpIPredicate;
Pred predicate = cmp.getPredicate();
@@ -100,6 +133,29 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
});
Block *afterBody = loop.getAfterBody();
+  if (argReorder) {
+    // 'after' argument `k` receives ConditionOp operand `k`, i.e. 'before'
+    // argument `(*argReorder)[k]`. Rebuild the 'after' arguments in 'before'
+    // order: new argument `i` must take over the uses of the 'after' argument
+    // that carries 'before' argument `i`, which requires the inverse
+    // permutation. (Indexing with `argReorder` directly is only correct when
+    // the permutation is its own inverse, e.g. a single swap.)
+    unsigned numArgs = argReorder->size();
+    SmallVector<unsigned> afterIdxOf(numArgs);
+    for (unsigned afterIdx = 0; afterIdx < numArgs; ++afterIdx)
+      afterIdxOf[(*argReorder)[afterIdx]] = afterIdx;
+    // NOTE(review): this rewrites the IR in place; if a later check in this
+    // function can still fail the match, the rewrite should be deferred until
+    // the match is certain -- confirm.
+    for (unsigned afterIdx : afterIdxOf) {
+      BlockArgument oldArg = afterBody->getArgument(afterIdx);
+      BlockArgument newArg =
+          afterBody->addArgument(oldArg.getType(), oldArg.getLoc());
+      // Route through the rewriter so listeners are notified of the change.
+      rewriter.replaceAllUsesWith(oldArg, newArg);
+    }
+    afterBody->eraseArguments(0, numArgs);
+  }
scf::YieldOp afterTerm = loop.getYieldOp();
unsigned argNumber = inductionVar.getArgNumber();
Value afterTermIndArg = afterTerm.getResults()[argNumber];
@@ -130,7 +174,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
assert(lb.getType() == ub.getType());
assert(lb.getType() == step.getType());
- llvm::SmallVector<Value> newArgs;
+  SmallVector<Value> newArgs;
// Populate inits for new `scf.for`, skip induction var.
newArgs.reserve(loop.getInits().size());
@@ -205,6 +249,16 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
newArgs.clear();
llvm::append_range(newArgs, newLoop.getResults());
newArgs.insert(newArgs.begin() + argNumber, res);
+  if (argReorder) {
+    // The `scf.while` results follow the ConditionOp operand order, whereas
+    // `newArgs` is in 'before' argument order: result `k` is the value of
+    // 'before' argument `(*argReorder)[k]`.
+    SmallVector<Value> whileResults;
+    whileResults.reserve(argReorder->size());
+    for (unsigned beforeIdx : *argReorder)
+      whileResults.push_back(newArgs[beforeIdx]);
+    newArgs.swap(whileResults);
+  }
rewriter.replaceOp(loop, newArgs);
return newLoop;
}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 25ea6142a332d..cbe2ce5076ad2 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -155,3 +155,33 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
// CHECK: return %[[R7]] : i64
+
+// -----
+
+// 'before' args (i32, index, f32) are forwarded to 'after' as (index, i32, f32).
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2.0 : f32
+ %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (index, i32, f32) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg3, %arg4, %arg5 : index, i32, f32
+ } do {
+ ^bb0(%arg3: index, %arg4: i32, %arg5: f32):
+ %1 = "test.test1"(%arg4) : (i32) -> i32
+ %added = arith.addi %arg3, %arg2 : index
+ %2 = "test.test2"(%arg5) : (f32) -> f32
+ scf.yield %1, %added, %2 : i32, index, f32
+ }
+ return %0#1, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
More information about the Mlir-commits
mailing list