[Mlir-commits] [mlir] 5657f93 - [mlir] Canonicalize IfOp with trivial `then` and `else` bodies to list of SelectOp's

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 20 02:19:49 PDT 2021


Author: Butygin
Date: 2021-03-20T12:18:49+03:00
New Revision: 5657f93e788f093c70fb448dd6f9398b149df278

URL: https://github.com/llvm/llvm-project/commit/5657f93e788f093c70fb448dd6f9398b149df278
DIFF: https://github.com/llvm/llvm-project/commit/5657f93e788f093c70fb448dd6f9398b149df278.diff

LOG: [mlir] Canonicalize IfOp with trivial `then` and `else` bodies to list of SelectOp's

* Do we need a threshold on the maximum number of Yield arguments processed (i.e. the maximum number of SelectOps to be generated)?
* Some existing IfOp tests had to be modified so that this pattern does not optimize them away.

Differential Revision: https://reviews.llvm.org/D98592

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index fdb9df82900c..78c72953ee6f 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -934,11 +934,49 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
     return success();
   }
 };
+
+/// Canonicalizes a result-producing `scf.if` whose `then` and `else` blocks
+/// each contain nothing but their `scf.yield` terminator. Every result of the
+/// if is replaced by a `SelectOp` on the if-condition that picks between the
+/// two yielded values, or directly by the yielded value when both branches
+/// yield the same one. Because the blocks hold no other operations, nothing
+/// has to be moved out of the regions.
+///
+/// NOTE(review): one select is created per differing result; there is no
+/// upper bound on how many SelectOps a single if may expand into.
+struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // An if without results has nothing to turn into selects.
+    if (op->getNumResults() == 0)
+      return failure();
+
+    // Only fire when each block consists of its terminator alone.
+    // NOTE(review): `elseRegion().front()` assumes an else block exists; this
+    // presumably holds because the IfOp verifier requires one whenever the op
+    // has results (ensured by the check above) -- confirm.
+    if (!llvm::hasSingleElement(op.thenRegion().front()) ||
+        !llvm::hasSingleElement(op.elseRegion().front()))
+      return failure();
+
+    auto cond = op.condition();
+    auto thenYieldArgs =
+        cast<scf::YieldOp>(op.thenRegion().front().getTerminator())
+            .getOperands();
+    auto elseYieldArgs =
+        cast<scf::YieldOp>(op.elseRegion().front().getTerminator())
+            .getOperands();
+    // One replacement value per if result; both yields are expected to carry
+    // exactly as many operands as the if has results.
+    SmallVector<Value> results(op->getNumResults());
+    assert(thenYieldArgs.size() == results.size());
+    assert(elseYieldArgs.size() == results.size());
+    for (auto it : llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
+      Value trueVal = std::get<0>(it.value());
+      Value falseVal = std::get<1>(it.value());
+      // Identical values on both sides need no select.
+      if (trueVal == falseVal)
+        results[it.index()] = trueVal;
+      else
+        results[it.index()] =
+            rewriter.create<SelectOp>(op.getLoc(), cond, trueVal, falseVal);
+    }
+
+    // Substitute the computed values for the if's results.
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
-  results.insert<RemoveUnusedResults, RemoveStaticCondition>(context);
+  // Register the scf.if canonicalization patterns, including
+  // ConvertTrivialIfToSelect for ifs whose blocks contain only a yield.
+  results.insert<RemoveUnusedResults, RemoveStaticCondition,
+                 ConvertTrivialIfToSelect>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index dffe9e252eb1..7c751623db86 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -35,10 +35,12 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
 
 // -----
 
