[Mlir-commits] [mlir] [mlir] add hoist-pure-ops to mlir (PR #168715)

lonely eagle llvmlistbot at llvm.org
Thu Nov 20 06:19:08 PST 2025


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/168715

>From ca50a53e98e7ed05029fdbec5226bb19c9322353 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 19 Nov 2025 01:48:15 +0000
Subject: [PATCH 1/4] add tablegen pass define.

---
 mlir/include/mlir/Transforms/Passes.td | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 28b4a01cf0ecd..c74ce1946cb03 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -636,4 +636,8 @@ def BubbleDownMemorySpaceCasts :
   }];
 }
 
+def HoistPureOps :
+    Pass<"hoist-pure-ops"> {
+}
+
 #endif // MLIR_TRANSFORMS_PASSES

>From ae17efd27ed5f88a8ffc84be1a452cbfcb10336d Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 19 Nov 2025 10:06:46 +0000
Subject: [PATCH 2/4] add basic implementation of hoist pure pass.

---
 mlir/include/mlir/Transforms/Passes.h |  1 +
 mlir/lib/Transforms/CMakeLists.txt    |  1 +
 mlir/lib/Transforms/HoistPureOps.cpp  | 89 +++++++++++++++++++++++++++
 3 files changed, 91 insertions(+)
 create mode 100644 mlir/lib/Transforms/HoistPureOps.cpp

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 724da009e70f1..690e9e88a87b8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -37,6 +37,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
 #define GEN_PASS_DECL_CONTROLFLOWSINK
 #define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
+#define GEN_PASS_DECL_HOISTPUREOPS
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
 #define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_MEM2REG
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 54b67f5c7a91e..b32865ed01b82 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_library(MLIRTransforms
   SymbolPrivatize.cpp
   TopologicalSort.cpp
   ViewOpGraph.cpp
+  HoistPureOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp
new file mode 100644
index 0000000000000..c7c515faa001f
--- /dev/null
+++ b/mlir/lib/Transforms/HoistPureOps.cpp
@@ -0,0 +1,89 @@
+//===- HoistPureOps.cpp - Hoist Pure ops ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that hoists pure operations, using SSA
+// dominance to choose where each op is moved.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_HOISTPUREOPS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return whichever of `a` and `b` is dominated by the other.
+static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
+  Block *aB = a.getParentBlock();
+  Block *bB = b.getParentBlock();
+  if (isa_and_present<BlockArgument>(a) && isa_and_present<BlockArgument>(b)) {
+    return dominanceInfo.dominates(aB, bB) ? b : a;
+  } else if (isa_and_present<BlockArgument>(a) ||
+             isa_and_present<BlockArgument>(b)) {
+    if (aB == bB)
+      return b;
+    return dominanceInfo.dominates(aB, bB) ? b : a;
+  } else {
+    Operation *aDefineOp = a.getDefiningOp();
+    Operation *bDefineOp = b.getDefiningOp();
+    return dominanceInfo.dominates(aDefineOp, bDefineOp) ? b : a;
+  }
+}
+
+/// Find the hoisting position for the pure op.
+static Value getDestPos(Operation *op) {
+  DominanceInfo dominanceInfo(op);
+  SmallVector<Value> operands(op->getOperands());
+  if (operands.empty())
+    return {};
+  Value ret = operands[0];
+  for (int i = 1, e = operands.size(); i < e; ++i) {
+    ret = getDomaincedValue(dominanceInfo, ret, operands[i]);
+  }
+  return ret;
+}
+
+/// Hoist single pure op.
+static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
+  Value pos = getDestPos(op);
+  if (!pos)
+    return;
+
+  if (Operation *defineOp = pos.getDefiningOp()) {
+    rewriter.moveOpAfter(op, defineOp);
+    return;
+  }
+  auto argument = cast<BlockArgument>(pos);
+  rewriter.moveOpBefore(op, &argument.getOwner()->front());
+}
+
+struct HoistPureOps : public impl::HoistPureOpsBase<HoistPureOps> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void HoistPureOps::runOnOperation() {
+  Operation *module = getOperation();
+  IRRewriter rewriter(module->getContext());
+  module->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (op->hasTrait<mlir::OpTrait::IsTerminator>())
+      return;
+    if (isPure(op)) {
+      hoistPureOp(rewriter, op);
+    }
+  });
+}

