[Mlir-commits] [mlir] cd84cf9 - [mlir][linalg][bufferize][NFC] Do not cache bufferized function types
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 6 07:14:14 PST 2022
Author: Matthias Springer
Date: 2022-01-07T00:04:57+09:00
New Revision: cd84cf90e90856f0672e80701a14c672014b3471
URL: https://github.com/llvm/llvm-project/commit/cd84cf90e90856f0672e80701a14c672014b3471
DIFF: https://github.com/llvm/llvm-project/commit/cd84cf90e90856f0672e80701a14c672014b3471.diff
LOG: [mlir][linalg][bufferize][NFC] Do not cache bufferized function types
Caching bufferized function types does not work if BufferizationState is passed around as a const reference in most places.
Differential Revision: https://reviews.llvm.org/D116741
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 00719b3e34661..171b47b6447c4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -24,9 +24,6 @@ namespace {
/// Extra bufferization state that is required for bufferization of function
/// boundaries.
struct ModuleBufferizationState : public DialectBufferizationState {
- /// A map for looking up bufferized function types.
- DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
-
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
@@ -161,23 +158,6 @@ static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
return FunctionType::get(ctx, argTypes, retTypes);
}
-/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
-/// it. Otherwise, construct a new entry based on `argumentTypes` and
-/// `resultTypes`.
-// TODO: improve the layering.
-static FunctionType getOrCreateBufferizedFunctionType(
- FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
- DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
- auto it = bufferizedFunctionTypes.find(funcOp);
- if (it != bufferizedFunctionTypes.end())
- return it->second;
-
- auto it2 = bufferizedFunctionTypes.try_emplace(
- funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
- resultTypes));
- return it2.first->second;
-}
-
/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
@@ -250,9 +230,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
return funcOp->emitError() << "cannot bufferize bodiless function that "
<< "returns a tensor";
- FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
- funcOp, funcOp.getType().getInputs(), TypeRange{},
- moduleState.bufferizedFunctionTypes);
+ FunctionType bufferizedFuncType = getBufferizedFunctionType(
+ funcOp.getContext(), funcOp.getType().getInputs(), TypeRange{});
funcOp.setType(bufferizedFuncType);
return success();
}
@@ -284,9 +263,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
// 2. Rewrite the terminator without the inPlace bufferizable values.
ValueRange retValues{returnValues};
- FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
- funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
- moduleState.bufferizedFunctionTypes);
+ FunctionType bufferizedFuncType = getBufferizedFunctionType(
+ funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes());
OpBuilder b(returnOp);
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
returnOp->erase();
@@ -590,9 +568,8 @@ struct CallOpInterface
SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
// Get the bufferized FunctionType for funcOp or construct it if not yet
// available.
- FunctionType bufferizedFuncType =
- getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes,
- moduleState.bufferizedFunctionTypes);
+ FunctionType bufferizedFuncType = getBufferizedFunctionType(
+ funcOp.getContext(), argumentTypes, resultTypes);
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
for (OpOperand &opOperand : callOp->getOpOperands()) {
More information about the Mlir-commits mailing list