[Mlir-commits] [mlir] 37eca08 - [mlir][NFC] Rename `MemRefType::getMemorySpace` to `getMemorySpaceAsInt`
Vladislav Vinogradov
llvmlistbot at llvm.org
Tue Mar 2 00:19:43 PST 2021
Author: Vladislav Vinogradov
Date: 2021-03-02T11:08:54+03:00
New Revision: 37eca08e5bcfbe926176412a4a0acf5b963da7e6
URL: https://github.com/llvm/llvm-project/commit/37eca08e5bcfbe926176412a4a0acf5b963da7e6
DIFF: https://github.com/llvm/llvm-project/commit/37eca08e5bcfbe926176412a4a0acf5b963da7e6.diff
LOG: [mlir][NFC] Rename `MemRefType::getMemorySpace` to `getMemorySpaceAsInt`
This is purely a method rename.
It is a preparation step for replacing "memory space as raw integer"
with the more generic "memory space as attribute", which will be done in
a separate commit.
The `MemRefType::getMemorySpace` method will return an `Attribute` and
become the main API, while `getMemorySpaceAsInt` will be deprecated and
its uses will be replaced in all in-tree dialects (also in separate
commits).
Reviewed By: mehdi_amini, rriddle
Differential Revision: https://reviews.llvm.org/D97476
Added:
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 29fc305b682a..10216da70ab2 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -120,7 +120,7 @@ class AffineDmaStartOp
/// Returns the memory space of the src memref.
unsigned getSrcMemorySpace() {
- return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
+ return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
/// Returns the operand index of the dst memref.
@@ -141,7 +141,7 @@ class AffineDmaStartOp
/// Returns the memory space of the src memref.
unsigned getDstMemorySpace() {
- return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
+ return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
/// Returns the affine map used to access the dst memref.
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 9a253f8e814a..241f4ed9fa84 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -177,10 +177,10 @@ class DmaStartOp
return getDstMemRef().getType().cast<MemRefType>().getRank();
}
unsigned getSrcMemorySpace() {
- return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
+ return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
unsigned getDstMemorySpace() {
- return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
+ return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
// Returns the destination memref indices for this DMA operation.
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index e3b8d597a2a7..61836b11fee8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -293,7 +293,7 @@ class BaseMemRefType : public ShapedType {
static bool classof(Type type);
/// Returns the memory space in which data referred to by this memref resides.
- unsigned getMemorySpace() const;
+ unsigned getMemorySpaceAsInt() const;
};
//===----------------------------------------------------------------------===//
@@ -314,7 +314,7 @@ class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
explicit Builder(MemRefType other)
: shape(other.getShape()), elementType(other.getElementType()),
affineMaps(other.getAffineMaps()),
- memorySpace(other.getMemorySpace()) {}
+ memorySpace(other.getMemorySpaceAsInt()) {}
// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType)
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 10cdbc4b1658..e4442ac4c567 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -270,7 +270,7 @@ MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
}
unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
- return unwrap(type).cast<MemRefType>().getMemorySpace();
+ return unwrap(type).cast<MemRefType>().getMemorySpaceAsInt();
}
bool mlirTypeIsAUnrankedMemRef(MlirType type) {
@@ -289,7 +289,7 @@ MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
}
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
- return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
+ return unwrap(type).cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 3deec2242c5e..bc4f65182bdb 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -118,7 +118,8 @@ struct LowerGpuOpsToNVVMOpsPass
/// converter drops the private memory space to support the use case above.
LLVMTypeConverter converter(m.getContext(), options);
converter.addConversion([&](MemRefType type) -> Optional<Type> {
- if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace())
+ if (type.getMemorySpaceAsInt() !=
+ gpu::GPUDialect::getPrivateAddressSpace())
return llvm::None;
return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
});
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index ce2590f2cdfa..3e11a5ef1a14 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -316,7 +316,8 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
Type elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
+ auto ptrTy =
+ LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
auto indexTy = getIndexType();
SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
@@ -388,7 +389,7 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
Type elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
+ return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
}
/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
@@ -1081,7 +1082,8 @@ bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
auto elementType = type.getElementType();
auto structElementType = unwrap(typeConverter->convertType(elementType));
- return LLVM::LLVMPointerType::get(structElementType, type.getMemorySpace());
+ return LLVM::LLVMPointerType::get(structElementType,
+ type.getMemorySpaceAsInt());
}
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
@@ -1899,7 +1901,7 @@ struct AllocOpLowering : public AllocLikeOpLowering {
Value alignedPtr = allocatedPtr;
if (alignment) {
- auto intPtrType = getIntPtrType(memRefType.getMemorySpace());
+ auto intPtrType = getIntPtrType(memRefType.getMemorySpaceAsInt());
// Compute the aligned type pointer.
Value allocatedInt =
rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, allocatedPtr);
@@ -2247,7 +2249,7 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
global, arrayTy, global.constant(), linkage, global.sym_name(),
- initialValue, type.getMemorySpace());
+ initialValue, type.getMemorySpaceAsInt());
return success();
}
};
@@ -2266,7 +2268,7 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
Operation *op) const override {
auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
- unsigned memSpace = type.getMemorySpace();
+ unsigned memSpace = type.getMemorySpaceAsInt();
Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
auto addressOf = rewriter.create<LLVM::AddressOfOp>(
@@ -2462,7 +2464,7 @@ static void extractPointersAndOffset(Location loc,
}
unsigned memorySpace =
- operandType.cast<UnrankedMemRefType>().getMemorySpace();
+ operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
Type llvmElementType = unwrap(typeConverter.convertType(elementType));
Type elementPtrPtrType = LLVM::LLVMPointerType::get(
@@ -2591,7 +2593,7 @@ struct MemRefReshapeOpLowering
// Extract address space and element type.
auto targetType =
reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
- unsigned addressSpace = targetType.getMemorySpace();
+ unsigned addressSpace = targetType.getMemorySpaceAsInt();
Type elementType = targetType.getElementType();
// Create the unranked memref descriptor that holds the ranked one. The
@@ -2751,7 +2753,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
auto scalarMemRefType =
MemRefType::get({}, unrankedMemRefType.getElementType());
- unsigned addressSpace = unrankedMemRefType.getMemorySpace();
+ unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
// Extract pointer to the underlying ranked descriptor and bitcast it to a
// memref<element_type> descriptor pointer to minimize the number of GEP
@@ -3265,7 +3267,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(targetElementTy,
- viewMemRefType.getMemorySpace()),
+ viewMemRefType.getMemorySpaceAsInt()),
extracted);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
@@ -3274,7 +3276,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(targetElementTy,
- viewMemRefType.getMemorySpace()),
+ viewMemRefType.getMemorySpaceAsInt()),
extracted);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
@@ -3491,7 +3493,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(targetElementTy,
- srcMemRefType.getMemorySpace()),
+ srcMemRefType.getMemorySpaceAsInt()),
allocatedPtr);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
@@ -3502,7 +3504,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(targetElementTy,
- srcMemRefType.getMemorySpace()),
+ srcMemRefType.getMemorySpaceAsInt()),
alignedPtr);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index c2db461cae79..e07934b16fe3 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -194,7 +194,7 @@ static bool isAllocationSupported(MemRefType t) {
// shape and int or float or vector of int or float element type.
if (!(t.hasStaticShape() &&
SPIRVTypeConverter::getMemorySpaceForStorageClass(
- spirv::StorageClass::Workgroup) == t.getMemorySpace()))
+ spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
return false;
Type elementType = t.getElementType();
if (auto vecType = elementType.dyn_cast<VectorType>())
@@ -207,7 +207,8 @@ static bool isAllocationSupported(MemRefType t) {
/// type. Returns None on failure.
static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
Optional<spirv::StorageClass> storageClass =
- SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace());
+ SPIRVTypeConverter::getStorageClassForMemorySpace(
+ t.getMemorySpaceAsInt());
if (!storageClass)
return {};
switch (*storageClass) {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d567e065479d..ac0e3fc003d1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -188,7 +188,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (failed(successStrides) || strides.back() != 1 ||
- memRefType.getMemorySpace() != 0)
+ memRefType.getMemorySpaceAsInt() != 0)
return failure();
auto pType = MemRefDescriptor(memref).getElementPtrType();
auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
@@ -200,7 +200,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, MemRefType memRefType, Type vt) {
- auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpace());
+ auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 005e7b30ea7c..42c072626cf5 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -94,8 +94,8 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
// MUBUF instruction operate only on addresspace 0(unified) or 1(global)
// In case of 3(LDS): fall back to vector->llvm pass
// In case of 5(VGPR): wrong
- if ((memRefType.getMemorySpace() != 0) &&
- (memRefType.getMemorySpace() != 1))
+ if ((memRefType.getMemorySpaceAsInt() != 0) &&
+ (memRefType.getMemorySpaceAsInt() != 1))
return failure();
// Note that the dataPtr starts at the offset address specified by
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 503a2f001345..724d0afdf67b 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -115,9 +115,11 @@ class NDTransferOpHelper {
VectorType::get(vectorType.getShape().take_back(minorRank),
vectorType.getElementType());
/// Memref of minor vector type is used for individual transfers.
- memRefMinorVectorType = MemRefType::get(
- majorVectorType.getShape(), minorVectorType, {},
- xferOp.getShapedType().template cast<MemRefType>().getMemorySpace());
+ memRefMinorVectorType =
+ MemRefType::get(majorVectorType.getShape(), minorVectorType, {},
+ xferOp.getShapedType()
+ .template cast<MemRefType>()
+ .getMemorySpaceAsInt());
}
LogicalResult doReplace();
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index e72119545001..27fde9b87405 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -727,7 +727,7 @@ static LogicalResult verifyAttributions(Operation *op,
if (!type)
return op->emitOpError() << "expected memref type in attribution";
- if (type.getMemorySpace() != memorySpace) {
+ if (type.getMemorySpaceAsInt() != memorySpace) {
return op->emitOpError()
<< "expected memory space " << memorySpace << " in attribution";
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1361fa06546c..d63e7753f93e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1345,7 +1345,7 @@ static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
if (!memrefType.hasStaticShape())
return op->emitOpError(
"unexpected bare pointer for dynamically shaped memref");
- if (memrefType.getMemorySpace() != ptrType.getAddressSpace())
+ if (memrefType.getMemorySpaceAsInt() != ptrType.getAddressSpace())
return op->emitError("invalid conversion between memref and pointer in "
"different memory spaces");
@@ -1369,7 +1369,7 @@ static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
// The first two elements are pointers to the element type.
auto allocatedPtr = structType.getBody()[0].dyn_cast<LLVMPointerType>();
if (!allocatedPtr ||
- allocatedPtr.getAddressSpace() != memrefType.getMemorySpace())
+ allocatedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
return op->emitOpError("expected first element of a memref descriptor to "
"be a pointer in the address space of the memref");
if (failed(verifyCast(op, allocatedPtr.getElementType(),
@@ -1378,7 +1378,7 @@ static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
auto alignedPtr = structType.getBody()[1].dyn_cast<LLVMPointerType>();
if (!alignedPtr ||
- alignedPtr.getAddressSpace() != memrefType.getMemorySpace())
+ alignedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
return op->emitOpError(
"expected second element of a memref descriptor to "
"be a pointer in the address space of the memref");
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index a91244eef23c..47269f4d5ec2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -344,7 +344,8 @@ static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
MemRefType type) {
Optional<spirv::StorageClass> storageClass =
- SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
+ SPIRVTypeConverter::getStorageClassForMemorySpace(
+ type.getMemorySpaceAsInt());
if (!storageClass) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot convert memory space\n");
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 539252af5cf9..536d71d89d4f 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2096,7 +2096,7 @@ bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
return false;
}
- if (aT.getMemorySpace() != bT.getMemorySpace())
+ if (aT.getMemorySpaceAsInt() != bT.getMemorySpaceAsInt())
return false;
// They must have the same rank, and any specified dimensions must match.
@@ -2123,8 +2123,10 @@ bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (aEltType != bEltType)
return false;
- auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
- auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
+ auto aMemSpace =
+ (aT) ? aT.getMemorySpaceAsInt() : uaT.getMemorySpaceAsInt();
+ auto bMemSpace =
+ (bT) ? bT.getMemorySpaceAsInt() : ubT.getMemorySpaceAsInt();
if (aMemSpace != bMemSpace)
return false;
@@ -2201,7 +2203,7 @@ static LogicalResult verify(MemRefReinterpretCastOp op) {
// The source and result memrefs should be in the same memory space.
auto srcType = op.source().getType().cast<BaseMemRefType>();
auto resultType = op.getType().cast<MemRefType>();
- if (srcType.getMemorySpace() != resultType.getMemorySpace())
+ if (srcType.getMemorySpaceAsInt() != resultType.getMemorySpaceAsInt())
return op.emitError("different memory spaces specified for source type ")
<< srcType << " and result memref type " << resultType;
if (srcType.getElementType() != resultType.getElementType())
@@ -2875,7 +2877,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
staticSizes, sourceMemRefType.getElementType(),
makeStridedLinearLayoutMap(targetStrides, targetOffset,
sourceMemRefType.getContext()),
- sourceMemRefType.getMemorySpace());
+ sourceMemRefType.getMemorySpaceAsInt());
}
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
@@ -2932,7 +2934,7 @@ Type SubViewOp::inferRankReducedResultType(
map = getProjectedMap(maps.front(), dimsToProject);
inferredType =
MemRefType::get(projectedShape, inferredType.getElementType(), map,
- inferredType.getMemorySpace());
+ inferredType.getMemorySpaceAsInt());
}
return inferredType;
}
@@ -3154,7 +3156,7 @@ isRankReducedType(Type originalType, Type candidateReducedType,
// Strided layout logic is relevant for MemRefType only.
MemRefType original = originalType.cast<MemRefType>();
MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
- if (original.getMemorySpace() != candidateReduced.getMemorySpace())
+ if (original.getMemorySpaceAsInt() != candidateReduced.getMemorySpaceAsInt())
return SubViewVerificationResult::MemSpaceMismatch;
llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
@@ -3228,7 +3230,7 @@ static LogicalResult verify(SubViewOp op) {
MemRefType subViewType = op.getType();
// The base memref and the view memref should be in the same memory space.
- if (baseType.getMemorySpace() != subViewType.getMemorySpace())
+ if (baseType.getMemorySpaceAsInt() != subViewType.getMemorySpaceAsInt())
return op.emitError("different memory spaces specified for base memref "
"type ")
<< baseType << " and subview memref type " << subViewType;
@@ -4179,7 +4181,7 @@ static LogicalResult verify(ViewOp op) {
return op.emitError("unsupported map for result memref type ") << viewType;
// The base memref and the view memref should be in the same memory space.
- if (baseType.getMemorySpace() != viewType.getMemorySpace())
+ if (baseType.getMemorySpaceAsInt() != viewType.getMemorySpaceAsInt())
return op.emitError("different memory spaces specified for base memref "
"type ")
<< baseType << " and view memref type " << viewType;
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 4702626c3e8c..b08a696281fa 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3174,7 +3174,7 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
VectorType::get(extractShape(memRefType),
getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
result.addTypes(
- MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
+ MemRefType::get({}, vectorType, {}, memRefType.getMemorySpaceAsInt()));
}
static LogicalResult verify(TypeCastOp op) {
@@ -3183,8 +3183,8 @@ static LogicalResult verify(TypeCastOp op) {
return op.emitOpError("expects operand to be a memref with no layout");
if (!op.getResultMemRefType().getAffineMaps().empty())
return op.emitOpError("expects result to be a memref with no layout");
- if (op.getResultMemRefType().getMemorySpace() !=
- op.getMemRefType().getMemorySpace())
+ if (op.getResultMemRefType().getMemorySpaceAsInt() !=
+ op.getMemRefType().getMemorySpaceAsInt())
return op.emitOpError("expects result in same memory space");
auto sourceType = op.getMemRefType();
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 437d353655ff..c0b20c064a78 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1882,16 +1882,16 @@ void ModulePrinter::printType(Type type) {
printAttribute(AffineMapAttr::get(map));
}
// Only print the memory space if it is the non-default one.
- if (memrefTy.getMemorySpace())
- os << ", " << memrefTy.getMemorySpace();
+ if (memrefTy.getMemorySpaceAsInt())
+ os << ", " << memrefTy.getMemorySpaceAsInt();
os << '>';
})
.Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
os << "memref<*x";
printType(memrefTy.getElementType());
// Only print the memory space if it is the non-default one.
- if (memrefTy.getMemorySpace())
- os << ", " << memrefTy.getMemorySpace();
+ if (memrefTy.getMemorySpaceAsInt())
+ os << ", " << memrefTy.getMemorySpaceAsInt();
os << '>';
})
.Case<ComplexType>([&](ComplexType complexTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b15854919e0..c84569a53531 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -206,7 +206,7 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
if (auto other = dyn_cast<UnrankedMemRefType>()) {
MemRefType::Builder b(shape, elementType);
- b.setMemorySpace(other.getMemorySpace());
+ b.setMemorySpace(other.getMemorySpaceAsInt());
return b;
}
@@ -229,7 +229,7 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
if (auto other = dyn_cast<UnrankedMemRefType>()) {
MemRefType::Builder b(shape, other.getElementType());
b.setShape(shape);
- b.setMemorySpace(other.getMemorySpace());
+ b.setMemorySpace(other.getMemorySpaceAsInt());
return b;
}
@@ -250,7 +250,7 @@ ShapedType ShapedType::clone(Type elementType) {
}
if (auto other = dyn_cast<UnrankedMemRefType>()) {
- return UnrankedMemRefType::get(elementType, other.getMemorySpace());
+ return UnrankedMemRefType::get(elementType, other.getMemorySpaceAsInt());
}
if (isa<TensorType>()) {
@@ -472,7 +472,7 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
// BaseMemRefType
//===----------------------------------------------------------------------===//
-unsigned BaseMemRefType::getMemorySpace() const {
+unsigned BaseMemRefType::getMemorySpaceAsInt() const {
return static_cast<ImplType *>(impl)->memorySpace;
}
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 26a20ee365b5..d6d18b3c6f7a 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -947,7 +947,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) {
newMemSpace = fastMemorySpace.getValue();
} else {
- newMemSpace = oldMemRefType.getMemorySpace();
+ newMemSpace = oldMemRefType.getMemorySpaceAsInt();
}
auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
{}, newMemSpace);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 77d24fb0c161..71a0fc8e5d89 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2725,12 +2725,12 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
// Gather regions to allocate to buffers in faster memory space.
if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
if ((filterMemRef.hasValue() && filterMemRef != loadOp.getMemRef()) ||
- (loadOp.getMemRefType().getMemorySpace() !=
+ (loadOp.getMemRefType().getMemorySpaceAsInt() !=
copyOptions.slowMemorySpace))
return;
} else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
if ((filterMemRef.hasValue() && filterMemRef != storeOp.getMemRef()) ||
- storeOp.getMemRefType().getMemorySpace() !=
+ storeOp.getMemRefType().getMemorySpaceAsInt() !=
copyOptions.slowMemorySpace)
return;
} else {
More information about the Mlir-commits
mailing list