[Mlir-commits] [mlir] 18d0f7d - [mlir] add canonicalization patterns for trivial SCF 'for' and 'if'

Alex Zinenko llvmlistbot at llvm.org
Fri Nov 20 10:04:48 PST 2020


Author: Alex Zinenko
Date: 2020-11-20T19:04:39+01:00
New Revision: 18d0f7d5c3b0f8fae2cb6cd5db3977df26e3533f

URL: https://github.com/llvm/llvm-project/commit/18d0f7d5c3b0f8fae2cb6cd5db3977df26e3533f
DIFF: https://github.com/llvm/llvm-project/commit/18d0f7d5c3b0f8fae2cb6cd5db3977df26e3533f.diff

LOG: [mlir] add canonicalization patterns for trivial SCF 'for' and 'if'

Add canonicalization patterns to remove zero-iteration 'for' loops, replace
single-iteration 'for' loops with their bodies; remove known-false conditionals
with no 'else' branch and replace conditionals with a known condition value by
the respective region. Although similar transformations are performed at the CFG
level, not all flows reach that level, e.g., the GPU flow may want to remove
single-iteration loops before deciding on loop mapping to thread dimensions.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D91865

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index fe2eb9ced469..5da9f7c29cab 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -393,6 +393,19 @@ LoopNest mlir::scf::buildLoopNest(
                        });
 }
 
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+/// The block's operations are moved immediately before `op`; `blockArgs` are
+/// the values substituted for the block's arguments (empty for argument-less
+/// blocks). The terminator is erased after `op` has been replaced.
+static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
+                                Region &region, ValueRange blockArgs = {}) {
+  assert(llvm::hasSingleElement(region) && "expected single-block region");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  // The terminator must stay alive until `op` has been replaced with its
+  // operands, hence it is only erased at the very end.
+  ValueRange results = terminator->getOperands();
+  rewriter.mergeBlockBefore(block, op, blockArgs);
+  rewriter.replaceOp(op, results);
+  rewriter.eraseOp(terminator);
+}
+
 namespace {
 // Fold away ForOp iter arguments that are also yielded by the op.
 // These arguments must be defined outside of the ForOp region and can just be
@@ -500,11 +513,51 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     return success();
   }
 };
+
+/// Rewriting pattern that erases loops that are known not to iterate and
+/// replaces single-iteration loops with their bodies.
+///
+/// Only applies when the lower and upper bounds are defined by `constant` ops
+/// (and, for the single-iteration case, the step as well). Bounds are compared
+/// as signed integers.
+struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
+    auto ub = op.upperBound().getDefiningOp<ConstantOp>();
+    if (!lb || !ub)
+      return failure();
+
+    // If the loop is known to have 0 iterations, remove it; its results are
+    // the initial values of the iteration arguments.
+    llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
+    llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
+    if (lbValue.sge(ubValue)) {
+      rewriter.replaceOp(op, op.getIterOperands());
+      return success();
+    }
+
+    auto step = op.step().getDefiningOp<ConstantOp>();
+    if (!step)
+      return failure();
+
+    // If the loop is known to have 1 iteration, inline its body and remove the
+    // loop. The step must be read from the `step` constant: reading it from the
+    // lower bound would, e.g., treat `42 to 47 step 1` as a single iteration
+    // (42 + 42 >= 47) and silently drop four iterations.
+    llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
+    if ((lbValue + stepValue).sge(ubValue)) {
+      // The induction variable is replaced by the lower bound and the
+      // loop-carried arguments by the initial iter_args values.
+      SmallVector<Value, 4> blockArgs;
+      blockArgs.reserve(op.getNumIterOperands() + 1);
+      blockArgs.push_back(op.lowerBound());
+      llvm::append_range(blockArgs, op.getIterOperands());
+      replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
+      return success();
+    }
+
+    return failure();
+  }
+};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert<ForOpIterArgsFolder>(context);
+  results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -710,11 +763,31 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
     return success();
   }
 };
