[Mlir-commits] [mlir] [MLIR] Add replace-operands option to mlir-reduce (PR #153100)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Aug 16 15:55:36 PDT 2025
https://github.com/aidint updated https://github.com/llvm/llvm-project/pull/153100
>From 720b851f6596bede3cd34107ac103425150b27e6 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Tue, 5 Aug 2025 20:38:35 +0200
Subject: [PATCH 1/4] add replaceOperands option to mlir-reduce
---
mlir/include/mlir/Reducer/Passes.td | 4 +++
mlir/lib/Reducer/ReductionTreePass.cpp | 38 ++++++++++++++++++++------
2 files changed, 34 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index 624e2e1edc329..1d453e31c3b2c 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -31,6 +31,10 @@ def ReductionTreePass : Pass<"reduction-tree"> {
Option<"traversalModeId", "traversal-mode", "unsigned",
/* default */"0",
"The graph traversal mode, the default is single-path mode">,
+ Option<"replaceOperands", "replace-operands", "bool",
+ /* default */"false",
+ "Whether the pass should replace operands with previously defined values with the same type">,
+
] # CommonReductionPassOptions.options;
}
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..29b2cfde8b1b8 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -15,11 +15,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
static void applyPatterns(Region &region,
const FrozenRewritePatternSet &patterns,
ArrayRef<ReductionNode::Range> rangeToKeep,
- bool eraseOpNotInRange) {
+ bool eraseOpNotInRange, bool replaceOperands) {
std::vector<Operation *> opsNotInRange;
std::vector<Operation *> opsInRange;
size_t keepIndex = 0;
@@ -53,17 +55,33 @@ static void applyPatterns(Region &region,
opsInRange.push_back(&op.value());
}
+ DominanceInfo domInfo(region.getParentOp());
+ mlir::DenseMap<mlir::Type, mlir::SmallVector<mlir::Value, 5>> valueMap;
+
// `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
// pattern matching in above iteration. Besides, erase op not-in-range may end
// up in invalid module, so `applyOpPatternsGreedily` with folding should come
// before that transform.
for (Operation *op : opsInRange) {
+ if (replaceOperands)
+ for (auto operandTie : llvm::enumerate(op->getOperands())) {
+ size_t index = operandTie.index();
+ auto operand = operandTie.value();
+ for (auto candidate : valueMap[operand.getType()])
+ if (domInfo.properlyDominates(candidate, op))
+ op->setOperand(index, candidate);
+ }
+
// `applyOpPatternsGreedily` with folding returns whether the op is
// converted. Omit it because we don't have expectation this reduction will
// be success or not.
(void)applyOpPatternsGreedily(op, patterns,
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingOps));
+
+ if (op && replaceOperands)
+ for (auto result : op->getResults())
+ valueMap[result.getType()].push_back(result);
}
if (eraseOpNotInRange)
@@ -83,7 +101,8 @@ static void applyPatterns(Region &region,
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
- const Tester &test, bool eraseOpNotInRange) {
+ const Tester &test, bool eraseOpNotInRange,
+ bool replaceOperands) {
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
// While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +130,7 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
Region &curRegion = currentNode.getRegion();
applyPatterns(curRegion, patterns, currentNode.getRanges(),
- eraseOpNotInRange);
+ eraseOpNotInRange, replaceOperands);
currentNode.update(test.isInteresting(currentNode.getModule()));
if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -134,7 +153,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
// Reduce the region through the optimal path.
while (!trace.empty()) {
ReductionNode *top = trace.pop_back_val();
- applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+ applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange,
+ replaceOperands);
}
if (test.isInteresting(module).first != Tester::Interestingness::True)
@@ -148,19 +168,21 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
- const Tester &test) {
+ const Tester &test, bool replaceOperands) {
// We separate the reduction process into 2 steps, the first one is to erase
// redundant operations and the second one is to apply the reducer patterns.
// In the first phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
- /*eraseOpNotInRange=*/true)))
+ /*eraseOpNotInRange=*/true,
+ replaceOperands)))
return failure();
// In the second phase, we suppose that no operation is redundant, so we try
// to rewrite the operation into simpler form.
return findOptimal<IteratorType>(module, region, patterns, test,
- /*eraseOpNotInRange=*/false);
+ /*eraseOpNotInRange=*/false,
+ /*replaceOperands=*/false);
}
namespace {
@@ -248,7 +270,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
switch (traversalModeId) {
case TraversalMode::SinglePath:
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
- module, region, reducerPatterns, test);
+ module, region, reducerPatterns, test, replaceOperands);
default:
return module.emitError() << "unsupported traversal mode detected";
}
>From 8e0b6ccaa02a5d46f9088289cf924a11b52e853b Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Tue, 12 Aug 2025 00:43:18 +0200
Subject: [PATCH 2/4] add test
---
mlir/test/mlir-reduce/replace-operands.mlir | 13 +++++++++++++
1 file changed, 13 insertions(+)
create mode 100644 mlir/test/mlir-reduce/replace-operands.mlir
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
new file mode 100644
index 0000000000000..261db546b9348
--- /dev/null
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
+
+// CHECK-LABEL: func.func @main
+func.func @main() {
+ // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 2 : i32
+ // CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
+ // CHECK-NEXT return
+
+ %c1 = arith.constant 3 : i32
+ %c2 = arith.constant 2 : i32
+ %2 = "test.op_crash" (%c1, %c2) : (i32, i32) -> i32
+ return
+}
>From c20f63797f91c1195f7cef46f9370bccc5f7ea85 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Tue, 12 Aug 2025 22:42:44 +0200
Subject: [PATCH 3/4] remove test from windows
---
mlir/test/mlir-reduce/replace-operands.mlir | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
index 261db546b9348..7f722d0f21e94 100644
--- a/mlir/test/mlir-reduce/replace-operands.mlir
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -1,3 +1,4 @@
+// UNSUPPORTED: system-windows
// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
// CHECK-LABEL: func.func @main
>From aa7f3faf9aaaacb6a82535c8327bfc4f8b05a000 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 17 Aug 2025 00:55:21 +0200
Subject: [PATCH 4/4] break if replacement found
---
mlir/lib/Reducer/ReductionTreePass.cpp | 4 +++-
mlir/test/mlir-reduce/replace-operands.mlir | 2 +-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 29b2cfde8b1b8..c106bf61f6cb1 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -68,8 +68,10 @@ static void applyPatterns(Region ®ion,
size_t index = operandTie.index();
auto operand = operandTie.value();
for (auto candidate : valueMap[operand.getType()])
- if (domInfo.properlyDominates(candidate, op))
+ if (domInfo.properlyDominates(candidate, op)) {
op->setOperand(index, candidate);
+ break;
+ }
}
// `applyOpPatternsGreedily` with folding returns whether the op is
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
index 7f722d0f21e94..b79a3aa663db3 100644
--- a/mlir/test/mlir-reduce/replace-operands.mlir
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: func.func @main
func.func @main() {
- // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 2 : i32
+ // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 3 : i32
// CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
// CHECK-NEXT return
More information about the Mlir-commits
mailing list