[Mlir-commits] [mlir] 606f7c8 - [mlir][bufferization][NFC] Move more unknown type conversion logic into BufferizationOptions
Matthias Springer
llvmlistbot at llvm.org
Thu Jul 7 04:40:53 PDT 2022
Author: Matthias Springer
Date: 2022-07-07T13:36:28+02:00
New Revision: 606f7c8f7a770718bd7061d5a506711a9c84f482
URL: https://github.com/llvm/llvm-project/commit/606f7c8f7a770718bd7061d5a506711a9c84f482
DIFF: https://github.com/llvm/llvm-project/commit/606f7c8f7a770718bd7061d5a506711a9c84f482.diff
LOG: [mlir][bufferization][NFC] Move more unknown type conversion logic into BufferizationOptions
The `unknownTypeConversion` bufferization option (an enum) has been replaced by a type-converter function option, `unknownTypeConverterFn`. Some of the logic of `getMemRefType` is now handled by that function.
This change makes type conversion more controllable. Previously, there were only two choices when generating memref types for non-bufferizable ops: a static identity layout or a fully dynamic layout. With this change, users of One-Shot Bufferize can provide a function with custom logic; by default, a fully dynamic layout map is still used.
Differential Revision: https://reviews.llvm.org/D129273
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ff8db00f7644e..2cc84c99d2040 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -179,6 +179,10 @@ struct BufferizationOptions {
/// Initializer function for dialect-specific analysis state.
using DialectStateInitFn =
std::function<std::unique_ptr<DialectAnalysisState>()>;
+ /// Tensor -> MemRef type converter.
+ /// Parameters: Value, memory space, bufferization options
+ using UnknownTypeConverterFn = std::function<BaseMemRefType(
+ Value, unsigned, const BufferizationOptions &)>;
enum class LayoutMapOption : int8_t {
InferLayoutMap = 0,
@@ -266,21 +270,11 @@ struct BufferizationOptions {
LayoutMapOption functionBoundaryTypeConversion =
LayoutMapOption::InferLayoutMap;
- /// This flag controls buffer types on unknown ops (to_memref wrappers) and in
- /// other cases where a precise memref type cannot be inferred (e.g., the
- /// bufferization of "tensor.cast").
- ///
- /// * InferLayoutMap: This option is invalid and cannot be used.
- /// * FullyDynamicLayoutMap: Assume that unknown ops have results with fully
- /// dynamic layout maps after bufferization. This option is most efficient
- /// because any layout map can be casted to a fully dynamic one.
- /// * IdentityLayoutMap: Assume that unknown ops have results with static
- /// identity layout (i.e., no layout map) after bufferization. This option
- /// introduces additional buffer allocs and copies if the unknown op is
- /// eventually bufferized to an op that returns a buffer with non-identity
- /// layout.
- LayoutMapOption unknownTypeConversion =
- LayoutMapOption::FullyDynamicLayoutMap;
+ /// Type converter from tensors to memrefs. This type converter is used if no
+ /// memref type could be inferred during bufferization. By default, a type
+ /// converter that returns a memref type with a fully dynamic layout map is
+ /// used.
+ UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
/// Specifies whether dealloc ops should be generated along with alloc ops. If
/// not, new memory allocations will leak.
@@ -505,20 +499,19 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}
-/// Return a MemRefType to which the `tensorType` can be bufferized.
+/// Return a MemRefType to which the type of the given value can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
///
-/// Unless a layout map was specified, `options.unknownTypeConverter` determines
-/// what kind of layout map will be used. For best composability (without
-/// copies), the fully dynamic layout map is used by default.
+/// Unless a layout map was specified, `options.unknownTypeConverterFn`
+/// determines what kind of layout map will be used. For best composability
+/// (without copies), the fully dynamic layout map is used by default.
///
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(TensorType tensorType,
- const BufferizationOptions &options,
+BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
unsigned memorySpace = 0);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 61caa18561d34..49a5c9115e718 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -351,8 +351,9 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*defaultImplementation=*/[{
assert(bbArg.getOwner()->getParentOp() == $_op &&
"bbArg must belong to this op");
- auto tensorType = bbArg.getType().cast<TensorType>();
- return bufferization::getMemRefType(tensorType, options);
+ assert(bbArg.getType().isa<TensorType>() &&
+ "expected tensor type");
+ return bufferization::getMemRefType(bbArg, options);
}]
>,
InterfaceMethod<
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 374fbd7da664d..97a84bf220536 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -222,8 +222,17 @@ bool OpFilter::isOpAllowed(Operation *op) const {
// BufferizationOptions
//===----------------------------------------------------------------------===//
+/// Default unknown type converter: Use a fully dynamic layout map.
+static BaseMemRefType
+defaultUnknownTypeConverter(Value value, unsigned memorySpace,
+ const BufferizationOptions &options) {
+ return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
+ memorySpace);
+}
+
// Default constructor for BufferizationOptions.
-BufferizationOptions::BufferizationOptions() = default;
+BufferizationOptions::BufferizationOptions()
+ : unknownTypeConverterFn(defaultUnknownTypeConverter) {}
bool BufferizationOptions::isOpAllowed(Operation *op) const {
// Special case: If function boundary bufferization is deactivated, do not
@@ -528,8 +537,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
- auto tensorType = value.getType().dyn_cast<TensorType>();
- assert(tensorType && "unexpected non-tensor type");
+ assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
Operation *op = getOwnerOfValue(value);
// ToTensorOp: Take buffer type directly from the op.
@@ -566,7 +574,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
if (!memorySpace.hasValue())
return op->emitError("could not infer memory space");
- return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace);
+ return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
}
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
@@ -652,10 +660,11 @@ bool bufferization::isFunctionArgument(Value value) {
return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}
-BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
+BaseMemRefType bufferization::getMemRefType(Value value,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
unsigned memorySpace) {
+ auto tensorType = value.getType().cast<TensorType>();
auto memorySpaceAttr = IntegerAttr::get(
IntegerType::get(tensorType.getContext(), 64), memorySpace);
@@ -674,17 +683,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
memorySpaceAttr);
}
- // Case 3: Configured with "fully dynamic layout maps".
- if (options.unknownTypeConversion ==
- BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
- return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
-
- // Case 4: Configured with "static identity layout maps".
- if (options.unknownTypeConversion ==
- BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
- return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
-
- llvm_unreachable("InferLayoutMap is an invalid option");
+ return options.unknownTypeConverterFn(value, memorySpace, options);
}
BaseMemRefType
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c68d1d120be6a..f1dfbd113947a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -192,8 +192,26 @@ struct OneShotBufferizePass
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
- opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
+ // Configure type converter.
+ BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
+ parseLayoutMapOption(unknownTypeConversion);
+ opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
+ const BufferizationOptions &options) {
+ auto tensorType = value.getType().cast<TensorType>();
+ if (unknownTypeConversionOption ==
+ BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
+ return bufferization::getMemRefTypeWithStaticIdentityLayout(
+ tensorType, memorySpace);
+ assert(
+ unknownTypeConversionOption ==
+ BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
+ "invalid layout map option");
+ return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+ memorySpace);
+ };
+
+ // Configure op filter.
OpFilter::Entry::FilterFn filterFn =
[&](Operation *op) {
// Filter may be specified via options.
@@ -372,10 +390,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationOptions &options,
bool copyBeforeWrite,
const OpFilter *opFilter) {
- assert(options.unknownTypeConversion !=
- BufferizationOptions::LayoutMapOption::InferLayoutMap &&
- "invalid layout map option");
-
if (copyBeforeWrite) {
AnalysisState state(options);
if (failed(insertTensorCopies(op, state)))
@@ -474,8 +488,11 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
options.allowUnknownOps = true;
options.createDeallocs = false;
options.enforceAliasingInvariants = false;
- options.unknownTypeConversion =
- BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+ options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
+ const BufferizationOptions &options) {
+ return getMemRefTypeWithStaticIdentityLayout(
+ value.getType().cast<TensorType>(), memorySpace);
+ };
options.opFilter.allowDialect<BufferizationDialect>();
return options;
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 97da5969a3004..6cd9134b097ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -67,7 +67,7 @@ struct CastOpInterface
// Compute the new memref type.
Type resultMemRefType =
- getMemRefType(resultTensorType, options, layout,
+ getMemRefType(castOp.getResult(), options, layout,
sourceMemRefType.getMemorySpaceAsInt());
// Replace the op with a memref.cast.
@@ -780,9 +780,8 @@ struct ReshapeOpInterface
getBuffer(rewriter, reshapeOp.getShape(), options);
if (failed(srcBuffer) || failed(shapeBuffer))
return failure();
- auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
auto resultMemRefType = getMemRefType(
- resultTensorType, options, /*layout=*/{},
+ reshapeOp.getResult(), options, /*layout=*/{},
srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
More information about the Mlir-commits mailing list