[Mlir-commits] [mlir] [mlir][reducer] Introduce the materialization mechanism in the reduction-tree and fix the logic for deleting operations (PR #185445)

lonely eagle llvmlistbot at llvm.org
Sun Apr 5 22:56:40 PDT 2026


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/185445

>From 8cd2423bdad47f0a94cb8a3ff6d0f0ce07982f2c Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 6 Apr 2026 05:56:10 +0000
Subject: [PATCH] rebase main.

---
 mlir/lib/Reducer/ReductionTreePass.cpp    | 87 ++++++++++++++++++++---
 mlir/test/mlir-reduce/reduction-tree.mlir | 31 ++++++++
 2 files changed, 108 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 12358f7d71688..25d41c8b815af 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -59,10 +59,83 @@ static void applyPatterns(Region &region,
       opsInRange.push_back(&op.value());
   }
 
-  // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
-  // pattern matching in above iteration. Besides, erase op not-in-range may end
-  // up in invalid module, so `applyOpPatternsGreedily` with folding should come
-  // before that transform.
+  if (eraseOpNotInRange) {
+
+    // clang-format off
+    LLVM_DEBUG(
+      LDBG() << "before erase ops not in ranges, keep the ranges:";
+      for (ReductionNode::Range range : rangeToKeep) {
+        LDBG() << "[" << range.first << " " << range.second << ")";
+      } 
+      LDBG() << "region:\n" << region;
+    );
+    // clang-format on
+
+    // The map uses the results of the operations as keys, while the values
+    // represent the remaining user count for each result. We iterate through
+    // `opsNotInRange` to update this map; if a key's value remains greater than
+    // zero, it indicates that materialization is required for that specific
+    // value.
+    DenseMap<Value, int64_t> valueToMaterializationMap;
+
+    for (Operation *op : opsNotInRange) {
+      if (op->hasTrait<mlir::OpTrait::IsTerminator>())
+        continue;
+
+      for (Value result : op->getResults())
+        valueToMaterializationMap[result] = result.getNumUses();
+
+      // Use a set to store all operands to prevent the map value from being
+      // decremented multiple times if an operation uses the same operand more
+      // than once.
+      SmallPtrSet<Value, 4> operandSet(op->getOperands().begin(),
+                                       op->getOperands().end());
+      for (Value operand : operandSet)
+        // If an `operand` is a key in the map, it indicates that the operand
+        // was defined within `opsNotInRange`.
+        if (valueToMaterializationMap.contains(operand))
+          --valueToMaterializationMap[operand];
+    }
+
+    SmallVector<Type, 4> materializationTypes;
+    SmallVector<Value, 4> valueNeedMaterialization;
+    for (auto mapValue : valueToMaterializationMap) {
+      // If a key in the map has a value greater than zero, it indicates that
+      // there are still operations in the remaining IR using this key.
+      // Therefore, we should materialize it.
+      if (mapValue.second > 0) {
+        materializationTypes.push_back(mapValue.first.getType());
+        valueNeedMaterialization.push_back(mapValue.first);
+      }
+    }
+
+    if (!materializationTypes.empty()) {
+      OpBuilder b(region.getContext());
+      b.setInsertionPointToStart(&region.front());
+      auto castOp = UnrealizedConversionCastOp::create(
+          b, b.getUnknownLoc(), materializationTypes, {});
+      for (auto [src, res] :
+           llvm::zip_equal(valueNeedMaterialization, castOp.getResults())) {
+        src.replaceAllUsesWith(res);
+      }
+    }
+
+    for (Operation *op : opsNotInRange) {
+      if (op->hasTrait<mlir::OpTrait::IsTerminator>())
+        continue;
+      op->dropAllUses();
+      op->erase();
+    }
+    LDBG() << "after erase ops not in ranges:\n" << region;
+  }
+
+  // After erasing `opsNotInRange`, we run `applyOpPatternsGreedily` on the
+  // ops in range, both to apply the reducer patterns and to eliminate
+  // operations that no longer have users. We do not erase every userless
+  // operation directly because some of them may be the `interesting` ops, so
+  // the cleanup is left to `applyOpPatternsGreedily`. This cleanup is
+  // essential: without it, the reduction fails whenever the erased ops are
+  // smaller than the newly introduced `unrealized_conversion_cast`.
   for (Operation *op : opsInRange) {
     // `applyOpPatternsGreedily` with folding returns whether the op is
     // converted. Omit it because we don't have expectation this reduction will
@@ -71,12 +144,6 @@ static void applyPatterns(Region &region,
                                   GreedyRewriteConfig().setStrictness(
                                       GreedyRewriteStrictness::ExistingOps));
   }
-
-  if (eraseOpNotInRange)
-    for (Operation *op : opsNotInRange) {
-      op->dropAllUses();
-      op->erase();
-    }
 }
 
 /// We will apply the reducer patterns to the operations in the ranges specified
diff --git a/mlir/test/mlir-reduce/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree.mlir
index b053a111e9a16..3775e9107cca7 100644
--- a/mlir/test/mlir-reduce/reduction-tree.mlir
+++ b/mlir/test/mlir-reduce/reduction-tree.mlir
@@ -123,3 +123,34 @@ func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf3
   return
 }
 // CHECK-NEXT:  "test.op_crash"(%[[ARG1]], %[[ARG2]])
+
+// -----
+
+// CHECK-LABEL: func @materialization
+//  CHECK-SAME:   %[[ARG0:.*]]: i32
+func.func @materialization(%arg0: i32) -> (i32) {
+  %0 = "test.op_crash_long" (%arg0, %arg0, %arg0) : (i32, i32, i32) -> i32
+  %1 = arith.addi %0, %0 : i32
+  %2 = arith.addi %1, %1 : i32
+  return %2 : i32
+}
+// CHECK-NEXT: %[[CAST:.*]] = builtin.unrealized_conversion_cast to i32
+// CHECK-NEXT: %{{.*}} = "test.op_crash_short"() : () -> i32
+// CHECK-NEXT: return %[[CAST]] : i32
+
+// -----
+
+// In this case, replacing the add operation with an unrealized_conversion_cast
+// would actually increase the file size, so the reduction fails and the add is kept.
+
+// CHECK-LABEL: func @no_materialization
+//  CHECK-SAME:   %[[ARG0:.*]]: i32
+func.func @no_materialization(%arg0: i32) -> (i32) {
+  %0 = "test.op_crash_long" (%arg0, %arg0, %arg0) : (i32, i32, i32) -> i32
+  %1 = arith.addi %0, %0 : i32
+  return %1 : i32
+}
+// CHECK-NEXT: %[[CRASH:.*]] = "test.op_crash_short"() : () -> i32
+// CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[CRASH]], %[[CRASH]] : i32
+// CHECK-NEXT:  return %[[ADDI]] : i32
+



More information about the Mlir-commits mailing list