[Mlir-commits] [mlir] [mlir][scf] Align `scf.while` `before` block args in canonicalizer (PR #76195)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 21 15:42:43 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
If `before` block args are directly forwarded to `scf.condition`, make sure they are passed in the same order.
This is needed for `scf.while` uplifting https://github.com/llvm/llvm-project/pull/76108
---
Full diff: https://github.com/llvm/llvm-project/pull/76195.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+76-1)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+29)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5570c2ec688c8a..de320723ce83f3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3872,6 +3872,81 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
return success();
}
};
+
+/// If `args2` is a permutation of `args1` (same values, each occurring exactly
+/// once), return a mapping `ret` such that `args2[j] == args1[ret[j]]`, i.e.
+/// `ret` is indexed by position in `args2` and yields the position in `args1`.
+/// Returns std::nullopt if the sizes differ, if a value of `args1` is missing
+/// from `args2`, or if `args1` contains duplicates.
+static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
+                                                           ValueRange args2) {
+  if (args1.size() != args2.size())
+    return std::nullopt;
+
+  SmallVector<unsigned> ret(args1.size());
+  // Marks `args2` slots already claimed. Without this, a duplicate value in
+  // `args1` would overwrite one slot and leave another with a bogus default
+  // index of 0, yielding an invalid mapping.
+  SmallVector<bool> claimed(args2.size(), false);
+  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
+    auto it = llvm::find(args2, arg1);
+    if (it == args2.end())
+      return std::nullopt;
+
+    auto j = it - args2.begin();
+    if (claimed[j])
+      return std::nullopt;
+
+    claimed[j] = true;
+    ret[j] = static_cast<unsigned>(i);
+  }
+
+  return ret;
+}
+
+/// If `before` block args are directly forwarded to `scf.condition`, rearrange
+/// `scf.condition` args into the same order as the block args. Update `after`
+/// block args and result values accordingly.
+/// Needed to simplify `scf.while` -> `scf.for` uplifting.
+struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Block *oldBefore = loop.getBeforeBody();
+    ConditionOp oldTerm = loop.getConditionOp();
+    ValueRange beforeArgs = oldBefore->getArguments();
+    ValueRange termArgs = oldTerm.getArgs();
+    // Already aligned: nothing to do. This also guarantees the pattern does
+    // not fire again on its own output.
+    if (beforeArgs == termArgs)
+      return failure();
+
+    // mapping[i] is the index of the `before` block arg forwarded as condition
+    // arg i; equivalently, the new position of old result / `after` arg i.
+    auto mapping = getArgsMapping(beforeArgs, termArgs);
+    if (!mapping)
+      return failure();
+
+    // Forward the `before` block args in their natural order. `termArgs` views
+    // the operands of `oldTerm` and must not be used after this point.
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(oldTerm);
+      rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
+                                               beforeArgs);
+    }
+
+    Block *oldAfter = loop.getAfterBody();
+
+    SmallVector<Type> newResultTypes(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping))
+      newResultTypes[j] = loop.getResult(i).getType();
+
+    auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
+                                            loop.getInits(), nullptr, nullptr);
+    Block *newBefore = newLoop.getBeforeBody();
+    Block *newAfter = newLoop.getAfterBody();
+
+    SmallVector<Value> newResults(beforeArgs.size());
+    SmallVector<Value> newAfterArgs(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping)) {
+      newResults[i] = newLoop.getResult(j);
+      BlockArgument newArg = newAfter->getArgument(j);
+      // The `after` args of the new loop are freshly created; carry over the
+      // original argument locations so they are not lost when the old block
+      // is inlined below.
+      newArg.setLoc(oldAfter->getArgument(i).getLoc());
+      newAfterArgs[i] = newArg;
+    }
+
+    rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
+                               newBefore->getArguments());
+    rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
+                               newAfterArgs);
+
+    rewriter.replaceOp(loop, newResults);
+    return success();
+  }
+};
} // namespace
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -3879,7 +3954,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs>(context);
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 52e0fdfa36d6cd..b4c9ed4db94e0e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 {
// CHECK: return %[[RES]] : i32
+// -----
+
+// CHECK-LABEL: func @test_align_args
+// CHECK: %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
+// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
+// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
+// CHECK: %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
+// CHECK: %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
+// CHECK: %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
+// CHECK: scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
+// CHECK: return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
+// The condition forwards the before-block args in rotated order
+// (%arg2, %arg0, %arg1). Per the CHECK lines above, canonicalization should
+// forward them in block-arg order and permute the loop results and the
+// after-block args to match.
+func.func @test_align_args() -> (i64, f32, i32) {
+ %0 = "test.test"() : () -> (f32)
+ %1 = "test.test"() : () -> (i32)
+ %2 = "test.test"() : () -> (i64)
+ %3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
+ %cond = "test.test"() : () -> (i1)
+ scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
+ } do {
+ ^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
+ %4 = "test.test"(%arg3) : (i64) -> (f32)
+ %5 = "test.test"(%arg4) : (f32) -> (i32)
+ %6 = "test.test"(%arg5) : (i32) -> (i64)
+ scf.yield %4, %5, %6 : f32, i32, i64
+ }
+ return %3#0, %3#1, %3#2 : i64, f32, i32
+}
+
+
// -----
// CHECK-LABEL: @combineIfs
``````````
</details>
https://github.com/llvm/llvm-project/pull/76195
More information about the Mlir-commits
mailing list