[Mlir-commits] [mlir] [mlir][scf] Align `scf.while` `before` block args in canonicalizer (PR #76195)

Ivan Butygin llvmlistbot at llvm.org
Sat Mar 30 06:45:37 PDT 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/76195

>From fa1a29f3e8502d24b8f7274bfc185fd9a631f25d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 22 Dec 2023 00:37:45 +0100
Subject: [PATCH 1/2] [mlir][scf] Align `scf.while` `before` block args in
 canonicalizer

If `before` block args are directly forwarded to `scf.condition`, make sure they are passed in the same order.
This is needed for `scf.while` -> `scf.for` uplifting (https://github.com/llvm/llvm-project/pull/76108).
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 77 ++++++++++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir | 29 ++++++++++
 2 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5bca8e85f889d9..3c3eb6de5986d1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3884,6 +3884,81 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
     return success();
   }
 };
+
+/// If both ranges contain same values return mappping indices from args1 to
+/// args2. Otherwise return std::nullopt
+static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
+                                                           ValueRange args2) {
+  if (args1.size() != args2.size())
+    return std::nullopt;
+
+  SmallVector<unsigned> ret(args1.size());
+  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
+    auto it = llvm::find(args2, arg1);
+    if (it == args2.end())
+      return std::nullopt;
+
+    auto j = it - args2.begin();
+    ret[j] = static_cast<unsigned>(i);
+  }
+
+  return ret;
+}
+
+/// If `before` block args are directly forwarded to `scf.condition`, rearrange
+/// `scf.condition` args into same order as block args. Update `after` block
+// args and results values accordingly.
+/// Needed to simplify `scf.while` -> `scf.for` uplifting.
+struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    auto oldBefore = loop.getBeforeBody();
+    ConditionOp oldTerm = loop.getConditionOp();
+    ValueRange beforeArgs = oldBefore->getArguments();
+    ValueRange termArgs = oldTerm.getArgs();
+    if (beforeArgs == termArgs)
+      return failure();
+
+    auto mapping = getArgsMapping(beforeArgs, termArgs);
+    if (!mapping)
+      return failure();
+
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(oldTerm);
+      rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
+                                               beforeArgs);
+    }
+
+    auto oldAfter = loop.getAfterBody();
+
+    SmallVector<Type> newResultTypes(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping))
+      newResultTypes[j] = loop.getResult(i).getType();
+
+    auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
+                                            loop.getInits(), nullptr, nullptr);
+    auto newBefore = newLoop.getBeforeBody();
+    auto newAfter = newLoop.getAfterBody();
+
+    SmallVector<Value> newResults(beforeArgs.size());
+    SmallVector<Value> newAfterArgs(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping)) {
+      newResults[i] = newLoop.getResult(j);
+      newAfterArgs[i] = newAfter->getArgument(j);
+    }
+
+    rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
+                               newBefore->getArguments());
+    rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
+                               newAfterArgs);
+
+    rewriter.replaceOp(loop, newResults);
+    return success();
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -3891,7 +3966,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs>(context);
+              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 52e0fdfa36d6cd..b4c9ed4db94e0e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 {
 // CHECK:         return %[[RES]] : i32
 
 
+// -----
+
+// CHECK-LABEL: func @test_align_args
+//       CHECK:  %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
+//       CHECK:  scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
+//       CHECK:  ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
+//       CHECK:  %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
+//       CHECK:  %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
+//       CHECK:  %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
+//       CHECK:  scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
+//       CHECK:  return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
+func.func @test_align_args() -> (i64, f32, i32) {
+  %0 = "test.test"() : () -> (f32)
+  %1 = "test.test"() : () -> (i32)
+  %2 = "test.test"() : () -> (i64)
+  %3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
+    %cond = "test.test"() : () -> (i1)
+    scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
+  } do {
+  ^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
+    %4 = "test.test"(%arg3) : (i64) -> (f32)
+    %5 = "test.test"(%arg4) : (f32) -> (i32)
+    %6 = "test.test"(%arg5) : (i32) -> (i64)
+    scf.yield %4, %5, %6 : f32, i32, i64
+  }
+  return %3#0, %3#1, %3#2 : i64, f32, i32
+}
+
+
 // -----
 
 // CHECK-LABEL: @combineIfs

>From 24797e2bfae3c2dbb7e72a83e1b48d1fb51720c8 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 30 Mar 2024 14:43:44 +0100
Subject: [PATCH 2/2] review comments

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 3c3eb6de5986d1..a04913610fcb61 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3885,8 +3885,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
   }
 };
 
-/// If both ranges contain same values return mappping indices from args1 to
-/// args2. Otherwise return std::nullopt
+/// If both ranges hold the same (unique) values, return `m` such that
+/// args2[j] == args1[m[j]] for every j. Otherwise return std::nullopt.
 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
                                                            ValueRange args2) {
   if (args1.size() != args2.size())
@@ -3898,16 +3898,26 @@ static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
     if (it == args2.end())
       return std::nullopt;
 
-    auto j = it - args2.begin();
-    ret[j] = static_cast<unsigned>(i);
+    ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
   }
 
   return ret;
 }
 
+/// Returns true if any value appears more than once in `args`.
+static bool hasDuplicates(ValueRange args) {
+  llvm::SmallDenseSet<Value> seen;
+  for (Value arg : args) {
+    // `insert` reports `false` in `.second` when `arg` is already present.
+    if (!seen.insert(arg).second)
+      return true;
+  }
+  return false;
+}
+
 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
 /// `scf.condition` args into same order as block args. Update `after` block
-// args and results values accordingly.
+/// args and op result values accordingly.
 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -3921,6 +3931,9 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
     if (beforeArgs == termArgs)
       return failure();
 
+    if (hasDuplicates(termArgs))
+      return failure();
+
     auto mapping = getArgsMapping(beforeArgs, termArgs);
     if (!mapping)
       return failure();



More information about the Mlir-commits mailing list