[Mlir-commits] [mlir] b06614e - [mlir][bufferization][NFC] Change signature of getMemRefType
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 27 01:41:50 PDT 2022
Author: Matthias Springer
Date: 2022-06-27T10:41:40+02:00
New Revision: b06614e2e8d74339f65a46931dbf1521552df35b
URL: https://github.com/llvm/llvm-project/commit/b06614e2e8d74339f65a46931dbf1521552df35b
DIFF: https://github.com/llvm/llvm-project/commit/b06614e2e8d74339f65a46931dbf1521552df35b.diff
LOG: [mlir][bufferization][NFC] Change signature of getMemRefType
These functions now accept unsigned integers for address spaces instead of Attributes.
Differential Revision: https://reviews.llvm.org/D128275
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index fa44fde98b6e..f28db1e26c09 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -513,18 +513,17 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
BaseMemRefType getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
- Attribute memorySpace = {});
+ unsigned memorySpace = 0);
/// Return a MemRef type with fully dynamic layout. If the given tensor type
/// is unranked, return an unranked MemRef type.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
- Attribute memorySpace = {});
+ unsigned memorySpace = 0);
/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
-BaseMemRefType
-getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
- Attribute memorySpace = {});
+BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
+ unsigned memorySpace = 0);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 85a3f562ce99..3fe32bc293ea 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -596,12 +596,15 @@ bool bufferization::isFunctionArgument(Value value) {
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
- Attribute memorySpace) {
+ unsigned memorySpace) {
+ auto memorySpaceAttr = IntegerAttr::get(
+ IntegerType::get(tensorType.getContext(), 64), memorySpace);
+
// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
assert(!layout && "UnrankedTensorType cannot have a layout map");
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
- memorySpace);
+ memorySpaceAttr);
}
// Case 2: Ranked memref type with specified layout.
@@ -609,7 +612,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
if (layout) {
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), layout,
- memorySpace);
+ memorySpaceAttr);
}
// Case 3: Configured with "fully dynamic layout maps".
@@ -627,7 +630,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
- Attribute memorySpace) {
+ unsigned memorySpace) {
// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
@@ -635,6 +638,8 @@ bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
}
// Case 2: Ranked memref type.
+ auto memorySpaceAttr = IntegerAttr::get(
+ IntegerType::get(tensorType.getContext(), 64), memorySpace);
auto rankedTensorType = tensorType.cast<RankedTensorType>();
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
@@ -643,14 +648,14 @@ bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
dynamicStrides, dynamicOffset, rankedTensorType.getContext());
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), stridedLayout,
- memorySpace);
+ memorySpaceAttr);
}
/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
- Attribute memorySpace) {
+ unsigned memorySpace) {
// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
@@ -659,8 +664,10 @@ bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
// Case 2: Ranked memref type.
auto rankedTensorType = tensorType.cast<RankedTensorType>();
+ auto memorySpaceAttr = IntegerAttr::get(
+ IntegerType::get(tensorType.getContext(), 64), memorySpace);
MemRefLayoutAttrInterface layout = {};
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), layout,
- memorySpace);
+ memorySpaceAttr);
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 6bdae8ce7a05..7f8b6a35491d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -54,7 +54,6 @@ struct CastOpInterface
// The result buffer still has the old (pre-cast) type.
Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
- Attribute memorySpace = sourceMemRefType.getMemorySpace();
TensorType resultTensorType =
castOp.getResult().getType().cast<TensorType>();
MemRefLayoutAttrInterface layout;
@@ -65,7 +64,8 @@ struct CastOpInterface
// Compute the new memref type.
Type resultMemRefType =
- getMemRefType(resultTensorType, options, layout, memorySpace);
+ getMemRefType(resultTensorType, options, layout,
+ sourceMemRefType.getMemorySpaceAsInt());
// Replace the op with a memref.cast.
assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
More information about the Mlir-commits
mailing list