[Mlir-commits] [mlir] [mlir][reducer] Separate Reduction Steps in `findOptimal` and `applyPatterns` (PR #190560)
llvmlistbot at llvm.org
Sun Apr 5 16:35:40 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-core
Author: AidinT (aidint)
<details>
<summary>Changes</summary>
Previously, `findOptimal` and `applyPatterns` were each handling multiple types of reduction simultaneously, coupling concerns that are better kept separate.
This PR refactors these functions to cleanly separate their responsibilities:
- **`findOptimal`** now focuses solely on tree traversal and is parameterized over an `ApplyFn`, making it more composable and easier to reason about.
- **`applyPatterns`** is split into two distinct functions:
  - one responsible for **elimination**
  - one responsible for **applying patterns**
  - both are then used as `ApplyFn` in consecutive, well-defined steps
- **`eraseAllOpsInRegion`** is also simplified into a small, focused `ApplyFn`.
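To make the decomposition above concrete, here is a minimal, self-contained C++ sketch of the same shape outside MLIR. It is not the PR's code: `ToyModule`, `Range`, `Tester`, `reduceUsing`, `eliminateOps` and `reduce` are hypothetical names, the IR is modeled as a list of op strings, size is measured in characters, and the search simply tries a fixed list of candidate ranges (one op at a time) instead of walking a `ReductionNode` tree. What it does mirror is the structure: one search routine parameterized over an apply function, and a driver that first tries clearing everything, then runs elimination and pattern application as separate steps.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: a "module" is a list of op strings, a Range is a
// half-open [first, second) interval of op indices, and the tester decides
// whether a candidate still reproduces the bug.
using ToyModule = std::vector<std::string>;
using Range = std::pair<std::size_t, std::size_t>;
using Ranges = std::vector<Range>;
using Tester = std::function<bool(const ToyModule &)>;

// Toy size metric: total number of characters in all op strings.
std::size_t sizeOf(const ToyModule &m) {
  std::size_t n = 0;
  for (const std::string &op : m)
    n += op.size();
  return n;
}

bool inRanges(std::size_t i, const Ranges &ranges) {
  for (const Range &r : ranges) {
    assert(r.first <= r.second && "ranges must be well-formed");
    if (i >= r.first && i < r.second)
      return true;
  }
  return false;
}

// The search only decides which candidate to keep; how a candidate is built
// is delegated to `applyFn`. Candidates are built on copies, so `module` is
// replaced only by a variant that is interesting and strictly smaller.
template <typename ApplyFn>
bool reduceUsing(ToyModule &module, const std::vector<Ranges> &candidates,
                 const Tester &isInteresting, ApplyFn applyFn) {
  for (const Ranges &ranges : candidates) {
    ToyModule copy = module;
    applyFn(copy, ranges);
    if (isInteresting(copy) && sizeOf(copy) < sizeOf(module)) {
      module = std::move(copy);
      return true;
    }
  }
  return false;
}

// Elimination-only ApplyFn: keep exactly the ops whose index is in `keep`.
void eliminateOps(ToyModule &m, const Ranges &keep) {
  ToyModule kept;
  for (std::size_t i = 0; i < m.size(); ++i)
    if (inRanges(i, keep))
      kept.push_back(m[i]);
  m = std::move(kept);
}

// Driver with the same three steps as the new findOptimal.
bool reduce(ToyModule &m, const Tester &isInteresting,
            const std::function<std::string(const std::string &)> &pattern) {
  if (!isInteresting(m))
    return false; // an uninteresting input is not reduced
  // Step 1: try clearing everything; stop early if that stays interesting.
  std::vector<Ranges> clearAll(1);
  if (reduceUsing(m, clearAll, isInteresting,
                  [](ToyModule &r, const Ranges &) { r.clear(); }))
    return true;
  // Step 2: elimination only (candidates here: keep one op at a time).
  std::vector<Ranges> singleOps;
  for (std::size_t i = 0; i < m.size(); ++i)
    singleOps.push_back(Ranges{Range{i, i + 1}});
  bool eliminated = reduceUsing(m, singleOps, isInteresting, eliminateOps);
  // Step 3: pattern application only, over the (possibly reduced) module.
  std::vector<Ranges> whole(1, Ranges{Range{0, m.size()}});
  bool rewritten = reduceUsing(m, whole, isInteresting,
                               [&](ToyModule &r, const Ranges &ranges) {
                                 for (std::size_t i = 0; i < r.size(); ++i)
                                   if (inRanges(i, ranges))
                                     r[i] = pattern(r[i]);
                               });
  return eliminated || rewritten;
}
```

With a tester that only looks for a `2`, the elimination step shrinks `{"const 1", "const 2", "const 3"}` to `{"const 2"}`; an identity pattern then makes no further progress, while a pattern that shortens `"const 2"` lets step 3 succeed as well. Keeping the traversal ignorant of what `applyFn` does is what allows each step to be tested and reasoned about in isolation.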
### Motivation
The current MLIR Reduce implementation has a correctness issue when reducing modules that contain **trivially dead code**: if the trivially dead code is itself the cause of the bug being investigated, MLIR Reduce fails to reduce such a module. This happens because the elimination step and the pattern application step are intertwined: dead code is removed during the elimination step even though no patterns are passed to `findOptimal`, because `applyOpPatternsGreedily` erases trivially dead operations regardless of the pattern set. The operation that triggers the bug therefore disappears from the candidates handed to the tester.
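The failure mode described above can be reproduced in a toy model. This C++ sketch is only an illustration under stated assumptions: `ToyOp`, `greedyApplyNoPatterns`, `coupledEliminate` and `pureEliminate` are hypothetical, and `greedyApplyNoPatterns` merely models the behavior described for the old elimination step (trivially dead ops vanishing although no patterns are supplied). The tester mirrors `trivially-dead.sh` from the diff: a candidate is interesting only if `arith.constant 2` survives.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy op: the arith.constant ops in the PR's test have no uses,
// so they are trivially dead; `return` is not.
struct ToyOp {
  std::string text;
  bool triviallyDead;
};
using ToyRegion = std::vector<ToyOp>;

// Models the greedy driver with an empty pattern set: it still erases every
// trivially dead op it visits.
void greedyApplyNoPatterns(ToyRegion &region) {
  region.erase(std::remove_if(region.begin(), region.end(),
                              [](const ToyOp &op) { return op.triviallyDead; }),
               region.end());
}

// Keep only the ops whose index lies in the half-open range [first, last).
void keepRange(ToyRegion &region, std::size_t first, std::size_t last) {
  ToyRegion kept;
  for (std::size_t i = first; i < last && i < region.size(); ++i)
    kept.push_back(region[i]);
  region = std::move(kept);
}

// Old, coupled step: selecting ops also ran the greedy driver on them. (The
// old code ran the driver first; the resulting set of ops is the same.)
ToyRegion coupledEliminate(ToyRegion region, std::size_t first,
                           std::size_t last) {
  keepRange(region, first, last);
  greedyApplyNoPatterns(region);
  return region;
}

// New, separated step: elimination only removes out-of-range ops.
ToyRegion pureEliminate(ToyRegion region, std::size_t first, std::size_t last) {
  keepRange(region, first, last);
  return region;
}

// Mirrors trivially-dead.sh: interesting iff "arith.constant 2" survives.
bool isInteresting(const ToyRegion &region) {
  return std::any_of(region.begin(), region.end(), [](const ToyOp &op) {
    return op.text.find("arith.constant 2") != std::string::npos;
  });
}
```

Under the coupled step no choice of range can keep the interesting constant, whereas pure elimination of the range `[1, 2)` keeps exactly that one op.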
### Impact
These refactoring changes, together with minor fixes to the traversal logic, enable MLIR Reduce to correctly handle and reduce **modules with trivially dead code**.
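One of the traversal fixes concerns how the winning path is replayed on the original region. In the new `findOptimalUsing`, the root is left untouched (the search starts from a duplicate, `firstVariant`), so the trace stops before the root, and an empty trace means nothing was reduced. A minimal C++ sketch of that bookkeeping, using a hypothetical `ToyNode` that stores a parent index and an integer standing in for the ranges applied at that node:

```cpp
#include <optional>
#include <vector>

// Hypothetical reduction-tree node: index 0 is the untouched root.
struct ToyNode {
  int parent; // -1 for the root
  int step;   // stands in for the ranges `applyFn` used at this node
};

// Walk from `smallest` up to, but excluding, the root, then return the steps
// in root-to-leaf order: the order in which they must be re-applied to the
// original region. std::nullopt means smallest == root, i.e. no reduction.
std::optional<std::vector<int>> replayPlan(const std::vector<ToyNode> &tree,
                                           int smallest) {
  std::vector<int> trace;
  for (int cur = smallest; cur != 0; cur = tree[cur].parent)
    trace.push_back(tree[cur].step);
  if (trace.empty())
    return std::nullopt;
  return std::vector<int>(trace.rbegin(), trace.rend());
}
```

The diff gets the same ordering without an explicit reversal: it pushes nodes while walking up and then consumes them from the back with `pop_back_val()`.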
---
Full diff: https://github.com/llvm/llvm-project/pull/190560.diff
4 Files Affected:
- (modified) mlir/include/mlir/Reducer/ReductionNode.h (+4-1)
- (modified) mlir/lib/Reducer/ReductionTreePass.cpp (+98-116)
- (added) mlir/test/mlir-reduce/trivially-dead.mlir (+16)
- (added) mlir/test/mlir-reduce/trivially-dead.sh (+4)
``````````diff
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..5ae2ae52d10ef 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -78,6 +78,9 @@ class ReductionNode {
/// Return the generated variants(the child nodes).
ArrayRef<ReductionNode *> getVariants() const { return variants; }
+ /// Add a variant to Node's variants
+ void addVariant(ReductionNode * node) { variants.push_back(node); }
+
/// Split the ranges and generate new variants.
ArrayRef<ReductionNode *> generateNewVariants();
@@ -88,7 +91,7 @@ class ReductionNode {
/// patterns. In addition, we only apply rewrite patterns in a certain region.
/// In init(), we will duplicate the module from parent node and locate the
/// corresponding region.
- LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
+ LogicalResult initialize(ModuleOp parentModule, Region &targetRegion);
private:
/// A custom BFS iterator. The difference between
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 83497143d9669..f07881f15ac77 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -32,64 +32,11 @@ namespace mlir {
using namespace mlir;
-/// We implicitly number each operation in the region and if an operation's
-/// number falls into rangeToKeep, we need to keep it and apply the given
-/// rewrite patterns on it.
-static void applyPatterns(Region &region,
- const FrozenRewritePatternSet &patterns,
- ArrayRef<ReductionNode::Range> rangeToKeep,
- bool eraseOpNotInRange) {
- std::vector<Operation *> opsNotInRange;
- std::vector<Operation *> opsInRange;
- size_t keepIndex = 0;
- for (const auto &op : enumerate(region.getOps())) {
- int index = op.index();
- if (keepIndex < rangeToKeep.size() &&
- index == rangeToKeep[keepIndex].second)
- ++keepIndex;
- if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
- opsNotInRange.push_back(&op.value());
- else
- opsInRange.push_back(&op.value());
- }
-
- // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
- // pattern matching in above iteration. Besides, erase op not-in-range may end
- // up in invalid module, so `applyOpPatternsGreedily` with folding should come
- // before that transform.
- for (Operation *op : opsInRange) {
- // `applyOpPatternsGreedily` with folding returns whether the op is
- // converted. Omit it because we don't have expectation this reduction will
- // be success or not.
- (void)applyOpPatternsGreedily(op, patterns,
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingOps));
- }
-
- if (eraseOpNotInRange)
- for (Operation *op : opsNotInRange) {
- op->dropAllUses();
- op->erase();
- }
-}
-
-/// We will apply the reducer patterns to the operations in the ranges specified
-/// by ReductionNode. Note that we are not able to remove an operation without
-/// replacing it with another valid operation. However, The validity of module
-/// reduction is based on the Tester provided by the user and that means certain
-/// invalid module is still interested by the use. Thus we provide an
-/// alternative way to remove operations, which is using `eraseOpNotInRange` to
-/// erase the operations not in the range specified by ReductionNode.
-template <typename IteratorType>
-static LogicalResult findOptimal(ModuleOp module, Region &region,
- const FrozenRewritePatternSet &patterns,
- const Tester &test, bool eraseOpNotInRange) {
- std::pair<Tester::Interestingness, size_t> initStatus =
- test.isInteresting(module);
- // While exploring the reduction tree, we always branch from an interesting
- // node. Thus the root node must be interesting.
- if (initStatus.first != Tester::Interestingness::True)
- return module.emitWarning() << "uninterested module will not be reduced";
+/// We will apply `applyFn` to the operations in the ranges specified by
+/// ReductionNode.
+template <typename IteratorType, typename ApplyFn>
+static LogicalResult findOptimalUsing(ModuleOp module, Region &region,
+ const Tester &test, ApplyFn applyFn) {
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
@@ -101,17 +48,20 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
// Duplicate the module for root node and locate the region in the copy.
if (failed(root->initialize(module, region)))
llvm_unreachable("unexpected initialization failure");
- root->update(initStatus);
+
+ // Create a duplicate of the root node as the first variant of the root.
+ // This will keep the root intact when applying `applyFn`.
+ ReductionNode *firstVariant = allocator.Allocate();
+ new (firstVariant) ReductionNode(root, ranges, allocator);
+ root->addVariant(firstVariant);
ReductionNode *smallestNode = root;
- IteratorType iter(root);
+ IteratorType iter(firstVariant);
while (iter != IteratorType::end()) {
ReductionNode &currentNode = *iter;
- Region &curRegion = currentNode.getRegion();
- applyPatterns(curRegion, patterns, currentNode.getRanges(),
- eraseOpNotInRange);
+ applyFn(currentNode.getRegion(), currentNode.getRanges());
currentNode.update(test.isInteresting(currentNode.getModule()));
if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -125,86 +75,118 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
// the path and apply the reducer to it.
SmallVector<ReductionNode *> trace;
ReductionNode *curNode = smallestNode;
- trace.push_back(curNode);
while (curNode != root) {
- curNode = curNode->getParent();
trace.push_back(curNode);
+ curNode = curNode->getParent();
}
+ if (trace.empty())
+ // If trace is empty, then the smallestNode == root and therefore we were
+ // not successful in reducing the module
+ return failure();
+
// Reduce the region through the optimal path.
while (!trace.empty()) {
ReductionNode *top = trace.pop_back_val();
- applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+ applyFn(region, top->getStartRanges());
}
- if (test.isInteresting(module).first != Tester::Interestingness::True)
+ std::pair<Tester::Interestingness, size_t> finalStatus =
+ test.isInteresting(module);
+
+ if (finalStatus.first != Tester::Interestingness::True)
llvm::report_fatal_error("Reduced module is not interesting");
- if (test.isInteresting(module).second != smallestNode->getSize())
+ if (finalStatus.second != smallestNode->getSize())
llvm::report_fatal_error(
"Reduced module doesn't have consistent size with smallestNode");
return success();
}
-/// This function attempts to erase all operations within the region currently
-/// being processed.
-static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
- const Tester &test) {
- std::pair<Tester::Interestingness, size_t> initStatus =
- test.isInteresting(module);
-
- // While exploring the reduction tree, we always branch from an interesting
- // node. Thus the root node must be interesting.
- if (initStatus.first != Tester::Interestingness::True)
- return module.emitWarning() << "uninterested module will not be reduced";
- llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
-
- // Setting the ranges to {{0, 0}} will result in the deletion of all ops
- // within the region.
- std::vector<ReductionNode::Range> ranges{{0, 0}};
-
- // We allocate memory on the stack, and the 'allocator' is only used to
- // construct the 'root node'. Since we won't be constructing any child nodes
- // for emptyRegionNode, it is only used within the current scope.
- ReductionNode emptyRegionNode(nullptr, ranges, allocator);
- ReductionNode *root = &emptyRegionNode;
+/// We implicitly number each operation in the region and if an operation's
+/// number falls into rangeToKeep, we'll keep it.
+static void eliminateOperations(Region &region,
+ ArrayRef<ReductionNode::Range> rangeToKeep) {
+ std::vector<Operation *> opsNotInRange;
+ size_t keepIndex = 0;
+ for (const auto &op : enumerate(region.getOps())) {
+ int index = op.index();
+ if (keepIndex < rangeToKeep.size() &&
+ index == rangeToKeep[keepIndex].second)
+ ++keepIndex;
+ if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
+ opsNotInRange.push_back(&op.value());
+ }
- // Create a copy of the current IR.
- if (failed(root->initialize(module, region)))
- llvm_unreachable("unexpected initialization failure");
+ for (Operation *op : opsNotInRange) {
+ op->dropAllUses();
+ op->erase();
+ }
+}
- // Erase all operations within the corresponding region of the clone.
- applyPatterns(root->getRegion(), {}, root->getRanges(), true);
- root->update(test.isInteresting(root->getModule()));
- if (root->isInteresting() == Tester::Interestingness::True) {
- // If we can successfully remove all ops in the region, we apply the same
- // transformation to the original IR and return success.
- applyPatterns(region, {}, root->getRanges(), true);
- return success();
+/// We implicitly number each operation in the region and if an operation's
+/// number falls into rangeToApply, we'll apply the given rewrite patterns on
+/// it.
+static void applyPatterns(Region &region,
+ const FrozenRewritePatternSet &patterns,
+ ArrayRef<ReductionNode::Range> rangeToApply) {
+ size_t rangeIndex = 0;
+ for (const auto &op : enumerate(region.getOps())) {
+ int index = op.index();
+ if (rangeIndex < rangeToApply.size() &&
+ index == rangeToApply[rangeIndex].second)
+ ++rangeIndex;
+ if (rangeIndex < rangeToApply.size() &&
+ index >= rangeToApply[rangeIndex].first)
+ // `applyOpPatternsGreedily` with folding returns whether the op is
+ // converted. Omit it because we don't have expectation this reduction
+ // will be success or not.
+ (void)applyOpPatternsGreedily(&op.value(), patterns,
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingOps));
}
- return failure();
}
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
const Tester &test) {
- // We separate the reduction process into 3 steps, the first one is to erase
- // redundant operations and the second one is to apply the reducer patterns.
- // In the first phase, we attempt to erase all operations within the entire
- // region.
- if (succeeded(eraseAllOpsInRegion(module, region, test)))
+ // We first test the interestingness of the module passed to findOptimal.
+ std::pair<Tester::Interestingness, size_t> initStatus =
+ test.isInteresting(module);
+ if (initStatus.first != Tester::Interestingness::True)
+ // If the module is not interesting, we can return failure
+ return module.emitWarning() << "uninterested module will not be reduced";
+
+ // We separate the reduction process into 3 steps:
+ // In the first step, we attempt to erase all operations within the
+ // entire region.
+ if (succeeded(findOptimalUsing<IteratorType>(
+ module, region, test, [](auto &region, auto) {
+ for (auto &block : region.getBlocks())
+ block.clear();
+ ;
+ })))
+ // If clearing the entire region kept the module interesting
+ // we will return success.
return success();
- // In the second phase, we don't apply any patterns so that we only select the
- // range of operations to keep to the module stay interesting.
- if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
- /*eraseOpNotInRange=*/true)))
- return failure();
- // In the third phase, we suppose that no operation is redundant, so we try
- // to rewrite the operation into simpler form.
- return findOptimal<IteratorType>(module, region, patterns, test,
- /*eraseOpNotInRange=*/false);
+ // In the second step, we eliminate redundant operations from the region to
+ // select those that keep the module interesting
+ auto eliminationResult =
+ findOptimalUsing<IteratorType>(module, region, test, eliminateOperations);
+
+ // In the third step, we suppose that no operation is redundant, so we try
+ // to rewrite the operation into simpler form by applying patterns.
+ auto applyPatternsResult = findOptimalUsing<IteratorType>(
+ module, region, test, [&](auto &region, auto ranges) {
+ applyPatterns(region, patterns, ranges);
+ });
+
+ if (succeeded(eliminationResult) || succeeded(applyPatternsResult))
+ // if step 2 or 3 was successful, then we return success.
+ return success();
+ return failure();
}
namespace {
diff --git a/mlir/test/mlir-reduce/trivially-dead.mlir b/mlir/test/mlir-reduce/trivially-dead.mlir
new file mode 100644
index 0000000000000..ee69f754a1971
--- /dev/null
+++ b/mlir/test/mlir-reduce/trivially-dead.mlir
@@ -0,0 +1,16 @@
+// UNSUPPORTED: system-windows
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/trivially-dead.sh' | FileCheck %s
+// We are testing the ability of keeping trivially-dead yet interesting code
+
+func.func @trivially_dead() {
+ %0 = arith.constant 1 : i32
+ %1 = arith.constant 2 : i32
+ %2 = arith.constant 3 : i32
+ %3 = arith.constant 4 : i32
+ return
+}
+
+// CHECK-LABEL: func.func @trivially_dead() {
+// CHECK-NEXT: {{.*}} = arith.constant 2 : i32
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/mlir-reduce/trivially-dead.sh b/mlir/test/mlir-reduce/trivially-dead.sh
new file mode 100755
index 0000000000000..3a973c12f0486
--- /dev/null
+++ b/mlir/test/mlir-reduce/trivially-dead.sh
@@ -0,0 +1,4 @@
+#!/bin/sh
+
+# break only on `arith.constant 2 : i32`
+! cat $1 | grep "arith.constant 2 : i32" 2>&1 1>/dev/null
``````````
</details>
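The range bookkeeping shared by `eliminateOperations` and the new `applyPatterns` in the diff is compact but easy to misread. The sketch below replays the same single-pass walk over plain indices. `inRangeMask` is a hypothetical helper, not part of the PR, and it assumes the ranges are half-open `[first, second)`, sorted and non-overlapping, which is what a single forward pass with one cursor requires.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Half-open [first, second); assumed sorted and non-overlapping.
using Range = std::pair<int, int>;

// Same cursor logic as eliminateOperations/applyPatterns in the diff, but
// over indices 0..numOps-1: mask[i] is true iff op i falls inside a range
// (kept by elimination, or visited by pattern application).
std::vector<bool> inRangeMask(std::size_t numOps,
                              const std::vector<Range> &ranges) {
  std::vector<bool> mask(numOps, false);
  std::size_t rangeIndex = 0;
  for (std::size_t i = 0; i < numOps; ++i) {
    int index = static_cast<int>(i);
    // Reaching the current range's end: move the cursor to the next range.
    if (rangeIndex < ranges.size() && index == ranges[rangeIndex].second)
      ++rangeIndex;
    mask[i] = rangeIndex < ranges.size() && index >= ranges[rangeIndex].first;
  }
  return mask;
}
```

For example, `{{0, 0}}` selects nothing, which is why the old `eraseAllOpsInRegion` used it to delete every op in the region, while `{{0, 2}, {2, 4}}` selects indices 0 through 3.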
https://github.com/llvm/llvm-project/pull/190560