[Mlir-commits] [mlir] 8ba8ab8 - [mlir] support reductions in loop to std conversion

Alex Zinenko llvmlistbot at llvm.org
Tue Mar 3 09:21:21 PST 2020


Author: Alex Zinenko
Date: 2020-03-03T18:21:13+01:00
New Revision: 8ba8ab8c95fb185d722842bf78053a8dad6181cd

URL: https://github.com/llvm/llvm-project/commit/8ba8ab8c95fb185d722842bf78053a8dad6181cd
DIFF: https://github.com/llvm/llvm-project/commit/8ba8ab8c95fb185d722842bf78053a8dad6181cd.diff

LOG: [mlir] support reductions in loop to std conversion

Summary:
Introduce support for converting loop.for operations with loop-carried values
to a CFG in the standard dialect. This is achieved by passing loop-carried
values as block arguments to the loop condition block. This block dominates
both the loop body and the block immediately following the loop, so the
arguments of this block remain visible there.

Differential Revision: https://reviews.llvm.org/D75513

Added: 
    

Modified: 
    mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
    mlir/test/Conversion/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
index c57ec80c7a99..b4d25e04e389 100644
--- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
+++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
@@ -45,21 +45,26 @@ struct LoopToStandardPass : public OperationPass<LoopToStandardPass> {
 // are split out into a separate continuation (exit) block. A condition block is
 // created before the continuation block. It checks the exit condition of the
 // loop and branches either to the continuation block, or to the first block of
-// the body. Induction variable modification is appended to the last block of
-// the body (which is the exit block from the body subgraph thanks to the
+// the body. The condition block takes as arguments the values of the induction
+// variable followed by loop-carried values. Since it dominates both the body
+// blocks and the continuation block, loop-carried values are visible in all of
+// those blocks. Induction variable modification is appended to the last block
+// of the body (which is the exit block from the body subgraph thanks to the
 // invariant we maintain) along with a branch that loops back to the condition
-// block.
+// block. Loop-carried values are the loop terminator operands, which are
+// forwarded to the branch.
 //
 //      +---------------------------------+
 //      |   <code before the ForOp>       |
+//      |   <definitions of %init...>     |
 //      |   <compute initial %iv value>   |
-//      |   br cond(%iv)                  |
+//      |   br cond(%iv, %init...)        |
 //      +---------------------------------+
 //             |
 //  -------|   |
 //  |      v   v
 //  |   +--------------------------------+
-//  |   | cond(%iv):                     |
+//  |   | cond(%iv, %init...):           |
 //  |   |   <compare %iv to upper bound> |
 //  |   |   cond_br %r, body, end        |
 //  |   +--------------------------------+
@@ -68,6 +73,7 @@ struct LoopToStandardPass : public OperationPass<LoopToStandardPass> {
 //  |          v                            |
 //  |   +--------------------------------+  |
 //  |   | body-first:                    |  |
+//  |   |   <%init visible by dominance> |  |
 //  |   |   <body contents>              |  |
 //  |   +--------------------------------+  |
 //  |                   |                   |
@@ -76,15 +82,17 @@ struct LoopToStandardPass : public OperationPass<LoopToStandardPass> {
 //  |   +--------------------------------+  |
 //  |   | body-last:                     |  |
 //  |   |   <body contents>              |  |
+//  |   |   <operands of yield = %yields>|  |
 //  |   |   %new_iv =<add step to %iv>   |  |
-//  |   |   br cond(%new_iv)             |  |
+//  |   |   br cond(%new_iv, %yields)    |  |
 //  |   +--------------------------------+  |
 //  |          |                            |
 //  |-----------        |--------------------
 //                      v
 //      +--------------------------------+
 //      | end:                           |
-//      |   <code after the ForOp> |
+//      |   <code after the ForOp>       |
+//      |   <%init visible by dominance> |
 //      +--------------------------------+
 //
 struct ForLowering : public OpRewritePattern<ForOp> {
@@ -133,7 +141,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
 //         v   v
 //      +--------------------------------+
 //      | continue:                      |
-//      |   <code after the IfOp>  |
+//      |   <code after the IfOp>        |
 //      +--------------------------------+
 //
 struct IfLowering : public OpRewritePattern<IfOp> {
@@ -162,10 +170,10 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
   auto initPosition = rewriter.getInsertionPoint();
   auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
 
-  // Use the first block of the loop body as the condition block since it is
-  // the block that has the induction variable as its argument.  Split out
-  // all operations from the first block into a new block.  Move all body
-  // blocks from the loop body region to the region containing the loop.
+  // Use the first block of the loop body as the condition block since it is the
+  // block that has the induction variable and loop-carried values as arguments.
+  // Split out all operations from the first block into a new block. Move all
+  // body blocks from the loop body region to the region containing the loop.
   auto *conditionBlock = &forOp.region().front();
   auto *firstBodyBlock =
       rewriter.splitBlock(conditionBlock, conditionBlock->begin());
@@ -174,15 +182,20 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
   auto iv = conditionBlock->getArgument(0);
 
   // Append the induction variable stepping logic to the last body block and
-  // branch back to the condition block.  Construct an expression f :
-  // (x -> x+step) and apply this expression to the induction variable.
-  rewriter.eraseOp(lastBodyBlock->getTerminator());
+  // branch back to the condition block. Loop-carried values are taken from
+  // operands of the loop terminator.
+  Operation *terminator = lastBodyBlock->getTerminator();
   rewriter.setInsertionPointToEnd(lastBodyBlock);
   auto step = forOp.step();
   auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
   if (!stepped)
     return matchFailure();
-  rewriter.create<BranchOp>(loc, conditionBlock, stepped);
+
+  SmallVector<Value, 8> loopCarried;
+  loopCarried.push_back(stepped);
+  loopCarried.append(terminator->operand_begin(), terminator->operand_end());
+  rewriter.create<BranchOp>(loc, conditionBlock, loopCarried);
+  rewriter.eraseOp(terminator);
 
   // Compute loop bounds before branching to the condition.
   rewriter.setInsertionPointToEnd(initBlock);
@@ -190,7 +203,14 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
   Value upperBound = forOp.upperBound();
   if (!lowerBound || !upperBound)
     return matchFailure();
-  rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
+
+  // The initial values of loop-carried values is obtained from the operands
+  // of the loop operation.
+  SmallVector<Value, 8> destOperands;
+  destOperands.push_back(lowerBound);
+  auto iterOperands = forOp.getIterOperands();
+  destOperands.append(iterOperands.begin(), iterOperands.end());
+  rewriter.create<BranchOp>(loc, conditionBlock, destOperands);
 
   // With the body block done, we can fill in the condition block.
   rewriter.setInsertionPointToEnd(conditionBlock);
@@ -199,8 +219,9 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
 
   rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
                                 ArrayRef<Value>(), endBlock, ArrayRef<Value>());
-  // Ok, we're done!
-  rewriter.eraseOp(forOp);
+  // The result of the loop operation is the values of the condition block
+  // arguments except the induction variable on the last iteration.
+  rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
   return matchSuccess();
 }
 

diff  --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
index b53dc23c7a78..c6be3fdb8953 100644
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -180,3 +180,59 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
   }
   return
 }
+
+// CHECK-LABEL: @for_yield
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK:        %[[INIT0:.*]] = constant 0
+// CHECK:        %[[INIT1:.*]] = constant 1
+// CHECK:        br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
+//
+// CHECK:      ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32):
+// CHECK:        %[[CMP:.*]] = cmpi "slt", %[[ITER]], %[[UB]] : index
+// CHECK:        cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+//
+// CHECK:      ^[[BODY]]:
+// CHECK:        %[[SUM:.*]] = addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32
+// CHECK:        %[[STEPPED:.*]] = addi %[[ITER]], %[[STEP]] : index
+// CHECK:        br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
+//
+// CHECK:      ^[[CONTINUE]]:
+// CHECK:        return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32
+func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) {
+  %s0 = constant 0.0 : f32
+  %s1 = constant 1.0 : f32
+  %result:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+    %sn = addf %si, %sj : f32
+    loop.yield %sn, %sn : f32, f32
+  }
+  return %result#0, %result#1 : f32, f32
+}
+
+// CHECK-LABEL: @nested_for_yield
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK:         %[[INIT:.*]] = constant
+// CHECK:         br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
+// CHECK:       ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32):
+// CHECK:         cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+// CHECK:       ^[[BODY_OUT]]:
+// CHECK:         br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
+// CHECK:       ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32):
+// CHECK:         cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+// CHECK:       ^[[BODY_IN]]
+// CHECK:         %[[RES:.*]] = addf
+// CHECK:         br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
+// CHECK:       ^[[CONT_IN]]:
+// CHECK:         br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
+// CHECK:       ^[[CONT_OUT]]:
+// CHECK:         return %[[ARG_OUT]] : f32
+func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
+  %s0 = constant 1.0 : f32
+  %r = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iter = %s0) -> (f32) {
+    %result = loop.for %i1 = %arg0 to %arg1 step %arg2 iter_args(%si = %iter) -> (f32) {
+      %sn = addf %si, %si : f32
+      loop.yield %sn : f32
+    }
+    loop.yield %result : f32
+  }
+  return %r : f32
+}


        


More information about the Mlir-commits mailing list