[llvm-branch-commits] [mlir] 31a233d - [mlir] canonicalize away zero-iteration SCF for loops

Alex Zinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Nov 23 06:09:24 PST 2020


Author: Alex Zinenko
Date: 2020-11-23T15:04:31+01:00
New Revision: 31a233d46367636f94c487b51aa2931a1cc9cf79

URL: https://github.com/llvm/llvm-project/commit/31a233d46367636f94c487b51aa2931a1cc9cf79
DIFF: https://github.com/llvm/llvm-project/commit/31a233d46367636f94c487b51aa2931a1cc9cf79.diff

LOG: [mlir] canonicalize away zero-iteration SCF for loops

An SCF 'for' loop does not iterate if its lower bound is equal to its upper
bound. Remove loops whose bounds are the same SSA value, since such bounds are
guaranteed to be equal; the results of a removed loop are replaced by its
iteration arguments (iter_args operands). Similarly, remove 'parallel' loops
where at least one pair of corresponding lower/upper bounds is the same SSA
value, replacing their results with the loop's init values.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D91880

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 5da9f7c29cab..48b1b473f86d 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -521,6 +521,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
+    // If the upper bound is the same as the lower bound, the loop does not
+    // iterate, just remove it.
+    if (op.lowerBound() == op.upperBound()) {
+      rewriter.replaceOp(op, op.getIterOperands());
+      return success();
+    }
+
     auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
     auto ub = op.upperBound().getDefiningOp<ConstantOp>();
     if (!lb || !ub)
@@ -1066,11 +1073,30 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
     return success();
   }
 };
+
+/// Removes parallel loops in which at least one lower/upper bound pair is the
+/// same SSA value: that dimension, hence the whole iteration domain, is empty.
+struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
+  using OpRewritePattern<ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
+      if (std::get<0>(dim) == std::get<1>(dim)) { // Same SSA value: lb == ub.
+        rewriter.replaceOp(op, op.initVals()); // Results are the init values.
+        return success();
+      }
+    }
+    return failure(); // No bound pair shares an SSA value.
+  }
+};
+
 } // namespace
 
 void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
-  results.insert<CollapseSingleIterationLoops>(context);
+  results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(
+      context); // Both are applied when canonicalizing ParallelOp.
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index faac86b94cdb..d57563461241 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -32,30 +32,6 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
 
 // -----
 
-func @no_iteration(%A: memref<?x?xi32>) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %c0) step (%c1, %c1) {
-    %c42 = constant 42 : i32
-    store %c42, %A[%i0, %i1] : memref<?x?xi32>
-    scf.yield
-  }
-  return
-}
-
-// CHECK-LABEL:   func @no_iteration(
-// CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?xi32>) {
-// CHECK:           [[C0:%.*]] = constant 0 : index
-// CHECK:           [[C1:%.*]] = constant 1 : index
-// CHECK:           [[C42:%.*]] = constant 42 : i32
-// CHECK:           scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[C0]]) step ([[C1]]) {
-// CHECK:             store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V1]]] : memref<?x?xi32>
-// CHECK:             scf.yield
-// CHECK:           }
-// CHECK:           return
-
-// -----
-
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -241,6 +217,22 @@ func @remove_zero_iteration_loop() {
   return
 }
 
+// CHECK-LABEL: @remove_zero_iteration_loop_vals
+func @remove_zero_iteration_loop_vals(%arg0: index) {
+  %c2 = constant 2 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  // CHECK-NOT: test.op
+  %0 = scf.for %i = %arg0 to %arg0 step %c2 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32 // Never runs: lb == ub.
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (i32) -> () // %0 folds to %init.
+  return
+}
+
 // CHECK-LABEL: @replace_single_iteration_loop
 func @replace_single_iteration_loop() {
   // CHECK: %[[LB:.*]] = constant 42
@@ -278,3 +270,24 @@ func @replace_single_iteration_loop_non_unit_step() {
   "test.consume"(%0) : (i32) -> ()
   return
 }
+
+// CHECK-LABEL: @remove_empty_parallel_loop
+func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> f32
+  // CHECK-NOT: scf.parallel
+  // CHECK-NOT: test.produce
+  // CHECK-NOT: test.transform
+  %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 {
+    %1 = "test.produce"() : () -> f32 // Never runs: %j spans %ub to %ub.
+    scf.reduce(%1) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32
+      scf.reduce.return %2 : f32
+    }
+    scf.yield
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (f32) -> () // %0 folds to %init.
+  return
+}


        


More information about the llvm-branch-commits mailing list