[Mlir-commits] [mlir] 340e1b2 - [mlir] LoopToStandard conversion: support "if/else" with results
Alex Zinenko
llvmlistbot at llvm.org
Fri Apr 3 14:49:13 PDT 2020
Author: Alex Zinenko
Date: 2020-04-03T23:49:03+02:00
New Revision: 340e1b20779ebeb93f681689a345217672a308e3
URL: https://github.com/llvm/llvm-project/commit/340e1b20779ebeb93f681689a345217672a308e3
DIFF: https://github.com/llvm/llvm-project/commit/340e1b20779ebeb93f681689a345217672a308e3.diff
LOG: [mlir] LoopToStandard conversion: support "if/else" with results
Summary:
A recent extension allowed the `loop.if` operation to return results yielded by
its regions. However, such operations could not be lowered to a CFG of standard
operations because doing so would have required modifying the argument list of
a block, which is not allowed in a conversion pattern. Now that the conversion
infrastructure supports block creation, use it to create a new block whose
arguments dominate the operations following the `loop.if`, and forward the
results as arguments of this block.
Depends On D77416
Differential Revision: https://reviews.llvm.org/D77418
Added:
Modified:
mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
mlir/test/Conversion/convert-to-cfg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
index e72c83027611..e971ca4ba4b5 100644
--- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
+++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
@@ -112,13 +112,21 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// blocks are respectively the first/last block of the enclosing region. The
// operations following the loop.if are split into a continuation (subgraph
// exit) block. The condition is lowered to a chain of blocks that implement the
-// short-circuit scheme. Condition blocks are created by splitting out an empty
-// block from the block that contains the loop.if operation. They
-// conditionally branch to either the first block of the "then" region, or to
-// the first block of the "else" region. If the latter is absent, they branch
-// to the continuation block instead. The last blocks of "then" and "else"
-// regions (which are known to be exit blocks thanks to the invariant we
-// maintain).
+// short-circuit scheme. The "loop.if" operation is replaced with a conditional
+// branch to either the first block of the "then" region, or to the first block
+// of the "else" region. In these blocks, "loop.yield" is replaced with an
+// unconditional branch to the post-dominating block. When the "loop.if" does
+// not return values, the post-dominating block is the continuation block. When
+// it returns values, the post-dominating block is a new block, with arguments
+// corresponding to the values returned by the "loop.if", that unconditionally
+// branches to the continuation block. This allows the block arguments to
+// dominate all uses of the former "loop.if" results that they replace.
+// (Inserting a new block avoids modifying the argument list of an existing
+// block, which is illegal in a conversion pattern.) When the "else" region is
+// empty, which is only allowed for "loop.if"s that don't return values, the
+// condition branches directly to the continuation block.
+//
+// CFG for a loop.if with else and without results.
//
// +--------------------------------+
// | <code before the IfOp> |
@@ -148,6 +156,42 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// | <code after the IfOp> |
// +--------------------------------+
//
+// CFG for a loop.if with results.
+//
+// +--------------------------------+
+// | <code before the IfOp> |
+// | cond_br %cond, %then, %else |
+// +--------------------------------+
+// | |
+// | --------------|
+// v |
+// +--------------------------------+ |
+// | then: | |
+// | <then contents> | |
+// | br dom(%args...) | |
+// +--------------------------------+ |
+// | |
+// |---------- |-------------
+// | V
+// | +--------------------------------+
+// | | else: |
+// | | <else contents> |
+// | | br dom(%args...) |
+// | +--------------------------------+
+// | |
+// ------| |
+// v v
+// +--------------------------------+
+// | dom(%args...): |
+// | br continue |
+// +--------------------------------+
+// |
+// v
+// +--------------------------------+
+// | continue: |
+// | <code after the IfOp> |
+// +--------------------------------+
+//
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
@@ -238,15 +282,25 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
// continuation point.
auto *condBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
- auto *continueBlock = rewriter.splitBlock(condBlock, opPosition);
+ auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+ Block *continueBlock;
+ if (ifOp.getNumResults() == 0) {
+ continueBlock = remainingOpsBlock;
+ } else {
+ continueBlock =
+ rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
+ rewriter.create<BranchOp>(loc, remainingOpsBlock);
+ }
// Move blocks from the "then" region to the region containing 'loop.if',
// place it before the continuation block, and branch to it.
auto &thenRegion = ifOp.thenRegion();
auto *thenBlock = &thenRegion.front();
- rewriter.eraseOp(thenRegion.back().getTerminator());
+ Operation *thenTerminator = thenRegion.back().getTerminator();
+ ValueRange thenTerminatorOperands = thenTerminator->getOperands();
rewriter.setInsertionPointToEnd(&thenRegion.back());
- rewriter.create<BranchOp>(loc, continueBlock);
+ rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
+ rewriter.eraseOp(thenTerminator);
rewriter.inlineRegionBefore(thenRegion, continueBlock);
// Move blocks from the "else" region (if present) to the region containing
@@ -256,9 +310,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
auto &elseRegion = ifOp.elseRegion();
if (!elseRegion.empty()) {
elseBlock = &elseRegion.front();
- rewriter.eraseOp(elseRegion.back().getTerminator());
+ Operation *elseTerminator = elseRegion.back().getTerminator();
+ ValueRange elseTerminatorOperands = elseTerminator->getOperands();
rewriter.setInsertionPointToEnd(&elseRegion.back());
- rewriter.create<BranchOp>(loc, continueBlock);
+ rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
+ rewriter.eraseOp(elseTerminator);
rewriter.inlineRegionBefore(elseRegion, continueBlock);
}
@@ -268,7 +324,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
/*falseArgs=*/ArrayRef<Value>());
// Ok, we're done!
- rewriter.eraseOp(ifOp);
+ rewriter.replaceOp(ifOp, continueBlock->getArguments());
return success();
}
diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
index 8a8a999d5ee9..74ae6aeffd9c 100644
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -148,6 +148,83 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
return
}
+// CHECK-LABEL: func @simple_if_yield
+func @simple_if_yield(%arg0: i1) -> (i1, i1) {
+// CHECK: cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
+ %0:2 = loop.if %arg0 -> (i1, i1) {
+// CHECK: ^[[then]]:
+// CHECK: %[[v0:.*]] = constant 0
+// CHECK: %[[v1:.*]] = constant 1
+// CHECK: br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
+ %c0 = constant 0 : i1
+ %c1 = constant 1 : i1
+ loop.yield %c0, %c1 : i1, i1
+ } else {
+// CHECK: ^[[else]]:
+// CHECK: %[[v2:.*]] = constant 0
+// CHECK: %[[v3:.*]] = constant 1
+// CHECK: br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
+ %c0 = constant 0 : i1
+ %c1 = constant 1 : i1
+ loop.yield %c1, %c0 : i1, i1
+ }
+// CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1):
+// CHECK: br ^[[cont:.*]]
+// CHECK: ^[[cont]]:
+// CHECK: return %[[arg1]], %[[arg2]]
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @nested_if_yield
+func @nested_if_yield(%arg0: i1) -> (index) {
+// CHECK: cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
+  %0 = loop.if %arg0 -> i1 {
+// CHECK: ^[[first_then]]:
+    %1 = constant 1 : i1
+// CHECK: br ^[[first_dom:.*]]({{.*}})
+    loop.yield %1 : i1
+  } else {
+// CHECK: ^[[first_else]]:
+    %2 = constant 0 : i1
+// CHECK: br ^[[first_dom]]({{.*}})
+    loop.yield %2 : i1
+  }
+// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1):
+// CHECK: br ^[[first_cont:.*]]
+// CHECK: ^[[first_cont]]:
+// CHECK: cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
+  %1 = loop.if %0 -> index {
+// CHECK: ^[[second_outer_then]]:
+// CHECK: cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
+    %3 = loop.if %arg0 -> index {
+// CHECK: ^[[second_inner_then]]:
+      %4 = constant 40 : index
+// CHECK: br ^[[second_inner_dom:.*]]({{.*}})
+      loop.yield %4 : index
+    } else {
+// CHECK: ^[[second_inner_else]]:
+      %5 = constant 41 : index
+// CHECK: br ^[[second_inner_dom]]({{.*}})
+      loop.yield %5 : index
+    }
+// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index):
+// CHECK: br ^[[second_inner_cont:.*]]
+// CHECK: ^[[second_inner_cont]]:
+// CHECK: br ^[[second_outer_dom:.*]]({{.*}})
+    loop.yield %3 : index
+  } else {
+// CHECK: ^[[second_outer_else]]:
+    %6 = constant 42 : index
+// CHECK: br ^[[second_outer_dom]]({{.*}})
+    loop.yield %6 : index
+  }
+// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index):
+// CHECK: br ^[[second_outer_cont:.*]]
+// CHECK: ^[[second_outer_cont]]:
+// CHECK: return %[[arg3]] : index
+  return %1 : index
+}
+
// CHECK-LABEL: func @parallel_loop(
// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
// CHECK: [[VAL_5:%.*]] = constant 1 : index
More information about the Mlir-commits
mailing list