[Mlir-commits] [mlir] [mlir] add hoist-pure-ops to mlir (PR #168715)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 19 06:45:16 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: lonely eagle (linuxlonelyeagle)
<details>
<summary>Changes</summary>
The `hoist-pure-ops` pass hoists `Pure` ops to expose more opportunities for optimisation.
---
Full diff: https://github.com/llvm/llvm-project/pull/168715.diff
5 Files Affected:
- (modified) mlir/include/mlir/Transforms/Passes.h (+1)
- (modified) mlir/include/mlir/Transforms/Passes.td (+4)
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Transforms/HoistPureOps.cpp (+107)
- (added) mlir/test/Transforms/hoist-pure-ops.mlir (+67)
``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 724da009e70f1..690e9e88a87b8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -37,6 +37,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
#define GEN_PASS_DECL_CONTROLFLOWSINK
#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
+#define GEN_PASS_DECL_HOISTPUREOPS
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_MEM2REG
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 28b4a01cf0ecd..c74ce1946cb03 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -636,4 +636,8 @@ def BubbleDownMemorySpaceCasts :
}];
}
+def HoistPureOps : // NOTE(review): no summary/description is set for this pass.
+ Pass<"hoist-pure-ops"> {
+}
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 54b67f5c7a91e..b32865ed01b82 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_library(MLIRTransforms
SymbolPrivatize.cpp
TopologicalSort.cpp
ViewOpGraph.cpp
+ HoistPureOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp
new file mode 100644
index 0000000000000..86b756b717bab
--- /dev/null
+++ b/mlir/lib/Transforms/HoistPureOps.cpp
@@ -0,0 +1,107 @@
+//===- HoistPureOps.cpp - Hoist Pure ops ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that hoists pure ops based on SSA
+// dominance.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_HOISTPUREOPS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "hoist-pure-ops"
+
+using namespace mlir;
+
+namespace {
+
+/// Return `b` if `a`'s definition dominates `b`'s; otherwise return `a`.
+static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
+ Block *aB = a.getParentBlock();
+ Block *bB = b.getParentBlock();
+ if (isa<BlockArgument>(a) && isa<BlockArgument>(b)) { // both block arguments
+ return dominanceInfo.dominates(aB, bB) ? b : a;
+ } else if (isa<BlockArgument>(a) || isa<BlockArgument>(b)) { // exactly one is
+ if (aB != bB) // different blocks: decide by block dominance
+ return dominanceInfo.dominates(aB, bB) ? b : a;
+ if (auto aArg = dyn_cast<BlockArgument>(a)) { // same block, a is the argument
+ Operation *aFrontOp = &aArg.getOwner()->front();
+ if (aFrontOp == b.getDefiningOp())
+ return b;
+ return dominanceInfo.dominates(aFrontOp, b.getDefiningOp()) ? b : a;
+ }
+ auto bArg = cast<BlockArgument>(b); // same block, b is the argument
+ Operation *bFrontOp = &bArg.getOwner()->front();
+ if (bFrontOp == a.getDefiningOp())
+ return a;
+ return dominanceInfo.dominates(a.getDefiningOp(), bFrontOp) ? b : a;
+ } else { // both are op results: compare the defining ops
+ Operation *aDefineOp = a.getDefiningOp();
+ Operation *bDefineOp = b.getDefiningOp();
+ return dominanceInfo.dominates(aDefineOp, bDefineOp) ? b : a;
+ }
+}
+
+/// Fold the operands of `op` with getDomaincedValue; null if it has none.
+static Value getDestPos(Operation *op) {
+  DominanceInfo domInfo(op);
+  OperandRange operandRange = op->getOperands();
+  if (operandRange.empty())
+    return Value();
+
+  Value candidate = operandRange.front();
+  for (Value operand : operandRange.drop_front())
+    candidate = getDomaincedValue(domInfo, candidate, operand);
+  return candidate;
+}
+
+/// Hoist a single pure op: move it right after the op defining the value from
+/// getDestPos, or to the front of the owner block if that is a block argument.
+static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
+  Value pos = getDestPos(op);
+  if (!pos) // No anchor value (getDestPos returned null); nothing to move.
+    return;
+  if (Operation *defOp = pos.getDefiningOp()) {
+    LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " after " << OpWithFlags(defOp, OpPrintingFlags().skipRegions());
+    rewriter.moveOpAfter(op, defOp);
+    return;
+  }
+  auto argument = cast<BlockArgument>(pos); // Not an op result.
+  LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+         << " before "
+         << OpWithFlags(&argument.getOwner()->front(),
+                        OpPrintingFlags().skipRegions());
+  rewriter.moveOpBefore(op, &argument.getOwner()->front());
+}
+
+struct HoistPureOps : public impl::HoistPureOpsBase<HoistPureOps> {
+ void runOnOperation() override; // Pass entry point; defined out of line.
+};
+} // namespace
+
+void HoistPureOps::runOnOperation() {
+  Operation *root = getOperation();
+  IRRewriter rewriter(root->getContext());
+  // Visit every op under the root in pre-order and try to hoist it.
+  root->walk<WalkOrder::PreOrder>([&rewriter](Operation *nested) {
+    // Skip terminators and ops that are not pure.
+    if (nested->hasTrait<OpTrait::IsTerminator>() || !isPure(nested))
+      return;
+    hoistPureOp(rewriter, nested);
+  });
+}
diff --git a/mlir/test/Transforms/hoist-pure-ops.mlir b/mlir/test/Transforms/hoist-pure-ops.mlir
new file mode 100644
index 0000000000000..d719e84862134
--- /dev/null
+++ b/mlir/test/Transforms/hoist-pure-ops.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -hoist-pure-ops -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @hoist_cast_pos
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: i1
+func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
+ // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+ // CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
+ // CHECK-NEXT: cf.cond_br %[[ARG1]]
+ cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+ %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_1]]
+ return %cast : memref<?xf32>
+^bb2:
+ %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_0]]
+ return %cast1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @hoist_cast_pos_alloc
+// CHECK-SAME: %[[ARG0:.*]]: i1
+func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
+ // CHECK: %[[ALLOC_0:.*]] = memref.alloc()
+ // CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
+ // CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
+ // CHECK-NEXT: cf.cond_br %[[ARG0]]
+ %alloc = memref.alloc() : memref<10xf32>
+ cf.cond_br %arg, ^bb1, ^bb2
+^bb1:
+ %cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_1]]
+ return %cast : memref<?xf32>
+^bb2:
+ %cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_0]]
+ return %cast1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @mult_scf_sum(
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @mult_scf_sum(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %c0 = arith.constant 0 : index
+ %res0 = scf.for %iv0 = %arg0 to %arg1 step %arg2 iter_args(%sum0 = %c0) -> index {
+ %res1 = scf.for %iv1 = %arg0 to %arg1 step %arg2 iter_args(%sum1 = %sum0) -> index {
+ %res2 = scf.for %iv2 = %arg0 to %arg1 step %arg2 iter_args(%sum2 = %sum1) -> index {
+ %add0 = arith.addi %iv0, %iv1 : index
+ %add1 = arith.addi %add0, %iv2 : index
+ %add2 = arith.addi %add1, %sum2 : index
+ scf.yield %add1 : index
+ }
+ scf.yield %res2 : index
+ }
+ scf.yield %res1 : index
+ }
+ // CHECK: %[[FOR_0:.*]] = scf.for %[[IV_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+ // CHECK-NEXT: %[[FOR_1:.*]] = scf.for %[[IV_1:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+ // CHECK-NEXT: %[[ADDI_0:.*]] = arith.addi %[[IV_0]], %[[IV_1]] : index
+ // CHECK-NEXT: %[[FOR_2:.*]] = scf.for %[[IV_3:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] iter_args(%[[ITER:.*]] = %{{.*}})
+ // CHECK-NEXT: %[[ADDI_1:.*]] = arith.addi %[[ADDI_0]], %[[IV_3]] : index
+ // CHECK-NEXT: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[ITER]] : index
+ return %res0 : index
+}
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/168715
More information about the Mlir-commits
mailing list