[Mlir-commits] [mlir] [mlir] add hoist-pure-ops to mlir (PR #168715)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 19 06:45:16 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

The `hoist-pure-ops` pass hoists `Pure` ops as early as SSA dominance allows, exposing more opportunities for optimisation.

---
Full diff: https://github.com/llvm/llvm-project/pull/168715.diff


5 Files Affected:

- (modified) mlir/include/mlir/Transforms/Passes.h (+1) 
- (modified) mlir/include/mlir/Transforms/Passes.td (+4) 
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Transforms/HoistPureOps.cpp (+107) 
- (added) mlir/test/Transforms/hoist-pure-ops.mlir (+67) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 724da009e70f1..690e9e88a87b8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -37,6 +37,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
 #define GEN_PASS_DECL_CONTROLFLOWSINK
 #define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
+#define GEN_PASS_DECL_HOISTPUREOPS
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
 #define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_MEM2REG
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 28b4a01cf0ecd..c74ce1946cb03 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -636,4 +636,8 @@ def BubbleDownMemorySpaceCasts :
   }];
 }
 
+def HoistPureOps :
+    Pass<"hoist-pure-ops"> {
+  let summary = "Hoist pure operations to their earliest legal position";
+  let description = [{
+    This pass moves every non-terminator operation that is `Pure` (free of
+    memory effects and speculatable) as early as SSA dominance allows:
+    directly after the defining operation of the latest value it uses, or to
+    the start of the block when that value is a block argument. Because the
+    moved operations are speculatable, they may be hoisted out of loops and
+    conditional branches, which can expose further optimization
+    opportunities (e.g. identical operations hoisted next to each other).
+    Operations that use no values are left in place.
+  }];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 54b67f5c7a91e..b32865ed01b82 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_library(MLIRTransforms
   SymbolPrivatize.cpp
   TopologicalSort.cpp
   ViewOpGraph.cpp
+  HoistPureOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp
new file mode 100644
index 0000000000000..86b756b717bab
--- /dev/null
+++ b/mlir/lib/Transforms/HoistPureOps.cpp
@@ -0,0 +1,107 @@
+//===- HoistPureOps.cpp - Hoist Pure ops ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that hoists pure ops to the earliest point
+// permitted by SSA dominance.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_HOISTPUREOPS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "hoist-pure-ops"
+
+using namespace mlir;
+
+namespace {
+
+/// Return whichever of `a` and `b` is defined later, i.e. the value whose
+/// definition is dominated by the other's. Both values are expected to be
+/// used by the same op, so that one of them dominates the other (callers
+/// must ensure this). Ties -- two arguments of the same block, or two
+/// results of the same op -- resolve to `b`.
+static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
+  auto lhsArg = dyn_cast<BlockArgument>(a);
+  auto rhsArg = dyn_cast<BlockArgument>(b);
+  Block *lhsBlock = a.getParentBlock();
+  Block *rhsBlock = b.getParentBlock();
+
+  // Both values are op results: compare their defining ops directly.
+  if (!lhsArg && !rhsArg) {
+    Operation *lhsDef = a.getDefiningOp();
+    Operation *rhsDef = b.getDefiningOp();
+    return dominanceInfo.dominates(lhsDef, rhsDef) ? b : a;
+  }
+
+  // Two block arguments, or values living in different blocks: block-level
+  // dominance decides which definition comes later.
+  if ((lhsArg && rhsArg) || lhsBlock != rhsBlock)
+    return dominanceInfo.dominates(lhsBlock, rhsBlock) ? b : a;
+
+  // Exactly one block argument, and both values live in the same block. The
+  // argument is treated as available at the block's first op.
+  if (lhsArg) {
+    Operation *firstOp = &lhsBlock->front();
+    Operation *rhsDef = b.getDefiningOp();
+    if (firstOp == rhsDef)
+      return b;
+    return dominanceInfo.dominates(firstOp, rhsDef) ? b : a;
+  }
+  Operation *firstOp = &rhsBlock->front();
+  Operation *lhsDef = a.getDefiningOp();
+  if (firstOp == lhsDef)
+    return a;
+  return dominanceInfo.dominates(lhsDef, firstOp) ? b : a;
+}
+
+/// Find the hoisting position for the pure op `op`: the value, among all
+/// values `op` uses, whose definition is dominated by all the others.
+/// Besides the operands, this includes values that ops nested in `op`'s
+/// regions use but that are defined outside `op`. Those must stay available
+/// at the new location too. Returns a null Value when `op` uses no such
+/// values.
+static Value getDestPos(Operation *op) {
+  // NOTE(review): a fresh DominanceInfo is created on every call, so nothing
+  // it computes is reused across ops; consider sharing one instance.
+  DominanceInfo dominanceInfo(op);
+  Value latest;
+  auto consider = [&](Value value) {
+    latest = latest ? getDomaincedValue(dominanceInfo, latest, value) : value;
+  };
+
+  for (Value operand : op->getOperands())
+    consider(operand);
+
+  // Values captured implicitly by nested regions. Without them, an op with
+  // regions could be moved above a definition its body still depends on.
+  op->walk([&](Operation *nested) {
+    if (nested == op)
+      return;
+    for (Value operand : nested->getOperands())
+      if (!op->isAncestor(operand.getParentRegion()->getParentOp()))
+        consider(operand);
+  });
+  return latest;
+}
+
+/// Hoist a single pure op to the position computed by `getDestPos`.
+/// - If that value is an op result, the op is moved directly after its
+///   defining op.
+/// - If it is a block argument, the op is moved to the start of the
+///   argument's block.
+/// - If `getDestPos` returns a null Value, the op is left in place.
+static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
+  Value pos = getDestPos(op);
+  if (!pos)
+    return;
+
+  if (Operation *defineOp = pos.getDefiningOp()) {
+    LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+           << " after "
+           << OpWithFlags(defineOp, OpPrintingFlags().skipRegions());
+    rewriter.moveOpAfter(op, defineOp);
+    return;
+  }
+
+  // `pos` is a block argument: the earliest legal point is the block start.
+  Operation *frontOp = &cast<BlockArgument>(pos).getOwner()->front();
+  LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+         << " before " << OpWithFlags(frontOp, OpPrintingFlags().skipRegions());
+  rewriter.moveOpBefore(op, frontOp);
+}
+
+/// Pass driver: visits every op nested under the anchor op and hoists the
+/// non-terminator pure ones (see `runOnOperation`).
+struct HoistPureOps final : impl::HoistPureOpsBase<HoistPureOps> {
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Walk the IR under the anchor op in pre-order (each op before the ops in its
+/// regions) and hoist every pure, non-terminator op.
+void HoistPureOps::runOnOperation() {
+  Operation *root = getOperation();
+  IRRewriter rewriter(root->getContext());
+  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    // The walk also visits the anchor op itself; a pass must never move the
+    // op it is anchored on.
+    if (op == root)
+      return;
+    // Terminators must stay at the end of their block, even when pure.
+    if (op->hasTrait<mlir::OpTrait::IsTerminator>())
+      return;
+    if (!isPure(op))
+      return;
+    hoistPureOp(rewriter, op);
+  });
+}
diff --git a/mlir/test/Transforms/hoist-pure-ops.mlir b/mlir/test/Transforms/hoist-pure-ops.mlir
new file mode 100644
index 0000000000000..d719e84862134
--- /dev/null
+++ b/mlir/test/Transforms/hoist-pure-ops.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -hoist-pure-ops -split-input-file -verify-diagnostics | FileCheck %s
+
+// Both casts use only the function argument %arg, so each is moved to the
+// start of the entry block. The ^bb2 cast is hoisted second and is also
+// placed before the block's first op, so it ends up ahead of the ^bb1 cast.
+// CHECK-LABEL: func @hoist_cast_pos
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<10xf32>,
+//  CHECK-SAME:   %[[ARG1:.*]]: i1
+func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
+  //      CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  //      CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
+  // CHECK-NEXT: cf.cond_br %[[ARG1]]
+  cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_1]]
+  return %cast : memref<?xf32>
+^bb2:
+  %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_0]]
+  return %cast1 : memref<?xf32> 
+}
+
+// -----
+
+// Here both casts use the result of memref.alloc, which stays in place, so
+// each cast is moved directly after the alloc. The ^bb2 cast is moved last
+// and therefore ends up first.
+// CHECK-LABEL: func.func @hoist_cast_pos_alloc
+//  CHECK-SAME:   %[[ARG0:.*]]: i1
+func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
+  //      CHECK: %[[ALLOC_0:.*]] = memref.alloc()
+  //      CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
+  //      CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
+  // CHECK-NEXT: cf.cond_br %[[ARG0]]
+  %alloc = memref.alloc() : memref<10xf32>
+  cf.cond_br %arg, ^bb1, ^bb2
+^bb1:
+  %cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_1]]
+  return %cast : memref<?xf32>
+^bb2:
+  %cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_0]]
+  return %cast1 : memref<?xf32> 
+}
+
+// -----
+
+// %add0 depends only on %iv0 and %iv1, so it moves out of the innermost loop
+// to the start of the middle loop's body. %add1 and %add2 use the innermost
+// induction variable / iter_arg and stay where they are. The innermost loop
+// yields the accumulated %add2 so that no op in the nest is dead.
+// CHECK-LABEL: func @mult_scf_sum(
+//  CHECK-SAME:   %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @mult_scf_sum(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %c0 = arith.constant 0 : index
+  %res0 = scf.for %iv0 = %arg0 to %arg1 step %arg2 iter_args(%sum0 = %c0) -> index {
+    %res1 = scf.for %iv1 = %arg0 to %arg1 step %arg2 iter_args(%sum1 = %sum0) -> index {
+      %res2 = scf.for %iv2 = %arg0 to %arg1 step %arg2 iter_args(%sum2 = %sum1) -> index {
+        %add0 = arith.addi %iv0, %iv1 : index
+        %add1 = arith.addi %add0, %iv2 : index
+        %add2 = arith.addi %add1, %sum2 : index
+        scf.yield %add2 : index
+      }
+      scf.yield %res2 : index
+    }
+    scf.yield %res1 : index
+  }
+  //      CHECK: %[[FOR_0:.*]] = scf.for %[[IV_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+  // CHECK-NEXT:   %[[FOR_1:.*]] = scf.for %[[IV_1:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+  // CHECK-NEXT:     %[[ADDI_0:.*]] = arith.addi %[[IV_0]], %[[IV_1]] : index
+  // CHECK-NEXT:       %[[FOR_2:.*]] = scf.for %[[IV_3:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] iter_args(%[[ITER:.*]] = %{{.*}})
+  // CHECK-NEXT:         %[[ADDI_1:.*]] = arith.addi %[[ADDI_0]], %[[IV_3]] : index
+  // CHECK-NEXT:         %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[ITER]] : index
+  // CHECK-NEXT:         scf.yield %[[ADDI_2]] : index
+  return %res0 : index
+}
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/168715


More information about the Mlir-commits mailing list