[Mlir-commits] [mlir] 37eca08 - [mlir][NFC] Rename `MemRefType::getMemorySpace` to `getMemorySpaceAsInt`

Vladislav Vinogradov llvmlistbot at llvm.org
Tue Mar 2 00:19:43 PST 2021


Author: Vladislav Vinogradov
Date: 2021-03-02T11:08:54+03:00
New Revision: 37eca08e5bcfbe926176412a4a0acf5b963da7e6

URL: https://github.com/llvm/llvm-project/commit/37eca08e5bcfbe926176412a4a0acf5b963da7e6
DIFF: https://github.com/llvm/llvm-project/commit/37eca08e5bcfbe926176412a4a0acf5b963da7e6.diff

LOG: [mlir][NFC] Rename `MemRefType::getMemorySpace` to `getMemorySpaceAsInt`

Just a pure method renaming.

It is a preparation step for replacing "memory space as raw integer"
with the more generic "memory space as attribute", which will be done in
a separate commit.

The `MemRefType::getMemorySpace` method will return `Attribute` and
become the main API, while `getMemorySpaceAsInt` will be declared as
deprecated and will be replaced in all in-tree dialects (also in separate
commits).

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D97476

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 29fc305b682a..10216da70ab2 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -120,7 +120,7 @@ class AffineDmaStartOp
 
   /// Returns the memory space of the src memref.
   unsigned getSrcMemorySpace() {
-    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
+    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
   }
 
   /// Returns the operand index of the dst memref.
@@ -141,7 +141,7 @@ class AffineDmaStartOp
 
   /// Returns the memory space of the src memref.
   unsigned getDstMemorySpace() {
-    return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
+    return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
   }
 
   /// Returns the affine map used to access the dst memref.

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 9a253f8e814a..241f4ed9fa84 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -177,10 +177,10 @@ class DmaStartOp
     return getDstMemRef().getType().cast<MemRefType>().getRank();
   }
   unsigned getSrcMemorySpace() {
-    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
+    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
   }
   unsigned getDstMemorySpace() {
-    return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
+    return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
   }
 
   // Returns the destination memref indices for this DMA operation.

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index e3b8d597a2a7..61836b11fee8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -293,7 +293,7 @@ class BaseMemRefType : public ShapedType {
   static bool classof(Type type);
 
   /// Returns the memory space in which data referred to by this memref resides.
-  unsigned getMemorySpace() const;
+  unsigned getMemorySpaceAsInt() const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -314,7 +314,7 @@ class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
     explicit Builder(MemRefType other)
         : shape(other.getShape()), elementType(other.getElementType()),
           affineMaps(other.getAffineMaps()),
-          memorySpace(other.getMemorySpace()) {}
+          memorySpace(other.getMemorySpaceAsInt()) {}
 
     // Build from scratch.
     Builder(ArrayRef<int64_t> shape, Type elementType)

diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 10cdbc4b1658..e4442ac4c567 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -270,7 +270,7 @@ MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
 }
 
 unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
-  return unwrap(type).cast<MemRefType>().getMemorySpace();
+  return unwrap(type).cast<MemRefType>().getMemorySpaceAsInt();
 }
 
 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
@@ -289,7 +289,7 @@ MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
 }
 
 unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
-  return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
+  return unwrap(type).cast<UnrankedMemRefType>().getMemorySpaceAsInt();
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 3deec2242c5e..bc4f65182bdb 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -118,7 +118,8 @@ struct LowerGpuOpsToNVVMOpsPass
     /// converter drops the private memory space to support the use case above.
     LLVMTypeConverter converter(m.getContext(), options);
     converter.addConversion([&](MemRefType type) -> Optional<Type> {
-      if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace())
+      if (type.getMemorySpaceAsInt() !=
+          gpu::GPUDialect::getPrivateAddressSpace())
         return llvm::None;
       return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
     });

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index ce2590f2cdfa..3e11a5ef1a14 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -316,7 +316,8 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
   Type elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
