[Mlir-commits] [mlir] [mlir][func] Fix multiple bugs in `DuplicateFunctionElimination` (PR #109571)

Longsheng Mou llvmlistbot at llvm.org
Sun Sep 22 01:54:26 PDT 2024


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/109571

>From 69f301a1761b5c9fe21529ff8cd8106c2be7968f Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 22 Sep 2024 15:44:32 +0800
Subject: [PATCH] [mlir][func] Fix multiple bugs in
 `DuplicateFunctionElimination`

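This PR fixes multiple bugs in `DuplicateFunctionElimination`:
- Function declarations (functions without a body) are never considered equivalent to another function, so they are no longer eliminated as duplicates.
- All symbol uses of an erased function, not only the callees of `func.call` ops, are updated to reference its unique representative.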
---
 .../DuplicateFunctionElimination.cpp          | 18 ++++---
 .../Func/duplicate-function-elimination.mlir  | 48 +++++++++++++++++++
 2 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
index d41d6c3e8972f9..5300f2ae159d40 100644
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
+
+    if (lhs.isDeclaration() || rhs.isDeclaration())
+      return false;
+
     // Check discardable attributes equivalence
     if (lhs->getDiscardableAttrDictionary() !=
         rhs->getDiscardableAttrDictionary())
@@ -97,14 +101,14 @@ struct DuplicateFunctionEliminationPass
       }
     });
 
-    // Update call ops to call unique func op representants.
-    module.walk([&](func::CallOp callOp) {
-      func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
-      callOp.setCallee(callee.getSymName());
-    });
-
-    // Erase redundant func ops.
+    // Update all symbol uses to reference unique func op
+    // representants and erase redundant func ops.
     for (auto it : toBeErased) {
+      auto oldSymbol = it.getSymNameAttr();
+      auto newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
+      if (failed(
+              SymbolTable::replaceAllSymbolUses(oldSymbol, newSymbol, module)))
+        return;
       it.erase();
     }
   }
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index 28d059a149bde8..1c6876c1327bc6 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -366,3 +366,51 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
 // CHECK:     @user
 // CHECK-2:     call @deep_tree
 // CHECK:       call @reverse_deep_tree
+
+// -----
+
+func.func private @func_declaration(i32, i32) -> i32
+func.func private @func_declaration1(i32, i32) -> i32
+
+func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) {
+  %0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32
+  %1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK:  @func_declaration
+// CHECK:  @func_declaration1
+// CHECK:  @user
+// CHECK:    call @func_declaration
+// CHECK:    call @func_declaration1
+
+
+// -----
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+  %f = constant @identity : (tensor<f32>) -> tensor<f32>
+  %0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32>
+  %f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32>
+  %1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32>
+  %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// CHECK:     @identity
+// CHECK-NOT: @also_identity
+// CHECK-NOT: @yet_another_identity
+// CHECK:     @user
+// CHECK-2:     constant @identity
+// CHECK:       call @identity



More information about the Mlir-commits mailing list