[Mlir-commits] [mlir] ca27260 - [MLIR] Add SCF.if Condition Canonicalizations

William S. Moses llvmlistbot at llvm.org
Mon Apr 26 17:13:31 PDT 2021


Author: William S. Moses
Date: 2021-04-26T20:13:08-04:00
New Revision: ca27260701e237a4470cc00f0791b93e78e5fed8

URL: https://github.com/llvm/llvm-project/commit/ca27260701e237a4470cc00f0791b93e78e5fed8
DIFF: https://github.com/llvm/llvm-project/commit/ca27260701e237a4470cc00f0791b93e78e5fed8.diff

LOG: [MLIR] Add SCF.if Condition Canonicalizations

Add two canonicalizations for scf.if:
  1) A canonicalization that lets uses of the condition inside an if assume the condition
     is true within the then region and false within the else region.
  2) A canonicalization that replaces uses of if results whose yielded values are equivalent
     to the condition or its negation, or that yield the same value from both branches.

Differential Revision: https://reviews.llvm.org/D101012

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index fa4fb9ffef33..b3f4b166947d 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1106,12 +1106,172 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
     return success();
   }
 };
+
+/// Allow the then region of an if to assume the condition is true and the
+/// else region to assume it is false. Every use of the condition nested (at
+/// any depth) inside one of the two regions is replaced by an i1 constant.
+/// For example:
+///
+///   scf.if %cmp {
+///      print(%cmp)
+///   }
+///
+///  becomes
+///
+///   scf.if %cmp {
+///      print(true)
+///   }
+///
+struct ConditionPropagation : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Early exit if the condition is constant since replacing a constant
+    // in the body with another constant isn't a simplification.
+    if (op.condition().getDefiningOp<ConstantOp>())
+      return failure();
+
+    // Materialize the constants immediately before the if so that they
+    // dominate both of its regions, independent of where the caller left
+    // the rewriter's insertion point.
+    rewriter.setInsertionPoint(op);
+    Type i1Ty = rewriter.getI1Type();
+
+    // Lazily created so that at most one `true` and one `false` constant is
+    // materialized per rewrite, and none when there is nothing to replace.
+    Value constantTrue;
+    Value constantFalse;
+    auto getConstant = [&](Value &cached, int64_t bit) -> Value {
+      if (!cached)
+        cached = rewriter.create<mlir::ConstantOp>(
+            op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, bit));
+      return cached;
+    };
+
+    bool changed = false;
+    // Early-increment iteration: `use.set` unlinks the operand from the
+    // condition's use list while we are walking it.
+    for (OpOperand &use :
+         llvm::make_early_inc_range(op.condition().getUses())) {
+      Region *useRegion = use.getOwner()->getParentRegion();
+      Value replacement;
+      if (op.thenRegion().isAncestor(useRegion))
+        replacement = getConstant(constantTrue, 1);
+      else if (op.elseRegion().isAncestor(useRegion))
+        replacement = getConstant(constantFalse, 0);
+      else
+        continue; // Use outside the if (e.g. the if's own operand).
+
+      rewriter.updateRootInPlace(use.getOwner(),
+                                 [&]() { use.set(replacement); });
+      changed = true;
+    }
+
+    return success(changed);
+  }
+};
+
+/// Replace uses of if results that are equivalent to the condition or its
+/// negation. For example:
+///
+///    %res:2 = scf.if %cmp {
+///       yield something(), true
+///    } else {
+///       yield something2(), false
+///    }
+///    print(%res#1)
+///
+///  becomes
+///    %res = scf.if %cmp {
+///       yield something()
+///    } else {
+///       yield something2()
+///    }
+///    print(%cmp)
+///
+/// Additionally if both branches yield the same value, replace all uses
+/// of the result with the yielded value
+///
+///    %res:2 = scf.if %cmp {
+///       yield something(), %arg1
+///    } else {
+///       yield something2(), %arg1
+///    }
+///    print(%res#1)
+///
+///  becomes
+///    %res = scf.if %cmp {
+///       yield something()
+///    } else {
+///       yield something2()
+///    }
+///    print(%arg1)
+///
+/// The examples show the final canonical form. This pattern itself only
+/// redirects uses; dropping the now-dead results is left to other patterns
+/// (e.g. RemoveUnusedResults).
+struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Early exit if there are no results that could be replaced.
+    if (op.getNumResults() == 0)
+      return failure();
+
+    // NOTE: assumes an scf.if with results always has an else block.
+    auto thenYieldOp =
+        cast<scf::YieldOp>(op.thenRegion().back().getTerminator());
+    auto elseYieldOp =
+        cast<scf::YieldOp>(op.elseRegion().back().getTerminator());
+
+    // Any new op (the negated condition) goes right before the if.
+    rewriter.setInsertionPoint(op);
+    Type i1Ty = rewriter.getI1Type();
+
+    // `cond xor true`, created at most once and shared by every result that
+    // turns out to be the negated condition.
+    Value notCond;
+
+    bool changed = false;
+    for (auto tup : llvm::zip(thenYieldOp.results(), elseYieldOp.results(),
+                              op.results())) {
+      Value thenValue, elseValue, opResult;
+      std::tie(thenValue, elseValue, opResult) = tup;
+
+      // Redirecting uses of a dead result would not change anything.
+      if (opResult.use_empty())
+        continue;
+
+      // Both branches yield the same value: the result is that value.
+      if (thenValue == elseValue) {
+        opResult.replaceAllUsesWith(thenValue);
+        changed = true;
+        continue;
+      }
+
+      // Otherwise only handle i1 results that are constants in both branches.
+      auto thenConst = thenValue.getDefiningOp<ConstantOp>();
+      auto elseConst = elseValue.getDefiningOp<ConstantOp>();
+      if (!thenConst || !elseConst || !thenConst.getType().isInteger(1))
+        continue;
+
+      auto thenAttr = thenConst.getValue().dyn_cast<BoolAttr>();
+      auto elseAttr = elseConst.getValue().dyn_cast<BoolAttr>();
+      if (!thenAttr || !elseAttr)
+        continue;
+
+      bool thenBit = thenAttr.getValue();
+      bool elseBit = elseAttr.getValue();
+      // Equal constants are neither the condition nor its negation.
+      if (thenBit == elseBit)
+        continue;
+
+      Value replacement = op.condition();
+      if (!thenBit) {
+        // then=false, else=true: the result is the negated condition.
+        if (!notCond) {
+          Value constantTrue = rewriter.create<mlir::ConstantOp>(
+              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+          notCond =
+              rewriter.create<XOrOp>(op.getLoc(), op.condition(), constantTrue);
+        }
+        replacement = notCond;
+      }
+      opResult.replaceAllUsesWith(replacement);
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results.add<RemoveUnusedResults, RemoveStaticCondition,
-              ConvertTrivialIfToSelect>(context);
+  // Register the scf.if canonicalization patterns.
+  results.add<RemoveUnusedResults, RemoveStaticCondition,
+              ConvertTrivialIfToSelect, ConditionPropagation,
+              ReplaceIfYieldWithConditionOrValue>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d0d9e9c9a847..4dee3825d870 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -103,22 +103,25 @@ func private @side_effect()
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
+  // The else branch yields values distinct from the then branch. Otherwise
+  // ReplaceIfYieldWithConditionOrValue, which forwards a value yielded by
+  // both branches, would replace the uses of %1, and this test would no
+  // longer check that only the unused result %0 is dropped.
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
   %0, %1 = scf.if %cond -> (index, index) {
     call @side_effect() : () -> ()
     scf.yield %c0, %c1 : index, index
   } else {
-    scf.yield %c0, %c1 : index, index
+    scf.yield %c2, %c3 : index, index
   }
   return %1 : index
 }

 // CHECK-LABEL:   func @one_unused
 // CHECK:           [[C0:%.*]] = constant 1 : index
+// CHECK:           [[C3:%.*]] = constant 3 : index
 // CHECK:           [[V0:%.*]] = scf.if %{{.*}} -> (index) {
 // CHECK:             call @side_effect() : () -> ()
 // CHECK:             scf.yield [[C0]] : index
 // CHECK:           } else
-// CHECK:             scf.yield [[C0]] : index
+// CHECK:             scf.yield [[C3]] : index
 // CHECK:           }
 // CHECK:           return [[V0]] : index
 
@@ -128,12 +131,14 @@ func private @side_effect()
 func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
   %0, %1 = scf.if %cond1 -> (index, index) {
     %2, %3 = scf.if %cond2 -> (index, index) {
       call @side_effect() : () -> ()
       scf.yield %c0, %c1 : index, index
     } else {
-      scf.yield %c0, %c1 : index, index
+      scf.yield %c2, %c3 : index, index
     }
     scf.yield %2, %3 : index, index
   } else {
@@ -144,12 +149,13 @@ func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
 
 // CHECK-LABEL:   func @nested_unused
 // CHECK:           [[C0:%.*]] = constant 1 : index
+// CHECK:           [[C3:%.*]] = constant 3 : index
 // CHECK:           [[V0:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK:             [[V1:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK:               call @side_effect() : () -> ()
 // CHECK:               scf.yield [[C0]] : index
 // CHECK:             } else
-// CHECK:               scf.yield [[C0]] : index
+// CHECK:               scf.yield [[C3]] : index
 // CHECK:             }
 // CHECK:             scf.yield [[V1]] : index
 // CHECK:           } else
@@ -610,3 +616,111 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
   %res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
   return %res : tensor<1024x1024xf32>
 }
+
+
+
+// Nested ifs on the same condition. Inside the outer then region %arg0 is
+// known to be true, and inside the outer else region it is known to be false.
+// Each inner if should therefore collapse to a single branch, leaving only
+// the paths that yield %c1 and %c4.
+// CHECK-LABEL: @cond_prop
+func @cond_prop(%arg0 : i1) -> index {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %res = scf.if %arg0 -> index {
+    %res1 = scf.if %arg0 -> index {
+      %v1 = "test.get_some_value"() : () -> i32
+      scf.yield %c1 : index
+    } else {
+      %v2 = "test.get_some_value"() : () -> i32
+      scf.yield %c2 : index
+    } 
+    scf.yield %res1 : index
+  } else {
+    %res2 = scf.if %arg0 -> index {
+      %v3 = "test.get_some_value"() : () -> i32
+      scf.yield %c3 : index
+    } else {
+      %v4 = "test.get_some_value"() : () -> i32
+      scf.yield %c4 : index
+    } 
+    scf.yield %res2 : index
+  }
+  return %res : index
+}
+// CHECK-DAG:  %[[c1:.+]] = constant 1 : index
+// CHECK-DAG:  %[[c4:.+]] = constant 4 : index
+// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK-NEXT:    %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[c1]] : index
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[c4]] : index
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[if]] : index
+// CHECK-NEXT:}
+
+// Result #1 is `true` in the then branch and `false` in the else branch, so
+// its uses are replaced by the condition %arg0 and the result is dropped.
+// CHECK-LABEL: @replace_if_with_cond1
+func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
+  %true = constant true
+  %false = constant false
+  %res:2 = scf.if %arg0 -> (i32, i1) {
+    %v = "test.get_some_value"() : () -> i32
+    scf.yield %v, %true : i32, i1
+  } else {
+    %v2 = "test.get_some_value"() : () -> i32
+    scf.yield %v2, %false : i32, i1
+  }
+  return %res#0, %res#1 : i32, i1
+}
+// CHECK-NEXT:    %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:      %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:      scf.yield %[[sv1]] : i32
+// CHECK-NEXT:    } else {
+// CHECK-NEXT:      %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:      scf.yield %[[sv2]] : i32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[if]], %arg0 : i32, i1
+
+// Result #1 is `false` in the then branch and `true` in the else branch, i.e.
+// the negated condition. Its uses are replaced by `xor %arg0, %true`, which is
+// materialized before the if.
+// CHECK-LABEL: @replace_if_with_cond2
+func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
+  %true = constant true
+  %false = constant false
+  %res:2 = scf.if %arg0 -> (i32, i1) {
+    %v = "test.get_some_value"() : () -> i32
+    scf.yield %v, %false : i32, i1
+  } else {
+    %v2 = "test.get_some_value"() : () -> i32
+    scf.yield %v2, %true : i32, i1
+  }
+  return %res#0, %res#1 : i32, i1
+}
+// CHECK-NEXT:     %true = constant true
+// CHECK-NEXT:     %[[toret:.+]] = xor %arg0, %true : i1
+// CHECK-NEXT:     %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:       %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[sv1]] : i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[sv2]] : i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[if]], %[[toret]] : i32, i1
+
+
+// Both branches yield %arg2, so uses of result #1 are replaced by it directly.
+// The CHECK line refers to that value as %arg1 because it is the second
+// function argument and the printed form uses positional argument names.
+// CHECK-LABEL: @replace_if_with_cond3
+func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
+  %res:2 = scf.if %arg0 -> (i32, i64) {
+    %v = "test.get_some_value"() : () -> i32
+    scf.yield %v, %arg2 : i32, i64
+  } else {
+    %v2 = "test.get_some_value"() : () -> i32
+    scf.yield %v2, %arg2 : i32, i64
+  }
+  return %res#0, %res#1 : i32, i64
+}
+// CHECK-NEXT:     %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:       %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[sv1]] : i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[sv2]] : i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[if]], %arg1 : i32, i64

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 49a35d162b2b..f83d7b0cfca3 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1198,11 +1198,12 @@ func @clone_loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2
 // -----
 
 // CHECK-LABEL: func @clone_nested_region
-func @clone_nested_region(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memref<?x?xf32> {
+  %cmp = cmpi eq, %arg0, %arg1 : index
   %0 = cmpi eq, %arg0, %arg1 : index
   %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
   %2 = scf.if %0 -> (memref<?x?xf32>) {
-    %3 = scf.if %0 -> (memref<?x?xf32>) {
+    %3 = scf.if %cmp -> (memref<?x?xf32>) {
       %9 = memref.clone %1 : memref<?x?xf32> to memref<?x?xf32>
       scf.yield %9 : memref<?x?xf32>
     } else {


        


More information about the Mlir-commits mailing list