-  auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
+  auto ptrTy =
+      LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
   auto indexTy = getIndexType();
 
   SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
@@ -388,7 +389,7 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
   Type elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
-  return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
+  return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
 }
 
 /// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
@@ -1081,7 +1082,8 @@ bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
   auto elementType = type.getElementType();
   auto structElementType = unwrap(typeConverter->convertType(elementType));
-  return LLVM::LLVMPointerType::get(structElementType, type.getMemorySpace());
+  return LLVM::LLVMPointerType::get(structElementType,
+                                    type.getMemorySpaceAsInt());
 }
 
 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
@@ -1899,7 +1901,7 @@ struct AllocOpLowering : public AllocLikeOpLowering {
 
     Value alignedPtr = allocatedPtr;
     if (alignment) {
-      auto intPtrType = getIntPtrType(memRefType.getMemorySpace());
+      auto intPtrType = getIntPtrType(memRefType.getMemorySpaceAsInt());
       // Compute the aligned type pointer.
       Value allocatedInt =
           rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, allocatedPtr);
@@ -2247,7 +2249,7 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
 
     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         global, arrayTy, global.constant(), linkage, global.sym_name(),
-        initialValue, type.getMemorySpace());
+        initialValue, type.getMemorySpaceAsInt());
     return success();
   }
 };
@@ -2266,7 +2268,7 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
                                           Operation *op) const override {
     auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
-    unsigned memSpace = type.getMemorySpace();
+    unsigned memSpace = type.getMemorySpaceAsInt();
 
     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
@@ -2462,7 +2464,7 @@ static void extractPointersAndOffset(Location loc,
   }
 
   unsigned memorySpace =
-      operandType.cast<UnrankedMemRefType>().getMemorySpace();
+      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
   Type llvmElementType = unwrap(typeConverter.convertType(elementType));
   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
@@ -2591,7 +2593,7 @@ struct MemRefReshapeOpLowering
     // Extract address space and element type.
     auto targetType =
         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
-    unsigned addressSpace = targetType.getMemorySpace();
+    unsigned addressSpace = targetType.getMemorySpaceAsInt();
     Type elementType = targetType.getElementType();
 
     // Create the unranked memref descriptor that holds the ranked one. The
@@ -2751,7 +2753,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
     auto scalarMemRefType =
         MemRefType::get({}, unrankedMemRefType.getElementType());
-    unsigned addressSpace = unrankedMemRefType.getMemorySpace();
+    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
 
     // Extract pointer to the underlying ranked descriptor and bitcast it to a
     // memref<element_type> descriptor pointer to minimize the number of GEP
@@ -3265,7 +3267,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
         LLVM::LLVMPointerType::get(targetElementTy,
-                                   viewMemRefType.getMemorySpace()),
+                                   viewMemRefType.getMemorySpaceAsInt()),
         extracted);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
@@ -3274,7 +3276,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
         LLVM::LLVMPointerType::get(targetElementTy,
-                                   viewMemRefType.getMemorySpace()),
+                                   viewMemRefType.getMemorySpaceAsInt()),
         extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
@@ -3491,7 +3493,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
         LLVM::LLVMPointerType::get(targetElementTy,
-                                   srcMemRefType.getMemorySpace()),
+                                   srcMemRefType.getMemorySpaceAsInt()),
         allocatedPtr);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
@@ -3502,7 +3504,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
         LLVM::LLVMPointerType::get(targetElementTy,
-                                   srcMemRefType.getMemorySpace()),
+                                   srcMemRefType.getMemorySpaceAsInt()),
         alignedPtr);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 

diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index c2db461cae79..e07934b16fe3 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -194,7 +194,7 @@ static bool isAllocationSupported(MemRefType t) {
   // shape and int or float or vector of int or float element type.
   if (!(t.hasStaticShape() &&
         SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Workgroup) == t.getMemorySpace()))
+            spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
     return false;
   Type elementType = t.getElementType();
   if (auto vecType = elementType.dyn_cast<VectorType>())
