[Mlir-commits] [mlir] 340e1b2 - [mlir] LoopToStandard conversion: support "if/else" with results

Alex Zinenko llvmlistbot at llvm.org
Fri Apr 3 14:49:13 PDT 2020


Author: Alex Zinenko
Date: 2020-04-03T23:49:03+02:00
New Revision: 340e1b20779ebeb93f681689a345217672a308e3

URL: https://github.com/llvm/llvm-project/commit/340e1b20779ebeb93f681689a345217672a308e3
DIFF: https://github.com/llvm/llvm-project/commit/340e1b20779ebeb93f681689a345217672a308e3.diff

LOG: [mlir] LoopToStandard conversion: support "if/else" with results

Summary:
A recent extension allowed the `loop.if` operation to return results yielded by
its regions. However, such operations could not be lowered to a CFG of standard
operations because that would have required modifying the argument list of a
block, which is not allowed in a conversion pattern. Now that the conversion
infrastructure supports block creation, use it to create a block with an
argument list that dominates the operations following the `loop.if` and forward
the results as arguments of this block.

Depends On D77416

Differential Revision: https://reviews.llvm.org/D77418

Added: 
    

Modified: 
    mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
    mlir/test/Conversion/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
index e72c83027611..e971ca4ba4b5 100644
--- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
+++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
@@ -112,13 +112,21 @@ struct ForLowering : public OpRewritePattern<ForOp> {
 // blocks are respectively the first/last block of the enclosing region. The
 // operations following the loop.if are split into a continuation (subgraph
 // exit) block. The condition is lowered to a chain of blocks that implement the
-// short-circuit scheme.  Condition blocks are created by splitting out an empty
-// block from the block that contains the loop.if operation.  They
-// conditionally branch to either the first block of the "then" region, or to
-// the first block of the "else" region.  If the latter is absent, they branch
-// to the continuation block instead.  The last blocks of "then" and "else"
-// regions (which are known to be exit blocks thanks to the invariant we
-// maintain).
+// short-circuit scheme. The "loop.if" operation is replaced with a conditional
+// branch to either the first block of the "then" region, or to the first block
+// of the "else" region. In these blocks, "loop.yield" is lowered to an
+// unconditional branch to the post-dominating block. When the "loop.if" does
+// not return values, the post-dominating block is the same as the continuation
+// block. When it returns values, the post-dominating block is a new block whose
+// arguments correspond to the values returned by the "loop.if" and which
+// unconditionally branches to the continuation block. The block arguments thus
+// dominate, and replace, all uses of the former "loop.if" results. (Inserting
+// a new block avoids modifying the argument list of an existing block, which
+// is illegal in a conversion pattern.) When the "else" region is empty, which
+// is only allowed for "loop.if"s that don't return values, the condition
+// branches directly to the continuation block.
+//
+// CFG for a loop.if with else and without results.
 //
 //      +--------------------------------+
 //      | <code before the IfOp>         |
@@ -148,6 +156,42 @@ struct ForLowering : public OpRewritePattern<ForOp> {
 //      |   <code after the IfOp>        |
 //      +--------------------------------+
 //
+// CFG for a loop.if with results.
+//
+//      +--------------------------------+
+//      | <code before the IfOp>         |
+//      | cond_br %cond, %then, %else    |
+//      +--------------------------------+
+//             |              |
+//             |              --------------|
+//             v                            |
+//      +--------------------------------+  |
+//      | then:                          |  |
+//      |   <then contents>              |  |
+//      |   br dom(%args...)             |  |
+//      +--------------------------------+  |
+//             |                            |
+//   |----------               |-------------
+//   |                         V
+//   |  +--------------------------------+
+//   |  | else:                          |
+//   |  |   <else contents>              |
+//   |  |   br dom(%args...)             |
+//   |  +--------------------------------+
+//   |         |
+//   ------|   |
+//         v   v
+//      +--------------------------------+
+//      | dom(%args...):                 |
+//      |   br continue                  |
+//      +--------------------------------+
+//             |
+//             v
+//      +--------------------------------+
+//      | continue:                      |
+//      | <code after the IfOp>          |
+//      +--------------------------------+
+//
 struct IfLowering : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
@@ -238,15 +282,25 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
   // continuation point.
   auto *condBlock = rewriter.getInsertionBlock();
   auto opPosition = rewriter.getInsertionPoint();
-  auto *continueBlock = rewriter.splitBlock(condBlock, opPosition);
+  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+  Block *continueBlock;
+  if (ifOp.getNumResults() == 0) {
+    continueBlock = remainingOpsBlock;
+  } else {
+    continueBlock =
+        rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
+    rewriter.create<BranchOp>(loc, remainingOpsBlock);
+  }
 
   // Move blocks from the "then" region to the region containing 'loop.if',
   // place it before the continuation block, and branch to it.
   auto &thenRegion = ifOp.thenRegion();
   auto *thenBlock = &thenRegion.front();
-  rewriter.eraseOp(thenRegion.back().getTerminator());
+  Operation *thenTerminator = thenRegion.back().getTerminator();
+  ValueRange thenTerminatorOperands = thenTerminator->getOperands();
   rewriter.setInsertionPointToEnd(&thenRegion.back());
-  rewriter.create<BranchOp>(loc, continueBlock);
+  rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
+  rewriter.eraseOp(thenTerminator);
   rewriter.inlineRegionBefore(thenRegion, continueBlock);
 
   // Move blocks from the "else" region (if present) to the region containing
@@ -256,9 +310,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
   auto &elseRegion = ifOp.elseRegion();
   if (!elseRegion.empty()) {
     elseBlock = &elseRegion.front();
-    rewriter.eraseOp(elseRegion.back().getTerminator());
+    Operation *elseTerminator = elseRegion.back().getTerminator();
+    ValueRange elseTerminatorOperands = elseTerminator->getOperands();
     rewriter.setInsertionPointToEnd(&elseRegion.back());
-    rewriter.create<BranchOp>(loc, continueBlock);
+    rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
+    rewriter.eraseOp(elseTerminator);
     rewriter.inlineRegionBefore(elseRegion, continueBlock);
   }
 
@@ -268,7 +324,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
                                 /*falseArgs=*/ArrayRef<Value>());
 
   // Ok, we're done!
-  rewriter.eraseOp(ifOp);
+  rewriter.replaceOp(ifOp, continueBlock->getArguments());
   return success();
 }
 

