[Mlir-commits] [mlir] 067d277 - [MLIR] Setting MemorySpace During Bufferization (#78484)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 8 07:59:40 PST 2024
Author: ian Bearman
Date: 2024-02-08T16:59:37+01:00
New Revision: 067d2779fcfc62dd429177f350b8cefe49b65b51
URL: https://github.com/llvm/llvm-project/commit/067d2779fcfc62dd429177f350b8cefe49b65b51
DIFF: https://github.com/llvm/llvm-project/commit/067d2779fcfc62dd429177f350b8cefe49b65b51.diff
LOG: [MLIR] Setting MemorySpace During Bufferization (#78484)
A collection of changes with the goal of being able to convert `encoding`
to `memorySpace` during bufferization:
- new API for encoder to allow the implementation to select the destination
memory space
- update existing bufferization implementations to support the new
interface
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 226a2fbd08563c..d8cfeee2466360 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -257,6 +257,9 @@ struct BufferizationOptions {
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
+ // Produce a MemorySpace attribute from a tensor type
+ using DefaultMemorySpaceFn =
+ std::function<std::optional<Attribute>(TensorType t)>;
BufferizationOptions();
@@ -296,11 +299,6 @@ struct BufferizationOptions {
/// bufferized or not.
bool bufferizeFunctionBoundaries = false;
- /// The default memory space that should be used when it cannot be inferred
- /// from the context. If case of std::nullopt, bufferization fails when the
- /// memory space cannot be inferred at any point.
- std::optional<Attribute> defaultMemorySpace = Attribute();
-
/// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for).
/// If this flag is set to `false`, those invariants are no longer enforced
/// with buffer copies.
@@ -351,6 +349,13 @@ struct BufferizationOptions {
/// used.
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
+ // Use during type conversion to determine the memory space for memref based
+ // on the original tensor type if the memory space cannot be inferred.
+ // Returning std::nullopt will cause bufferization to fail (useful to indicate
+ // failure to determine memory space for a tensor type).
+ DefaultMemorySpaceFn defaultMemorySpaceFn =
+ [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+
/// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
/// Should be used only with `testAnalysisOnly = true`.
unsigned analysisFuzzerSeed = 0;
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f69b2557eec922..d7492c9e25db31 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,17 +26,18 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
+ auto type = constantOp.getType().dyn_cast<RankedTensorType>();
+
+ // Only ranked tensors are supported.
+ if (!type)
+ return failure();
Attribute memorySpace;
- if (options.defaultMemorySpace.has_value())
- memorySpace = *options.defaultMemorySpace;
+ if (auto memSpace = options.defaultMemorySpaceFn(type))
+ memorySpace = *memSpace;
else
return constantOp->emitError("could not infer memory space");
- // Only ranked tensors are supported.
- if (!isa<RankedTensorType>(constantOp.getType()))
- return failure();
-
// Only constants inside a module are supported.
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 6ca9702cbbc66b..8f0f6d1fcc8490 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -682,11 +682,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
return bufferizableOp.getBufferType(value, options, invocationStack);
// Op is not bufferizable.
- if (!options.defaultMemorySpace.has_value())
+ auto memSpace =
+ options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
+ if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{},
- *options.defaultMemorySpace);
+ return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -936,11 +937,12 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If we do not know the memory space and there is no default memory space,
// report a failure.
- if (!options.defaultMemorySpace.has_value())
+ auto memSpace =
+ options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
+ if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{},
- *options.defaultMemorySpace);
+ return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index eb4a96f3549904..34a0c594a5a5a3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -234,8 +234,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
if (failed(copyBufferType))
return failure();
memorySpace = copyBufferType->getMemorySpace();
- } else if (options.defaultMemorySpace.has_value()) {
- memorySpace = *options.defaultMemorySpace;
+ } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
+ memorySpace = *ms;
} else {
return getOperation()->emitError("could not infer memory space");
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index dc94b72edcdf0c..208cbda3a9eb63 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -210,8 +210,12 @@ struct OneShotBufferizePass
opt.dumpAliasSets = dumpAliasSets;
opt.setFunctionBoundaryTypeConversion(
parseLayoutMapOption(functionBoundaryTypeConversion));
- if (mustInferMemorySpace)
- opt.defaultMemorySpace = std::nullopt;
+ if (mustInferMemorySpace) {
+ opt.defaultMemorySpaceFn =
+ [](TensorType t) -> std::optional<Attribute> {
+ return std::nullopt;
+ };
+ }
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 07cd1f90b17df4..4cdbbf35dc876b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
assert(tensorType && "expected TensorType");
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpace, funcOp, options);
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
index, BufferizationDialect::kBufferLayoutAttrName);
@@ -443,7 +443,8 @@ struct FuncOpInterface
// Note: If `inferFunctionResultLayout = true`, cast are later folded
// away.
BaseMemRefType resultType = options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpace, funcOp, options);
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+ options);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
loc, resultType, returnVal);
returnValues.push_back(toMemrefOp);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 678b7c099fa369..957f6314f35876 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -473,14 +473,14 @@ struct FromElementsOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
+ auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
// TODO: Implement memory space for this op.
- if (options.defaultMemorySpace != Attribute())
+ if (options.defaultMemorySpaceFn(tensorType) != Attribute())
return op->emitError("memory space not implemented yet");
// Allocate a buffer for the result.
Location loc = op->getLoc();
- auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
auto shape = tensorType.getShape();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -588,8 +588,10 @@ struct GenerateOpInterface
const BufferizationOptions &options) const {
auto generateOp = cast<tensor::GenerateOp>(op);
+ auto type = generateOp.getResult().getType();
+
// TODO: Implement memory space for this op.
- if (options.defaultMemorySpace != Attribute())
+ if (options.defaultMemorySpaceFn(type) != Attribute())
return op->emitError("memory space not implemented yet");
// Allocate memory.
@@ -1007,10 +1009,6 @@ struct SplatOpInterface
OpBuilder::InsertionGuard g(rewriter);
auto splatOp = cast<tensor::SplatOp>(op);
- // TODO: Implement memory space for this op.
- if (options.defaultMemorySpace != Attribute())
- return op->emitError("memory space not implemented yet");
-
// Allocate memory.
Location loc = op->getLoc();
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -1021,6 +1019,11 @@ struct SplatOpInterface
// Create linalg::MapOp.
auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+ // TODO: Implement memory space for this op.
+ if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+ return op->emitError("memory space not implemented yet");
+
auto linalgOp =
rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
/*init=*/*tensorAlloc);
diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
index fedfbe350a51a9..2991a3c165ee2d 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
@@ -44,8 +44,10 @@ struct TestTensorCopyInsertionPass
bufferization::OneShotBufferizationOptions options;
options.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
- if (mustInferMemorySpace)
- options.defaultMemorySpace = std::nullopt;
+ if (mustInferMemorySpace) {
+ options.defaultMemorySpaceFn =
+ [](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
+ }
if (failed(bufferization::insertTensorCopies(getOperation(), options)))
signalPassFailure();
}
More information about the Mlir-commits
mailing list