[Mlir-commits] [mlir] bf58256 - [mlir][bufferize] Fix bug in module equivalence analysis
Matthias Springer
llvmlistbot at llvm.org
Thu Jun 9 09:35:39 PDT 2022
Author: Matthias Springer
Date: 2022-06-09T18:32:17+02:00
New Revision: bf58256967e5dab0d991b0a2e50671b943c6dc2e
URL: https://github.com/llvm/llvm-project/commit/bf58256967e5dab0d991b0a2e50671b943c6dc2e
DIFF: https://github.com/llvm/llvm-project/commit/bf58256967e5dab0d991b0a2e50671b943c6dc2e.diff
LOG: [mlir][bufferize] Fix bug in module equivalence analysis
CallOp results are not equivalent to an OpOperand if that OpOperand bufferizes out-of-place.
Differential Revision: https://reviews.llvm.org/D126813
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index df89f682fae38..243db9651c4f0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -258,7 +258,8 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
BufferizationAliasInfo &aliasInfo,
- FuncAnalysisState &funcState) {
+ OneShotAnalysisState &state) {
+ FuncAnalysisState &funcState = getFuncAnalysisState(state);
funcOp->walk([&](func::CallOp callOp) {
func::FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
@@ -270,6 +271,8 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
int64_t returnIdx = it.first;
int64_t bbargIdx = it.second;
+ if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
+ continue;
Value returnVal = callOp.getResult(returnIdx);
Value argVal = callOp->getOperand(bbargIdx);
aliasInfo.unionEquivalenceClasses(returnVal, argVal);
@@ -409,7 +412,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
funcState.startFunctionAnalysis(funcOp);
// Gather equivalence info for CallOps.
- equivalenceAnalysis(funcOp, aliasInfo, funcState);
+ equivalenceAnalysis(funcOp, aliasInfo, state);
// Analyze funcOp.
if (failed(analyzeOp(funcOp, state)))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d617a29c03642..beb3b38da7b0e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -196,8 +196,9 @@ func.func @call_func_with_non_tensor_return(
// CHECK: %[[call:.*]] = call @inner_func(%[[casted]])
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
- // Note: The tensor return value has folded away.
- // CHECK: return %[[call]] : f32
+ // Note: The tensor return value cannot fold away because the CallOp
+ // bufferized out-of-place.
+ // CHECK: return %[[call]], %[[alloc]] : f32, memref<?xf32>
return %1, %0 : f32, tensor<?xf32>
}
More information about the Mlir-commits mailing list