[Mlir-commits] [mlir] [MLIR][SCF] Speed up ConditionPropagation (PR #166080)

William Moses llvmlistbot at llvm.org
Sun Nov 2 10:25:01 PST 2025


https://github.com/wsmoses created https://github.com/llvm/llvm-project/pull/166080

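Fixes https://github.com/llvm/llvm-project/issues/166039

Previously, every use of the `scf.if` condition was classified separately, by calling `Region::isAncestor` against the then and else regions. The pattern now memoizes that classification per region in a `DenseMap`, so later uses in the same or nested regions reuse earlier results.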

>From 25f4ef9ca607a747be453a5c1819c93a2cf101a8 Mon Sep 17 00:00:00 2001
From: Billy Moses <wmoses at google.com>
Date: Sun, 2 Nov 2025 12:22:23 -0600
Subject: [PATCH] [MLIR][SCF] Speed up ConditionPropagation

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 48 ++++++++++++++++++++++++++++++---
 1 file changed, 45 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2946b53c8cb36..2d5a0052d4c53 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2565,6 +2565,41 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
 struct ConditionPropagation : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
+  /// Where a region sits relative to the `scf.if` being rewritten: nested
+  /// (transitively) in its then region, in its else region, or in neither.
+  enum class Parent { Then, Else, None };
+
+  /// Classifies `toCheck` relative to `op` by walking up the region tree.
+  ///
+  /// Every region visited on the walk is recorded in `cache` with the final
+  /// answer, so no region's ancestors are re-walked and the total cost is
+  /// O(#uses + #regions) instead of O(#uses * nesting depth).
+  static Parent getParentType(Region *toCheck, IfOp op,
+                              DenseMap<Region *, Parent> &cache) {
+    SmallVector<Region *> seen;
+    // Memoize the answer for every region walked so far, then return it.
+    auto finish = [&](Parent result) {
+      for (Region *region : seen)
+        cache[region] = result;
+      return result;
+    };
+    // The then/else regions are nested directly under `op`, so `op`'s parent
+    // region is a strict ancestor of both: reaching it means neither.
+    Region *stop = op->getParentRegion();
+    while (toCheck && toCheck != stop) {
+      auto found = cache.find(toCheck);
+      if (found != cache.end())
+        return finish(found->second);
+      seen.push_back(toCheck);
+      if (toCheck == &op.getThenRegion())
+        return finish(Parent::Then);
+      if (toCheck == &op.getElseRegion())
+        return finish(Parent::Else);
+      toCheck = toCheck->getParentRegion();
+    }
+    return finish(Parent::None);
+  }
+
   LogicalResult matchAndRewrite(IfOp op,
                                 PatternRewriter &rewriter) const override {
     // Early exit if the condition is constant since replacing a constant
@@ -2580,9 +2615,11 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
     Value constantTrue = nullptr;
     Value constantFalse = nullptr;
 
+    DenseMap<Region *, Parent> cache; // Shared by all uses of the condition.
     for (OpOperand &use :
          llvm::make_early_inc_range(op.getCondition().getUses())) {
-      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
+      switch (getParentType(use.getOwner()->getParentRegion(), op, cache)) {
+      case Parent::Then: {
         changed = true;
 
         if (!constantTrue)
@@ -2591,8 +2628,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
 
         rewriter.modifyOpInPlace(use.getOwner(),
                                  [&]() { use.set(constantTrue); });
-      } else if (op.getElseRegion().isAncestor(
-                     use.getOwner()->getParentRegion())) {
+        break;
+      }
+      case Parent::Else: {
         changed = true;
 
         if (!constantFalse)
@@ -2601,6 +2639,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
 
         rewriter.modifyOpInPlace(use.getOwner(),
                                  [&]() { use.set(constantFalse); });
+        break;
+      }
+      case Parent::None:
+        break;
       }
     }
 



More information about the Mlir-commits mailing list