[Mlir-commits] [mlir] aff6bf4 - [mlir] support conversion of parallel reduction loops to std

Alex Zinenko llvmlistbot at llvm.org
Wed Mar 4 07:37:26 PST 2020


Author: Alex Zinenko
Date: 2020-03-04T16:37:17+01:00
New Revision: aff6bf4ff81a35a85034b478cccc7015499ce427

URL: https://github.com/llvm/llvm-project/commit/aff6bf4ff81a35a85034b478cccc7015499ce427
DIFF: https://github.com/llvm/llvm-project/commit/aff6bf4ff81a35a85034b478cccc7015499ce427.diff

LOG: [mlir] support conversion of parallel reduction loops to std

The recently introduced support for converting sequential reduction loops to
a CFG of basic blocks in the Standard dialect makes it possible to perform a
staged conversion of parallel reduction loops into a similar CFG, using
sequential loops as an intermediate step. Staging is already how parallel
loops without reductions are converted, so extend the existing pattern to
support this additional use case.
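
As an illustration, here is a hypothetical sketch of the staged conversion
(value names are illustrative; the syntax mirrors the tests in this patch):
a one-dimensional parallel reduction loop is first rewritten into a
sequential "loop.for" that carries the partial reduction as an iteration
argument, and the existing loop-to-CFG rewrite then lowers that loop
further.

  // Input: a parallel loop reducing a single f32 value.
  %res = loop.parallel (%i) = (%lb) to (%ub) step (%s) init(%init) {
    %cst = constant 1.0 : f32
    loop.reduce(%cst) {
    ^bb0(%lhs: f32, %rhs: f32):
      %0 = addf %lhs, %rhs : f32
      loop.reduce.return %0 : f32
    } : f32
  } : f32

  // Intermediate step: the reduction block is inlined into a sequential
  // loop whose iteration argument carries the partial result.
  %res = loop.for %i = %lb to %ub step %s iter_args(%acc = %init) -> (f32) {
    %cst = constant 1.0 : f32
    %0 = addf %acc, %cst : f32
    loop.yield %0 : f32
  }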

Differential Revision: https://reviews.llvm.org/D75599

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LoopOps/LoopOps.td
    mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
    mlir/lib/Dialect/LoopOps/LoopOps.cpp
    mlir/test/Conversion/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index f92a399dce70..8850349af574 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -131,7 +131,8 @@ def ForOp : Loop_Op<"for",
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, "
-              "Value lowerBound, Value upperBound, Value step">
+              "Value lowerBound, Value upperBound, Value step, "
+              "ValueRange iterArgs = llvm::None">
   ];
 
   let extraClassDeclaration = [{

diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
index b4d25e04e389..a16c4a0c5cfb 100644
--- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
+++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
@@ -274,29 +274,75 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   Location loc = parallelOp.getLoc();
   BlockAndValueMapping mapping;
 
-  if (parallelOp.getNumResults() != 0) {
-    // TODO: Implement lowering of parallelOp with reductions.
-    return matchFailure();
-  }
-
   // For a parallel loop, we essentially need to create an n-dimensional loop
   // nest. We do this by translating to loop.for ops and have those lowered in
-  // a further rewrite.
+  // a further rewrite. If a parallel loop contains reductions (and thus returns
+  // values), forward the initial values for the reductions down the loop
+  // hierarchy and bubble up the results by modifying the "yield" terminator.
+  SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.initVals());
+  bool first = true;
+  SmallVector<Value, 4> loopResults(iterArgs);
   for (auto loop_operands :
        llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
                  parallelOp.upperBound(), parallelOp.step())) {
     Value iv, lower, upper, step;
     std::tie(iv, lower, upper, step) = loop_operands;
-    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step);
+    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
     mapping.map(iv, forOp.getInductionVar());
+    auto iterRange = forOp.getRegionIterArgs();
+    iterArgs.assign(iterRange.begin(), iterRange.end());
+
+    if (first) {
+      // Store the results of the outermost loop that will be used to replace
+      // the results of the parallel loop when it is fully rewritten.
+      loopResults.assign(forOp.result_begin(), forOp.result_end());
+      first = false;
+    } else {
+      // A loop is constructed with an empty "yield" terminator by default.
+      // Replace it with another "yield" that forwards the results of the nested
+      // loop to the parent loop. We need to explicitly make sure the new
+      // terminator is the last operation in the block because further transforms
+      // rely on this.
+      rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+      rewriter.replaceOpWithNewOp<YieldOp>(
+          rewriter.getInsertionBlock()->getTerminator(), forOp.getResults());
+    }
+
     rewriter.setInsertionPointToStart(forOp.getBody());
   }
 
   // Now copy over the contents of the body.
-  for (auto &op : parallelOp.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
+  SmallVector<Value, 4> yieldOperands;
+  yieldOperands.reserve(parallelOp.getNumResults());
+  for (auto &op : parallelOp.getBody()->without_terminator()) {
+    // Reduction blocks are handled differently.
+    auto reduce = dyn_cast<ReduceOp>(op);
+    if (!reduce) {
+      rewriter.clone(op, mapping);
+      continue;
+    }
+
+    // Clone the body of the reduction operation into the body of the loop,
+    // using operands of "loop.reduce" and iteration arguments corresponding
+    // to the reduction value to replace arguments of the reduction block.
+    // Collect operands of "loop.reduce.return" to be returned by a final
+    // "loop.yield" instead.
+    Value arg = iterArgs[yieldOperands.size()];
+    Block &reduceBlock = reduce.reductionOperator().front();
+    mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg));
+    mapping.map(reduceBlock.getArgument(1),
+                mapping.lookupOrDefault(reduce.operand()));
+    for (auto &nested : reduceBlock.without_terminator())
+      rewriter.clone(nested, mapping);
+    yieldOperands.push_back(
+        mapping.lookup(reduceBlock.getTerminator()->getOperand(0)));
+  }
+
+  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+  rewriter.replaceOpWithNewOp<YieldOp>(
+      rewriter.getInsertionBlock()->getTerminator(), yieldOperands);
 
-  rewriter.eraseOp(parallelOp);
+  rewriter.replaceOp(parallelOp, loopResults);
 
   return matchSuccess();
 }
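
For multi-dimensional parallel loops, the pattern above threads the
iteration arguments through every loop.for in the generated nest and
replaces each inner "yield" terminator so the partial results bubble up to
the outermost loop. A hypothetical 2-D sketch of the intermediate IR (names
illustrative):

  // The outer loop forwards its iteration argument into the inner loop and
  // yields the inner loop's result; only the innermost body performs the
  // partial reduction.
  %res = loop.for %i = %lb0 to %ub0 step %s0 iter_args(%acc0 = %init) -> (f32) {
    %inner = loop.for %j = %lb1 to %ub1 step %s1 iter_args(%acc1 = %acc0) -> (f32) {
      %cst = constant 1.0 : f32
      %0 = addf %acc1, %cst : f32
      loop.yield %0 : f32
    }
    loop.yield %inner : f32
  }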

diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index e9e9397a6d19..c0cb149bf815 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -61,11 +61,16 @@ LoopOpsDialect::LoopOpsDialect(MLIRContext *context)
 //===----------------------------------------------------------------------===//
 
 void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
-                  Value step) {
+                  Value step, ValueRange iterArgs) {
   result.addOperands({lb, ub, step});
+  result.addOperands(iterArgs);
+  for (Value v : iterArgs)
+    result.addTypes(v.getType());
   Region *bodyRegion = result.addRegion();
   ForOp::ensureTerminator(*bodyRegion, *builder, result.location);
   bodyRegion->front().addArgument(builder->getIndexType());
+  for (Value v : iterArgs)
+    bodyRegion->front().addArgument(v.getType());
 }
 
 static LogicalResult verify(ForOp op) {
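
With this extension, the builder creates an entry block whose arguments are
the index induction variable followed by one argument per loop-carried
value, and the op gets one result per such value. A minimal hypothetical
example of the resulting op (names illustrative):

  // Two loop-carried values; the block arguments and result types mirror
  // the iter_args. This trivial body forwards the values unchanged.
  %r:2 = loop.for %iv = %lb to %ub step %step
             iter_args(%a = %initA, %b = %initB) -> (f32, i64) {
    loop.yield %a, %b : f32, i64
  }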

diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
index c6be3fdb8953..54c5d4c4a9cf 100644
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -236,3 +236,88 @@ func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
   }
   return %r : f32
 }
+
+func @generate() -> i64
+
+// CHECK-LABEL: @simple_parallel_reduce_loop
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[INIT:.*]]: f32
+func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
+                                  %arg2: index, %arg3: f32) -> f32 {
+  // A parallel loop with reduction is converted through sequential loops with
+  // reductions into a CFG of blocks where the partially reduced value is
+  // passed across as a block argument.
+
+  // Branch to the condition block passing in the initial reduction value.
+  // CHECK:   br ^[[COND:.*]](%[[LB]], %[[INIT]]
+
+  // Condition branch takes as arguments the current value of the iteration
+  // variable and the current partially reduced value.
+  // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
+  // CHECK:   %[[COMP:.*]] = cmpi "slt", %[[ITER]], %[[UB]]
+  // CHECK:   cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+
+  // Bodies of loop.reduce operations are folded into the main loop body. The
+  // result of this partial reduction is passed as argument to the condition
+  // block.
+  // CHECK: ^[[BODY]]:
+  // CHECK:   %[[CST:.*]] = constant 4.2
+  // CHECK:   %[[PROD:.*]] = mulf %[[ITER_ARG]], %[[CST]]
+  // CHECK:   %[[INCR:.*]] = addi %[[ITER]], %[[STEP]]
+  // CHECK:   br ^[[COND]](%[[INCR]], %[[PROD]]
+
+  // The continuation block has access to the final value of the reduction.
+  // CHECK: ^[[CONTINUE]]:
+  // CHECK:   return %[[ITER_ARG]]
+  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) {
+    %cst = constant 42.0 : f32
+    loop.reduce(%cst) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = mulf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+  } : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: parallel_reduce_loop
+// CHECK-SAME: %[[INIT1:[0-9A-Za-z_]*]]: f32)
+func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                           %arg3 : index, %arg4 : index, %arg5 : f32) -> (f32, i64) {
+  // Multiple reduction blocks should be folded into the same body, and the
+  // reduction values must be forwarded through the block structure.
+  // CHECK:   %[[INIT2:.*]] = constant 42
+  // CHECK:   br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
+  // CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
+  // CHECK:   cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+  // CHECK: ^[[BODY_OUT]]:
+  // CHECK:   br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  // CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
+  // CHECK:   cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+  // CHECK: ^[[BODY_IN]]:
+  // CHECK:   %[[REDUCE1:.*]] = addf %[[ITER_ARG1_IN]], %{{.*}}
+  // CHECK:   %[[REDUCE2:.*]] = or %[[ITER_ARG2_IN]], %{{.*}}
+  // CHECK:   br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
+  // CHECK: ^[[CONT_IN]]:
+  // CHECK:   br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
+  // CHECK: ^[[CONT_OUT]]:
+  // CHECK:   return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  %step = constant 1 : index
+  %init = constant 42 : i64
+  %0:2 = loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                       step (%arg4, %step) init(%arg5, %init) {
+    %cf = constant 42.0 : f32
+    loop.reduce(%cf) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = addf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+
+    %2 = call @generate() : () -> i64
+    loop.reduce(%2) {
+    ^bb0(%lhs: i64, %rhs: i64):
+      %3 = or %lhs, %rhs : i64
+      loop.reduce.return %3 : i64
+    } : i64
+  } : f32, i64
+  return %0#0, %0#1 : f32, i64
+}