+
+/// Rewriting pattern for an 'scf.if' whose condition is defined by a
+/// `constant` op: the op is replaced by the region that would execute, or
+/// erased when the condition is false and there is no 'else' region.
+struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto condOp = op.condition().getDefiningOp<ConstantOp>();
+    if (!condOp)
+      return failure();
+
+    // NOTE(review): assumes the constant carries a BoolAttr; `cast` asserts
+    // otherwise -- confirm for conditions written as `constant 1 : i1`.
+    bool takesThenBranch = condOp.getValue().cast<BoolAttr>().getValue();
+    if (takesThenBranch) {
+      replaceOpWithRegion(rewriter, op, op.thenRegion());
+      return success();
+    }
+
+    // Known-false condition without an 'else' region: nothing executes.
+    if (op.elseRegion().empty()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    replaceOpWithRegion(rewriter, op, op.elseRegion());
+    return success();
+  }
+};
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
-  results.insert<RemoveUnusedResults>(context);
+  results.insert<RemoveUnusedResults, RemoveStaticCondition>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index dd44e3d2933a..faac86b94cdb 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -56,11 +56,10 @@ func @no_iteration(%A: memref<?x?xi32>) {
 
 // -----
 
-func @one_unused() -> (index) {
+func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %true = constant true
-  %0, %1 = scf.if %true -> (index, index) {
+  %0, %1 = scf.if %cond -> (index, index) {
     scf.yield %c0, %c1 : index, index
   } else {
     scf.yield %c0, %c1 : index, index
@@ -70,8 +69,7 @@ func @one_unused() -> (index) {
 
 // CHECK-LABEL:   func @one_unused
 // CHECK:           [[C0:%.*]] = constant 1 : index
-// CHECK:           [[C1:%.*]] = constant true
-// CHECK:           [[V0:%.*]] = scf.if [[C1]] -> (index) {
+// CHECK:           [[V0:%.*]] = scf.if %{{.*}} -> (index) {
 // CHECK:             scf.yield [[C0]] : index
 // CHECK:           } else
 // CHECK:             scf.yield [[C0]] : index
@@ -80,12 +78,11 @@ func @one_unused() -> (index) {
 
 // -----
 
-func @nested_unused() -> (index) {
+func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %true = constant true
-  %0, %1 = scf.if %true -> (index, index) {
-    %2, %3 = scf.if %true -> (index, index) {
+  %0, %1 = scf.if %cond1 -> (index, index) {
+    %2, %3 = scf.if %cond2 -> (index, index) {
       scf.yield %c0, %c1 : index, index
     } else {
       scf.yield %c0, %c1 : index, index
@@ -99,9 +96,8 @@ func @nested_unused() -> (index) {
 
 // CHECK-LABEL:   func @nested_unused
 // CHECK:           [[C0:%.*]] = constant 1 : index
-// CHECK:           [[C1:%.*]] = constant true
-// CHECK:           [[V0:%.*]] = scf.if [[C1]] -> (index) {
-// CHECK:             [[V1:%.*]] = scf.if [[C1]] -> (index) {
+// CHECK:           [[V0:%.*]] = scf.if {{.*}} -> (index) {
+// CHECK:             [[V1:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK:               scf.yield [[C0]] : index
 // CHECK:             } else
 // CHECK:               scf.yield [[C0]] : index
@@ -115,11 +111,10 @@ func @nested_unused() -> (index) {
 // -----
 
 func private @side_effect() {}
-func @all_unused() {
+func @all_unused(%cond: i1) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %true = constant true
-  %0, %1 = scf.if %true -> (index, index) {
+  %0, %1 = scf.if %cond -> (index, index) {
     call @side_effect() : () -> ()
     scf.yield %c0, %c1 : index, index
   } else {
@@ -130,8 +125,7 @@ func @all_unused() {
 }
 
 // CHECK-LABEL:   func @all_unused
-// CHECK:           [[C1:%.*]] = constant true
-// CHECK:           scf.if [[C1]] {
+// CHECK:           scf.if %{{.*}} {
 // CHECK:             call @side_effect() : () -> ()
 // CHECK:           } else
 // CHECK:             call @side_effect() : () -> ()
@@ -172,3 +166,115 @@ func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
 //  CHECK-NEXT:       scf.yield %[[c]] : i32
 //  CHECK-NEXT:     }
 //  CHECK-NEXT:     return %[[a]], %[[r1]], %[[b]] : i32, i32, i32
+
+// An 'scf.if' on a constant true condition is replaced by its 'then' body.
+// CHECK-LABEL: @replace_true_if
+func @replace_true_if() {
+  %true = constant true
+  // CHECK-NOT: scf.if
+  // CHECK: "test.op"
+  scf.if %true {
+    "test.op"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// An 'scf.if' on a constant false condition with no 'else' region is erased
+// together with its body.
+// CHECK-LABEL: @remove_false_if
+func @remove_false_if() {
+  %false = constant false
+  // CHECK-NOT: scf.if
+  // CHECK-NOT: "test.op"
+  scf.if %false {
+    "test.op"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// A value-producing 'scf.if' on a constant true condition is replaced by the
+// 'then' body; uses of its result are rewired to the yielded value.
+// CHECK-LABEL: @replace_true_if_with_values
+func @replace_true_if_with_values() {
+  %true = constant true
+  // CHECK-NOT: scf.if
+  // CHECK: %[[VAL:.*]] = "test.op"
+  %0 = scf.if %true -> (i32) {
+    %1 = "test.op"() : () -> i32
+    scf.yield %1 : i32
+  } else {
+    %2 = "test.other_op"() : () -> i32
+    scf.yield %2 : i32
+  }
+  // CHECK: "test.consume"(%[[VAL]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
+// A value-producing 'scf.if' on a constant false condition is replaced by the
+// 'else' body; uses of its result are rewired to the yielded value.
+// CHECK-LABEL: @replace_false_if_with_values
+func @replace_false_if_with_values() {
+  %false = constant false
+  // CHECK-NOT: scf.if
+  // CHECK: %[[VAL:.*]] = "test.other_op"
+  %0 = scf.if %false -> (i32) {
+    %1 = "test.op"() : () -> i32
+    scf.yield %1 : i32
+  } else {
+    %2 = "test.other_op"() : () -> i32
+    scf.yield %2 : i32
+  }
+  // CHECK: "test.consume"(%[[VAL]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
+// A loop whose constant lower bound is not below its upper bound never runs:
+// it is erased and its result is replaced by the initial iter_args value.
+// CHECK-LABEL: @remove_zero_iteration_loop
+func @remove_zero_iteration_loop() {
+  %c42 = constant 42 : index
+  %c1 = constant 1 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  %0 = scf.for %i = %c42 to %c1 step %c1 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
+// A unit-step loop with exactly one iteration is inlined: the induction
+// variable becomes the lower bound and the loop-carried argument the initial
+// value.
+// CHECK-LABEL: @replace_single_iteration_loop
+func @replace_single_iteration_loop() {
+  // CHECK: %[[LB:.*]] = constant 42
+  %c42 = constant 42 : index
+  %c43 = constant 43 : index
+  %c1 = constant 1 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
+  %0 = scf.for %i = %c42 to %c43 step %c1 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[VAL]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
+// A loop whose non-unit step jumps past the upper bound after one iteration
+// (42 + 5 >= 47) is inlined like any other single-iteration loop.
+// CHECK-LABEL: @replace_single_iteration_loop_non_unit_step
+func @replace_single_iteration_loop_non_unit_step() {
+  // CHECK: %[[LB:.*]] = constant 42
+  %c42 = constant 42 : index
+  %c47 = constant 47 : index
+  %c5 = constant 5 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
+  %0 = scf.for %i = %c42 to %c47 step %c5 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[VAL]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
+// A loop with five iterations must be kept. The single-iteration test has to
+// use the step (42 + 1 < 47), not the lower bound (42 + 42 >= 47).
+// CHECK-LABEL: @keep_multi_iteration_loop
+func @keep_multi_iteration_loop() {
+  %c42 = constant 42 : index
+  %c47 = constant 47 : index
+  %c1 = constant 1 : index
+  %init = "test.init"() : () -> i32
+  // CHECK: scf.for
+  %0 = scf.for %i = %c42 to %c47 step %c1 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  "test.consume"(%0) : (i32) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list