[Mlir-commits] [mlir] 31a233d - [mlir] canonicalize away zero-iteration SCF for loops
Alex Zinenko
llvmlistbot at llvm.org
Mon Nov 23 06:04:40 PST 2020
Author: Alex Zinenko
Date: 2020-11-23T15:04:31+01:00
New Revision: 31a233d46367636f94c487b51aa2931a1cc9cf79
URL: https://github.com/llvm/llvm-project/commit/31a233d46367636f94c487b51aa2931a1cc9cf79
DIFF: https://github.com/llvm/llvm-project/commit/31a233d46367636f94c487b51aa2931a1cc9cf79.diff
LOG: [mlir] canonicalize away zero-iteration SCF for loops
An SCF 'for' loop does not iterate if its lower bound is equal to its upper
bound. Remove loops whose lower and upper bounds are the same SSA value, since
such bounds are guaranteed to be equal. Similarly, remove 'parallel' loops in
which at least one pair of corresponding lower/upper bounds is the same SSA
value.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D91880
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 5da9f7c29cab..48b1b473f86d 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -521,6 +521,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
+ // If the upper bound is the same as the lower bound, the loop does not
+ // iterate, just remove it.
+ if (op.lowerBound() == op.upperBound()) {
+ rewriter.replaceOp(op, op.getIterOperands());
+ return success();
+ }
+
auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
auto ub = op.upperBound().getDefiningOp<ConstantOp>();
if (!lb || !ub)
@@ -1066,11 +1073,30 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
return success();
}
};
+
+/// Erases an scf.parallel loop whose iteration domain is trivially empty: if
+/// any dimension uses one and the same SSA value as both its lower and its
+/// upper bound, that dimension has no iterations, so the loop body never runs
+/// and the loop results are simply the initial reduction values.
+struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
+  using OpRewritePattern<ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    auto boundsAreIdentical = [](auto bounds) {
+      return std::get<0>(bounds) == std::get<1>(bounds);
+    };
+    bool domainIsEmpty = llvm::any_of(
+        llvm::zip(op.lowerBound(), op.upperBound()), boundsAreIdentical);
+    if (!domainIsEmpty)
+      return failure();
+
+    // Forward the initial values to every user of the loop results.
+    rewriter.replaceOp(op, op.initVals());
+    return success();
+  }
+};
+
} // namespace
+/// Populates `results` with the canonicalization patterns for scf.parallel.
void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<CollapseSingleIterationLoops>(context);
+ // RemoveEmptyParallelLoops erases loops in which some dimension uses the same
+ // SSA value as both its lower and its upper bound.
+ results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index faac86b94cdb..d57563461241 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -32,30 +32,6 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
// -----
-func @no_iteration(%A: memref<?x?xi32>) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %c0) step (%c1, %c1) {
- %c42 = constant 42 : i32
- store %c42, %A[%i0, %i1] : memref<?x?xi32>
- scf.yield
- }
- return
-}
-
-// CHECK-LABEL: func @no_iteration(
-// CHECK-SAME: [[ARG0:%.*]]: memref<?x?xi32>) {
-// CHECK: [[C0:%.*]] = constant 0 : index
-// CHECK: [[C1:%.*]] = constant 1 : index
-// CHECK: [[C42:%.*]] = constant 42 : i32
-// CHECK: scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[C0]]) step ([[C1]]) {
-// CHECK: store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V1]]] : memref<?x?xi32>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: return
-
-// -----
-
func @one_unused(%cond: i1) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -241,6 +217,22 @@ func @remove_zero_iteration_loop() {
return
}
+// An scf.for whose lower and upper bounds are the same SSA value (%arg0) runs
+// zero times: the loop and its body are expected to disappear, and users of
+// the loop result receive the iter_args initial value instead.
+// CHECK-LABEL: @remove_zero_iteration_loop_vals
+func @remove_zero_iteration_loop_vals(%arg0: index) {
+ %c2 = constant 2 : index
+ // CHECK: %[[INIT:.*]] = "test.init"
+ %init = "test.init"() : () -> i32
+ // CHECK-NOT: scf.for
+ // CHECK-NOT: test.op
+ %0 = scf.for %i = %arg0 to %arg0 step %c2 iter_args(%arg = %init) -> (i32) {
+ %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+ scf.yield %1 : i32
+ }
+ // CHECK: "test.consume"(%[[INIT]])
+ "test.consume"(%0) : (i32) -> ()
+ return
+}
+
// CHECK-LABEL: @replace_single_iteration_loop
func @replace_single_iteration_loop() {
// CHECK: %[[LB:.*]] = constant 42
@@ -278,3 +270,24 @@ func @replace_single_iteration_loop_non_unit_step() {
"test.consume"(%0) : (i32) -> ()
return
}
+
+// The second dimension (%j) of this scf.parallel goes from %ub to %ub, so the
+// whole iteration domain is empty: the loop, its body and its reduction are
+// expected to be erased, and the loop result replaced by the init value.
+// CHECK-LABEL: @remove_empty_parallel_loop
+func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
+ // CHECK: %[[INIT:.*]] = "test.init"
+ %init = "test.init"() : () -> f32
+ // CHECK-NOT: scf.parallel
+ // CHECK-NOT: test.produce
+ // CHECK-NOT: test.transform
+ %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 {
+ %1 = "test.produce"() : () -> f32
+ scf.reduce(%1) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32
+ scf.reduce.return %2 : f32
+ }
+ scf.yield
+ }
+ // CHECK: "test.consume"(%[[INIT]])
+ "test.consume"(%0) : (f32) -> ()
+ return
+}
More information about the Mlir-commits
mailing list