[Mlir-commits] [mlir] [mlir][scf] Align `scf.while` `before` block args in canonicalizer (PR #76195)
Ivan Butygin
llvmlistbot at llvm.org
Sat Mar 30 06:45:37 PDT 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/76195
>From fa1a29f3e8502d24b8f7274bfc185fd9a631f25d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 22 Dec 2023 00:37:45 +0100
Subject: [PATCH 1/2] [mlir][scf] Align `scf.while` `before` block args in
canonicalizer
If `before` block args are directly forwarded to `scf.condition`, make sure they are passed in the same order.
This is needed for `scf.while` uplifting https://github.com/llvm/llvm-project/pull/76108
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 77 ++++++++++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 29 ++++++++++
2 files changed, 105 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5bca8e85f889d9..3c3eb6de5986d1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3884,6 +3884,81 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
return success();
}
};
+
+/// If both ranges contain the same values (in any order), return a mapping
+/// from indices of `args2` to indices of `args1`, i.e. `args2[j] ==
+/// args1[ret[j]]` for every `j`. Every position of `args1` and `args2` is
+/// used exactly once, so repeated values are paired one-to-one. Otherwise
+/// return std::nullopt.
+static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
+                                                           ValueRange args2) {
+  if (args1.size() != args2.size())
+    return std::nullopt;
+
+  SmallVector<unsigned> ret(args1.size());
+  // Positions of `args2` already paired with an element of `args1`. Without
+  // this, a value repeated in `args1` would hit the same slot twice and leave
+  // another slot of `ret` silently defaulted to 0.
+  SmallVector<bool> matched(args2.size(), false);
+  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
+    // Pick the first occurrence of `arg1` in `args2` not yet paired.
+    std::optional<size_t> slot;
+    for (size_t j = 0, e = args2.size(); j < e; ++j) {
+      if (!matched[j] && args2[j] == arg1) {
+        slot = j;
+        break;
+      }
+    }
+    if (!slot)
+      return std::nullopt;
+
+    matched[*slot] = true;
+    ret[*slot] = static_cast<unsigned>(i);
+  }
+
+  return ret;
+}
+
+/// If `before` block args are directly forwarded to `scf.condition`, rearrange
+/// `scf.condition` args into same order as block args. Update `after` block
+/// args and op result values accordingly.
+/// Needed to simplify `scf.while` -> `scf.for` uplifting.
+struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    auto oldBefore = loop.getBeforeBody();
+    ConditionOp oldTerm = loop.getConditionOp();
+    ValueRange beforeArgs = oldBefore->getArguments();
+    ValueRange termArgs = oldTerm.getArgs();
+    // Already forwarded in block-argument order: nothing to do.
+    if (beforeArgs == termArgs)
+      return failure();
+
+    // `(*mapping)[i]` is the index of the `before` block arg passed as the
+    // i-th `scf.condition` operand. Bail out unless the operands are exactly
+    // a permutation of the block args.
+    auto mapping = getArgsMapping(beforeArgs, termArgs);
+    if (!mapping)
+      return failure();
+
+    // Forward the block args to `scf.condition` in their natural order.
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(oldTerm);
+      rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
+                                               beforeArgs);
+    }
+
+    auto oldAfter = loop.getAfterBody();
+
+    // Old result `i` becomes new result `(*mapping)[i]`.
+    SmallVector<Type> newResultTypes(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping))
+      newResultTypes[j] = loop.getResult(i).getType();
+
+    auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
+                                            loop.getInits(), nullptr, nullptr);
+    auto newBefore = newLoop.getBeforeBody();
+    auto newAfter = newLoop.getAfterBody();
+
+    // The `before` block keeps its argument order; carry over the original
+    // argument locations so debug info is not lost by the rebuild.
+    for (unsigned k = 0, e = newBefore->getNumArguments(); k < e; ++k)
+      newBefore->getArgument(k).setLoc(oldBefore->getArgument(k).getLoc());
+
+    SmallVector<Value> newResults(beforeArgs.size());
+    SmallVector<Value> newAfterArgs(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping)) {
+      newResults[i] = newLoop.getResult(j);
+      BlockArgument newAfterArg = newAfter->getArgument(j);
+      // Keep the location of the old `after` arg this one replaces.
+      newAfterArg.setLoc(oldAfter->getArgument(i).getLoc());
+      newAfterArgs[i] = newAfterArg;
+    }
+
+    rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
+                               newBefore->getArguments());
+    rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
+                               newAfterArgs);
+
+    rewriter.replaceOp(loop, newResults);
+    return success();
+  }
+};
} // namespace
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -3891,7 +3966,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs>(context);
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 52e0fdfa36d6cd..b4c9ed4db94e0e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 {
// CHECK: return %[[RES]] : i32
+// -----
+
+// The `before` block forwards its arguments to `scf.condition` in rotated
+// order (%arg2, %arg0, %arg1). After canonicalization they are expected to be
+// forwarded in block-argument order, with the `after` block arguments and the
+// uses of the loop results (the `return`) permuted to compensate.
+// CHECK-LABEL: func @test_align_args
+// CHECK: %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
+// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
+// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
+// CHECK: %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
+// CHECK: %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
+// CHECK: %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
+// CHECK: scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
+// CHECK: return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
+func.func @test_align_args() -> (i64, f32, i32) {
+ %0 = "test.test"() : () -> (f32)
+ %1 = "test.test"() : () -> (i32)
+ %2 = "test.test"() : () -> (i64)
+ %3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
+ %cond = "test.test"() : () -> (i1)
+ scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
+ } do {
+ ^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
+ %4 = "test.test"(%arg3) : (i64) -> (f32)
+ %5 = "test.test"(%arg4) : (f32) -> (i32)
+ %6 = "test.test"(%arg5) : (i32) -> (i64)
+ scf.yield %4, %5, %6 : f32, i32, i64
+ }
+ return %3#0, %3#1, %3#2 : i64, f32, i32
+}
+
+
// -----
// CHECK-LABEL: @combineIfs
>From 24797e2bfae3c2dbb7e72a83e1b48d1fb51720c8 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 30 Mar 2024 14:43:44 +0100
Subject: [PATCH 2/2] review comments
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 23 ++++++++++++++++++-----
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 3c3eb6de5986d1..a04913610fcb61 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3885,8 +3885,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
}
};
-/// If both ranges contain same values return mappping indices from args1 to
-/// args2. Otherwise return std::nullopt
+/// If both ranges contain the same values, return mapping indices from args2 to
+/// args1. Otherwise return std::nullopt.
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
ValueRange args2) {
if (args1.size() != args2.size())
@@ -3898,16 +3898,26 @@ static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
if (it == args2.end())
return std::nullopt;
- auto j = it - args2.begin();
- ret[j] = static_cast<unsigned>(i);
+ ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
}
return ret;
}
+/// Return true if any value occurs more than once in `args`.
+static bool hasDuplicates(ValueRange args) {
+  llvm::SmallDenseSet<Value> set;
+  for (Value arg : args) {
+    // `insert` reports whether the value was newly added, so a failed insert
+    // means it was seen before; this avoids a separate `contains` lookup.
+    if (!set.insert(arg).second)
+      return true;
+  }
+  return false;
+}
+
/// If `before` block args are directly forwarded to `scf.condition`, rearrange
/// `scf.condition` args into same order as block args. Update `after` block
-// args and results values accordingly.
+/// args and op result values accordingly.
/// Needed to simplify `scf.while` -> `scf.for` uplifting.
struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
using OpRewritePattern::OpRewritePattern;
@@ -3921,6 +3931,9 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
if (beforeArgs == termArgs)
return failure();
+ if (hasDuplicates(termArgs))
+ return failure();
+
auto mapping = getArgsMapping(beforeArgs, termArgs);
if (!mapping)
return failure();
More information about the Mlir-commits
mailing list