[Mlir-commits] [mlir] [mlir][IR] Tweak `RewriterBase::replaceUsesWithIf` call (PR #172883)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 19 03:15:10 PST 2025


https://github.com/BabakkGraphcore updated https://github.com/llvm/llvm-project/pull/172883

>From 16613bd820f54cb1573acecfef79a20c149a5547 Mon Sep 17 00:00:00 2001
From: Babak Khataee <babakk at graphcore.ai>
Date: Thu, 18 Dec 2025 16:49:54 +0000
Subject: [PATCH 1/2] Tweak replaceUsesWithIf call so that it only sets the
 allUsesReplaced flag if called with that flag set.

---
 mlir/lib/IR/PatternMatch.cpp | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 8bc0fcd4517d8..07fd52e0812c6 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -276,15 +276,16 @@ void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
                                      function_ref<bool(OpOperand &)> functor,
                                      bool *allUsesReplaced) {
   assert(from.size() == to.size() && "incorrect number of replacements");
-  bool allReplaced = true;
-  for (auto it : llvm::zip_equal(from, to)) {
-    bool r;
-    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
-                      /*allUsesReplaced=*/&r);
-    allReplaced &= r;
+  for (auto [fromVal, toVal] : llvm::zip_equal(from, to)) {
+    if (allUsesReplaced) {
+      bool r;
+      replaceUsesWithIf(fromVal, toVal, functor,
+                        /*allUsesReplaced=*/&r);
+      *allUsesReplaced &= r;
+    } else {
+      replaceUsesWithIf(fromVal, toVal, functor);
+    }
   }
-  if (allUsesReplaced)
-    *allUsesReplaced = allReplaced;
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Block *dest,

>From 41be3aa5cd35b58e47d4363e12f7ceecd307f0d3 Mon Sep 17 00:00:00 2001
From: Babak Khataee <babakk at graphcore.ai>
Date: Fri, 19 Dec 2025 11:14:38 +0000
Subject: [PATCH 2/2] Review comments.

---
 mlir/lib/IR/PatternMatch.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 07fd52e0812c6..226e4e518d3e0 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -276,16 +276,15 @@ void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
                                      function_ref<bool(OpOperand &)> functor,
                                      bool *allUsesReplaced) {
   assert(from.size() == to.size() && "incorrect number of replacements");
-  for (auto [fromVal, toVal] : llvm::zip_equal(from, to)) {
-    if (allUsesReplaced) {
-      bool r;
-      replaceUsesWithIf(fromVal, toVal, functor,
-                        /*allUsesReplaced=*/&r);
-      *allUsesReplaced &= r;
-    } else {
-      replaceUsesWithIf(fromVal, toVal, functor);
-    }
+  bool allReplaced = true;
+  for (auto it : llvm::zip_equal(from, to)) {
+    bool r = true;
+    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
+                      /*allUsesReplaced=*/allUsesReplaced ? &r : nullptr);
+    allReplaced &= r;
   }
+  if (allUsesReplaced)
+    *allUsesReplaced = allReplaced;
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Block *dest,



More information about the Mlir-commits mailing list