@@ -207,7 +207,8 @@ static bool isAllocationSupported(MemRefType t) {
 /// type. Returns None on failure.
 static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
   Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace());
+      SPIRVTypeConverter::getStorageClassForMemorySpace(
+          t.getMemorySpaceAsInt());
   if (!storageClass)
     return {};
   switch (*storageClass) {

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d567e065479d..ac0e3fc003d1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -188,7 +188,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (failed(successStrides) || strides.back() != 1 ||
-      memRefType.getMemorySpace() != 0)
+      memRefType.getMemorySpaceAsInt() != 0)
     return failure();
   auto pType = MemRefDescriptor(memref).getElementPtrType();
   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
@@ -200,7 +200,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
 // will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                          Value ptr, MemRefType memRefType, Type vt) {
-  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpace());
+  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
 }
 

diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 005e7b30ea7c..42c072626cf5 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -94,8 +94,8 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
     // MUBUF instruction operate only on addresspace 0(unified) or 1(global)
     // In case of 3(LDS): fall back to vector->llvm pass
     // In case of 5(VGPR): wrong
-    if ((memRefType.getMemorySpace() != 0) &&
-        (memRefType.getMemorySpace() != 1))
+    if ((memRefType.getMemorySpaceAsInt() != 0) &&
+        (memRefType.getMemorySpaceAsInt() != 1))
       return failure();
 
     // Note that the dataPtr starts at the offset address specified by

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 503a2f001345..724d0afdf67b 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -115,9 +115,11 @@ class NDTransferOpHelper {
         VectorType::get(vectorType.getShape().take_back(minorRank),
                         vectorType.getElementType());
     /// Memref of minor vector type is used for individual transfers.
-    memRefMinorVectorType = MemRefType::get(
-        majorVectorType.getShape(), minorVectorType, {},
-        xferOp.getShapedType().template cast<MemRefType>().getMemorySpace());
+    memRefMinorVectorType =
+        MemRefType::get(majorVectorType.getShape(), minorVectorType, {},
+                        xferOp.getShapedType()
+                            .template cast<MemRefType>()
+                            .getMemorySpaceAsInt());
   }
 
   LogicalResult doReplace();

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index e72119545001..27fde9b87405 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -727,7 +727,7 @@ static LogicalResult verifyAttributions(Operation *op,
     if (!type)
       return op->emitOpError() << "expected memref type in attribution";
 
-    if (type.getMemorySpace() != memorySpace) {
+    if (type.getMemorySpaceAsInt() != memorySpace) {
       return op->emitOpError()
              << "expected memory space " << memorySpace << " in attribution";
     }

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1361fa06546c..d63e7753f93e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1345,7 +1345,7 @@ static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
       if (!memrefType.hasStaticShape())
         return op->emitOpError(
             "unexpected bare pointer for dynamically shaped memref");
-      if (memrefType.getMemorySpace() != ptrType.getAddressSpace())
+      if (memrefType.getMemorySpaceAsInt() != ptrType.getAddressSpace())
         return op->emitError("invalid conversion between memref and pointer in "
                              "different memory spaces");
 
@@ -1369,7 +1369,7 @@ static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
     // The first two elements are pointers to the element type.
     auto allocatedPtr = structType.getBody()[0].dyn_cast<LLVMPointerType>();
     if (!allocatedPtr ||
-        allocatedPtr.getAddressSpace() != memrefType.getMemorySpace())
+        allocatedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
       return op->emitOpError("expected first element of a memref descriptor to "
                              "be a pointer in the address space of the memref");
     if (failed(verifyCast(op, allocatedPtr.getElementType(),
@@ -1378,7 +1378,7 @@ static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
 
     auto alignedPtr = structType.getBody()[1].dyn_cast<LLVMPointerType>();
     if (!alignedPtr ||
-        alignedPtr.getAddressSpace() != memrefType.getMemorySpace())
+        alignedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
       return op->emitOpError(
           "expected second element of a memref descriptor to "
           "be a pointer in the address space of the memref");

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index a91244eef23c..47269f4d5ec2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -344,7 +344,8 @@ static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
                                         MemRefType type) {
   Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
+      SPIRVTypeConverter::getStorageClassForMemorySpace(
+          type.getMemorySpaceAsInt());
   if (!storageClass) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot convert memory space\n");

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 539252af5cf9..536d71d89d4f 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2096,7 +2096,7 @@ bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
           return false;
     }
-    if (aT.getMemorySpace() != bT.getMemorySpace())
+    if (aT.getMemorySpaceAsInt() != bT.getMemorySpaceAsInt())
       return false;
 
     // They must have the same rank, and any specified dimensions must match.
@@ -2123,8 +2123,10 @@ bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
     if (aEltType != bEltType)
       return false;
 
-    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
-    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
+    auto aMemSpace =
+        (aT) ? aT.getMemorySpaceAsInt() : uaT.getMemorySpaceAsInt();
+    auto bMemSpace =
+        (bT) ? bT.getMemorySpaceAsInt() : ubT.getMemorySpaceAsInt();
     if (aMemSpace != bMemSpace)
       return false;
 
@@ -2201,7 +2203,7 @@ static LogicalResult verify(MemRefReinterpretCastOp op) {
   // The source and result memrefs should be in the same memory space.
   auto srcType = op.source().getType().cast<BaseMemRefType>();
   auto resultType = op.getType().cast<MemRefType>();
-  if (srcType.getMemorySpace() != resultType.getMemorySpace())
+  if (srcType.getMemorySpaceAsInt() != resultType.getMemorySpaceAsInt())
     return op.emitError("different memory spaces specified for source type ")
            << srcType << " and result memref type " << resultType;
   if (srcType.getElementType() != resultType.getElementType())
@@ -2875,7 +2877,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
       staticSizes, sourceMemRefType.getElementType(),
       makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                  sourceMemRefType.getContext()),
-      sourceMemRefType.getMemorySpace());
+      sourceMemRefType.getMemorySpaceAsInt());
 }
 
 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
@@ -2932,7 +2934,7 @@ Type SubViewOp::inferRankReducedResultType(
       map = getProjectedMap(maps.front(), dimsToProject);
     inferredType =
         MemRefType::get(projectedShape, inferredType.getElementType(), map,
-                        inferredType.getMemorySpace());
+                        inferredType.getMemorySpaceAsInt());
   }
   return inferredType;
 }
