[Mlir-commits] [mlir] 854d0ed - [MLIR] Conditional Branch Argument Propagation

William S. Moses llvmlistbot at llvm.org
Mon Jun 7 10:33:17 PDT 2021


Author: William S. Moses
Date: 2021-06-07T13:33:10-04:00
New Revision: 854d0edce6c6e29ee2803c1f6590dee62994419a

URL: https://github.com/llvm/llvm-project/commit/854d0edce6c6e29ee2803c1f6590dee62994419a
DIFF: https://github.com/llvm/llvm-project/commit/854d0edce6c6e29ee2803c1f6590dee62994419a.diff

LOG: [MLIR] Conditional Branch Argument Propagation

For an operation in the true/false destination of a conditional branch,
one can assume that the branch condition itself is true/false if
only that edge can reach the operation.

Differential Revision: https://reviews.llvm.org/D101709

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a3c2513e28bcc..7118afbb397aa 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1044,13 +1044,87 @@ struct SimplifyCondBranchFromCondBranchOnSameCondition
     return success();
   }
 };
+
+///   cond_br %arg0, ^trueB, ^falseB
+///
+/// ^trueB:
+///   "test.consumer1"(%arg0) : (i1) -> ()
+///   ...
+///
+/// ^falseB:
+///   "test.consumer2"(%arg0) : (i1) -> ()
+///   ...
+///
+/// ->
+///
+///   cond_br %arg0, ^trueB, ^falseB
+/// ^trueB:
+///   "test.consumer1"(%true) : (i1) -> ()
+///   ...
+///
+/// ^falseB:
+///   "test.consumer2"(%false) : (i1) -> ()
+///   ...
+struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp condbr,
+                                PatternRewriter &rewriter) const override {
+    // Tracks whether any use of the condition was rewritten.
+    bool replaced = false;
+    Type ty = rewriter.getI1Type();
+
+    // Lazily created constant true/false values, cached here so that each
+    // constant is materialized at most once per invocation.
+    Value constantTrue = nullptr;
+    Value constantFalse = nullptr;
+
+    // TODO These checks can be expanded to encompass any use with only
+    // either the true or false edge as a predecessor. For now, we fall
+    // back to checking the single predecessor is given by the true/false
+    // destination, thereby ensuring that only that edge can reach the
+    // op.
+    if (condbr.getTrueDest()->getSinglePredecessor()) {
+      for (OpOperand &use :
+           llvm::make_early_inc_range(condbr.condition().getUses())) {
+        if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
+          replaced = true;
+
+          if (!constantTrue)
+            constantTrue = rewriter.create<mlir::ConstantOp>(
+                condbr.getLoc(), ty, rewriter.getBoolAttr(true));
+
+          rewriter.updateRootInPlace(use.getOwner(),
+                                     [&] { use.set(constantTrue); });
+        }
+      }
+    }
+    if (condbr.getFalseDest()->getSinglePredecessor()) {
+      for (OpOperand &use :
+           llvm::make_early_inc_range(condbr.condition().getUses())) {
+        if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
+          replaced = true;
+
+          if (!constantFalse)
+            constantFalse = rewriter.create<mlir::ConstantOp>(
+                condbr.getLoc(), ty, rewriter.getBoolAttr(false));
+
+          rewriter.updateRootInPlace(use.getOwner(),
+                                     [&] { use.set(constantFalse); });
+        }
+      }
+    }
+    return success(replaced);
+  }
+};
 } // end anonymous namespace
 
 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
-              SimplifyCondBranchFromCondBranchOnSameCondition>(context);
+              SimplifyCondBranchFromCondBranchOnSameCondition,
+              CondBranchTruthPropagation>(context);
 }
 
 Optional<MutableOperandRange>

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index f67ba447cf3d1..d3e48d4c7edd2 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -429,8 +429,6 @@ func @truncConstant(%arg0: i8) -> i16 {
   return %tr : i16
 }
 
-// -----
-
 // CHECK-LABEL: @tripleAddAdd
 //       CHECK:   %[[cres:.+]] = constant 59 : index 
 //       CHECK:   %[[add:.+]] = addi %arg0, %[[cres]] : index 
@@ -648,3 +646,25 @@ func @notCmpUGE(%arg0: i8, %arg1: i8) -> i1 {
   %ncmp = xor %cmp, %true : i1
   return %ncmp : i1
 }
+
+// -----
+
+// CHECK-LABEL: @branchCondProp
+//       CHECK:       %[[trueval:.+]] = constant true
+//       CHECK:       %[[falseval:.+]] = constant false
+//       CHECK:       "test.consumer1"(%[[trueval]]) : (i1) -> ()
+//       CHECK:       "test.consumer2"(%[[falseval]]) : (i1) -> ()
+func @branchCondProp(%arg0: i1) {
+  cond_br %arg0, ^trueB, ^falseB
+
+^trueB:
+  "test.consumer1"(%arg0) : (i1) -> ()
+  br ^exit
+
+^falseB:
+  "test.consumer2"(%arg0) : (i1) -> ()
+  br ^exit
+
+^exit:
+  return
+}


        


More information about the Mlir-commits mailing list