[Mlir-commits] [mlir] eb31540 - [mlir] Canonicalize single-iteration ParallelOp
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 13 03:42:49 PDT 2021
Author: Butygin
Date: 2021-04-13T13:42:19+03:00
New Revision: eb31540066736658a71d7fc1154be8432e553a11
URL: https://github.com/llvm/llvm-project/commit/eb31540066736658a71d7fc1154be8432e553a11
DIFF: https://github.com/llvm/llvm-project/commit/eb31540066736658a71d7fc1154be8432e553a11.diff
LOG: [mlir] Canonicalize single-iteration ParallelOp
Differential Revision: https://reviews.llvm.org/D100248
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 2d1ad054bf05..efbc87273ca8 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1239,10 +1239,36 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
newSteps.push_back(step);
}
}
- // Exit if all or none of the loop dimensions perform a single iteration.
- if (newLowerBounds.size() == 0 ||
- newLowerBounds.size() == op.lowerBound().size())
+ // Exit if none of the loop dimensions perform a single iteration.
+ if (newLowerBounds.size() == op.lowerBound().size())
return failure();
+
+ if (newLowerBounds.empty()) {
+ // All of the loop dimensions perform a single iteration. Inline
+ // loop body and nested ReduceOp's
+ SmallVector<Value> results;
+ results.reserve(op.initVals().size());
+ for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
+ auto reduce = dyn_cast<ReduceOp>(bodyOp);
+ if (!reduce) {
+ rewriter.clone(bodyOp, mapping);
+ continue;
+ }
+ Block &reduceBlock = reduce.reductionOperator().front();
+ auto initValIndex = results.size();
+ mapping.map(reduceBlock.getArgument(0), op.initVals()[initValIndex]);
+ mapping.map(reduceBlock.getArgument(1),
+ mapping.lookupOrDefault(reduce.operand()));
+ for (auto &reduceBodyOp : reduceBlock.without_terminator())
+ rewriter.clone(reduceBodyOp, mapping);
+
+ auto result = mapping.lookupOrDefault(
+ cast<ReduceReturnOp>(reduceBlock.getTerminator()).result());
+ results.push_back(result);
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
// Replace the parallel loop by lower-dimensional parallel loop.
auto newOp =
rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 7c751623db86..3964f85ba3d2 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -3,7 +3,7 @@
// -----
-func @single_iteration(%A: memref<?x?x?xi32>) {
+func @single_iteration_some(%A: memref<?x?x?xi32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
@@ -19,7 +19,7 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
return
}
-// CHECK-LABEL: func @single_iteration(
+// CHECK-LABEL: func @single_iteration_some(
// CHECK-SAME: [[ARG0:%.*]]: memref<?x?x?xi32>) {
// CHECK-DAG: [[C42:%.*]] = constant 42 : i32
// CHECK-DAG: [[C7:%.*]] = constant 7 : index
@@ -35,6 +35,70 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
// -----
+func @single_iteration_all(%A: memref<?x?x?xi32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %c7 = constant 7 : index
+ %c10 = constant 10 : index
+ scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c3, %c3) {
+ %c42 = constant 42 : i32
+ memref.store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>
+ scf.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func @single_iteration_all(
+// CHECK-SAME: [[ARG0:%.*]]: memref<?x?x?xi32>) {
+// CHECK-DAG: [[C42:%.*]] = constant 42 : i32
+// CHECK-DAG: [[C7:%.*]] = constant 7 : index
+// CHECK-DAG: [[C3:%.*]] = constant 3 : index
+// CHECK-DAG: [[C0:%.*]] = constant 0 : index
+// CHECK-NOT: scf.parallel
+// CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[C3]], [[C7]]] : memref<?x?x?xi32>
+// CHECK-NOT: scf.yield
+// CHECK: return
+
+// -----
+
+func @single_iteration_reduce(%A: index, %B: index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c6 = constant 6 : index
+ %0:2 = scf.parallel (%i0, %i1) = (%c1, %c3) to (%c2, %c6) step (%c1, %c3) init(%A, %B) -> (index, index) {
+ scf.reduce(%i0) : index {
+ ^bb0(%lhs: index, %rhs: index):
+ %1 = addi %lhs, %rhs : index
+ scf.reduce.return %1 : index
+ }
+ scf.reduce(%i1) : index {
+ ^bb0(%lhs: index, %rhs: index):
+ %2 = muli %lhs, %rhs : index
+ scf.reduce.return %2 : index
+ }
+ scf.yield
+ }
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @single_iteration_reduce(
+// CHECK-SAME: [[ARG0:%.*]]: index, [[ARG1:%.*]]: index)
+// CHECK-DAG: [[C3:%.*]] = constant 3 : index
+// CHECK-DAG: [[C1:%.*]] = constant 1 : index
+// CHECK-NOT: scf.parallel
+// CHECK-NOT: scf.reduce
+// CHECK-NOT: scf.reduce.return
+// CHECK-NOT: scf.yield
+// CHECK: [[V0:%.*]] = addi [[ARG0]], [[C1]]
+// CHECK: [[V1:%.*]] = muli [[ARG1]], [[C3]]
+// CHECK: return [[V0]], [[V1]]
+
+// -----
+
func private @side_effect()
func @one_unused(%cond: i1) -> (index) {
%c0 = constant 0 : index
@@ -488,7 +552,7 @@ func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
%ub : index, %lb : index, %step : index) -> (i32, i32) {
// CHECK-NEXT: %[[C32:.*]] = constant 32 : i32
%cst = constant 32 : i32
- // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
+ // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
%0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
-> (i32, i32) {
%1 = addi %arg2, %cst : i32
@@ -512,7 +576,7 @@ func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
%1 = addi %arg2, %cst : i32
scf.yield %1, %1 : i32, i32
}
-
+
// CHECK: return %[[FOR_RES]] : i32
return %0#0 : i32
}
More information about the Mlir-commits mailing list