[Mlir-commits] [mlir] 2ab2775 - Revert "[MLIR][SCF] Inline ExecuteRegion if parent can contain multiple blocks"

William S. Moses llvmlistbot at llvm.org
Mon Jun 28 10:53:19 PDT 2021


Author: William S. Moses
Date: 2021-06-28T13:52:30-04:00
New Revision: 2ab27758d5c5e7985cee1a2651bc0a9ee4c2d8c9

URL: https://github.com/llvm/llvm-project/commit/2ab27758d5c5e7985cee1a2651bc0a9ee4c2d8c9
DIFF: https://github.com/llvm/llvm-project/commit/2ab27758d5c5e7985cee1a2651bc0a9ee4c2d8c9.diff

LOG: Revert "[MLIR][SCF] Inline ExecuteRegion if parent can contain multiple blocks"

This reverts commit 5d6240b77e7e7199fcf0e89f6dd2f7eea3596a3c.

The commit was mistakenly landed without PR approval; it is being
reverted now and will be resubmitted.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 9f039b6fcda68..c10441f59bd55 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -108,8 +108,14 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
 
   let regions = (region AnyRegion:$region);
 
+  // TODO: If the parent is a func like op (which would be the case if all other
+  // ops are from the std dialect), the inliner logic could be readily used to
+  // inline.
   let hasCanonicalizer = 1;
 
+  // TODO: can fold if it returns a constant.
+  // TODO: Single block execute_region ops can be readily inlined irrespective
+  // of which op is a parent. Add a fold for this.
   let hasFolder = 0;
 }
 

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 38760ca4050d3..99d2386ced1b1 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -143,94 +143,23 @@ static LogicalResult verify(ExecuteRegionOp op) {
 //
 //     "test.foo"() : () -> ()
 //     %x = "test.val"() : () -> i64
-//     "test.bar"(%x) : (i64) -> ()
+//     "test.bar"(%v) : (i64) -> ()
 //
 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::hasSingleElement(op.region()))
+    if (op.region().getBlocks().size() != 1)
       return failure();
     replaceOpWithRegion(rewriter, op, op.region());
     return success();
   }
 };
 
-// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
-// TODO generalize the conditions for operations which can be inlined into.
-// func @func_execute_region_elim() {
-//     "test.foo"() : () -> ()
-//     %v = scf.execute_region -> i64 {
-//       %c = "test.cmp"() : () -> i1
-//       cond_br %c, ^bb2, ^bb3
-//     ^bb2:
-//       %x = "test.val1"() : () -> i64
-//       br ^bb4(%x : i64)
-//     ^bb3:
-//       %y = "test.val2"() : () -> i64
-//       br ^bb4(%y : i64)
-//     ^bb4(%z : i64):
-//       scf.yield %z : i64
-//     }
-//     "test.bar"(%v) : (i64) -> ()
-//   return
-// }
-//
-//  becomes
-//
-// func @func_execute_region_elim() {
-//    "test.foo"() : () -> ()
-//    %c = "test.cmp"() : () -> i1
-//    cond_br %c, ^bb1, ^bb2
-//  ^bb1:  // pred: ^bb0
-//    %x = "test.val1"() : () -> i64
-//    br ^bb3(%x : i64)
-//  ^bb2:  // pred: ^bb0
-//    %y = "test.val2"() : () -> i64
-//    br ^bb3(%y : i64)
-//  ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
-//    "test.bar"(%z) : (i64) -> ()
-//    return
-//  }
-//
-struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
-  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExecuteRegionOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!isa<FuncOp, ExecuteRegionOp>(op->getParentOp()))
-      return failure();
-
-    Block *prevBlock = op->getBlock();
-    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
-    rewriter.setInsertionPointToEnd(prevBlock);
-
-    rewriter.create<BranchOp>(op.getLoc(), &op.region().front());
-
-    for (Block &blk : op.region()) {
-      if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
-        rewriter.setInsertionPoint(yieldOp);
-        rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
-                                  yieldOp.results());
-        rewriter.eraseOp(yieldOp);
-      }
-    }
-
-    rewriter.inlineRegionBefore(op.region(), postBlock);
-    SmallVector<Value> blockArgs;
-
-    for (auto res : op.getResults())
-      blockArgs.push_back(postBlock->addArgument(res.getType()));
-
-    rewriter.replaceOp(op, blockArgs);
-    return success();
-  }
-};
-
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+  results.add<SingleBlockExecuteInliner>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d1789c6dfde52..8692f2d9705e0 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -948,70 +948,3 @@ func @execute_region_elim() {
 // CHECK-NEXT:       "test.bar"(%[[VAL]]) : (i64) -> ()
 // CHECK-NEXT:     }
 
-
-// -----
-
-// CHECK-LABEL: func @func_execute_region_elim
-func @func_execute_region_elim() {
-    "test.foo"() : () -> ()
-    %v = scf.execute_region -> i64 {
-      %c = "test.cmp"() : () -> i1
-      cond_br %c, ^bb2, ^bb3
-    ^bb2:
-      %x = "test.val1"() : () -> i64
-      br ^bb4(%x : i64)
-    ^bb3:
-      %y = "test.val2"() : () -> i64
-      br ^bb4(%y : i64)
-    ^bb4(%z : i64):
-      scf.yield %z : i64
-    }
-    "test.bar"(%v) : (i64) -> ()
-  return
-}
-
-// CHECK:     "test.foo"
-// CHECK:     %[[cmp:.+]] = "test.cmp"
-// CHECK:     cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
-// CHECK:   ^[[bb1]]:  // pred: ^bb0
-// CHECK:     %[[x:.+]] = "test.val1"
-// CHECK:     br ^[[bb3:.+]](%[[x]] : i64)
-// CHECK:   ^[[bb2]]:  // pred: ^bb0
-// CHECK:     %[[y:.+]] = "test.val2"
-// CHECK:     br ^[[bb3]](%[[y:.+]] : i64)
-// CHECK:   ^[[bb3]](%[[z:.+]]: i64):
-// CHECK:     "test.bar"(%[[z]])
-// CHECK:     return
-
-
-// -----
-
-// CHECK-LABEL: func @func_execute_region_elim2
-func @func_execute_region_elim2() {
-    "test.foo"() : () -> ()
-    %v = scf.execute_region -> i64 {
-      %c = "test.cmp"() : () -> i1
-      cond_br %c, ^bb2, ^bb3
-    ^bb2:
-      %x = "test.val1"() : () -> i64
-      scf.yield %x : i64
-    ^bb3:
-      %y = "test.val2"() : () -> i64
-      scf.yield %y : i64
-    }
-    "test.bar"(%v) : (i64) -> ()
-  return
-}
-
-// CHECK:     "test.foo"
-// CHECK:     %[[cmp:.+]] = "test.cmp"
-// CHECK:     cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
-// CHECK:   ^[[bb1]]:  // pred: ^bb0
-// CHECK:     %[[x:.+]] = "test.val1"
-// CHECK:     br ^[[bb3:.+]](%[[x]] : i64)
-// CHECK:   ^[[bb2]]:  // pred: ^bb0
-// CHECK:     %[[y:.+]] = "test.val2"
-// CHECK:     br ^[[bb3]](%[[y:.+]] : i64)
-// CHECK:   ^[[bb3]](%[[z:.+]]: i64):
-// CHECK:     "test.bar"(%[[z]])
-// CHECK:     return


        


More information about the Mlir-commits mailing list