[Mlir-commits] [mlir] ca27260 - [MLIR] Add SCF.if Condition Canonicalizations
William S. Moses
llvmlistbot at llvm.org
Mon Apr 26 17:13:31 PDT 2021
Author: William S. Moses
Date: 2021-04-26T20:13:08-04:00
New Revision: ca27260701e237a4470cc00f0791b93e78e5fed8
URL: https://github.com/llvm/llvm-project/commit/ca27260701e237a4470cc00f0791b93e78e5fed8
DIFF: https://github.com/llvm/llvm-project/commit/ca27260701e237a4470cc00f0791b93e78e5fed8.diff
LOG: [MLIR] Add SCF.if Condition Canonicalizations
Add two canonicalizations for scf.if.
1) A canonicalization that allows users of the condition within an if to assume the condition
is true inside the then region and false inside the else region.
2) A canonicalization that replaces uses of if results whose yielded values are equivalent to the condition
or its negation (or are the same value in both branches).
Differential Revision: https://reviews.llvm.org/D101012
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index fa4fb9ffef33..b3f4b166947d 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1106,12 +1106,172 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
return success();
}
};
+
+/// Allow the then region of an scf.if to assume its condition is true and the
+/// else region to assume it is false. For example:
+///
+///   scf.if %cmp {
+///     print(%cmp)
+///   }
+///
+/// becomes
+///
+///   scf.if %cmp {
+///     print(true)
+///   }
+///
+struct ConditionPropagation : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    Value cond = op.condition();
+    // A constant condition is already as simple as it gets; substituting
+    // another constant inside the regions would not make progress.
+    if (cond.getDefiningOp<ConstantOp>())
+      return failure();
+
+    Type i1Ty = rewriter.getI1Type();
+    // Lazily materialized i1 constants indexed by the boolean they hold, so at
+    // most one `true` and one `false` constant are created per rewrite.
+    Value known[2] = {nullptr, nullptr};
+    auto getKnown = [&](bool value) -> Value {
+      Value &slot = known[value];
+      if (!slot)
+        slot = rewriter.create<mlir::ConstantOp>(
+            op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, value ? 1 : 0));
+      return slot;
+    };
+
+    bool modified = false;
+    // Operands are rewritten while walking the use list, so the iterator must
+    // advance before the current use is mutated.
+    for (OpOperand &operand : llvm::make_early_inc_range(cond.getUses())) {
+      Region *userRegion = operand.getOwner()->getParentRegion();
+      bool inThen = op.thenRegion().isAncestor(userRegion);
+      if (!inThen && !op.elseRegion().isAncestor(userRegion))
+        continue;
+      Value replacement = getKnown(inThen);
+      // Route the mutation through the rewriter so the driver sees it.
+      rewriter.updateRootInPlace(operand.getOwner(),
+                                 [&]() { operand.set(replacement); });
+      modified = true;
+    }
+    return success(modified);
+  }
+};
+
+/// Simplify uses of scf.if results whose yielded values make the result
+/// redundant. If the then/else branches yield the i1 constants true/false, the
+/// result is equivalent to the condition; false/true gives its negation.
+/// For example:
+///
+///   %res:2 = scf.if %cmp {
+///     yield something(), true
+///   } else {
+///     yield something2(), false
+///   }
+///   print(%res#1)
+///
+/// becomes
+///
+///   %res:2 = scf.if %cmp {
+///     yield something(), true
+///   } else {
+///     yield something2(), false
+///   }
+///   print(%cmp)
+///
+/// Additionally, if both branches yield the same value, all uses of the
+/// corresponding result are replaced with that value:
+///
+///   %res:2 = scf.if %cmp {
+///     yield something(), %arg1
+///   } else {
+///     yield something2(), %arg1
+///   }
+///   print(%res#1)
+///
+/// becomes
+///
+///   ...
+///   print(%arg1)
+///
+/// This pattern only rewrites uses of the results; it does not remove them.
+/// The canonicalization tests expect the now-dead results to disappear,
+/// presumably via RemoveUnusedResults.
+struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Early exit if there are no results that could be replaced.
+    if (op.getNumResults() == 0)
+      return failure();
+
+    // An scf.if with results always has both regions, each ending in a yield.
+    auto thenYield =
+        cast<scf::YieldOp>(op.thenRegion().back().getTerminator());
+    auto elseYield =
+        cast<scf::YieldOp>(op.elseRegion().back().getTerminator());
+
+    // New ops (the negation) are materialized right before the scf.if.
+    rewriter.setInsertionPoint(op);
+    bool changed = false;
+    Type i1Ty = rewriter.getI1Type();
+    // `true` constant feeding the xor negations; created at most once.
+    Value constantTrue = nullptr;
+
+    // Redirect every use of `from` to `to` through the rewriter so the pattern
+    // driver is notified about each modified user (as ConditionPropagation
+    // does), instead of mutating use-lists behind its back.
+    auto replaceUses = [&](Value from, Value to) {
+      for (OpOperand &use : llvm::make_early_inc_range(from.getUses()))
+        rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(to); });
+    };
+
+    for (auto tup :
+         llvm::zip(thenYield.results(), elseYield.results(), op.results())) {
+      Value thenResult, elseResult, opResult;
+      std::tie(thenResult, elseResult, opResult) = tup;
+      // Nothing to gain for a result nobody reads.
+      if (opResult.use_empty())
+        continue;
+
+      // Both branches yield the same value: the result is that value.
+      if (thenResult == elseResult) {
+        replaceUses(opResult, thenResult);
+        changed = true;
+        continue;
+      }
+
+      // Otherwise only boolean constants on both sides can be folded.
+      // dyn_cast (rather than cast) rejects non-boolean constants gracefully.
+      auto thenConst = thenResult.getDefiningOp<ConstantOp>();
+      auto elseConst = elseResult.getDefiningOp<ConstantOp>();
+      if (!thenConst || !elseConst)
+        continue;
+      auto thenAttr = thenConst.getValue().dyn_cast<BoolAttr>();
+      auto elseAttr = elseConst.getValue().dyn_cast<BoolAttr>();
+      if (!thenAttr || !elseAttr)
+        continue;
+
+      bool thenVal = thenAttr.getValue();
+      bool elseVal = elseAttr.getValue();
+      // Equal constants defined inside the regions need not dominate the uses
+      // of the result, so that case is left alone.
+      if (thenVal == elseVal)
+        continue;
+
+      if (thenVal) {
+        // then: true, else: false  =>  the result is the condition itself.
+        replaceUses(opResult, op.condition());
+      } else {
+        // then: false, else: true  =>  the result is the negated condition.
+        if (!constantTrue)
+          constantTrue = rewriter.create<mlir::ConstantOp>(
+              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+        Value notCond =
+            rewriter.create<XOrOp>(op.getLoc(), op.condition(), constantTrue);
+        replaceUses(opResult, notCond);
+      }
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
} // namespace
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<RemoveUnusedResults, RemoveStaticCondition,
- ConvertTrivialIfToSelect>(context);
+  // The patterns registered before this change, followed by the two new
+  // condition-aware patterns defined above (same order as a single add call).
+  results.add<RemoveUnusedResults, RemoveStaticCondition,
+              ConvertTrivialIfToSelect>(context);
+  results.add<ConditionPropagation, ReplaceIfYieldWithConditionOrValue>(
+      context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d0d9e9c9a847..4dee3825d870 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -103,22 +103,25 @@ func private @side_effect()
func @one_unused(%cond: i1) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
+ // The else branch yields values distinct from the then branch: if both
+ // branches yielded the same value, ReplaceIfYieldWithConditionOrValue would
+ // fold %1 away instead of leaving only the unused result %0 to be dropped.
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
%0, %1 = scf.if %cond -> (index, index) {
call @side_effect() : () -> ()
scf.yield %c0, %c1 : index, index
} else {
- scf.yield %c0, %c1 : index, index
+ scf.yield %c2, %c3 : index, index
}
return %1 : index
}
// CHECK-LABEL: func @one_unused
// CHECK: [[C0:%.*]] = constant 1 : index
+// CHECK: [[C3:%.*]] = constant 3 : index
// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
// CHECK: call @side_effect() : () -> ()
// CHECK: scf.yield [[C0]] : index
// CHECK: } else
-// CHECK: scf.yield [[C0]] : index
+// CHECK: scf.yield [[C3]] : index
// CHECK: }
// CHECK: return [[V0]] : index
@@ -128,12 +131,14 @@ func private @side_effect()
func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
%0, %1 = scf.if %cond1 -> (index, index) {
%2, %3 = scf.if %cond2 -> (index, index) {
call @side_effect() : () -> ()
scf.yield %c0, %c1 : index, index
} else {
- scf.yield %c0, %c1 : index, index
+ scf.yield %c2, %c3 : index, index
}
scf.yield %2, %3 : index, index
} else {
@@ -144,12 +149,13 @@ func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
// CHECK-LABEL: func @nested_unused
// CHECK: [[C0:%.*]] = constant 1 : index
+// CHECK: [[C3:%.*]] = constant 3 : index
// CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
// CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
// CHECK: call @side_effect() : () -> ()
// CHECK: scf.yield [[C0]] : index
// CHECK: } else
-// CHECK: scf.yield [[C0]] : index
+// CHECK: scf.yield [[C3]] : index
// CHECK: }
// CHECK: scf.yield [[V1]] : index
// CHECK: } else
@@ -610,3 +616,111 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
%res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
return %res : tensor<1024x1024xf32>
}
+
+
+
+// CHECK-LABEL: @cond_prop
+// Nested scf.if ops on the same condition %arg0: ConditionPropagation replaces
+// the inner uses of %arg0 with constants, after which only the then/then (%c1)
+// and else/else (%c4) paths are expected to remain.
+func @cond_prop(%arg0 : i1) -> index {
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %res = scf.if %arg0 -> index {
+ %res1 = scf.if %arg0 -> index {
+ %v1 = "test.get_some_value"() : () -> i32
+ scf.yield %c1 : index
+ } else {
+ %v2 = "test.get_some_value"() : () -> i32
+ scf.yield %c2 : index
+ }
+ scf.yield %res1 : index
+ } else {
+ %res2 = scf.if %arg0 -> index {
+ %v3 = "test.get_some_value"() : () -> i32
+ scf.yield %c3 : index
+ } else {
+ %v4 = "test.get_some_value"() : () -> i32
+ scf.yield %c4 : index
+ }
+ scf.yield %res2 : index
+ }
+ return %res : index
+}
+// CHECK-DAG: %[[c1:.+]] = constant 1 : index
+// CHECK-DAG: %[[c4:.+]] = constant 4 : index
+// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: scf.yield %[[c1]] : index
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: scf.yield %[[c4]] : index
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[if]] : index
+// CHECK-NEXT:}
+
+// CHECK-LABEL: @replace_if_with_cond1
+// The i1 result yields true in the then branch and false in the else branch,
+// so its use is expected to be replaced by the condition %arg0 itself.
+func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
+ %true = constant true
+ %false = constant false
+ %res:2 = scf.if %arg0 -> (i32, i1) {
+ %v = "test.get_some_value"() : () -> i32
+ scf.yield %v, %true : i32, i1
+ } else {
+ %v2 = "test.get_some_value"() : () -> i32
+ scf.yield %v2, %false : i32, i1
+ }
+ return %res#0, %res#1 : i32, i1
+}
+// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: scf.yield %[[sv1]] : i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: scf.yield %[[sv2]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[if]], %arg0 : i32, i1
+
+// CHECK-LABEL: @replace_if_with_cond2
+// The i1 result yields false in the then branch and true in the else branch,
+// so its use is expected to be replaced by the negated condition (xor with true).
+func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
+ %true = constant true
+ %false = constant false
+ %res:2 = scf.if %arg0 -> (i32, i1) {
+ %v = "test.get_some_value"() : () -> i32
+ scf.yield %v, %false : i32, i1
+ } else {
+ %v2 = "test.get_some_value"() : () -> i32
+ scf.yield %v2, %true : i32, i1
+ }
+ return %res#0, %res#1 : i32, i1
+}
+// CHECK-NEXT: %true = constant true
+// CHECK-NEXT: %[[toret:.+]] = xor %arg0, %true : i1
+// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: scf.yield %[[sv1]] : i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: scf.yield %[[sv2]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
+
+
+// CHECK-LABEL: @replace_if_with_cond3
+// Both branches yield the same value for the second result, so its use is
+// expected to be replaced by that value. The argument is named %arg1 so the
+// input matches the name used in the expected output below.
+func @replace_if_with_cond3(%arg0 : i1, %arg1: i64) -> (i32, i64) {
+  %res:2 = scf.if %arg0 -> (i32, i64) {
+    %v = "test.get_some_value"() : () -> i32
+    scf.yield %v, %arg1 : i32, i64
+  } else {
+    %v2 = "test.get_some_value"() : () -> i32
+    scf.yield %v2, %arg1 : i32, i64
+  }
+  return %res#0, %res#1 : i32, i64
+}
+// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: scf.yield %[[sv1]] : i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: scf.yield %[[sv2]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[if]], %arg1 : i32, i64
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 49a35d162b2b..f83d7b0cfca3 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1198,11 +1198,12 @@ func @clone_loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2
// -----
// CHECK-LABEL: func @clone_nested_region
-func @clone_nested_region(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memref<?x?xf32> {
+ %cmp = cmpi eq, %arg0, %arg1 : index
%0 = cmpi eq, %arg0, %arg1 : index
%1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
%2 = scf.if %0 -> (memref<?x?xf32>) {
- %3 = scf.if %0 -> (memref<?x?xf32>) {
+ %3 = scf.if %cmp -> (memref<?x?xf32>) {
%9 = memref.clone %1 : memref<?x?xf32> to memref<?x?xf32>
scf.yield %9 : memref<?x?xf32>
} else {
More information about the Mlir-commits
mailing list