[Mlir-commits] [mlir] 75ef84b - [mlir][bufferization] Make function boundary type conversion logic dynamic.
Oleg Shyshkov
llvmlistbot at llvm.org
Wed Apr 12 02:03:33 PDT 2023
Author: Oleg Shyshkov
Date: 2023-04-12T11:02:43+02:00
New Revision: 75ef84bf52605cbb24f4b7b8ca9ea8a46077d885
URL: https://github.com/llvm/llvm-project/commit/75ef84bf52605cbb24f4b7b8ca9ea8a46077d885
DIFF: https://github.com/llvm/llvm-project/commit/75ef84bf52605cbb24f4b7b8ca9ea8a46077d885.diff
LOG: [mlir][bufferization] Make function boundary type conversion logic dynamic.
Having to choose a single layout (either static or dynamic) for all functions is limiting.
Differential Revision: https://reviews.llvm.org/D148074
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index b6644963a751d..2dbd113547e91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -19,6 +19,9 @@
namespace mlir {
class OpBuilder;
+namespace func {
+class FuncOp;
+} // namespace func
namespace bufferization {
@@ -250,6 +253,11 @@ struct BufferizationOptions {
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
+ /// Parameters: TensorType, memory space, func op, bufferization options
+ using FunctionArgTypeConverterFn =
+ std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+ func::FuncOp, const BufferizationOptions &)>;
+ /// Tensor -> MemRef type converter used when no memref type is inferred.
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
@@ -313,7 +321,8 @@ struct BufferizationOptions {
/// OpOperands out-of-place.
bool enforceAliasingInvariants = true;
- /// This flag controls buffer types on function signatures.
+ /// This function controls buffer types on function signatures. Sets
+ /// `functionArgTypeConverterFn` and `inferFunctionResultLayout` accordingly.
///
/// * InferLayoutMap: All function parameter types have a fully dynamic layout
/// map, but function result types are inferred from the body of the
@@ -326,13 +335,25 @@ struct BufferizationOptions {
/// additional buffer allocs and copies because layout maps cannot be casted
/// away.
///
- /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
- ///
/// Note: Inferred layout maps may not be desireable when interacting with
/// external functions, because the generated function signatures will be less
/// predictable.
- LayoutMapOption functionBoundaryTypeConversion =
- LayoutMapOption::InferLayoutMap;
+ void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
+
+ /// Type converter from tensors to memrefs. This type converter is used to
+ /// determine bufferized function argument types. By default, a type
+ /// converter that returns a memref type with a fully dynamic layout map is
+ /// used.
+ ///
+ /// If `bufferizeFunctionBoundaries` is not set, this converter is not used.
+ FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
+
+ /// If true, function result types are inferred from the body of the function.
+ /// Otherwise, function result types are determined by
+ /// `functionArgTypeConverterFn`.
+ ///
+ /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
+ bool inferFunctionResultLayout = true;
/// Type converter from tensors to memrefs. This type converter is used if no
/// memref type could be inferred during bufferization. By default, a type
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 3b965cf732086..70d857b20ad60 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -322,17 +322,29 @@ bool OpFilter::isOpAllowed(Operation *op) const {
// BufferizationOptions
//===----------------------------------------------------------------------===//
+namespace {
+
+/// Default function arg type converter: Use a fully dynamic layout map.
+BaseMemRefType
+defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
+ func::FuncOp funcOp,
+ const BufferizationOptions &options) {
+ return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+}
/// Default unknown type converter: Use a fully dynamic layout map.
-static BaseMemRefType
+BaseMemRefType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
memorySpace);
}
+} // namespace
+
// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
- : unknownTypeConverterFn(defaultUnknownTypeConverter) {}
+ : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
+ unknownTypeConverterFn(defaultUnknownTypeConverter) {}
bool BufferizationOptions::isOpAllowed(Operation *op) const {
// Special case: If function boundary bufferization is deactivated, do not
@@ -362,6 +374,21 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
return nullptr;
}
+void BufferizationOptions::setFunctionBoundaryTypeConversion(
+ LayoutMapOption layoutMapOption) {
+ functionArgTypeConverterFn = [layoutMapOption](
+ TensorType tensorType, Attribute memorySpace, func::FuncOp funcOp,
+ const BufferizationOptions &options) -> BaseMemRefType {
+ if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
+ return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
+ memorySpace);
+ return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+ memorySpace);
+ };
+ inferFunctionResultLayout =
+ layoutMapOption == LayoutMapOption::InferLayoutMap;
+}
+
//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 58766e83f255b..ed95a62b9b6f8 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -38,8 +38,8 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
options.testAnalysisOnly = getTestAnalysisOnly();
options.printConflicts = getPrintConflicts();
if (getFunctionBoundaryTypeConversion().has_value())
- options.functionBoundaryTypeConversion =
- *getFunctionBoundaryTypeConversion();
+ options.setFunctionBoundaryTypeConversion(
+ *getFunctionBoundaryTypeConversion());
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
for (Operation *target : payloadOps) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index e5e125f731032..4eabfccf2514b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -208,8 +208,8 @@ struct OneShotBufferizePass
opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
opt.copyBeforeWrite = copyBeforeWrite;
opt.createDeallocs = createDeallocs;
- opt.functionBoundaryTypeConversion =
- parseLayoutMapOption(functionBoundaryTypeConversion);
+ opt.setFunctionBoundaryTypeConversion(
+ parseLayoutMapOption(functionBoundaryTypeConversion));
if (mustInferMemorySpace)
opt.defaultMemorySpace = std::nullopt;
opt.printConflicts = printConflicts;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 1c5767973960c..bf14e466190b4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -55,8 +55,7 @@ static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
-/// specified by the user. If no layout map is specified, the default layout map
-/// (as per `options.functionBoundaryTypeConversion`) is used.
+/// specified by the user; otherwise `functionArgTypeConverterFn` is used.
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
@@ -64,17 +63,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
assert(tensorType && "expected TensorType");
- BaseMemRefType memrefType;
- if (options.functionBoundaryTypeConversion ==
- LayoutMapOption::IdentityLayoutMap) {
- memrefType = getMemRefTypeWithStaticIdentityLayout(
- tensorType, *options.defaultMemorySpace);
- } else {
- // Note: Layout maps on function parameters cannot be inferred. The best we
- // can do at the moment is "fully dynamic".
- memrefType = getMemRefTypeWithFullyDynamicLayout(
- tensorType, *options.defaultMemorySpace);
- }
+ BaseMemRefType memrefType = options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpace, funcOp, options);
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
index, BufferizationDialect::kBufferLayoutAttrName);
@@ -423,16 +413,10 @@ struct FuncOpInterface
continue;
}
- BaseMemRefType resultType;
- if (options.functionBoundaryTypeConversion ==
- LayoutMapOption::IdentityLayoutMap) {
- resultType = getMemRefTypeWithStaticIdentityLayout(
- tensorType, *options.defaultMemorySpace);
- } else {
- // Note: If `InferLayoutMap`, cast are later folded away.
- resultType = getMemRefTypeWithFullyDynamicLayout(
- tensorType, *options.defaultMemorySpace);
- }
+ // Note: If `inferFunctionResultLayout` is set, casts are later folded
+ // away. Otherwise the result type is the one computed by this converter.
+ BaseMemRefType resultType = options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpace, funcOp, options);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
loc, resultType, returnVal);
returnValues.push_back(toMemrefOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index c96a507ab654f..27b560afdbb34 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -433,8 +433,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
/*opFilter=*/nullptr, statistics)))
return failure();
// Change buffer return types to more precise layout maps.
- if (options.functionBoundaryTypeConversion ==
- LayoutMapOption::InferLayoutMap)
+ if (options.inferFunctionResultLayout)
foldMemRefCasts(funcOp);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index a2fa480c8b39a..99a619cda7b63 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -37,7 +37,7 @@ getBufferizationOptions(bool analysisOnly) {
// TODO(springerm): To spot memory leaks more easily, returning dense allocs
// should be disallowed.
options.allowReturnAllocs = true;
- options.functionBoundaryTypeConversion = LayoutMapOption::IdentityLayoutMap;
+ options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
More information about the Mlir-commits
mailing list