[Mlir-commits] [mlir] [MLIR] Add replace-operands option to mlir-reduce (PR #153100)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Aug 16 15:55:36 PDT 2025


https://github.com/aidint updated https://github.com/llvm/llvm-project/pull/153100

>From 720b851f6596bede3cd34107ac103425150b27e6 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Tue, 5 Aug 2025 20:38:35 +0200
Subject: [PATCH 1/4] add replaceOperands option to mlir-reduce

---
 mlir/include/mlir/Reducer/Passes.td    |  4 +++
 mlir/lib/Reducer/ReductionTreePass.cpp | 38 ++++++++++++++++++++------
 2 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index 624e2e1edc329..1d453e31c3b2c 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -31,6 +31,9 @@ def ReductionTreePass : Pass<"reduction-tree"> {
     Option<"traversalModeId", "traversal-mode", "unsigned",
            /* default */"0",
            "The graph traversal mode, the default is single-path mode">,
+    Option<"replaceOperands", "replace-operands", "bool",
+           /* default */"false",
+           "Whether the pass should replace operands with previously defined values with the same type">,
   ] # CommonReductionPassOptions.options;
 }
 
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..29b2cfde8b1b8 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -15,11 +15,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Reducer/Tester.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
 static void applyPatterns(Region &region,
                           const FrozenRewritePatternSet &patterns,
                           ArrayRef<ReductionNode::Range> rangeToKeep,
-                          bool eraseOpNotInRange) {
+                          bool eraseOpNotInRange, bool replaceOperands) {
   std::vector<Operation *> opsNotInRange;
   std::vector<Operation *> opsInRange;
   size_t keepIndex = 0;
@@ -53,17 +55,39 @@ static void applyPatterns(Region &region,
       opsInRange.push_back(&op.value());
   }
 
+  DominanceInfo domInfo(region.getParentOp());
+  mlir::DenseMap<mlir::Type, mlir::SmallVector<mlir::Value, 5>> valueMap;
+
   // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
   // pattern matching in above iteration. Besides, erase op not-in-range may end
   // up in invalid module, so `applyOpPatternsGreedily` with folding should come
   // before that transform.
   for (Operation *op : opsInRange) {
+    if (replaceOperands)
+      for (auto operandTie : llvm::enumerate(op->getOperands())) {
+        size_t index = operandTie.index();
+        auto operand = operandTie.value();
+        for (auto candidate : valueMap[operand.getType()])
+          if (domInfo.properlyDominates(candidate, op))
+            op->setOperand(index, candidate);
+      }
+
     // `applyOpPatternsGreedily` with folding returns whether the op is
     // converted. Omit it because we don't have expectation this reduction will
     // be success or not.
-    (void)applyOpPatternsGreedily(op, patterns,
-                                  GreedyRewriteConfig().setStrictness(
-                                      GreedyRewriteStrictness::ExistingOps));
+    // Folding may erase `op`; `opErased` tells us whether the pointer is
+    // still valid before its results are inspected below.
+    bool opErased = false;
+    (void)applyOpPatternsGreedily(op, patterns,
+                                  GreedyRewriteConfig().setStrictness(
+                                      GreedyRewriteStrictness::ExistingOps),
+                                  /*changed=*/nullptr, /*allErased=*/&opErased);
+
+    // Only surviving ops contribute their results as replacement candidates
+    // for operands of later ops.
+    if (replaceOperands && !opErased)
+      for (auto result : op->getResults())
+        valueMap[result.getType()].push_back(result);
   }
 
   if (eraseOpNotInRange)
@@ -83,7 +101,8 @@ static void applyPatterns(Region &region,
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
-                                 const Tester &test, bool eraseOpNotInRange) {
+                                 const Tester &test, bool eraseOpNotInRange,
+                                 bool replaceOperands) {
   std::pair<Tester::Interestingness, size_t> initStatus =
       test.isInteresting(module);
   // While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +130,7 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
     Region &curRegion = currentNode.getRegion();
 
     applyPatterns(curRegion, patterns, currentNode.getRanges(),
-                  eraseOpNotInRange);
+                  eraseOpNotInRange, replaceOperands);
     currentNode.update(test.isInteresting(currentNode.getModule()));
 
     if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -134,7 +153,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   // Reduce the region through the optimal path.
   while (!trace.empty()) {
     ReductionNode *top = trace.pop_back_val();
-    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange,
+                  replaceOperands);
   }
 
   if (test.isInteresting(module).first != Tester::Interestingness::True)