>From 397448aa3b0c5bf0a5d0e751d04da1fffc653a03 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 19 Nov 2025 13:31:09 +0000
Subject: [PATCH 3/4] handle bound condition.

---
 mlir/lib/Transforms/HoistPureOps.cpp     | 30 ++++++++---
 mlir/test/Transforms/hoist-pure-ops.mlir | 67 ++++++++++++++++++++++++
 2 files changed, 91 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Transforms/hoist-pure-ops.mlir

diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp
index c7c515faa001f..86b756b717bab 100644
--- a/mlir/lib/Transforms/HoistPureOps.cpp
+++ b/mlir/lib/Transforms/HoistPureOps.cpp
@@ -16,12 +16,15 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_HOISTPUREOPS
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
+#define DEBUG_TYPE "hoist-pure-ops"
+
 using namespace mlir;
 
 namespace {
@@ -30,13 +33,22 @@ namespace {
 static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
   Block *aB = a.getParentBlock();
   Block *bB = b.getParentBlock();
-  if (isa_and_present<BlockArgument>(a) && isa_and_present<BlockArgument>(b)) {
-    return dominanceInfo.dominates(aB, bB) ? b : a;
-  } else if (isa_and_present<BlockArgument>(a) ||
-             isa_and_present<BlockArgument>(b)) {
-    if (aB == bB)
-      return b;
+  if (isa<BlockArgument>(a) && isa<BlockArgument>(b)) {
     return dominanceInfo.dominates(aB, bB) ? b : a;
+  } else if (isa<BlockArgument>(a) || isa<BlockArgument>(b)) {
+    if (aB != bB)
+      return dominanceInfo.dominates(aB, bB) ? b : a;
+    if (auto aArg = dyn_cast<BlockArgument>(a)) {
+      Operation *aFrontOp = &aArg.getOwner()->front();
+      if (aFrontOp == b.getDefiningOp())
+        return b;
+      return dominanceInfo.dominates(aFrontOp, b.getDefiningOp()) ? b : a;
+    }
+    auto bArg = cast<BlockArgument>(b);
+    Operation *bFrontOp = &bArg.getOwner()->front();
+    if (bFrontOp == a.getDefiningOp())
+      return a;
+    return dominanceInfo.dominates(a.getDefiningOp(), bFrontOp) ? b : a;
   } else {
     Operation *aDefineOp = a.getDefiningOp();
     Operation *bDefineOp = b.getDefiningOp();
@@ -64,10 +76,16 @@ static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
     return;
 
   if (Operation *defineOp = pos.getDefiningOp()) {
+    LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " after " << OpWithFlags(op, OpPrintingFlags().skipRegions());
     rewriter.moveOpAfter(op, defineOp);
     return;
   }
   auto argument = cast<BlockArgument>(pos);
+  LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+         << " before "
+         << OpWithFlags(&argument.getOwner()->front(),
+                        OpPrintingFlags().skipRegions());
   rewriter.moveOpBefore(op, &argument.getOwner()->front());
 }
 
diff --git a/mlir/test/Transforms/hoist-pure-ops.mlir b/mlir/test/Transforms/hoist-pure-ops.mlir
new file mode 100644
index 0000000000000..d719e84862134
--- /dev/null
+++ b/mlir/test/Transforms/hoist-pure-ops.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -hoist-pure-ops -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @hoist_cast_pos
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<10xf32>,
+//  CHECK-SAME:   %[[ARG1:.*]]: i1
+func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
+  //      CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  //      CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
+  // CHECK-NEXT: cf.cond_br %[[ARG1]]
+  cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_1]]
+  return %cast : memref<?xf32>
+^bb2:
+  %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_0]]
+  return %cast1 : memref<?xf32> 
+}
+
+// -----
+
+// CHECK-LABEL: func.func @hoist_cast_pos_alloc
+//  CHECK-SAME:   %[[ARG0:.*]]: i1
+func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
+  //      CHECK: %[[ALLOC_0:.*]] = memref.alloc()
+  //      CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
+  //      CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
+  // CHECK-NEXT: cf.cond_br %[[ARG0]]
+  %alloc = memref.alloc() : memref<10xf32>
+  cf.cond_br %arg, ^bb1, ^bb2
+^bb1:
+  %cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_1]]
+  return %cast : memref<?xf32>
+^bb2:
+  %cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_0]]
+  return %cast1 : memref<?xf32> 
+}
+
+// -----
+
+// CHECK-LABEL: func @mult_scf_sum(
+//  CHECK-SAME:   %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @mult_scf_sum(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %c0 = arith.constant 0 : index
+  %res0 = scf.for %iv0 = %arg0 to %arg1 step %arg2 iter_args(%sum0 = %c0) -> index {
+    %res1 = scf.for %iv1 = %arg0 to %arg1 step %arg2 iter_args(%sum1 = %sum0) -> index {
+      %res2 = scf.for %iv2 = %arg0 to %arg1 step %arg2 iter_args(%sum2 = %sum1) -> index {
+        %add0 = arith.addi %iv0, %iv1 : index
+        %add1 = arith.addi %add0, %iv2 : index
+        %add2 = arith.addi %add1, %sum2 : index
+        scf.yield %add1 : index
+      }
+      scf.yield %res2 : index
+    }
+    scf.yield %res1 : index
+  }
+  //      CHECK: %[[FOR_0:.*]] = scf.for %[[IV_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+  // CHECK-NEXT:   %[[FOR_1:.*]] = scf.for %[[IV_1:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+  // CHECK-NEXT:     %[[ADDI_0:.*]] = arith.addi %[[IV_0]], %[[IV_1]] : index
+  // CHECK-NEXT:       %[[FOR_2:.*]] = scf.for %[[IV_3:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] iter_args(%[[ITER:.*]] = %{{.*}})
+  // CHECK-NEXT:         %[[ADDI_1:.*]] = arith.addi %[[ADDI_0]], %[[IV_3]] : index
+  // CHECK-NEXT:         %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[ITER]] : index
+  return %res0 : index
+}
\ No newline at end of file

