[Mlir-commits] [mlir] d62b4b0 - [mlir][linalg][bufferize] Compose dialect-specific bufferization state
Matthias Springer
llvmlistbot at llvm.org
Thu Nov 25 18:38:37 PST 2021
Author: Matthias Springer
Date: 2021-11-26T11:35:45+09:00
New Revision: d62b4b08af03a9fc25274ed0e380d9d052fe251b
URL: https://github.com/llvm/llvm-project/commit/d62b4b08af03a9fc25274ed0e380d9d052fe251b
DIFF: https://github.com/llvm/llvm-project/commit/d62b4b08af03a9fc25274ed0e380d9d052fe251b.diff
LOG: [mlir][linalg][bufferize] Compose dialect-specific bufferization state
Use composition instead of inheritance for storing dialect-specific bufferization state. This is in preparation for adding "tensor dialect"-specific bufferization state.
Differential Revision: https://reviews.llvm.org/D114508
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index f52a9aa7b4f1..e03aaea85731 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -230,6 +230,13 @@ struct AllocationCallbacks {
MemCpyFn memCpyFn;
};
+/// Dialect-specific bufferization state. Analysis/bufferization information
+/// that is specific to ops from a certain dialect can be stored in derived
+/// variants of this struct.
+struct DialectBufferizationState {
+ virtual ~DialectBufferizationState() = default;
+};
+
/// BufferizationState keeps track of bufferization state and provides access to
/// the results of the analysis.
struct BufferizationState {
@@ -271,6 +278,14 @@ struct BufferizationState {
/// Erase all ops that were marked obsolete.
void eraseObsoleteOps();
+ /// Return dialect-specific bufferization state.
+ template <typename StateT> StateT &getDialectState(StringRef name) {
+ // Create state if it does not exist yet.
+ if (!dialectState.count(name))
+ dialectState[name] = std::make_unique<StateT>();
+ return static_cast<StateT &>(*dialectState[name]);
+ }
+
/// `aliasInfo` keeps track of aliasing and equivalent values.
BufferizationAliasInfo aliasInfo;
@@ -284,6 +299,9 @@ struct BufferizationState {
/// Obsolete ops that should be deleted after bufferization.
SmallVector<Operation *> obsoleteOps;
+
+ /// Dialect-specific bufferization state.
+ DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
};
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index e354195dfca0..c98cc1de60af 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -27,11 +27,9 @@ using namespace tensor;
using namespace comprehensive_bufferize;
namespace {
-/// A specialization of BufferizationState that keeps track of additional
-/// state required for bufferization of function boundaries.
-struct ModuleBufferizationState : public BufferizationState {
- using BufferizationState::BufferizationState;
-
+/// Extra bufferization state that is required for bufferization of function
+/// boundaries.
+struct ModuleBufferizationState : public DialectBufferizationState {
/// A map for looking up bufferized function types.
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
@@ -40,6 +38,12 @@ struct ModuleBufferizationState : public BufferizationState {
};
} // namespace
+static ModuleBufferizationState &
+getModuleBufferizationState(BufferizationState &state) {
+ return state.getDialectState<ModuleBufferizationState>(
+ StandardOpsDialect::getDialectNamespace());
+}
+
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
/// If `value` is a memref::CastOp, return its source. Otherwise, return
@@ -127,7 +131,9 @@ static FunctionType getOrCreateBufferizedFunctionType(
/// Store function BlockArguments that are equivalent to a returned value in
/// the given ModuleBufferizationState.
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
- ModuleBufferizationState &state) {
+ BufferizationState &state) {
+ ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+
// Support only single return-terminated block in the function.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
@@ -137,7 +143,7 @@ static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<RankedTensorType>())
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
- state.equivalentReturnValToBBArg[returnVal] = bbArg;
+ moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
}
/// Rewrite the `funcOp` arguments analysis return values and terminator into
@@ -155,8 +161,9 @@ static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
/// originate from an op with an Alloc effect, they could be hoisted in the
/// future.
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
- ModuleBufferizationState &state) {
+ BufferizationState &state) {
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+ ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// If nothing to do then we are done.
@@ -188,7 +195,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
<< "returns a tensor";
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, funcOp.getType().getInputs(), TypeRange{},
- state.bufferizedFunctionTypes);
+ moduleState.bufferizedFunctionTypes);
funcOp.setType(bufferizedFuncType);
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
return success();
@@ -210,7 +217,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
}
// If return operand is equivalent to some bbArg, no need to return it.
- if (state.equivalentReturnValToBBArg.count(returnVal))
+ if (moduleState.equivalentReturnValToBBArg.count(returnVal))
continue;
// Cast values at the call site if necessary.
@@ -221,7 +228,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
ValueRange retValues{returnValues};
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
- state.bufferizedFunctionTypes);
+ moduleState.bufferizedFunctionTypes);
OpBuilder b(returnOp);
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
returnOp->erase();
@@ -474,7 +481,7 @@ struct CallOpInterface
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected Callop to a FuncOp");
- auto &moduleState = static_cast<ModuleBufferizationState &>(state);
+ ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -649,7 +656,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
- ModuleBufferizationState state(moduleOp, *options.allocationFns);
+ BufferizationState state(moduleOp, *options.allocationFns);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// Interestingly, all function args that are not visible outside of a module
More information about the Mlir-commits mailing list