[Mlir-commits] [mlir] 854d0ed - [MLIR] Conditional Branch Argument Propagation
William S. Moses
llvmlistbot at llvm.org
Mon Jun 7 10:33:17 PDT 2021
Author: William S. Moses
Date: 2021-06-07T13:33:10-04:00
New Revision: 854d0edce6c6e29ee2803c1f6590dee62994419a
URL: https://github.com/llvm/llvm-project/commit/854d0edce6c6e29ee2803c1f6590dee62994419a
DIFF: https://github.com/llvm/llvm-project/commit/854d0edce6c6e29ee2803c1f6590dee62994419a.diff
LOG: [MLIR] Conditional Branch Argument Propagation
For an operation in the true/false destination of a conditional branch,
one can assume that the branch condition itself is true/false if
only that edge can reach the operation.
Differential Revision: https://reviews.llvm.org/D101709
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a3c2513e28bcc..7118afbb397aa 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1044,13 +1044,87 @@ struct SimplifyCondBranchFromCondBranchOnSameCondition
return success();
}
};
+
+/// Propagates the known value of a cond_br condition into its successors:
+/// uses of the condition in the true (false) destination become a constant
+/// true (false) when that destination has a single predecessor. Example:
+///
+/// cond_br %arg0, ^trueB, ^falseB
+/// ^trueB:
+/// "test.consumer1"(%arg0) : (i1) -> ()
+/// ...
+/// ^falseB:
+/// "test.consumer2"(%arg0) : (i1) -> ()
+/// ...
+///
+/// ->
+///
+/// cond_br %arg0, ^trueB, ^falseB
+/// ^trueB:
+/// "test.consumer1"(%true) : (i1) -> ()
+/// ...
+/// ^falseB:
+/// "test.consumer2"(%false) : (i1) -> ()
+struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ // Becomes true once any use of the condition has been replaced.
+ bool replaced = false;
+ Type ty = rewriter.getI1Type();
+
+ // Lazily created constant true/false values, cached here so that each
+ // constant is created at most once per invocation.
+ Value constantTrue = nullptr;
+ Value constantFalse = nullptr;
+
+ // TODO These checks can be expanded to encompass any use with only
+ // either the true or false edge as a predecessor. For now, we fall
+ // back to checking that the true/false destination has a single
+ // predecessor, thereby ensuring that only that edge can reach the
+ // op.
+ if (condbr.getTrueDest()->getSinglePredecessor()) {
+ for (OpOperand &use :
+ llvm::make_early_inc_range(condbr.condition().getUses())) {
+ if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
+ replaced = true;
+
+ if (!constantTrue)
+ constantTrue = rewriter.create<mlir::ConstantOp>(
+ condbr.getLoc(), ty, rewriter.getBoolAttr(true));
+
+ rewriter.updateRootInPlace(use.getOwner(),
+ [&] { use.set(constantTrue); });
+ }
+ }
+ }
+ if (condbr.getFalseDest()->getSinglePredecessor()) {
+ for (OpOperand &use :
+ llvm::make_early_inc_range(condbr.condition().getUses())) {
+ if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
+ replaced = true;
+
+ if (!constantFalse)
+ constantFalse = rewriter.create<mlir::ConstantOp>(
+ condbr.getLoc(), ty, rewriter.getBoolAttr(false));
+
+ rewriter.updateRootInPlace(use.getOwner(),
+ [&] { use.set(constantFalse); });
+ }
+ }
+ }
+ return success(replaced);
+ }
+};
} // end anonymous namespace
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
- SimplifyCondBranchFromCondBranchOnSameCondition>(context);
+ SimplifyCondBranchFromCondBranchOnSameCondition,
+ CondBranchTruthPropagation>(context);
}
Optional<MutableOperandRange>
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index f67ba447cf3d1..d3e48d4c7edd2 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -429,8 +429,6 @@ func @truncConstant(%arg0: i8) -> i16 {
return %tr : i16
}
-// -----
-
// CHECK-LABEL: @tripleAddAdd
// CHECK: %[[cres:.+]] = constant 59 : index
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
@@ -648,3 +646,25 @@ func @notCmpUGE(%arg0: i8, %arg1: i8) -> i1 {
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
+
+// -----
+
+// CHECK-LABEL: @branchCondProp
+// CHECK: %[[trueval:.+]] = constant true
+// CHECK: %[[falseval:.+]] = constant false
+// CHECK: "test.consumer1"(%[[trueval]]) : (i1) -> ()
+// CHECK: "test.consumer2"(%[[falseval]]) : (i1) -> ()
+func @branchCondProp(%arg0: i1) {
+ cond_br %arg0, ^trueB, ^falseB
+
+^trueB:
+ "test.consumer1"(%arg0) : (i1) -> ()
+ br ^exit
+
+^falseB:
+ "test.consumer2"(%arg0) : (i1) -> ()
+ br ^exit
+
+^exit:
+ return
+}
More information about the Mlir-commits
mailing list