[Mlir-commits] [mlir] bf58256 - [mlir][bufferize] Fix bug in module equivalence analysis

Matthias Springer llvmlistbot at llvm.org
Thu Jun 9 09:35:39 PDT 2022


Author: Matthias Springer
Date: 2022-06-09T18:32:17+02:00
New Revision: bf58256967e5dab0d991b0a2e50671b943c6dc2e

URL: https://github.com/llvm/llvm-project/commit/bf58256967e5dab0d991b0a2e50671b943c6dc2e
DIFF: https://github.com/llvm/llvm-project/commit/bf58256967e5dab0d991b0a2e50671b943c6dc2e.diff

LOG: [mlir][bufferize] Fix bug in module equivalence analysis

CallOp results are not equivalent to an OpOperand if that OpOperand bufferizes out-of-place.

Differential Revision: https://reviews.llvm.org/D126813

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index df89f682fae38..243db9651c4f0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -258,7 +258,8 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
 // TODO: This does not handle cyclic function call graphs etc.
 static void equivalenceAnalysis(func::FuncOp funcOp,
                                 BufferizationAliasInfo &aliasInfo,
-                                FuncAnalysisState &funcState) {
+                                OneShotAnalysisState &state) {
+  FuncAnalysisState &funcState = getFuncAnalysisState(state);
   funcOp->walk([&](func::CallOp callOp) {
     func::FuncOp calledFunction = getCalledFunction(callOp);
     assert(calledFunction && "could not retrieved called func::FuncOp");
@@ -270,6 +271,8 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
     for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
       int64_t returnIdx = it.first;
       int64_t bbargIdx = it.second;
+      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
+        continue;
       Value returnVal = callOp.getResult(returnIdx);
       Value argVal = callOp->getOperand(bbargIdx);
       aliasInfo.unionEquivalenceClasses(returnVal, argVal);
@@ -409,7 +412,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     funcState.startFunctionAnalysis(funcOp);
 
     // Gather equivalence info for CallOps.
-    equivalenceAnalysis(funcOp, aliasInfo, funcState);
+    equivalenceAnalysis(funcOp, aliasInfo, state);
 
     // Analyze funcOp.
     if (failed(analyzeOp(funcOp, state)))

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d617a29c03642..beb3b38da7b0e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -196,8 +196,9 @@ func.func @call_func_with_non_tensor_return(
   // CHECK: %[[call:.*]] = call @inner_func(%[[casted]])
   %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
 
-  // Note: The tensor return value has folded away.
-  // CHECK: return %[[call]] : f32
+  // Note: The tensor return value cannot fold away because the CallOp
+  // bufferized out-of-place.
+  // CHECK: return %[[call]], %[[alloc]] : f32, memref<?xf32>
   return %1, %0 : f32, tensor<?xf32>
 }
 


        


More information about the Mlir-commits mailing list