diff  --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
index 8a8a999d5ee9..74ae6aeffd9c 100644
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -148,6 +148,83 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
   return
 }
 
+// CHECK-LABEL: func @simple_if_yield
+func @simple_if_yield(%cond: i1) -> (i1, i1) {
+// CHECK:   cond_br %{{.*}}, ^[[true_dest:.*]], ^[[false_dest:.*]]
+  %res:2 = loop.if %cond -> (i1, i1) {
+// CHECK: ^[[true_dest]]:
+// CHECK:   %[[t0:.*]] = constant 0
+// CHECK:   %[[t1:.*]] = constant 1
+// CHECK:   br ^[[merge:.*]](%[[t0]], %[[t1]] : i1, i1)
+    %false = constant 0 : i1
+    %true = constant 1 : i1
+    loop.yield %false, %true : i1, i1
+  } else {
+// CHECK: ^[[false_dest]]:
+// CHECK:   %[[e0:.*]] = constant 0
+// CHECK:   %[[e1:.*]] = constant 1
+// CHECK:   br ^[[merge]](%[[e1]], %[[e0]] : i1, i1)
+    %false = constant 0 : i1
+    %true = constant 1 : i1
+    loop.yield %true, %false : i1, i1
+  }
+// CHECK: ^[[merge]](%[[r0:.*]]: i1, %[[r1:.*]]: i1):
+// CHECK:   br ^[[after:.*]]
+// CHECK: ^[[after]]:
+// CHECK:   return %[[r0]], %[[r1]]
+  return %res#0, %res#1 : i1, i1
+}
+
+// CHECK-LABEL: func @nested_if_yield
+func @nested_if_yield(%arg0: i1) -> (index) {
+// CHECK:   cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
+  %0 = loop.if %arg0 -> i1 {
+// CHECK: ^[[first_then]]:
+    %true = constant 1 : i1
+// CHECK:   br ^[[first_dom:.*]]({{.*}})
+    loop.yield %true : i1
+  } else {
+// CHECK: ^[[first_else]]:
+    %false = constant 0 : i1
+// CHECK:   br ^[[first_dom]]({{.*}})
+    loop.yield %false : i1
+  }
+// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1):
+// CHECK:   br ^[[first_cont:.*]]
+// CHECK: ^[[first_cont]]:
+// CHECK:   cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
+  %1 = loop.if %0 -> index {
+// CHECK: ^[[second_outer_then]]:
+// CHECK:   cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
+    %3 = loop.if %arg0 -> index {
+// CHECK: ^[[second_inner_then]]:
+      %4 = constant 40 : index
+// CHECK:   br ^[[second_inner_dom:.*]]({{.*}})
+      loop.yield %4 : index
+    } else {
+// CHECK: ^[[second_inner_else]]:
+      %5 = constant 41 : index
+// CHECK:   br ^[[second_inner_dom]]({{.*}})
+      loop.yield %5 : index
+    }
+// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index):
+// CHECK:   br ^[[second_inner_cont:.*]]
+// CHECK: ^[[second_inner_cont]]:
+// CHECK:   br ^[[second_outer_dom:.*]](%[[arg2]] : index)
+    loop.yield %3 : index
+  } else {
+// CHECK: ^[[second_outer_else]]:
+    %6 = constant 42 : index
+// CHECK:   br ^[[second_outer_dom]]({{.*}})
+    loop.yield %6 : index
+  }
+// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index):
+// CHECK:   br ^[[second_outer_cont:.*]]
+// CHECK: ^[[second_outer_cont]]:
+// CHECK:   return %[[arg3]] : index
+  return %1 : index
+}
+
 // CHECK-LABEL:   func @parallel_loop(
 // CHECK-SAME:                        [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
 // CHECK:           [[VAL_5:%.*]] = constant 1 : index


        


More information about the Mlir-commits mailing list