@@ -148,19 +168,21 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
-                                 const Tester &test) {
+                                 const Tester &test, bool replaceOperands) {
   // We separate the reduction process into 2 steps, the first one is to erase
   // redundant operations and the second one is to apply the reducer patterns.
 
   // In the first phase, we don't apply any patterns so that we only select the
   // range of operations to keep to the module stay interesting.
   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
-                                       /*eraseOpNotInRange=*/true)))
+                                       /*eraseOpNotInRange=*/true,
+                                       replaceOperands)))
     return failure();
   // In the second phase, we suppose that no operation is redundant, so we try
   // to rewrite the operation into simpler form.
   return findOptimal<IteratorType>(module, region, patterns, test,
-                                   /*eraseOpNotInRange=*/false);
+                                   /*eraseOpNotInRange=*/false,
+                                   /*replaceOperands=*/false);
 }
 
 namespace {
@@ -248,7 +270,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
   switch (traversalModeId) {
   case TraversalMode::SinglePath:
     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
-        module, region, reducerPatterns, test);
+        module, region, reducerPatterns, test, replaceOperands);
   default:
     return module.emitError() << "unsupported traversal mode detected";
   }

>From 8e0b6ccaa02a5d46f9088289cf924a11b52e853b Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Tue, 12 Aug 2025 00:43:18 +0200
Subject: [PATCH 2/4] add test

---
 mlir/test/mlir-reduce/replace-operands.mlir | 13 +++++++++++++
 1 file changed, 13 insertions(+)
 create mode 100644 mlir/test/mlir-reduce/replace-operands.mlir

diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
new file mode 100644
index 0000000000000..261db546b9348
--- /dev/null
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
+
+// CHECK-LABEL: func.func @main
+func.func @main() {
+  // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 2 : i32
+  // CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
+  // CHECK-NEXT return
+
+  %c1 = arith.constant 3 : i32
+  %c2 = arith.constant 2 : i32
+  %2 = "test.op_crash" (%c1, %c2) : (i32, i32) -> i32
+  return
+}

>From c20f63797f91c1195f7cef46f9370bccc5f7ea85 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Tue, 12 Aug 2025 22:42:44 +0200
Subject: [PATCH 3/4] mark replace-operands test unsupported on Windows

---
 mlir/test/mlir-reduce/replace-operands.mlir | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
index 261db546b9348..7f722d0f21e94 100644
--- a/mlir/test/mlir-reduce/replace-operands.mlir
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -1,3 +1,4 @@
+// UNSUPPORTED: system-windows
 // RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
 
 // CHECK-LABEL: func.func @main

>From aa7f3faf9aaaacb6a82535c8327bfc4f8b05a000 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 17 Aug 2025 00:55:21 +0200
Subject: [PATCH 4/4] break if replacement found

---
 mlir/lib/Reducer/ReductionTreePass.cpp      | 4 +++-
 mlir/test/mlir-reduce/replace-operands.mlir | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 29b2cfde8b1b8..c106bf61f6cb1 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -68,8 +68,10 @@ static void applyPatterns(Region &region,
         size_t index = operandTie.index();
         auto operand = operandTie.value();
         for (auto candidate : valueMap[operand.getType()])
-          if (domInfo.properlyDominates(candidate, op))
+          if (domInfo.properlyDominates(candidate, op)) {
             op->setOperand(index, candidate);
+            break;
+          }
       }
 
     // `applyOpPatternsGreedily` with folding returns whether the op is
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
index 7f722d0f21e94..b79a3aa663db3 100644
--- a/mlir/test/mlir-reduce/replace-operands.mlir
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -3,7 +3,7 @@

 // CHECK-LABEL: func.func @main
 func.func @main() {
-  // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 2 : i32
+  // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 3 : i32
   // CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
-  // CHECK-NEXT return
+  // CHECK-NEXT: return




More information about the Mlir-commits mailing list