[Mlir-commits] [mlir] 1ec6086 - [mlir] Avoid cloning ops in SCF parallel conversion to CFG

Alex Zinenko llvmlistbot at llvm.org
Mon Nov 23 05:01:30 PST 2020


Author: Alex Zinenko
Date: 2020-11-23T14:01:22+01:00
New Revision: 1ec60862d7024118b2db5bcbb280eafcd9193ac5

URL: https://github.com/llvm/llvm-project/commit/1ec60862d7024118b2db5bcbb280eafcd9193ac5
DIFF: https://github.com/llvm/llvm-project/commit/1ec60862d7024118b2db5bcbb280eafcd9193ac5.diff

LOG: [mlir] Avoid cloning ops in SCF parallel conversion to CFG

The existing implementation of the conversion from SCF Parallel operation to
SCF "for" loops in order to further convert those loops to branch-based CFG has
been cloning the loop and reduction body operations into the new loop because
ConversionPatternRewriter was missing support for moving blocks while replacing
their arguments. This functionality is now available, so use it to implement the
conversion and avoid cloning operations, which may lead to doubling of the IR
size during the conversion.

In addition, this fixes an issue with converting nested SCF "if" conditionals
present in "parallel" operations that would cause the conversion infrastructure
to stop because of the repeated application of the pattern converting "newly"
created "if"s (which were in fact just moved). Arguably, this should be fixed
at the infrastructure level and this fix is a workaround.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D91955

Added: 
    

Modified: 
    mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
    mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index 56f6bf2f05fc..b8f3140dee73 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -404,7 +404,6 @@ LogicalResult
 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
                                   PatternRewriter &rewriter) const {
   Location loc = parallelOp.getLoc();
-  BlockAndValueMapping mapping;
 
   // For a parallel loop, we essentially need to create an n-dimensional loop
   // nest. We do this by translating to scf.for ops and have those lowered in
@@ -412,6 +411,8 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   // values), forward the initial values for the reductions down the loop
   // hierarchy and bubble up the results by modifying the "yield" terminator.
   SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.initVals());
+  SmallVector<Value, 4> ivs;
+  ivs.reserve(parallelOp.getNumLoops());
   bool first = true;
   SmallVector<Value, 4> loopResults(iterArgs);
   for (auto loop_operands :
@@ -420,7 +421,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
     Value iv, lower, upper, step;
     std::tie(iv, lower, upper, step) = loop_operands;
     ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
-    mapping.map(iv, forOp.getInductionVar());
+    ivs.push_back(forOp.getInductionVar());
     auto iterRange = forOp.getRegionIterArgs();
     iterArgs.assign(iterRange.begin(), iterRange.end());
 
@@ -439,33 +440,33 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
     rewriter.setInsertionPointToStart(forOp.getBody());
   }
 
-  // Now copy over the contents of the body.
+  // First, merge reduction blocks into the main region.
   SmallVector<Value, 4> yieldOperands;
   yieldOperands.reserve(parallelOp.getNumResults());
-  for (auto &op : parallelOp.getBody()->without_terminator()) {
-    // Reduction blocks are handled differently.
+  for (auto &op : *parallelOp.getBody()) {
     auto reduce = dyn_cast<ReduceOp>(op);
-    if (!reduce) {
-      rewriter.clone(op, mapping);
+    if (!reduce)
       continue;
-    }
 
-    // Clone the body of the reduction operation into the body of the loop,
-    // using operands of "scf.reduce" and iteration arguments corresponding
-    // to the reduction value to replace arguments of the reduction block.
-    // Collect operands of "scf.reduce.return" to be returned by a final
-    // "scf.yield" instead.
-    Value arg = iterArgs[yieldOperands.size()];
     Block &reduceBlock = reduce.reductionOperator().front();
-    mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg));
-    mapping.map(reduceBlock.getArgument(1),
-                mapping.lookupOrDefault(reduce.operand()));
-    for (auto &nested : reduceBlock.without_terminator())
-      rewriter.clone(nested, mapping);
-    yieldOperands.push_back(
-        mapping.lookup(reduceBlock.getTerminator()->getOperand(0)));
+    Value arg = iterArgs[yieldOperands.size()];
+    yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
+    rewriter.eraseOp(reduceBlock.getTerminator());
+    rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.operand()});
+    rewriter.eraseOp(reduce);
   }
 
+  // Then merge the loop body without the terminator.
+  rewriter.eraseOp(parallelOp.getBody()->getTerminator());
+  Block *newBody = rewriter.getInsertionBlock();
+  if (newBody->empty())
+    rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
+  else
+    rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
+                              ivs);
+
+  // Finally, create the terminator if required (for loops with no results, it
+  // has been already created in loop construction).
   if (!yieldOperands.empty()) {
     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
     rewriter.create<scf::YieldOp>(loc, yieldOperands);

diff --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
index 7e0671b93607..67e0bb5f9739 100644
--- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
@@ -546,3 +546,44 @@ func @nested_while_ops(%arg0: f32) -> i64 {
   return %0 : i64
 }
 
+// CHECK-LABEL: @ifs_in_parallel
+// CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1)
+func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) {
+  // CHECK:   br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
+  // CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
+  // CHECK:   %[[LOOP_COND:.*]] = cmpi "slt", %[[LOOP_IV]], %[[ARG1]] : index
+  // CHECK:   cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
+  // CHECK: ^[[LOOP_BODY]]:
+  // CHECK:   cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
+  // CHECK: ^[[IF1_THEN]]:
+  // CHECK:   cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]]
+  // CHECK: ^[[IF2_THEN]]:
+  // CHECK:   %{{.*}} = "test.if2"() : () -> index
+  // CHECK:   br ^[[IF2_MERGE:.*]](%{{.*}} : index)
+  // CHECK: ^[[IF2_ELSE]]:
+  // CHECK:   %{{.*}} = "test.else2"() : () -> index
+  // CHECK:   br ^[[IF2_MERGE]](%{{.*}} : index)
+  // CHECK: ^[[IF2_MERGE]](%{{.*}}: index):
+  // CHECK:   br ^[[IF2_CONT:.*]]
+  // CHECK: ^[[IF2_CONT]]:
+  // CHECK:   br ^[[IF1_CONT]]
+  // CHECK: ^[[IF1_CONT]]:
+  // CHECK:   %{{.*}} = addi %[[LOOP_IV]], %[[ARG2]] : index
+  // CHECK:   br ^[[LOOP_LATCH]](%{{.*}} : index)
+  scf.parallel (%i) = (%arg1) to (%arg2) step (%arg3) {
+    scf.if %arg4 {
+      %0 = scf.if %arg5 -> (index) {
+        %1 = "test.if2"() : () -> index
+        scf.yield %1 : index
+      } else {
+        %2 = "test.else2"() : () -> index
+        scf.yield %2 : index
+      }
+    }
+    scf.yield
+  }
+
+  // CHECK: ^[[LOOP_CONT]]:
+  // CHECK:   return
+  return
+}


        


More information about the Mlir-commits mailing list