@@ -3154,7 +3156,7 @@ isRankReducedType(Type originalType, Type candidateReducedType,
   // Strided layout logic is relevant for MemRefType only.
   MemRefType original = originalType.cast<MemRefType>();
   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
-  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
+  if (original.getMemorySpaceAsInt() != candidateReduced.getMemorySpaceAsInt())
     return SubViewVerificationResult::MemSpaceMismatch;
 
   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
@@ -3228,7 +3230,7 @@ static LogicalResult verify(SubViewOp op) {
   MemRefType subViewType = op.getType();
 
   // The base memref and the view memref should be in the same memory space.
-  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
+  if (baseType.getMemorySpaceAsInt() != subViewType.getMemorySpaceAsInt())
     return op.emitError("different memory spaces specified for base memref "
                         "type ")
            << baseType << " and subview memref type " << subViewType;
@@ -4179,7 +4181,7 @@ static LogicalResult verify(ViewOp op) {
     return op.emitError("unsupported map for result memref type ") << viewType;
 
   // The base memref and the view memref should be in the same memory space.
-  if (baseType.getMemorySpace() != viewType.getMemorySpace())
+  if (baseType.getMemorySpaceAsInt() != viewType.getMemorySpaceAsInt())
     return op.emitError("different memory spaces specified for base memref "
                         "type ")
            << baseType << " and view memref type " << viewType;

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 4702626c3e8c..b08a696281fa 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3174,7 +3174,7 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
       VectorType::get(extractShape(memRefType),
                       getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
   result.addTypes(
-      MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
+      MemRefType::get({}, vectorType, {}, memRefType.getMemorySpaceAsInt()));
 }
 
 static LogicalResult verify(TypeCastOp op) {
@@ -3183,8 +3183,8 @@ static LogicalResult verify(TypeCastOp op) {
     return op.emitOpError("expects operand to be a memref with no layout");
   if (!op.getResultMemRefType().getAffineMaps().empty())
     return op.emitOpError("expects result to be a memref with no layout");
-  if (op.getResultMemRefType().getMemorySpace() !=
-      op.getMemRefType().getMemorySpace())
+  if (op.getResultMemRefType().getMemorySpaceAsInt() !=
+      op.getMemRefType().getMemorySpaceAsInt())
     return op.emitOpError("expects result in same memory space");
 
   auto sourceType = op.getMemRefType();

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 437d353655ff..c0b20c064a78 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1882,16 +1882,16 @@ void ModulePrinter::printType(Type type) {
           printAttribute(AffineMapAttr::get(map));
         }
         // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpace())
-          os << ", " << memrefTy.getMemorySpace();
+        if (memrefTy.getMemorySpaceAsInt())
+          os << ", " << memrefTy.getMemorySpaceAsInt();
         os << '>';
       })
       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
         os << "memref<*x";
         printType(memrefTy.getElementType());
         // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpace())
-          os << ", " << memrefTy.getMemorySpace();
+        if (memrefTy.getMemorySpaceAsInt())
+          os << ", " << memrefTy.getMemorySpaceAsInt();
         os << '>';
       })
       .Case<ComplexType>([&](ComplexType complexTy) {

diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b15854919e0..c84569a53531 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -206,7 +206,7 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
 
   if (auto other = dyn_cast<UnrankedMemRefType>()) {
     MemRefType::Builder b(shape, elementType);
-    b.setMemorySpace(other.getMemorySpace());
+    b.setMemorySpace(other.getMemorySpaceAsInt());
     return b;
   }
 
@@ -229,7 +229,7 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
   if (auto other = dyn_cast<UnrankedMemRefType>()) {
     MemRefType::Builder b(shape, other.getElementType());
     b.setShape(shape);
-    b.setMemorySpace(other.getMemorySpace());
+    b.setMemorySpace(other.getMemorySpaceAsInt());
     return b;
   }
 
@@ -250,7 +250,7 @@ ShapedType ShapedType::clone(Type elementType) {
   }
 
   if (auto other = dyn_cast<UnrankedMemRefType>()) {
-    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
+    return UnrankedMemRefType::get(elementType, other.getMemorySpaceAsInt());
   }
 
   if (isa<TensorType>()) {
@@ -472,7 +472,7 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
 // BaseMemRefType
 //===----------------------------------------------------------------------===//
 
-unsigned BaseMemRefType::getMemorySpace() const {
+unsigned BaseMemRefType::getMemorySpaceAsInt() const {
   return static_cast<ImplType *>(impl)->memorySpace;
 }
 

diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 26a20ee365b5..d6d18b3c6f7a 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -947,7 +947,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) {
     newMemSpace = fastMemorySpace.getValue();
   } else {
-    newMemSpace = oldMemRefType.getMemorySpace();
+    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
   }
   auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                        {}, newMemSpace);

diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 77d24fb0c161..71a0fc8e5d89 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2725,12 +2725,12 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
     // Gather regions to allocate to buffers in faster memory space.
     if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
       if ((filterMemRef.hasValue() && filterMemRef != loadOp.getMemRef()) ||
-          (loadOp.getMemRefType().getMemorySpace() !=
+          (loadOp.getMemRefType().getMemorySpaceAsInt() !=
            copyOptions.slowMemorySpace))
         return;
     } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
       if ((filterMemRef.hasValue() && filterMemRef != storeOp.getMemRef()) ||
-          storeOp.getMemRefType().getMemorySpace() !=
+          storeOp.getMemRefType().getMemorySpaceAsInt() !=
               copyOptions.slowMemorySpace)
         return;
     } else {


        


More information about the Mlir-commits mailing list