>From af0cbf4a1928f35fd2ae9d328a6cb4ef554c1fe5 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Thu, 20 Nov 2025 13:53:17 +0000
Subject: [PATCH 4/4] support region op.

---
 mlir/lib/Transforms/HoistPureOps.cpp | 31 +++++++++++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp
index 86b756b717bab..b35743f8ffd40 100644
--- a/mlir/lib/Transforms/HoistPureOps.cpp
+++ b/mlir/lib/Transforms/HoistPureOps.cpp
@@ -56,10 +56,34 @@ static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
   }
 }
 
+static bool isOpContainBlock(Operation *op, Block *block) {
+  Operation *parentOp = block->getParentOp();
+  while (parentOp && parentOp != op) {
+    parentOp = parentOp->getParentOp();
+  }
+  return parentOp == op ? true : false;
+}
+
 /// Find the hoisting position for the pure op.
 static Value getDestPos(Operation *op) {
   DominanceInfo dominanceInfo(op);
   SmallVector<Value> operands(op->getOperands());
+  if (op->getNumRegions()) {
+    op->walk([&](Operation *operation) {
+      for (auto operand : operation->getOperands()) {
+        Operation *defineOp = operand.getDefiningOp();
+        if (!defineOp) {
+          BlockArgument argument = cast<BlockArgument>(operand);
+          if (!isOpContainBlock(op, argument.getOwner()))
+            operands.push_back(operand);
+          continue;
+        }
+        if (!isOpContainBlock(op, defineOp->getBlock())) {
+          operands.push_back(operand);
+        }
+      }
+    });
+  }
   if (operands.empty())
     return {};
   Value ret = operands[0];
@@ -71,13 +95,18 @@ static Value getDestPos(Operation *op) {
 
 /// Hoist single pure op.
 static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
+  LDBG() << "hoistPureOp: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
   Value pos = getDestPos(op);
   if (!pos)
     return;
 
   if (Operation *defineOp = pos.getDefiningOp()) {
+    if (op == defineOp)
+      return;
+
     LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
-           << " after " << OpWithFlags(op, OpPrintingFlags().skipRegions());
+           << " after "
+           << OpWithFlags(defineOp, OpPrintingFlags().skipRegions());
     rewriter.moveOpAfter(op, defineOp);
     return;
   }



More information about the Mlir-commits mailing list