[Mlir-commits] [mlir] [mlir][func] Add eliminate-function-parameter pass (PR #160654)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 25 00:08:42 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-func

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

Added the eliminate-function-parameter pass. During the IR transformation process, function parameters may become unused, and this pass is used to remove unused parameters in functions.

---
Full diff: https://github.com/llvm/llvm-project/pull/160654.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Func/Transforms/Passes.h (+1-1) 
- (modified) mlir/include/mlir/Dialect/Func/Transforms/Passes.td (+8) 
- (modified) mlir/lib/Dialect/Func/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp (+88) 
- (added) mlir/test/Dialect/Func/eliminate-function-parameter.mlir (+32) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
index 6fe9cc4bb2986..1369a3627d0f2 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
@@ -22,7 +22,7 @@ class RewritePatternSet;
 
 namespace func {
 
-#define GEN_PASS_DECL_DUPLICATEFUNCTIONELIMINATIONPASS
+#define GEN_PASS_DECL
 #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
index 4163997515bb0..0e57bc3e0da91 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
@@ -21,4 +21,12 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
   }];
 }
 
+def EliminateFunctionParameterPass : Pass<"eliminate-function-parameter",
+  "ModuleOp"> {
+  let summary = "Eliminate unused function parameters";
+  let description = [{
+    Removes the parameters of a `func.func` that are unused in its body and
+    drops the matching operands from each `func.call` to that function.
+  }];
+}
 #endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 0bed59e109503..3553613543c86 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRFuncTransforms
+  EliminateFunctionParameter.cpp
   DuplicateFunctionElimination.cpp
   FuncConversions.cpp
 
diff --git a/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
new file mode 100644
index 0000000000000..c5412e117878b
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
@@ -0,0 +1,88 @@
+//===- EliminateFunctionParameter.cpp - Eliminate function parameter ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+
+namespace mlir {
+namespace func {
+#define GEN_PASS_DEF_ELIMINATEFUNCTIONPARAMETERPASS
+#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
+} // namespace func
+
+/// Removes every argument of `funcOp` that has no uses inside the function.
+///
+/// `arguemntNoUse` must be sized to the number of arguments of `funcOp` and
+/// have no bit set on entry. On return, bit `i` is set iff the argument that
+/// was originally at position `i` has been removed.
+///
+/// Returns success if at least one argument was removed. Returns failure if
+/// nothing changed, which includes external declarations, or if the function
+/// could not be rewritten.
+static LogicalResult updateFunc(func::FuncOp funcOp, BitVector &arguemntNoUse) {
+  // A declaration has no entry block, so there is nothing to inspect.
+  if (funcOp.isExternal())
+    return failure();
+
+  // Only record the dead arguments here. Erasing while iterating would shift
+  // the positions of the remaining arguments, so later erasures would hit the
+  // wrong argument (or run past the end of the argument list).
+  for (BlockArgument argument : funcOp.getArguments())
+    if (argument.use_empty())
+      arguemntNoUse.set(argument.getArgNumber());
+
+  if (arguemntNoUse.none())
+    return failure();
+
+  // Erase all recorded arguments at once. This also updates the function type
+  // and drops the per-argument attributes of the removed arguments.
+  return funcOp.eraseArguments(arguemntNoUse);
+}
+
+/// Rewrites the operand list of `callOp` so that it only keeps the operands
+/// whose position is not marked in `argumentsNoUse`, preserving their relative
+/// order. Always returns success.
+static LogicalResult updateCall(func::CallOp callOp,
+                                BitVector &argumentsNoUse) {
+  SmallVector<Value, 4> keptOperands;
+  for (OpOperand &operand : callOp->getOpOperands()) {
+    // Skip operands that fed a parameter which has been removed.
+    if (argumentsNoUse.test(operand.getOperandNumber()))
+      continue;
+    keptOperands.push_back(operand.get());
+  }
+  callOp->setOperands(keptOperands);
+  return success();
+}
+
+namespace {
+/// Module pass that removes unused parameters from `func.func` operations and
+/// drops the corresponding operands from the `func.call` operations that
+/// reference them.
+///
+/// A function is only rewritten when every symbol use of it inside the module
+/// is known and is a direct `func.call` to that function; otherwise it is left
+/// untouched so that the IR never ends up with mismatched callers.
+///
+/// NOTE(review): public functions are rewritten as well, which changes their
+/// externally visible signature. Consider restricting this to private
+/// functions.
+struct EliminateFunctionParameterPass
+    : public func::impl::EliminateFunctionParameterPassBase<
+          EliminateFunctionParameterPass> {
+  using EliminateFunctionParameterPassBase<
+      EliminateFunctionParameterPass>::EliminateFunctionParameterPassBase;
+  void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
+    for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+      // Declarations have no body whose argument uses could be inspected.
+      if (funcOp.isExternal() || funcOp.getNumArguments() == 0)
+        continue;
+
+      // Validate the call sites before touching the function: once the
+      // signature has changed, every user must be updated as well.
+      auto uses = SymbolTable::getSymbolUses(funcOp, moduleOp);
+      if (!uses.has_value())
+        continue;
+
+      SmallVector<func::CallOp, 4> calls;
+      bool onlyDirectCalls = true;
+      for (SymbolTable::SymbolUse use : *uses) {
+        // Anything other than a direct call to this function (for example a
+        // `func.constant`) cannot be rewritten by updateCall.
+        auto call = dyn_cast<func::CallOp>(use.getUser());
+        if (!call || call.getCallee() != funcOp.getSymName()) {
+          onlyDirectCalls = false;
+          break;
+        }
+        calls.push_back(call);
+      }
+      if (!onlyDirectCalls)
+        continue;
+
+      BitVector argumentNoUse(funcOp.getNumArguments());
+      if (failed(updateFunc(funcOp, argumentNoUse)))
+        continue;
+
+      for (func::CallOp call : calls)
+        (void)updateCall(call, argumentNoUse);
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir
diff --git a/mlir/test/Dialect/Func/eliminate-function-parameter.mlir b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
new file mode 100644
index 0000000000000..0bd35ec6bd1c7
--- /dev/null
+++ b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s --split-input-file --eliminate-function-parameter | \
+// RUN: FileCheck %s
+
+func.func @single_parameter(%arg: index) {
+  return
+}
+
+func.func @mutl_parameter(%arg0 : index, %arg1 : index) -> index {
+  return %arg0 : index
+}
+
+func.func @eliminate_parameter(%arg0: index, %arg1: index) -> index {
+  func.call @single_parameter(%arg0) : (index) -> ()
+  %ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index)
+  return %ret : index
+}
+
+// CHECK-LABEL: func @single_parameter() {
+//       CHECK:   return
+//       CHECK: }
+
+// CHECK-LABEL: func @mutl_parameter(
+//  CHECK-SAME:   %[[ARG0:.*]]: index) -> index {
+//       CHECK:   return %[[ARG0]] : index
+//       CHECK: }
+
+// CHECK-LABEL: func @eliminate_parameter(
+//  CHECK-SAME:   %[[ARG0:.*]]: index) -> index {
+//       CHECK:   call @single_parameter() : () -> ()
+//       CHECK:   %[[RET:.*]] = call @mutl_parameter(%[[ARG0]]) : (index) -> index
+//       CHECK:   return %[[RET]] : index
+//       CHECK: }

``````````

</details>


https://github.com/llvm/llvm-project/pull/160654


More information about the Mlir-commits mailing list