[Mlir-commits] [mlir] [mlir] add hoist-pure-ops to mlir (PR #168715)
lonely eagle
llvmlistbot at llvm.org
Thu Nov 20 06:19:08 PST 2025
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/168715
>From ca50a53e98e7ed05029fdbec5226bb19c9322353 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 19 Nov 2025 01:48:15 +0000
Subject: [PATCH 1/4] add TableGen pass definition.
---
mlir/include/mlir/Transforms/Passes.td | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 28b4a01cf0ecd..c74ce1946cb03 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -636,4 +636,8 @@ def BubbleDownMemorySpaceCasts :
}];
}
+def HoistPureOps : Pass<"hoist-pure-ops"> {
+  let summary = "Hoist pure ops to the earliest point dominated by their operands";
+}
+
#endif // MLIR_TRANSFORMS_PASSES
>From ae17efd27ed5f88a8ffc84be1a452cbfcb10336d Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 19 Nov 2025 10:06:46 +0000
Subject: [PATCH 2/4] add basic implementation of the hoist-pure-ops pass.
---
mlir/include/mlir/Transforms/Passes.h | 1 +
mlir/lib/Transforms/CMakeLists.txt | 1 +
mlir/lib/Transforms/HoistPureOps.cpp | 89 +++++++++++++++++++++++++++
3 files changed, 91 insertions(+)
create mode 100644 mlir/lib/Transforms/HoistPureOps.cpp
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 724da009e70f1..690e9e88a87b8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -37,6 +37,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
#define GEN_PASS_DECL_CONTROLFLOWSINK
#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
+#define GEN_PASS_DECL_HOISTPUREOPS
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_MEM2REG
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 54b67f5c7a91e..b32865ed01b82 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_library(MLIRTransforms
SymbolPrivatize.cpp
TopologicalSort.cpp
ViewOpGraph.cpp
+ HoistPureOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp
new file mode 100644
index 0000000000000..c7c515faa001f
--- /dev/null
+++ b/mlir/lib/Transforms/HoistPureOps.cpp
@@ -0,0 +1,89 @@
+//===- HoistPureOps.cpp - Hoist Pure ops ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that hoists pure ops as early as the SSA
+// dominance of their operands allows.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_HOISTPUREOPS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return whichever of `a` and `b` is dominated by the other (the later def).
+static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
+ Block *aB = a.getParentBlock();
+ Block *bB = b.getParentBlock();
+ if (isa_and_present<BlockArgument>(a) && isa_and_present<BlockArgument>(b)) {
+ return dominanceInfo.dominates(aB, bB) ? b : a;
+ } else if (isa_and_present<BlockArgument>(a) ||
+ isa_and_present<BlockArgument>(b)) {
+ if (aB == bB)
+ return b;
+ return dominanceInfo.dominates(aB, bB) ? b : a;
+ } else {
+ Operation *aDefineOp = a.getDefiningOp();
+ Operation *bDefineOp = b.getDefiningOp();
+ return dominanceInfo.dominates(aDefineOp, bDefineOp) ? b : a;
+ }
+}
+
+/// Find the hoisting position for the pure op.
+static Value getDestPos(Operation *op) {
+ DominanceInfo dominanceInfo(op);
+ SmallVector<Value> operands(op->getOperands());
+ if (operands.empty())
+ return {};
+ Value ret = operands[0];
+ for (int i = 1, e = operands.size(); i < e; ++i) { // Keep the latest-defined operand.
+ ret = getDomaincedValue(dominanceInfo, ret, operands[i]);
+ }
+ return ret; // hoistPureOp moves op to just after this value's definition.
+}
+
+/// Hoist single pure op.
+static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
+ Value pos = getDestPos(op);
+ if (!pos)
+ return;
+
+ if (Operation *defineOp = pos.getDefiningOp()) {
+ rewriter.moveOpAfter(op, defineOp);
+ return;
+ }
+ auto argument = cast<BlockArgument>(pos);
+ rewriter.moveOpBefore(op, &argument.getOwner()->front());
+}
+
+struct HoistPureOps final : impl::HoistPureOpsBase<HoistPureOps> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void HoistPureOps::runOnOperation() {
+  Operation *root = getOperation();
+  IRRewriter rewriter(root->getContext());
+  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    // Keep the pass anchor (moving it leaves the pass scope) and terminators.
+    if (op == root || op->hasTrait<mlir::OpTrait::IsTerminator>())
+      return;
+    if (isPure(op))
+      hoistPureOp(rewriter, op);
+  });
+}
>From 397448aa3b0c5bf0a5d0e751d04da1fffc653a03 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 19 Nov 2025 13:31:09 +0000
Subject: [PATCH 3/4] handle boundary conditions.
---
mlir/lib/Transforms/HoistPureOps.cpp | 30 ++++++++---
mlir/test/Transforms/hoist-pure-ops.mlir | 67 ++++++++++++++++++++++++
2 files changed, 91 insertions(+), 6 deletions(-)
create mode 100644 mlir/test/Transforms/hoist-pure-ops.mlir
diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp
index c7c515faa001f..86b756b717bab 100644
--- a/mlir/lib/Transforms/HoistPureOps.cpp
+++ b/mlir/lib/Transforms/HoistPureOps.cpp
@@ -16,12 +16,15 @@
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_HOISTPUREOPS
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "hoist-pure-ops"
+
using namespace mlir;
namespace {
@@ -30,13 +33,22 @@ namespace {
static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
Block *aB = a.getParentBlock();
Block *bB = b.getParentBlock();
- if (isa_and_present<BlockArgument>(a) && isa_and_present<BlockArgument>(b)) {
- return dominanceInfo.dominates(aB, bB) ? b : a;
- } else if (isa_and_present<BlockArgument>(a) ||
- isa_and_present<BlockArgument>(b)) {
- if (aB == bB)
- return b;
+ if (isa<BlockArgument>(a) && isa<BlockArgument>(b)) {
return dominanceInfo.dominates(aB, bB) ? b : a;
+ } else if (isa<BlockArgument>(a) || isa<BlockArgument>(b)) {
+ if (aB != bB)
+ return dominanceInfo.dominates(aB, bB) ? b : a;
+ if (auto aArg = dyn_cast<BlockArgument>(a)) {
+ Operation *aFrontOp = &aArg.getOwner()->front();
+ if (aFrontOp == b.getDefiningOp())
+ return b;
+ return dominanceInfo.dominates(aFrontOp, b.getDefiningOp()) ? b : a;
+ }
+ auto bArg = cast<BlockArgument>(b);
+ Operation *bFrontOp = &bArg.getOwner()->front();
+ if (bFrontOp == a.getDefiningOp())
+ return a;
+ return dominanceInfo.dominates(a.getDefiningOp(), bFrontOp) ? b : a;
} else {
Operation *aDefineOp = a.getDefiningOp();
Operation *bDefineOp = b.getDefiningOp();
@@ -64,10 +76,16 @@ static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
return;
if (Operation *defineOp = pos.getDefiningOp()) {
+ LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " after " << OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.moveOpAfter(op, defineOp);
return;
}
auto argument = cast<BlockArgument>(pos);
+ LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " before "
+ << OpWithFlags(&argument.getOwner()->front(),
+ OpPrintingFlags().skipRegions());
rewriter.moveOpBefore(op, &argument.getOwner()->front());
}
diff --git a/mlir/test/Transforms/hoist-pure-ops.mlir b/mlir/test/Transforms/hoist-pure-ops.mlir
new file mode 100644
index 0000000000000..d719e84862134
--- /dev/null
+++ b/mlir/test/Transforms/hoist-pure-ops.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -hoist-pure-ops -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @hoist_cast_pos
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: i1
+func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
+ // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+ // CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
+ // CHECK-NEXT: cf.cond_br %[[ARG1]]
+ cf.cond_br %arg1, ^bb1, ^bb2 // Both casts below are hoisted above this branch.
+^bb1:
+ %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_1]]
+ return %cast : memref<?xf32>
+^bb2:
+ %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_0]]
+ return %cast1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @hoist_cast_pos_alloc
+// CHECK-SAME: %[[ARG0:.*]]: i1
+func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
+ // CHECK: %[[ALLOC_0:.*]] = memref.alloc()
+ // CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
+ // CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
+ // CHECK-NEXT: cf.cond_br %[[ARG0]]
+ %alloc = memref.alloc() : memref<10xf32>
+ cf.cond_br %arg, ^bb1, ^bb2 // Both casts below are hoisted to just after %alloc.
+^bb1:
+ %cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_1]]
+ return %cast : memref<?xf32>
+^bb2:
+ %cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_0]]
+ return %cast1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @mult_scf_sum(
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @mult_scf_sum(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %c0 = arith.constant 0 : index
+ %res0 = scf.for %iv0 = %arg0 to %arg1 step %arg2 iter_args(%sum0 = %c0) -> index {
+ %res1 = scf.for %iv1 = %arg0 to %arg1 step %arg2 iter_args(%sum1 = %sum0) -> index {
+ %res2 = scf.for %iv2 = %arg0 to %arg1 step %arg2 iter_args(%sum2 = %sum1) -> index {
+ %add0 = arith.addi %iv0, %iv1 : index // Only uses the two outer IVs, so it leaves the innermost loop.
+ %add1 = arith.addi %add0, %iv2 : index
+ %add2 = arith.addi %add1, %sum2 : index
+ scf.yield %add1 : index
+ }
+ scf.yield %res2 : index
+ }
+ scf.yield %res1 : index
+ }
+ // CHECK: %[[FOR_0:.*]] = scf.for %[[IV_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+ // CHECK-NEXT: %[[FOR_1:.*]] = scf.for %[[IV_1:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+ // CHECK-NEXT: %[[ADDI_0:.*]] = arith.addi %[[IV_0]], %[[IV_1]] : index
+ // CHECK-NEXT: %[[FOR_2:.*]] = scf.for %[[IV_2:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] iter_args(%[[ITER:.*]] = %{{.*}})
+ // CHECK-NEXT: %[[ADDI_1:.*]] = arith.addi %[[ADDI_0]], %[[IV_2]] : index
+ // CHECK-NEXT: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[ITER]] : index
+ return %res0 : index
+}
>From af0cbf4a1928f35fd2ae9d328a6cb4ef554c1fe5 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Thu, 20 Nov 2025 13:53:17 +0000
Subject: [PATCH 4/4] support ops with regions.
---
mlir/lib/Transforms/HoistPureOps.cpp | 31 +++++++++++++++++++++++++++-
1 file changed, 30 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp
index 86b756b717bab..b35743f8ffd40 100644
--- a/mlir/lib/Transforms/HoistPureOps.cpp
+++ b/mlir/lib/Transforms/HoistPureOps.cpp
@@ -56,10 +56,34 @@ static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
}
}
+/// Return true if `block` is nested (at any depth) inside a region of `op`.
+static bool isOpContainBlock(Operation *op, Block *block) {
+  Operation *parentOp = block->getParentOp();
+  while (parentOp && parentOp != op)
+    parentOp = parentOp->getParentOp();
+  return parentOp == op;
+}
+
/// Find the hoisting position for the pure op.
static Value getDestPos(Operation *op) {
DominanceInfo dominanceInfo(op);
SmallVector<Value> operands(op->getOperands());
+ if (op->getNumRegions()) {
+ op->walk([&](Operation *operation) {
+ for (auto operand : operation->getOperands()) {
+ Operation *defineOp = operand.getDefiningOp();
+ if (!defineOp) {
+ BlockArgument argument = cast<BlockArgument>(operand);
+ if (!isOpContainBlock(op, argument.getOwner()))
+ operands.push_back(operand);
+ continue;
+ }
+ if (!isOpContainBlock(op, defineOp->getBlock())) {
+ operands.push_back(operand);
+ }
+ }
+ });
+ }
if (operands.empty())
return {};
Value ret = operands[0];
@@ -71,13 +95,18 @@ static Value getDestPos(Operation *op) {
/// Hoist single pure op.
static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
+ LDBG() << "hoistPureOp: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
Value pos = getDestPos(op);
if (!pos)
return;
if (Operation *defineOp = pos.getDefiningOp()) {
+ if (op == defineOp)
+ return;
+
LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
- << " after " << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ << " after "
+ << OpWithFlags(defineOp, OpPrintingFlags().skipRegions());
rewriter.moveOpAfter(op, defineOp);
return;
}
More information about the Mlir-commits
mailing list