+func private @side_effect()
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %0, %1 = scf.if %cond -> (index, index) {
+    call @side_effect() : () -> ()
     scf.yield %c0, %c1 : index, index
   } else {
     scf.yield %c0, %c1 : index, index
@@ -49,6 +51,7 @@ func @one_unused(%cond: i1) -> (index) {
 // CHECK-LABEL:   func @one_unused
 // CHECK:           [[C0:%.*]] = constant 1 : index
 // CHECK:           [[V0:%.*]] = scf.if %{{.*}} -> (index) {
+// CHECK:             call @side_effect() : () -> ()
 // CHECK:             scf.yield [[C0]] : index
 // CHECK:           } else
 // CHECK:             scf.yield [[C0]] : index
@@ -57,11 +60,13 @@ func @one_unused(%cond: i1) -> (index) {
 
 // -----
 
+func private @side_effect()
 func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %0, %1 = scf.if %cond1 -> (index, index) {
     %2, %3 = scf.if %cond2 -> (index, index) {
+      call @side_effect() : () -> ()
       scf.yield %c0, %c1 : index, index
     } else {
       scf.yield %c0, %c1 : index, index
@@ -77,6 +82,7 @@ func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
 // CHECK:           [[C0:%.*]] = constant 1 : index
 // CHECK:           [[V0:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK:             [[V1:%.*]] = scf.if {{.*}} -> (index) {
+// CHECK:               call @side_effect() : () -> ()
 // CHECK:               scf.yield [[C0]] : index
 // CHECK:             } else
 // CHECK:               scf.yield [[C0]] : index
@@ -113,6 +119,96 @@ func @all_unused(%cond: i1) {
 
 // -----
 
+// A result-less scf.if with only a `then` region that holds just its yield
+// is expected to be removed entirely.
+func @empty_if1(%cond: i1) {
+  scf.if %cond {
+    scf.yield
+  }
+  return
+}
+
+// CHECK-LABEL:   func @empty_if1
+// CHECK-NOT:       scf.if
+// CHECK:           return
+
+// -----
+
+// A result-less scf.if whose `then` and `else` regions each hold only their
+// yield is expected to be removed entirely.
+func @empty_if2(%cond: i1) {
+  scf.if %cond {
+    scf.yield
+  } else {
+    scf.yield
+  }
+  return
+}
+
+// CHECK-LABEL:   func @empty_if2
+// CHECK-NOT:       scf.if
+// CHECK:           return
+
+// -----
+
+// A single-result scf.if whose branches only yield two different values is
+// expected to become one `select` on the condition.
+func @to_select1(%cond: i1) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = scf.if %cond -> index {
+    scf.yield %c0 : index
+  } else {
+    scf.yield %c1 : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL:   func @to_select1
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]]
+// CHECK:           return [[V0]] : index
+
+// -----
+
+// When both branches yield the same value for a result (%c1 here), that value
+// is used directly; only the result with differing values gets a `select`.
+func @to_select_same_val(%cond: i1) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0, %1 = scf.if %cond -> (index, index) {
+    scf.yield %c0, %c1 : index, index
+  } else {
+    scf.yield %c1, %c1 : index, index
+  }
+  return %0, %1 : index, index
+}
+
+// CHECK-LABEL:   func @to_select_same_val
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]]
+// CHECK:           return [[V0]], [[C1]] : index, index
+
+// -----
+
+// Every result whose branches yield different values gets its own `select`.
+func @to_select2(%cond: i1) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %0, %1 = scf.if %cond -> (index, index) {
+    scf.yield %c0, %c1 : index, index
+  } else {
+    scf.yield %c2, %c3 : index, index
+  }
+  return %0, %1 : index, index
+}
+
+// CHECK-LABEL:   func @to_select2
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[C2:%.*]] = constant 2 : index
+// CHECK:           [[C3:%.*]] = constant 3 : index
+// CHECK:           [[V0:%.*]] = select {{.*}}, [[C0]], [[C2]]
+// CHECK:           [[V1:%.*]] = select {{.*}}, [[C1]], [[C3]]
+// CHECK:           return [[V0]], [[V1]] : index, index
+
+// -----
+
 func private @make_i32() -> i32
 
 func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {


        


More information about the Mlir-commits mailing list