[Mlir-commits] [mlir] [mlir][bufferization] Improve performance of DropEquivalentBufferResultsPass (PR #101281)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 30 20:10:17 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

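Use a `DenseMap` to collect the call sites of every function in a single walk of the module, instead of walking the entire module again for each function being updated. This removes the repeated traversals of all `func::CallOp`s and greatly improves the running time of this pass.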
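The idea can be sketched in plain C++, outside MLIR. This is a minimal illustration under stated assumptions, not the pass itself. `CallSite`, `buildCallerMap`, and `callersNaive` are hypothetical names. `std::unordered_map` stands in for `llvm::DenseMap`, and a callee name stands in for the result of `getCalledFunction`.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a func::CallOp: an id plus the callee's name.
struct CallSite {
  int id;
  std::string callee;
};

// Old approach (mirrors the per-function module.walk): scan every call site
// and keep the ones that target `func`. Calling this once per function costs
// O(F * N) visits for F functions and N call sites; `visits` counts them.
std::vector<int> callersNaive(const std::string &func,
                              const std::vector<CallSite> &calls,
                              std::size_t &visits) {
  std::vector<int> result;
  for (const CallSite &c : calls) {
    ++visits;
    if (c.callee == func)
      result.push_back(c.id);
  }
  return result;
}

// New approach (mirrors callerMap): a single O(N) scan groups call sites by
// callee, so each function afterwards looks up only its own callers.
std::unordered_map<std::string, std::vector<int>>
buildCallerMap(const std::vector<CallSite> &calls) {
  std::unordered_map<std::string, std::vector<int>> callerMap;
  for (const CallSite &c : calls)
    callerMap[c.callee].push_back(c.id);
  return callerMap;
}
```

Take call sites `{0, "foo"}, {1, "bar"}, {2, "foo"}, {3, "baz"}` and three functions. The naive lookup performs 12 visits, while `buildCallerMap` visits each call site once and maps `"foo"` to `{0, 2}`.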

---
Full diff: https://github.com/llvm/llvm-project/pull/101281.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp (+10-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 016ec2be62dce..d86bdb20a66bb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -71,6 +71,14 @@ LogicalResult
 mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   IRRewriter rewriter(module.getContext());
 
+  DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
+  // Collect the mapping of functions to their call sites.
+  module.walk([&](func::CallOp callOp) {
+    if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
+      callerMap[calledFunc].insert(callOp);
+    }
+  });
+
   for (auto funcOp : module.getOps<func::FuncOp>()) {
     if (funcOp.isExternal())
       continue;
@@ -109,10 +117,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
     returnOp.getOperandsMutable().assign(newReturnValues);
 
     // Update function calls.
-    module.walk([&](func::CallOp callOp) {
-      if (getCalledFunction(callOp) != funcOp)
-        return WalkResult::skip();
-
+    for (func::CallOp callOp : callerMap[funcOp]) {
       rewriter.setInsertionPoint(callOp);
       auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
                                                      callOp.getOperands());
@@ -136,8 +141,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
         newResults.push_back(replacement);
       }
       rewriter.replaceOp(callOp, newResults);
-      return WalkResult::advance();
-    });
+    }
   }
 
   return success();

``````````

</details>


https://github.com/llvm/llvm-project/pull/101281

