[Mlir-commits] [mlir] [mlir][reducer] Separate Reduction Steps in `findOptimal` and `applyPatterns` (PR #190560)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Apr 18 15:06:47 PDT 2026
https://github.com/aidint updated https://github.com/llvm/llvm-project/pull/190560
>From 614d083119a7cbc372d3e75938fdca1ffb5b865e Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Mon, 6 Apr 2026 00:25:24 +0200
Subject: [PATCH 1/5] refactor findOptimal function
---
mlir/include/mlir/Reducer/ReductionNode.h | 5 +-
mlir/lib/Reducer/ReductionTreePass.cpp | 214 ++++++++++------------
2 files changed, 102 insertions(+), 117 deletions(-)
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..5ae2ae52d10ef 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -78,6 +78,9 @@ class ReductionNode {
/// Return the generated variants(the child nodes).
ArrayRef<ReductionNode *> getVariants() const { return variants; }
+ /// Add a variant to Node's variants
+ void addVariant(ReductionNode * node) { variants.push_back(node); }
+
/// Split the ranges and generate new variants.
ArrayRef<ReductionNode *> generateNewVariants();
@@ -88,7 +91,7 @@ class ReductionNode {
/// patterns. In addition, we only apply rewrite patterns in a certain region.
/// In init(), we will duplicate the module from parent node and locate the
/// corresponding region.
- LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
+ LogicalResult initialize(ModuleOp parentModule, Region &targetRegion);
private:
/// A custom BFS iterator. The difference between
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 83497143d9669..66fea68b625b6 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -32,64 +32,11 @@ namespace mlir {
using namespace mlir;
-/// We implicitly number each operation in the region and if an operation's
-/// number falls into rangeToKeep, we need to keep it and apply the given
-/// rewrite patterns on it.
-static void applyPatterns(Region &region,
- const FrozenRewritePatternSet &patterns,
- ArrayRef<ReductionNode::Range> rangeToKeep,
- bool eraseOpNotInRange) {
- std::vector<Operation *> opsNotInRange;
- std::vector<Operation *> opsInRange;
- size_t keepIndex = 0;
- for (const auto &op : enumerate(region.getOps())) {
- int index = op.index();
- if (keepIndex < rangeToKeep.size() &&
- index == rangeToKeep[keepIndex].second)
- ++keepIndex;
- if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
- opsNotInRange.push_back(&op.value());
- else
- opsInRange.push_back(&op.value());
- }
-
- // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
- // pattern matching in above iteration. Besides, erase op not-in-range may end
- // up in invalid module, so `applyOpPatternsGreedily` with folding should come
- // before that transform.
- for (Operation *op : opsInRange) {
- // `applyOpPatternsGreedily` with folding returns whether the op is
- // converted. Omit it because we don't have expectation this reduction will
- // be success or not.
- (void)applyOpPatternsGreedily(op, patterns,
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingOps));
- }
-
- if (eraseOpNotInRange)
- for (Operation *op : opsNotInRange) {
- op->dropAllUses();
- op->erase();
- }
-}
-
-/// We will apply the reducer patterns to the operations in the ranges specified
-/// by ReductionNode. Note that we are not able to remove an operation without
-/// replacing it with another valid operation. However, The validity of module
-/// reduction is based on the Tester provided by the user and that means certain
-/// invalid module is still interested by the use. Thus we provide an
-/// alternative way to remove operations, which is using `eraseOpNotInRange` to
-/// erase the operations not in the range specified by ReductionNode.
-template <typename IteratorType>
-static LogicalResult findOptimal(ModuleOp module, Region &region,
- const FrozenRewritePatternSet &patterns,
- const Tester &test, bool eraseOpNotInRange) {
- std::pair<Tester::Interestingness, size_t> initStatus =
- test.isInteresting(module);
- // While exploring the reduction tree, we always branch from an interesting
- // node. Thus the root node must be interesting.
- if (initStatus.first != Tester::Interestingness::True)
- return module.emitWarning() << "uninterested module will not be reduced";
+/// We will apply `applyFn` to the operations in the ranges specified by
+/// ReductionNode.
+template <typename IteratorType, typename ApplyFn>
+static LogicalResult findOptimalUsing(ModuleOp module, Region &region,
+ const Tester &test, ApplyFn applyFn) {
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
@@ -101,17 +48,20 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
// Duplicate the module for root node and locate the region in the copy.
if (failed(root->initialize(module, region)))
llvm_unreachable("unexpected initialization failure");
- root->update(initStatus);
+
+ // Create a duplicate of the root node as the first variant of the root.
+ // This will keep the root intact when applying `applyFn`.
+ ReductionNode *firstVariant = allocator.Allocate();
+ new (firstVariant) ReductionNode(root, ranges, allocator);
+ root->addVariant(firstVariant);
ReductionNode *smallestNode = root;
- IteratorType iter(root);
+ IteratorType iter(firstVariant);
while (iter != IteratorType::end()) {
ReductionNode &currentNode = *iter;
- Region &curRegion = currentNode.getRegion();
- applyPatterns(curRegion, patterns, currentNode.getRanges(),
- eraseOpNotInRange);
+ applyFn(currentNode.getRegion(), currentNode.getRanges());
currentNode.update(test.isInteresting(currentNode.getModule()));
if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -125,86 +75,118 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
// the path and apply the reducer to it.
SmallVector<ReductionNode *> trace;
ReductionNode *curNode = smallestNode;
- trace.push_back(curNode);
while (curNode != root) {
- curNode = curNode->getParent();
trace.push_back(curNode);
+ curNode = curNode->getParent();
}
+ if (trace.empty())
+ // If trace is empty, then the smallestNode == root and therefore we were
+ // not successful in reducing the module
+ return failure();
+
// Reduce the region through the optimal path.
while (!trace.empty()) {
ReductionNode *top = trace.pop_back_val();
- applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+ applyFn(region, top->getStartRanges());
}
- if (test.isInteresting(module).first != Tester::Interestingness::True)
+ std::pair<Tester::Interestingness, size_t> finalStatus =
+ test.isInteresting(module);
+
+ if (finalStatus.first != Tester::Interestingness::True)
llvm::report_fatal_error("Reduced module is not interesting");
- if (test.isInteresting(module).second != smallestNode->getSize())
+ if (finalStatus.second != smallestNode->getSize())
llvm::report_fatal_error(
"Reduced module doesn't have consistent size with smallestNode");
return success();
}
-/// This function attempts to erase all operations within the region currently
-/// being processed.
-static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
- const Tester &test) {
- std::pair<Tester::Interestingness, size_t> initStatus =
- test.isInteresting(module);
-
- // While exploring the reduction tree, we always branch from an interesting
- // node. Thus the root node must be interesting.
- if (initStatus.first != Tester::Interestingness::True)
- return module.emitWarning() << "uninterested module will not be reduced";
- llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
-
- // Setting the ranges to {{0, 0}} will result in the deletion of all ops
- // within the region.
- std::vector<ReductionNode::Range> ranges{{0, 0}};
-
- // We allocate memory on the stack, and the 'allocator' is only used to
- // construct the 'root node'. Since we won't be constructing any child nodes
- // for emptyRegionNode, it is only used within the current scope.
- ReductionNode emptyRegionNode(nullptr, ranges, allocator);
- ReductionNode *root = &emptyRegionNode;
+/// We implicitly number each operation in the region and if an operation's
+/// number falls into rangeToKeep, we'll keep it.
+static void eliminateOperations(Region &region,
+ ArrayRef<ReductionNode::Range> rangeToKeep) {
+ std::vector<Operation *> opsNotInRange;
+ size_t keepIndex = 0;
+ for (const auto &op : enumerate(region.getOps())) {
+ int index = op.index();
+ if (keepIndex < rangeToKeep.size() &&
+ index == rangeToKeep[keepIndex].second)
+ ++keepIndex;
+ if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
+ opsNotInRange.push_back(&op.value());
+ }
- // Create a copy of the current IR.
- if (failed(root->initialize(module, region)))
- llvm_unreachable("unexpected initialization failure");
+ for (Operation *op : opsNotInRange) {
+ op->dropAllUses();
+ op->erase();
+ }
+}
- // Erase all operations within the corresponding region of the clone.
- applyPatterns(root->getRegion(), {}, root->getRanges(), true);
- root->update(test.isInteresting(root->getModule()));
- if (root->isInteresting() == Tester::Interestingness::True) {
- // If we can successfully remove all ops in the region, we apply the same
- // transformation to the original IR and return success.
- applyPatterns(region, {}, root->getRanges(), true);
- return success();
+/// We implicitly number each operation in the region and if an operation's
+/// number falls into rangeToApply, we'll apply the given rewrite patterns on
+/// it.
+static void applyPatterns(Region &region,
+ const FrozenRewritePatternSet &patterns,
+ ArrayRef<ReductionNode::Range> rangeToApply) {
+ size_t rangeIndex = 0;
+ for (const auto &op : enumerate(region.getOps())) {
+ int index = op.index();
+ if (rangeIndex < rangeToApply.size() &&
+ index == rangeToApply[rangeIndex].second)
+ ++rangeIndex;
+ if (rangeIndex < rangeToApply.size() &&
+ index > rangeToApply[rangeIndex].first)
+ // `applyOpPatternsGreedily` with folding returns whether the op is
+ // converted. Omit it because we don't have expectation this reduction
+ // will be success or not.
+ (void)applyOpPatternsGreedily(&op.value(), patterns,
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingOps));
}
- return failure();
}
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
const Tester &test) {
- // We separate the reduction process into 3 steps, the first one is to erase
- // redundant operations and the second one is to apply the reducer patterns.
- // In the first phase, we attempt to erase all operations within the entire
- // region.
- if (succeeded(eraseAllOpsInRegion(module, region, test)))
+ // We first test the interestingness of the module passed to findOptimal.
+ std::pair<Tester::Interestingness, size_t> initStatus =
+ test.isInteresting(module);
+ if (initStatus.first != Tester::Interestingness::True)
+ // If the module is not interesting, we can return failure
+ return module.emitWarning() << "uninterested module will not be reduced";
+
+ // We separate the reduction process into 3 steps:
+ // In the first step, we attempt to erase all operations within the
+ // entire region.
+ if (succeeded(findOptimalUsing<IteratorType>(
+ module, region, test, [](auto &region, auto) {
+ for (auto &block : region.getBlocks())
+ block.clear();
+ ;
+ })))
+ // If clearing the entire region kept the module interesting
+ // we will return success.
return success();
- // In the second phase, we don't apply any patterns so that we only select the
- // range of operations to keep to the module stay interesting.
- if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
- /*eraseOpNotInRange=*/true)))
- return failure();
- // In the third phase, we suppose that no operation is redundant, so we try
- // to rewrite the operation into simpler form.
- return findOptimal<IteratorType>(module, region, patterns, test,
- /*eraseOpNotInRange=*/false);
+ // In the second step, we eliminate redundant operations from the region to
+ // select those that keep the module interesting
+ auto eliminationResult =
+ findOptimalUsing<IteratorType>(module, region, test, eliminateOperations);
+
+ // In the third step, we suppose that no operation is redundant, so we try
+ // to rewrite the operation into simpler form by applying patterns.
+ auto applyPatternsResult = findOptimalUsing<IteratorType>(
+ module, region, test, [&](auto &region, auto ranges) {
+ applyPatterns(region, patterns, ranges);
+ });
+
+ if (succeeded(eliminationResult) || succeeded(applyPatternsResult))
+ // if step 2 or 3 was successful, then we return success.
+ return success();
+ return failure();
}
namespace {
>From 5fc2c2d9d1680ac75ba88132c6310158a6250a9d Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Mon, 6 Apr 2026 01:20:22 +0200
Subject: [PATCH 2/5] add test
---
mlir/lib/Reducer/ReductionTreePass.cpp | 2 +-
mlir/test/mlir-reduce/trivially-dead.mlir | 16 ++++++++++++++++
mlir/test/mlir-reduce/trivially-dead.sh | 4 ++++
3 files changed, 21 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/mlir-reduce/trivially-dead.mlir
create mode 100755 mlir/test/mlir-reduce/trivially-dead.sh
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 66fea68b625b6..f07881f15ac77 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -136,7 +136,7 @@ static void applyPatterns(Region &region,
index == rangeToApply[rangeIndex].second)
++rangeIndex;
if (rangeIndex < rangeToApply.size() &&
- index > rangeToApply[rangeIndex].first)
+ index >= rangeToApply[rangeIndex].first)
// `applyOpPatternsGreedily` with folding returns whether the op is
// converted. Omit it because we don't have expectation this reduction
// will be success or not.
diff --git a/mlir/test/mlir-reduce/trivially-dead.mlir b/mlir/test/mlir-reduce/trivially-dead.mlir
new file mode 100644
index 0000000000000..ee69f754a1971
--- /dev/null
+++ b/mlir/test/mlir-reduce/trivially-dead.mlir
@@ -0,0 +1,16 @@
+// UNSUPPORTED: system-windows
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/trivially-dead.sh' | FileCheck %s
+// We are testing the ability of keeping trivially-dead yet interesting code
+
+func.func @trivially_dead() {
+ %0 = arith.constant 1 : i32
+ %1 = arith.constant 2 : i32
+ %2 = arith.constant 3 : i32
+ %3 = arith.constant 4 : i32
+ return
+}
+
+// CHECK-LABEL: func.func @trivially_dead() {
+// CHECK-NEXT: {{.*}} = arith.constant 2 : i32
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/mlir-reduce/trivially-dead.sh b/mlir/test/mlir-reduce/trivially-dead.sh
new file mode 100755
index 0000000000000..3a973c12f0486
--- /dev/null
+++ b/mlir/test/mlir-reduce/trivially-dead.sh
@@ -0,0 +1,4 @@
+#!/bin/sh
+
+# break only on `arith.constant 2 : i32`
+! cat $1 | grep "arith.constant 2 : i32" 2>&1 1>/dev/null
>From 0b56ecfbb1d1ce8d89108c93c8d011558e231321 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Mon, 6 Apr 2026 01:38:05 +0200
Subject: [PATCH 3/5] resolve clang format issue
---
mlir/include/mlir/Reducer/ReductionNode.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 5ae2ae52d10ef..8815be0442676 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -79,7 +79,7 @@ class ReductionNode {
ArrayRef<ReductionNode *> getVariants() const { return variants; }
/// Add a variant to Node's variants
- void addVariant(ReductionNode * node) { variants.push_back(node); }
+ void addVariant(ReductionNode *node) { variants.push_back(node); }
/// Split the ranges and generate new variants.
ArrayRef<ReductionNode *> generateNewVariants();
>From d21982f699d11d606e35b7a040a1197b09e9d30f Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Thu, 9 Apr 2026 01:18:50 +0200
Subject: [PATCH 4/5] remove extra semicolon
---
mlir/lib/Reducer/ReductionTreePass.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 6bacd4f336545..bd0189577bcbe 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -165,7 +165,6 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
module, region, test, [](auto &region, auto) {
for (auto &block : region.getBlocks())
block.clear();
- ;
})))
// If clearing the entire region kept the module interesting
// we will return success.
>From f84512b3375ef1116d87ff8d1d54e4dfe33d1a08 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 19 Apr 2026 00:06:29 +0200
Subject: [PATCH 5/5] fix applyPattern bug
---
mlir/lib/Reducer/ReductionTreePass.cpp | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index bd0189577bcbe..bfcab5264c24a 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -130,6 +130,7 @@ static void applyPatterns(Region &region,
const FrozenRewritePatternSet &patterns,
ArrayRef<ReductionNode::Range> rangeToApply) {
size_t rangeIndex = 0;
+ std::vector<Operation *> opsInRange;
for (const auto &op : enumerate(region.getOps())) {
int index = op.index();
if (rangeIndex < rangeToApply.size() &&
@@ -137,13 +138,16 @@ static void applyPatterns(Region &region,
++rangeIndex;
if (rangeIndex < rangeToApply.size() &&
index >= rangeToApply[rangeIndex].first)
- // `applyOpPatternsGreedily` with folding returns whether the op is
- // converted. Omit it because we don't have expectation this reduction
- // will be success or not.
- (void)applyOpPatternsGreedily(&op.value(), patterns,
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingOps));
+ opsInRange.push_back(&op.value());
}
+
+ for (auto *op : opsInRange)
+ // `applyOpPatternsGreedily` with folding returns whether the op is
+ // converted. Omit it because we don't have expectation this reduction
+ // will be success or not.
+ (void)applyOpPatternsGreedily(op, patterns,
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingOps));
}
template <typename IteratorType>
More information about the Mlir-commits
mailing list