[Mlir-commits] [mlir] [mlir][func] Fix multiple bugs in `DuplicateFunctionElimination` (PR #109571)
Longsheng Mou
llvmlistbot at llvm.org
Sun Sep 22 00:59:42 PDT 2024
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/109571
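This PR fixes multiple bugs in `DuplicateFunctionElimination`:
- Prevents function declarations from being treated as duplicates and eliminated; a declaration has no body, so a matching signature does not prove that two declarations are equivalent.
- Updates `func.constant` ops to reference the unique function representative instead of a duplicate that is about to be erased.
- Simplifies the symbol-to-representative `DenseMap` by keying it on `StringRef` instead of `StringAttr`, so lookups can use `getCallee()`/`getValue()` directly.

Fixes #93483.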
>From 2b8ccab75096cb4a4f338a4ac530f2e1b5972f33 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 22 Sep 2024 15:44:32 +0800
Subject: [PATCH] [mlir][func] Fix multiple bugs in
`DuplicateFunctionElimination`
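This PR fixes multiple bugs in `DuplicateFunctionElimination`:
- Prevents function declarations from being treated as duplicates and eliminated;
  a declaration has no body, so a matching signature does not prove equivalence.
- Updates `func.constant` ops to reference the unique function representative
  instead of a duplicate that is about to be erased.
- Simplifies the symbol-to-representative `DenseMap` by keying it on `StringRef`
  instead of `StringAttr`, so lookups can use `getCallee()`/`getValue()` directly.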
---
.../DuplicateFunctionElimination.cpp | 15 ++++--
.../Func/duplicate-function-elimination.mlir | 48 +++++++++++++++++++
2 files changed, 60 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
index d41d6c3e8972f9..5e23207eabf9c4 100644
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;
+
+ if (lhs.isDeclaration() || rhs.isDeclaration())
+ return false;
+
// Check discardable attributes equivalence
if (lhs->getDiscardableAttrDictionary() !=
rhs->getDiscardableAttrDictionary())
@@ -87,11 +91,11 @@ struct DuplicateFunctionEliminationPass
// Find unique representant per equivalent func ops.
DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
- DenseMap<StringAttr, func::FuncOp> getRepresentant;
+ DenseMap<StringRef, func::FuncOp> getRepresentant;
DenseSet<func::FuncOp> toBeErased;
module.walk([&](func::FuncOp f) {
auto [repr, inserted] = uniqueFuncOps.insert(f);
- getRepresentant[f.getSymNameAttr()] = *repr;
+ getRepresentant[f.getSymName()] = *repr;
if (!inserted) {
toBeErased.insert(f);
}
@@ -99,9 +103,14 @@ struct DuplicateFunctionEliminationPass
// Update call ops to call unique func op representants.
module.walk([&](func::CallOp callOp) {
- func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
+ func::FuncOp callee = getRepresentant[callOp.getCallee()];
callOp.setCallee(callee.getSymName());
});
+ // Update constant ops to reference unique func op representants.
+ module.walk([&](func::ConstantOp constantOp) {
+ func::FuncOp value = getRepresentant[constantOp.getValue()];
+ constantOp.setValue(value.getSymName());
+ });
// Erase redundant func ops.
for (auto it : toBeErased) {
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index 28d059a149bde8..1c6876c1327bc6 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -366,3 +366,56 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
// CHECK: @user
// CHECK-2: call @deep_tree
// CHECK: call @reverse_deep_tree
+
+// -----
+
+// Function declarations have no body, so a matching signature does not prove
+// equivalence: both declarations and both calls must survive the pass.
+
+func.func private @func_declaration(i32, i32) -> i32
+func.func private @func_declaration1(i32, i32) -> i32
+
+func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) {
+ %0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32
+ %1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK: @func_declaration(
+// CHECK: @func_declaration1(
+// CHECK: @user
+// CHECK: call @func_declaration(
+// CHECK: call @func_declaration1(
+
+// -----
+
+// `func.constant` ops referencing an erased duplicate must be redirected to the
+// surviving representative, just like the callees of direct `call` ops.
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+ %f = constant @identity : (tensor<f32>) -> tensor<f32>
+ %0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32>
+ %f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32>
+ %1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32>
+ %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
+ return %2 : tensor<f32>
+}
+
+// CHECK: @identity
+// CHECK-NOT: @also_identity
+// CHECK-NOT: @yet_another_identity
+// CHECK: @user
+// CHECK-COUNT-2: constant @identity
+// CHECK: call @identity
More information about the Mlir-commits
mailing list