[Mlir-commits] [mlir] e78d341 - [mlir][scf] More WhileOp canonicalizations

Ivan Butygin llvmlistbot at llvm.org
Wed Apr 12 07:40:53 PDT 2023


Author: Ivan Butygin
Date: 2023-04-12T16:39:05+02:00
New Revision: e78d341f936e52df92279cc5bff1bff0625b7d23

URL: https://github.com/llvm/llvm-project/commit/e78d341f936e52df92279cc5bff1bff0625b7d23
DIFF: https://github.com/llvm/llvm-project/commit/e78d341f936e52df92279cc5bff1bff0625b7d23.diff

LOG: [mlir][scf] More WhileOp canonicalizations

Remove duplicated ConditionOp args, remove unused init/yield args.

Differential Revision: https://reviews.llvm.org/D146252

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 3eda0d68f5fbb..b43615c9f1933 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3785,13 +3785,129 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
     return success();
   }
 };
+
+/// Remove duplicated ConditionOp args, forwarding each distinct value once.
+struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+
+    auto condOp = cast<ConditionOp>(beforeBlock.getTerminator());
+    ValueRange condOpArgs = condOp.getArgs();
+    // Deduplicate forwarded values; `argPositions[i]` is the index in
+    // `newArgs` of the value forwarded by original operand `i`.
+    llvm::SmallDenseMap<Value, unsigned> argsMap;
+    SmallVector<Value> newArgs;
+    SmallVector<unsigned> argPositions;
+    argPositions.reserve(condOpArgs.size());
+    for (Value arg : condOpArgs) {
+      auto [it, inserted] =
+          argsMap.try_emplace(arg, static_cast<unsigned>(newArgs.size()));
+      if (inserted)
+        newArgs.push_back(arg);
+      argPositions.push_back(it->second);
+    }
+
+    if (newArgs.size() == condOpArgs.size())
+      return rewriter.notifyMatchFailure(op, "No results to remove");
+
+    ValueRange argsRange(newArgs);
+    // Region bodies are left empty here; the old bodies are merged in below.
+    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {};
+    Location loc = op.getLoc();
+    auto newWhileOp = rewriter.create<scf::WhileOp>(
+        loc, argsRange.getTypes(), op.getInits(), emptyBuilder, emptyBuilder);
+    Block &newBeforeBlock = newWhileOp.getBefore().front();
+    Block &newAfterBlock = newWhileOp.getAfter().front();
+
+    // Each old after-block argument / result maps to its deduplicated slot.
+    SmallVector<Value> afterArgsMapping;
+    SmallVector<Value> resultsMapping;
+    for (unsigned pos : argPositions) {
+      afterArgsMapping.push_back(newAfterBlock.getArgument(pos));
+      resultsMapping.push_back(newWhileOp->getResult(pos));
+    }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(condOp);
+    rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
+                                             argsRange);
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
+                         newBeforeBlock.getArguments());
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
+    rewriter.replaceOp(op, resultsMapping);
+    return success();
+  }
+};
+
+/// Remove unused init/yield args: a "before" block argument without uses is
+/// dropped together with its init operand and the matching scf.yield operand.
+struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+    auto yield = cast<YieldOp>(afterBlock.getTerminator());
+    llvm::BitVector argsToRemove;
+    SmallVector<Value> newInits;
+    SmallVector<Value> newYieldArgs;
+    for (auto &&[arg, init, yieldArg] : llvm::zip(
+             beforeBlock.getArguments(), op.getInits(), yield.getResults())) {
+      bool empty = arg.use_empty();
+      argsToRemove.push_back(empty);
+      if (empty)
+        continue;
+      newInits.emplace_back(init);
+      newYieldArgs.emplace_back(yieldArg);
+    }
+
+    if (argsToRemove.none())
+      return rewriter.notifyMatchFailure(op, "No args to remove");
+
+    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {};
+    Location loc = op.getLoc();
+    auto newWhileOp = rewriter.create<WhileOp>(
+        loc, op->getResultTypes(), newInits, emptyBuilder, emptyBuilder);
+    Block &newBeforeBlock = newWhileOp.getBefore().front();
+    Block &newAfterBlock = newWhileOp.getAfter().front();
+
+    // Rather than erasing unused arguments in place (bypassing the rewriter),
+    // map them to their init value: a no-op, as they have no uses. Kept
+    // arguments map, in order, onto the new "before" block arguments.
+    SmallVector<Value> beforeArgsMapping;
+    unsigned nextArg = 0;
+    for (auto &&[idx, init] : llvm::enumerate(op.getInits())) {
+      if (argsToRemove[idx])
+        beforeArgsMapping.push_back(init);
+      else
+        beforeArgsMapping.push_back(newBeforeBlock.getArgument(nextArg++));
+    }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<YieldOp>(yield, newYieldArgs);
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, beforeArgsMapping);
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
+                         newAfterBlock.getArguments());
+    rewriter.replaceOp(op, newWhileOp.getResults());
+    return success();
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
-              WhileCmpCond, WhileUnusedResult>(context);
+              WhileCmpCond, WhileUnusedResult,
+              WhileRemoveDuplicatedResults, WhileRemoveUnusedArgs>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ec6e35a200075..355d343442c46 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1195,6 +1195,60 @@ func.func @while_cmp_rhs(%arg0 : i32) {
 // CHECK-NEXT:           scf.yield
 // CHECK-NEXT:         }
 
+// -----
+
+// CHECK-LABEL: @while_duplicated_res
+func.func @while_duplicated_res() -> (i32, i32) {
+  %0:2 = scf.while () : () -> (i32, i32) {
+    %val = "test.val"() : () -> i32
+    %condition = "test.condition"() : () -> i1
+    scf.condition(%condition) %val, %val : i32, i32
+  } do {
+  ^bb0(%val2: i32, %val3: i32):
+    "test.use"(%val2, %val3) : (i32, i32) -> ()
+    scf.yield
+  }
+  return %0#0, %0#1 : i32, i32
+}
+// CHECK:         %[[RES:.*]] = scf.while : () -> i32 {
+// CHECK-NEXT:    %[[VAL:.*]] = "test.val"() : () -> i32
+// CHECK-NEXT:    %[[COND:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT:      scf.condition(%[[COND]]) %[[VAL]] : i32
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:    ^bb0(%[[ARG:.*]]: i32):
+// CHECK-NEXT:      "test.use"(%[[ARG]], %[[ARG]]) : (i32, i32) -> ()
+// CHECK-NEXT:      scf.yield{{$}}
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[RES]], %[[RES]] : i32, i32
+
+// -----
+
+// CHECK-LABEL: @while_unused_arg
+func.func @while_unused_arg(%val0: i32) -> i32 {
+  %0 = scf.while (%val1 = %val0) : (i32) -> i32 {
+    %val = "test.val"() : () -> i32
+    %condition = "test.condition"() : () -> i1
+    scf.condition(%condition) %val : i32
+  } do {
+  ^bb0(%val2: i32):
+    "test.use"(%val2) : (i32) -> ()
+    %next = "test.val1"() : () -> i32
+    scf.yield %next : i32
+  }
+  return %0 : i32
+}
+// CHECK:         %[[RES:.*]] = scf.while : () -> i32 {
+// CHECK:         %[[VAL:.*]] = "test.val"() : () -> i32
+// CHECK:         %[[COND:.*]] = "test.condition"() : () -> i1
+// CHECK:           scf.condition(%[[COND]]) %[[VAL]] : i32
+// CHECK:         } do {
+// CHECK:         ^bb0(%[[ARG:.*]]: i32):
+// CHECK:           "test.use"(%[[ARG]]) : (i32) -> ()
+// CHECK:           scf.yield{{$}}
+// CHECK:         }
+// CHECK:         return %[[RES]] : i32
+
+
 // -----
 
 // CHECK-LABEL: @combineIfs

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 6a7bd40e3bbee..e23e86184c311 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -518,6 +518,8 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
                                               %arg2: index) {
   scf.while (%arg3 = %arg1) : (tensor<5xi1>) -> () {
     %0 = tensor.extract %arg0[%arg2] : tensor<5xi1>
+    %1 = tensor.extract %arg3[%arg2] : tensor<5xi1>
+    "dummy.use"(%1) : (i1) -> ()
     scf.condition(%0)
   } do {
     %0 = "dummy.some_op"() : () -> index


        


More information about the Mlir-commits mailing list