[Mlir-commits] [mlir] eb31540 - [mlir] Canonicalize single-iteration ParallelOp

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 13 03:42:49 PDT 2021


Author: Butygin
Date: 2021-04-13T13:42:19+03:00
New Revision: eb31540066736658a71d7fc1154be8432e553a11

URL: https://github.com/llvm/llvm-project/commit/eb31540066736658a71d7fc1154be8432e553a11
DIFF: https://github.com/llvm/llvm-project/commit/eb31540066736658a71d7fc1154be8432e553a11.diff

LOG: [mlir] Canonicalize single-iteration ParallelOp

Differential Revision: https://reviews.llvm.org/D100248

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 2d1ad054bf05..efbc87273ca8 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1239,10 +1239,36 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
         newSteps.push_back(step);
       }
     }
-    // Exit if all or none of the loop dimensions perform a single iteration.
-    if (newLowerBounds.size() == 0 ||
-        newLowerBounds.size() == op.lowerBound().size())
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newLowerBounds.size() == op.lowerBound().size())
       return failure();
+
+    if (newLowerBounds.empty()) {
+      // All of the loop dimensions perform a single iteration. Inline
+      // loop body and nested ReduceOp's
+      SmallVector<Value> results;
+      results.reserve(op.initVals().size());
+      for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
+        auto reduce = dyn_cast<ReduceOp>(bodyOp);
+        if (!reduce) {
+          rewriter.clone(bodyOp, mapping);
+          continue;
+        }
+        Block &reduceBlock = reduce.reductionOperator().front();
+        auto initValIndex = results.size();
+        mapping.map(reduceBlock.getArgument(0), op.initVals()[initValIndex]);
+        mapping.map(reduceBlock.getArgument(1),
+                    mapping.lookupOrDefault(reduce.operand()));
+        for (auto &reduceBodyOp : reduceBlock.without_terminator())
+          rewriter.clone(reduceBodyOp, mapping);
+
+        auto result = mapping.lookupOrDefault(
+            cast<ReduceReturnOp>(reduceBlock.getTerminator()).result());
+        results.push_back(result);
+      }
+      rewriter.replaceOp(op, results);
+      return success();
+    }
     // Replace the parallel loop by lower-dimensional parallel loop.
     auto newOp =
         rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 7c751623db86..3964f85ba3d2 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -3,7 +3,7 @@
 
 // -----
 
-func @single_iteration(%A: memref<?x?x?xi32>) {
+func @single_iteration_some(%A: memref<?x?x?xi32>) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %c2 = constant 2 : index
@@ -19,7 +19,7 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
   return
 }
 
-// CHECK-LABEL:   func @single_iteration(
+// CHECK-LABEL:   func @single_iteration_some(
 // CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?x?xi32>) {
 // CHECK-DAG:           [[C42:%.*]] = constant 42 : i32
 // CHECK-DAG:           [[C7:%.*]] = constant 7 : index
@@ -35,6 +35,70 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
 
 // -----
 
+func @single_iteration_all(%A: memref<?x?x?xi32>) {  // Every loop dimension runs exactly once, so the whole scf.parallel should be inlined.
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %c7 = constant 7 : index
+  %c10 = constant 10 : index
+  scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c3, %c3) {  // 0 to 1 step 1, 3 to 6 step 3, 7 to 10 step 3: one iteration per dim.
+    %c42 = constant 42 : i32
+    memref.store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>  // Induction vars are expected to be replaced by the lower bounds (0, 3, 7).
+    scf.yield
+  }
+  return
+}
+
+// CHECK-LABEL:   func @single_iteration_all(
+// CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?x?xi32>) {
+// CHECK-DAG:           [[C42:%.*]] = constant 42 : i32
+// CHECK-DAG:           [[C7:%.*]] = constant 7 : index
+// CHECK-DAG:           [[C3:%.*]] = constant 3 : index
+// CHECK-DAG:           [[C0:%.*]] = constant 0 : index
+// CHECK-NOT:           scf.parallel
+// CHECK:               memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[C3]], [[C7]]] : memref<?x?x?xi32>
+// CHECK-NOT:           scf.yield
+// CHECK:               return
+
+// -----
+
+func @single_iteration_reduce(%A: index, %B: index) -> (index, index) {  // Both dims run once (1 to 2 step 1, 3 to 6 step 3); the reductions should fold into straight-line code seeded by the init values.
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %0:2 = scf.parallel (%i0, %i1) = (%c1, %c3) to (%c2, %c6) step (%c1, %c3) init(%A, %B) -> (index, index) {
+    scf.reduce(%i0) : index {  // The only iteration has %i0 = %c1, so this is expected to become addi %A, %c1.
+    ^bb0(%lhs: index, %rhs: index):
+      %1 = addi %lhs, %rhs : index
+      scf.reduce.return %1 : index
+    }
+    scf.reduce(%i1) : index {  // The only iteration has %i1 = %c3, so this is expected to become muli %B, %c3.
+    ^bb0(%lhs: index, %rhs: index):
+      %2 = muli %lhs, %rhs : index
+      scf.reduce.return %2 : index
+    }
+    scf.yield
+  }
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL:   func @single_iteration_reduce(
+// CHECK-SAME:                        [[ARG0:%.*]]: index, [[ARG1:%.*]]: index)
+// CHECK-DAG:           [[C3:%.*]] = constant 3 : index
+// CHECK-DAG:           [[C1:%.*]] = constant 1 : index
+// CHECK-NOT:           scf.parallel
+// CHECK-NOT:           scf.reduce
+// CHECK-NOT:           scf.reduce.return
+// CHECK-NOT:           scf.yield
+// CHECK:               [[V0:%.*]] = addi [[ARG0]], [[C1]]
+// CHECK:               [[V1:%.*]] = muli [[ARG1]], [[C3]]
+// CHECK:               return [[V0]], [[V1]]
+
+// -----
+
 func private @side_effect()
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
@@ -488,7 +552,7 @@ func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
                     %ub : index, %lb : index, %step : index) -> (i32, i32) {
   // CHECK-NEXT: %[[C32:.*]] = constant 32 : i32
   %cst = constant 32 : i32
-  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) { 
+  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
   %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
     -> (i32, i32) {
     %1 = addi %arg2, %cst : i32
@@ -512,7 +576,7 @@ func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
     %1 = addi %arg2, %cst : i32
     scf.yield %1, %1 : i32, i32
   }
-  
+
   // CHECK: return %[[FOR_RES]] : i32
   return %0#0 : i32
 }


        


More information about the Mlir-commits mailing list