[Mlir-commits] [mlir] [mlir][reducer] Separate Reduction Steps in `findOptimal` and `applyPatterns` (PR #190560)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Apr 18 15:06:47 PDT 2026


https://github.com/aidint updated https://github.com/llvm/llvm-project/pull/190560

>From 614d083119a7cbc372d3e75938fdca1ffb5b865e Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Mon, 6 Apr 2026 00:25:24 +0200
Subject: [PATCH 1/5] refactor findOptimal function

---
 mlir/include/mlir/Reducer/ReductionNode.h |   5 +-
 mlir/lib/Reducer/ReductionTreePass.cpp    | 214 ++++++++++------------
 2 files changed, 102 insertions(+), 117 deletions(-)

diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..5ae2ae52d10ef 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -78,6 +78,9 @@ class ReductionNode {
   /// Return the generated variants(the child nodes).
   ArrayRef<ReductionNode *> getVariants() const { return variants; }
 
+  /// Add a variant to Node's variants
+  void addVariant(ReductionNode * node) { variants.push_back(node); }
+
   /// Split the ranges and generate new variants.
   ArrayRef<ReductionNode *> generateNewVariants();
 
@@ -88,7 +91,7 @@ class ReductionNode {
   /// patterns. In addition, we only apply rewrite patterns in a certain region.
   /// In init(), we will duplicate the module from parent node and locate the
   /// corresponding region.
-  LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
+  LogicalResult initialize(ModuleOp parentModule, Region &targetRegion);
 
 private:
   /// A custom BFS iterator. The difference between
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 83497143d9669..66fea68b625b6 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -32,64 +32,11 @@ namespace mlir {
 
 using namespace mlir;
 
-/// We implicitly number each operation in the region and if an operation's
-/// number falls into rangeToKeep, we need to keep it and apply the given
-/// rewrite patterns on it.
-static void applyPatterns(Region &region,
-                          const FrozenRewritePatternSet &patterns,
-                          ArrayRef<ReductionNode::Range> rangeToKeep,
-                          bool eraseOpNotInRange) {
-  std::vector<Operation *> opsNotInRange;
-  std::vector<Operation *> opsInRange;
-  size_t keepIndex = 0;
-  for (const auto &op : enumerate(region.getOps())) {
-    int index = op.index();
-    if (keepIndex < rangeToKeep.size() &&
-        index == rangeToKeep[keepIndex].second)
-      ++keepIndex;
-    if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
-      opsNotInRange.push_back(&op.value());
-    else
-      opsInRange.push_back(&op.value());
-  }
-
-  // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
-  // pattern matching in above iteration. Besides, erase op not-in-range may end
-  // up in invalid module, so `applyOpPatternsGreedily` with folding should come
-  // before that transform.
-  for (Operation *op : opsInRange) {
-    // `applyOpPatternsGreedily` with folding returns whether the op is
-    // converted. Omit it because we don't have expectation this reduction will
-    // be success or not.
-    (void)applyOpPatternsGreedily(op, patterns,
-                                  GreedyRewriteConfig().setStrictness(
-                                      GreedyRewriteStrictness::ExistingOps));
-  }
-
-  if (eraseOpNotInRange)
-    for (Operation *op : opsNotInRange) {
-      op->dropAllUses();
-      op->erase();
-    }
-}
-
-/// We will apply the reducer patterns to the operations in the ranges specified
-/// by ReductionNode. Note that we are not able to remove an operation without
-/// replacing it with another valid operation. However, The validity of module
-/// reduction is based on the Tester provided by the user and that means certain
-/// invalid module is still interested by the use. Thus we provide an
-/// alternative way to remove operations, which is using `eraseOpNotInRange` to
-/// erase the operations not in the range specified by ReductionNode.
-template <typename IteratorType>
-static LogicalResult findOptimal(ModuleOp module, Region &region,
-                                 const FrozenRewritePatternSet &patterns,
-                                 const Tester &test, bool eraseOpNotInRange) {
-  std::pair<Tester::Interestingness, size_t> initStatus =
-      test.isInteresting(module);
-  // While exploring the reduction tree, we always branch from an interesting
-  // node. Thus the root node must be interesting.
-  if (initStatus.first != Tester::Interestingness::True)
-    return module.emitWarning() << "uninterested module will not be reduced";
+/// Explore the reduction tree, applying `applyFn` to the operation ranges
+/// selected by each ReductionNode.
+template <typename IteratorType, typename ApplyFn>
+static LogicalResult findOptimalUsing(ModuleOp module, Region &region,
+                                      const Tester &test, ApplyFn applyFn) {
 
   llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
 
@@ -101,17 +48,20 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   // Duplicate the module for root node and locate the region in the copy.
   if (failed(root->initialize(module, region)))
     llvm_unreachable("unexpected initialization failure");
-  root->update(initStatus);
+
+  // Create a duplicate of the root node as the first variant of the root.
+  // This will keep the root intact when applying `applyFn`.
+  ReductionNode *firstVariant = allocator.Allocate();
+  new (firstVariant) ReductionNode(root, ranges, allocator);
+  root->addVariant(firstVariant);
 
   ReductionNode *smallestNode = root;
-  IteratorType iter(root);
+  IteratorType iter(firstVariant);
 
   while (iter != IteratorType::end()) {
     ReductionNode &currentNode = *iter;
-    Region &curRegion = currentNode.getRegion();
 
-    applyPatterns(curRegion, patterns, currentNode.getRanges(),
-                  eraseOpNotInRange);
+    applyFn(currentNode.getRegion(), currentNode.getRanges());
     currentNode.update(test.isInteresting(currentNode.getModule()));
 
     if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -125,86 +75,118 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   // the path and apply the reducer to it.
   SmallVector<ReductionNode *> trace;
   ReductionNode *curNode = smallestNode;
-  trace.push_back(curNode);
   while (curNode != root) {
-    curNode = curNode->getParent();
     trace.push_back(curNode);
+    curNode = curNode->getParent();
   }
 
+  if (trace.empty())
+    // An empty trace means smallestNode == root, i.e. no variant reduced the
+    // module, so report failure.
+    return failure();
+
   // Reduce the region through the optimal path.
   while (!trace.empty()) {
     ReductionNode *top = trace.pop_back_val();
-    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+    applyFn(region, top->getStartRanges());
   }
 
-  if (test.isInteresting(module).first != Tester::Interestingness::True)
+  std::pair<Tester::Interestingness, size_t> finalStatus =
+      test.isInteresting(module);
+
+  if (finalStatus.first != Tester::Interestingness::True)
     llvm::report_fatal_error("Reduced module is not interesting");
-  if (test.isInteresting(module).second != smallestNode->getSize())
+  if (finalStatus.second != smallestNode->getSize())
     llvm::report_fatal_error(
         "Reduced module doesn't have consistent size with smallestNode");
   return success();
 }
 
-/// This function attempts to erase all operations within the region currently
-/// being processed.
-static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
-                                         const Tester &test) {
-  std::pair<Tester::Interestingness, size_t> initStatus =
-      test.isInteresting(module);
-
-  // While exploring the reduction tree, we always branch from an interesting
-  // node. Thus the root node must be interesting.
-  if (initStatus.first != Tester::Interestingness::True)
-    return module.emitWarning() << "uninterested module will not be reduced";
-  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
-
-  // Setting the ranges to {{0, 0}} will result in the deletion of all ops
-  // within the region.
-  std::vector<ReductionNode::Range> ranges{{0, 0}};
-
-  // We allocate memory on the stack, and the 'allocator' is only used to
-  // construct the 'root node'. Since we won't be constructing any child nodes
-  // for emptyRegionNode, it is only used within the current scope.
-  ReductionNode emptyRegionNode(nullptr, ranges, allocator);
-  ReductionNode *root = &emptyRegionNode;
+/// We implicitly number each operation in the region; operations whose
+/// number falls into rangeToKeep are kept and all others are erased.
+static void eliminateOperations(Region &region,
+                                ArrayRef<ReductionNode::Range> rangeToKeep) {
+  std::vector<Operation *> opsNotInRange;
+  size_t keepIndex = 0;
+  for (const auto &op : enumerate(region.getOps())) {
+    int index = op.index();
+    if (keepIndex < rangeToKeep.size() &&
+        index == rangeToKeep[keepIndex].second)
+      ++keepIndex;
+    if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
+      opsNotInRange.push_back(&op.value());
+  }
 
-  // Create a copy of the current IR.
-  if (failed(root->initialize(module, region)))
-    llvm_unreachable("unexpected initialization failure");
+  for (Operation *op : opsNotInRange) {
+    op->dropAllUses();
+    op->erase();
+  }
+}
 
-  // Erase all operations within the corresponding region of the clone.
-  applyPatterns(root->getRegion(), {}, root->getRanges(), true);
-  root->update(test.isInteresting(root->getModule()));
-  if (root->isInteresting() == Tester::Interestingness::True) {
-    // If we can successfully remove all ops in the region, we apply the same
-    // transformation to the original IR and return success.
-    applyPatterns(region, {}, root->getRanges(), true);
-    return success();
+/// We implicitly number each operation in the region and if an operation's
+/// number falls into rangeToApply, we'll apply the given rewrite patterns on
+/// it.
+static void applyPatterns(Region &region,
+                          const FrozenRewritePatternSet &patterns,
+                          ArrayRef<ReductionNode::Range> rangeToApply) {
+  size_t rangeIndex = 0;
+  for (const auto &op : enumerate(region.getOps())) {
+    int index = op.index();
+    if (rangeIndex < rangeToApply.size() &&
+        index == rangeToApply[rangeIndex].second)
+      ++rangeIndex;
+    if (rangeIndex < rangeToApply.size() &&
+        index > rangeToApply[rangeIndex].first)
+      // `applyOpPatternsGreedily` with folding returns whether the op is
+      // converted. Omit it because we don't have expectation this reduction
+      // will be success or not.
+      (void)applyOpPatternsGreedily(&op.value(), patterns,
+                                    GreedyRewriteConfig().setStrictness(
+                                        GreedyRewriteStrictness::ExistingOps));
   }
-  return failure();
 }
 
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
                                  const Tester &test) {
-  // We separate the reduction process into 3 steps, the first one is to erase
-  // redundant operations and the second one is to apply the reducer patterns.
 
-  // In the first phase, we attempt to erase all operations within the entire
-  // region.
-  if (succeeded(eraseAllOpsInRegion(module, region, test)))
+  // We first test the interestingness of the module passed to findOptimal.
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+  if (initStatus.first != Tester::Interestingness::True)
+    // An uninteresting module is not reduced: emit a warning and fail.
+    return module.emitWarning() << "uninterested module will not be reduced";
+
+  // We separate the reduction process into 3 steps:
+  // In the first step, we attempt to erase all operations within the
+  // entire region.
+  if (succeeded(findOptimalUsing<IteratorType>(
+          module, region, test, [](auto &region, auto) {
+            for (auto &block : region.getBlocks())
+              block.clear();
+            ;
+          })))
+    // If clearing the entire region kept the module interesting
+    // we will return success.
     return success();
 
-  // In the second phase, we don't apply any patterns so that we only select the
-  // range of operations to keep to the module stay interesting.
-  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
-                                       /*eraseOpNotInRange=*/true)))
-    return failure();
-  // In the third phase, we suppose that no operation is redundant, so we try
-  // to rewrite the operation into simpler form.
-  return findOptimal<IteratorType>(module, region, patterns, test,
-                                   /*eraseOpNotInRange=*/false);
+  // In the second step, we eliminate redundant operations from the region,
+  // keeping only those needed for the module to stay interesting.
+  auto eliminationResult =
+      findOptimalUsing<IteratorType>(module, region, test, eliminateOperations);
+
+  // In the third step, we suppose that no operation is redundant, so we try
+  // to rewrite the operation into simpler form by applying patterns.
+  auto applyPatternsResult = findOptimalUsing<IteratorType>(
+      module, region, test, [&](auto &region, auto ranges) {
+        applyPatterns(region, patterns, ranges);
+      });
+
+  if (succeeded(eliminationResult) || succeeded(applyPatternsResult))
+    // If either step 2 or step 3 succeeded, report success.
+    return success();
+  return failure();
 }
 
 namespace {

>From 5fc2c2d9d1680ac75ba88132c6310158a6250a9d Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Mon, 6 Apr 2026 01:20:22 +0200
Subject: [PATCH 2/5] add test

---
 mlir/lib/Reducer/ReductionTreePass.cpp    |  2 +-
 mlir/test/mlir-reduce/trivially-dead.mlir | 16 ++++++++++++++++
 mlir/test/mlir-reduce/trivially-dead.sh   |  4 ++++
 3 files changed, 21 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/mlir-reduce/trivially-dead.mlir
 create mode 100755 mlir/test/mlir-reduce/trivially-dead.sh

diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 66fea68b625b6..f07881f15ac77 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -136,7 +136,7 @@ static void applyPatterns(Region &region,
         index == rangeToApply[rangeIndex].second)
       ++rangeIndex;
     if (rangeIndex < rangeToApply.size() &&
-        index > rangeToApply[rangeIndex].first)
+        index >= rangeToApply[rangeIndex].first)
       // `applyOpPatternsGreedily` with folding returns whether the op is
       // converted. Omit it because we don't have expectation this reduction
       // will be success or not.
diff --git a/mlir/test/mlir-reduce/trivially-dead.mlir b/mlir/test/mlir-reduce/trivially-dead.mlir
new file mode 100644
index 0000000000000..ee69f754a1971
--- /dev/null
+++ b/mlir/test/mlir-reduce/trivially-dead.mlir
@@ -0,0 +1,16 @@
+// UNSUPPORTED: system-windows
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/trivially-dead.sh' | FileCheck %s
+// We are testing the ability to keep trivially-dead yet interesting code.
+
+func.func @trivially_dead() {
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 2 : i32
+  %2 = arith.constant 3 : i32
+  %3 = arith.constant 4 : i32
+  return
+}
+
+// CHECK-LABEL: func.func @trivially_dead() {
+// CHECK-NEXT: {{.*}} = arith.constant 2 : i32
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/mlir-reduce/trivially-dead.sh b/mlir/test/mlir-reduce/trivially-dead.sh
new file mode 100755
index 0000000000000..3a973c12f0486
--- /dev/null
+++ b/mlir/test/mlir-reduce/trivially-dead.sh
@@ -0,0 +1,4 @@
+#!/bin/sh
+
+# Exit non-zero only when the input contains `arith.constant 2 : i32`.
+! cat $1 | grep "arith.constant 2 : i32" 2>&1 1>/dev/null

>From 0b56ecfbb1d1ce8d89108c93c8d011558e231321 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Mon, 6 Apr 2026 01:38:05 +0200
Subject: [PATCH 3/5] resolve clang format issue

---
 mlir/include/mlir/Reducer/ReductionNode.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 5ae2ae52d10ef..8815be0442676 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -79,7 +79,7 @@ class ReductionNode {
   ArrayRef<ReductionNode *> getVariants() const { return variants; }
 
   /// Add a variant to Node's variants
-  void addVariant(ReductionNode * node) { variants.push_back(node); }
+  void addVariant(ReductionNode *node) { variants.push_back(node); }
 
   /// Split the ranges and generate new variants.
   ArrayRef<ReductionNode *> generateNewVariants();

>From d21982f699d11d606e35b7a040a1197b09e9d30f Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Thu, 9 Apr 2026 01:18:50 +0200
Subject: [PATCH 4/5] remove extra semicolon

---
 mlir/lib/Reducer/ReductionTreePass.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 6bacd4f336545..bd0189577bcbe 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -165,7 +165,6 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
           module, region, test, [](auto &region, auto) {
             for (auto &block : region.getBlocks())
               block.clear();
-            ;
           })))
     // If clearing the entire region kept the module interesting
     // we will return success.

>From f84512b3375ef1116d87ff8d1d54e4dfe33d1a08 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 19 Apr 2026 00:06:29 +0200
Subject: [PATCH 5/5] fix applyPattern bug

---
 mlir/lib/Reducer/ReductionTreePass.cpp | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index bd0189577bcbe..bfcab5264c24a 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -130,6 +130,7 @@ static void applyPatterns(Region &region,
                           const FrozenRewritePatternSet &patterns,
                           ArrayRef<ReductionNode::Range> rangeToApply) {
   size_t rangeIndex = 0;
+  std::vector<Operation *> opsInRange;
   for (const auto &op : enumerate(region.getOps())) {
     int index = op.index();
     if (rangeIndex < rangeToApply.size() &&
@@ -137,13 +138,16 @@ static void applyPatterns(Region &region,
       ++rangeIndex;
     if (rangeIndex < rangeToApply.size() &&
         index >= rangeToApply[rangeIndex].first)
-      // `applyOpPatternsGreedily` with folding returns whether the op is
-      // converted. Omit it because we don't have expectation this reduction
-      // will be success or not.
-      (void)applyOpPatternsGreedily(&op.value(), patterns,
-                                    GreedyRewriteConfig().setStrictness(
-                                        GreedyRewriteStrictness::ExistingOps));
+      opsInRange.push_back(&op.value());
   }
+
+  for (auto *op : opsInRange)
+    // The result of `applyOpPatternsGreedily` indicates whether the op was
+    // converted. It is deliberately ignored: we have no expectation about
+    // whether this particular reduction will succeed.
+    (void)applyOpPatternsGreedily(op, patterns,
+                                  GreedyRewriteConfig().setStrictness(
+                                      GreedyRewriteStrictness::ExistingOps));
 }
 
 template <typename IteratorType>



More information about the Mlir-commits mailing list