[Mlir-commits] [mlir] f3bf5c0 - [mlir] Model MemRef memory space as Attribute

Vladislav Vinogradov llvmlistbot at llvm.org
Wed Mar 10 02:01:38 PST 2021


Author: Vladislav Vinogradov
Date: 2021-03-10T12:57:27+03:00
New Revision: f3bf5c053b06c756f75ff07517d1e54c48cc99c0

URL: https://github.com/llvm/llvm-project/commit/f3bf5c053b06c756f75ff07517d1e54c48cc99c0
DIFF: https://github.com/llvm/llvm-project/commit/f3bf5c053b06c756f75ff07517d1e54c48cc99c0.diff

LOG: [mlir] Model MemRef memory space as Attribute

Based on the following discussion:
https://llvm.discourse.group/t/rfc-memref-memory-shape-as-attribute/2229

The goal of this change is to give the memory space property a more
expressive representation than "magic" integer values.

This allows a cleaner ASM form:

```
gpu.func @test(%arg0: memref<100xf32, "workgroup">)

// instead of

gpu.func @test(%arg0: memref<100xf32, 3>)
```

Rationale for choosing `Attribute` instead of a plain `string`:

* `Attribute` classes allow a more type-safe API based on LLVM-style RTTI
  (`isa`/`dyn_cast`).
* `Attribute` classes provide a faster comparison operator based on pointer
  comparison (attributes are uniqued in the `MLIRContext`), in contrast to
  generic string comparison.
* `Attribute` can store more complex data, such as structs or dictionaries.
  This allows a richer memory space hierarchy.
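
The pointer-comparison argument rests on attributes being uniqued: each distinct
value is stored once per context, so two equal attributes share the same storage.
The sketch below is a hypothetical, self-contained model of that idea
(`ToyContext`, `StringAttrStorage` and `sameMemorySpace` are illustrative names,
not MLIR API); it shows why comparing two memory spaces reduces to one pointer
comparison.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical model of attribute uniquing; NOT the MLIR API. Real attributes
// are uniqued by the storage uniquer owned by MLIRContext.
struct StringAttrStorage {
  std::string value;
};

class ToyContext {
public:
  // Returns the single storage object for `value`, creating it on first use.
  const StringAttrStorage *getString(const std::string &value) {
    std::unique_ptr<StringAttrStorage> &slot = strings[value];
    if (!slot)
      slot = std::make_unique<StringAttrStorage>(StringAttrStorage{value});
    return slot.get();
  }

private:
  // unique_ptr keeps each storage address stable even if the map rehashes.
  std::unordered_map<std::string, std::unique_ptr<StringAttrStorage>> strings;
};

// With uniqued storage, equality is a single pointer comparison rather than a
// character-by-character string comparison.
bool sameMemorySpace(const StringAttrStorage *lhs,
                     const StringAttrStorage *rhs) {
  return lhs == rhs;
}
```

Looking up "workgroup" twice returns the same pointer, so `sameMemorySpace`
never inspects the characters; the hashing cost is paid once, when the value is
first created.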

This commit preserves the old integer-based API and implements it on top
of the new one.
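
The compatibility layer reduces to three small conversions
(`wrapIntegerMemorySpace`, `skipDefaultMemorySpace` and `getMemorySpaceAsInt`
in `mlir/lib/IR/BuiltinTypes.cpp`). Below is a hypothetical standalone model of
them: `ToyMemorySpace` is an illustrative `std::variant` stand-in for
`Attribute`, with `std::monostate` playing the null attribute, and the
`MLIRContext` parameter is dropped. It shows the key invariant: integer `0` and
the empty attribute denote the same default memory space.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>

// Hypothetical stand-in for a memory-space Attribute (not MLIR API):
// std::monostate models the null ("empty") attribute.
using ToyMemorySpace = std::variant<std::monostate, int64_t, std::string>;

// Models wrapIntegerMemorySpace: integer 0 maps to the empty attribute,
// anything else to a 64-bit integer attribute.
ToyMemorySpace wrapIntegerMemorySpace(unsigned memorySpace) {
  if (memorySpace == 0)
    return std::monostate{};
  return static_cast<int64_t>(memorySpace);
}

// Models skipDefaultMemorySpace: an explicit integer 0 is canonicalized to the
// empty attribute; every other value is returned unchanged.
ToyMemorySpace skipDefaultMemorySpace(const ToyMemorySpace &memorySpace) {
  if (const int64_t *value = std::get_if<int64_t>(&memorySpace);
      value && *value == 0)
    return std::monostate{};
  return memorySpace;
}

// Models the deprecated getMemorySpaceAsInt: only meaningful for the empty
// attribute or an integer attribute.
unsigned getMemorySpaceAsInt(const ToyMemorySpace &memorySpace) {
  if (std::holds_alternative<std::monostate>(memorySpace))
    return 0;
  assert(std::holds_alternative<int64_t>(memorySpace) &&
         "getMemorySpaceAsInt used with a non-integer memory space");
  return static_cast<unsigned>(std::get<int64_t>(memorySpace));
}
```

Because the new builders pass the attribute through `skipDefaultMemorySpace`,
an integer `0` memory space is meant to be stored as the empty attribute, which
the printer then omits.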

Depends on D97476

Reviewed By: rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D96145

Added: 
    

Modified: 
    mlir/docs/LangRef.md
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/TypeDetail.h
    mlir/lib/Parser/TypeParser.cpp
    mlir/test/Bindings/Python/ir_types.py
    mlir/test/CAPI/ir.c
    mlir/test/IR/invalid.mlir
    mlir/test/IR/parser.mlir
    mlir/unittests/IR/ShapedTypeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 1ccc5026a0da4..7b58b63258a5e 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -1003,14 +1003,9 @@ sugar is supported to make certain layout specifications more intuitive to read.
 For the moment, a `memref` supports parsing a strided form which is converted to
 a semi-affine map automatically.
 
-The memory space of a memref is specified by a target-specific integer index. If
-no memory space is specified, then the default memory space (0) is used. The
-default space is target specific but always at index 0.
-
-TODO: MLIR will eventually have target-dialects which allow symbolic use of
-memory hierarchy names (e.g. L3, L2, L1, ...) but we have not spec'd the details
-of that mechanism yet. Until then, this document pretends that it is valid to
-refer to these memories by `bare-id`.
+The memory space of a memref is specified by a target-specific attribute.
+It might be an integer value, string, dictionary or custom dialect attribute.
+The empty memory space (attribute is None) is target specific.
 
 The notionally dynamic value of a memref value includes the address of the
 buffer allocated, as well as the symbols referred to by the shape, layout map,

diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a706c58efc7dd..b2ec37c9deb64 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -224,38 +224,38 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
 /// same context as element type. The type is owned by the context.
 MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(
     MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
-    MlirAffineMap const *affineMaps, unsigned memorySpace);
+    MlirAffineMap const *affineMaps, MlirAttribute memorySpace);
 
 /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o
 /// illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
     MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
-    intptr_t numMaps, MlirAffineMap const *affineMaps, unsigned memorySpace);
+    intptr_t numMaps, MlirAffineMap const *affineMaps,
+    MlirAttribute memorySpace);
 
 /// Creates a MemRef type with the given rank, shape, memory space and element
 /// type in the same context as the element type. The type has no affine maps,
 /// i.e. represents a default row-major contiguous memref. The type is owned by
 /// the context.
-MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGet(MlirType elementType,
-                                                        intptr_t rank,
-                                                        const int64_t *shape,
-                                                        unsigned memorySpace);
+MLIR_CAPI_EXPORTED MlirType
+mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
+                            const int64_t *shape, MlirAttribute memorySpace);
 
 /// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping
 /// MlirType on illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked(
     MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
-    unsigned memorySpace);
+    MlirAttribute memorySpace);
 
 /// Creates an Unranked MemRef type with the given element type and in the given
 /// memory space. The type is owned by the context of element type.
-MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
-                                                      unsigned memorySpace);
+MLIR_CAPI_EXPORTED MlirType
+mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace);
 
 /// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping
 /// MlirType on illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(
-    MlirLocation loc, MlirType elementType, unsigned memorySpace);
+    MlirLocation loc, MlirType elementType, MlirAttribute memorySpace);
 
 /// Returns the number of affine layout maps in the given MemRef type.
 MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);
@@ -265,10 +265,11 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type,
                                                             intptr_t pos);
 
 /// Returns the memory space of the given MemRef type.
-MLIR_CAPI_EXPORTED unsigned mlirMemRefTypeGetMemorySpace(MlirType type);
+MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
 
 /// Returns the memory spcae of the given Unranked MemRef type.
-MLIR_CAPI_EXPORTED unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUnrankedMemrefGetMemorySpace(MlirType type);
 
 //===----------------------------------------------------------------------===//
 // Tuple type.

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 039ef47bc4cbb..0e945b4035e98 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_IR_BUILTINTYPES_H
 #define MLIR_IR_BUILTINTYPES_H
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Types.h"
 
 namespace llvm {
@@ -175,6 +176,10 @@ class BaseMemRefType : public ShapedType {
   static bool classof(Type type);
 
   /// Returns the memory space in which data referred to by this memref resides.
+  Attribute getMemorySpace() const;
+
+  /// [deprecated] Returns the memory space in old raw integer representation.
+  /// New `Attribute getMemorySpace()` method should be used instead.
   unsigned getMemorySpaceAsInt() const;
 };
 
@@ -199,12 +204,12 @@ class MemRefType::Builder {
   // Build from another MemRefType.
   explicit Builder(MemRefType other)
       : shape(other.getShape()), elementType(other.getElementType()),
-        affineMaps(other.getAffineMaps()),
-        memorySpace(other.getMemorySpaceAsInt()) {}
+        affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()) {
+  }
 
   // Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType)
-      : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {}
+      : shape(shape), elementType(elementType), affineMaps() {}
 
   Builder &setShape(ArrayRef<int64_t> newShape) {
     shape = newShape;
@@ -221,11 +226,14 @@ class MemRefType::Builder {
     return *this;
   }
 
-  Builder &setMemorySpace(unsigned newMemorySpace) {
+  Builder &setMemorySpace(Attribute newMemorySpace) {
     memorySpace = newMemorySpace;
     return *this;
   }
 
+  // [deprecated] `setMemorySpace(Attribute)` should be used instead.
+  Builder &setMemorySpace(unsigned newMemorySpace);
+
   operator MemRefType() {
     return MemRefType::get(shape, elementType, affineMaps, memorySpace);
   }
@@ -234,7 +242,7 @@ class MemRefType::Builder {
   ArrayRef<int64_t> shape;
   Type elementType;
   ArrayRef<AffineMap> affineMaps;
-  unsigned memorySpace;
+  Attribute memorySpace;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 9953eafae2914..02f699ab3628f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -268,7 +268,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
     strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
     semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
     layout-specification ::= semi-affine-map-composition | strided-layout
-    memory-space ::= integer-literal /* | TODO: address-space-id */
+    memory-space ::= attribute-value
     ```
 
     A `memref` type is a reference to a region of memory (similar to a buffer
@@ -335,14 +335,9 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
     intuitive to read. For the moment, a `memref` supports parsing a strided
     form which is converted to a semi-affine map automatically.
 
-    The memory space of a memref is specified by a target-specific integer
-    index. If no memory space is specified, then the default memory space (0)
-    is used. The default space is target specific but always at index 0.
-
-    TODO: MLIR will eventually have target-dialects which allow symbolic use of
-    memory hierarchy names (e.g. L3, L2, L1, ...) but we have not spec'd the
-    details of that mechanism yet. Until then, this document pretends that it
-    is valid to refer to these memories by `bare-id`.
+    The memory space of a memref is specified by a target-specific attribute.
+    It might be an integer value, string, dictionary or custom dialect attribute.
+    The empty memory space (attribute is None) is target specific.
 
     The notionally dynamic value of a memref value includes the address of the
     buffer allocated, as well as the symbols referred to by the shape, layout
@@ -527,22 +522,34 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
     ArrayRefParameter<"AffineMap">:$affineMaps,
-    "unsigned":$memorySpaceAsInt
+    "Attribute":$memorySpace
   );
-
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
       CArg<"ArrayRef<AffineMap>", "{}">:$affineMaps,
-      CArg<"unsigned", "0">:$memorySpace
+      CArg<"Attribute", "{}">:$memorySpace
     ), [{
       // Drop identity maps from the composition. This may lead to the
       // composition becoming empty, which is interpreted as an implicit
       // identity.
       auto nonIdentityMaps = llvm::make_filter_range(affineMaps,
         [](AffineMap map) { return !map.isIdentity(); });
+      // Drop default memory space value and replace it with empty attribute.
+      Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace);
       return $_get(elementType.getContext(), shape, elementType,
-                   llvm::to_vector<4>(nonIdentityMaps), memorySpace);
+                   llvm::to_vector<4>(nonIdentityMaps), nonDefaultMemorySpace);
+    }]>,
+    /// [deprecated] `Attribute`-based form should be used instead.
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<AffineMap>":$affineMaps,
+      "unsigned":$memorySpace
+    ), [{
+      // Convert deprecated integer-like memory space to Attribute.
+      Attribute memorySpaceAttr =
+          wrapIntegerMemorySpace(memorySpace, elementType.getContext());
+      return MemRefType::get(shape, elementType, affineMaps, memorySpaceAttr);
     }]>
   ];
   let extraClassDeclaration = [{
@@ -550,6 +557,10 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
     /// Arguments that are passed into the builder must out-live the builder.
     class Builder;
 
+    /// [deprecated] Returns the memory space in old raw integer representation.
+    /// New `Attribute getMemorySpace()` method should be used instead.
+    unsigned getMemorySpaceAsInt() const;
+
     // TODO: merge these two special values in a single one used everywhere.
     // Unfortunately, uses of `-1` have crept deep into the codebase now and are
     // hard to track.
@@ -767,7 +778,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> {
 
     ```
     unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
-    memory-space ::= integer-literal /* | TODO: address-space-id */
+    memory-space ::= attribute-value
     ```
 
     A `memref` type with an unknown rank (e.g. `memref<*xf32>`). The purpose of
@@ -787,16 +798,30 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> {
     memref<*f32, 10>
     ```
   }];
-  let parameters = (ins "Type":$elementType, "unsigned":$memorySpaceAsInt);
+  let parameters = (ins "Type":$elementType, "Attribute":$memorySpace);
 
   let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType,
+                                        "Attribute":$memorySpace), [{
+      // Drop default memory space value and replace it with empty attribute.
+      Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace);
+      return $_get(elementType.getContext(), elementType, nonDefaultMemorySpace);
+    }]>,
+    /// [deprecated] `Attribute`-based form should be used instead.
     TypeBuilderWithInferredContext<(ins "Type":$elementType,
                                         "unsigned":$memorySpace), [{
-      return $_get(elementType.getContext(), elementType, memorySpace);
+      // Convert deprecated integer-like memory space to Attribute.
+      Attribute memorySpaceAttr =
+          wrapIntegerMemorySpace(memorySpace, elementType.getContext());
+      return UnrankedMemRefType::get(elementType, memorySpaceAttr);
     }]>
   ];
   let extraClassDeclaration = [{
     ArrayRef<int64_t> getShape() const { return llvm::None; }
+
+    /// [deprecated] Returns the memory space in old raw integer representation.
+    /// New `Attribute getMemorySpace()` method should be used instead.
+    unsigned getMemorySpaceAsInt() const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 9152fd06d36ac..a544e52c26131 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2861,16 +2861,20 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
     c.def_static(
          "get",
          [](std::vector<int64_t> shape, PyType &elementType,
-            std::vector<PyAffineMap> layout, unsigned memorySpace,
+            std::vector<PyAffineMap> layout, PyAttribute *memorySpace,
             DefaultingPyLocation loc) {
            SmallVector<MlirAffineMap> maps;
            maps.reserve(layout.size());
            for (PyAffineMap &map : layout)
              maps.push_back(map);
 
+           MlirAttribute memSpaceAttr = {};
+           if (memorySpace)
+             memSpaceAttr = *memorySpace;
+
            MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
                                                  shape.data(), maps.size(),
-                                                 maps.data(), memorySpace);
+                                                 maps.data(), memSpaceAttr);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -2885,14 +2889,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
            return PyMemRefType(elementType.getContext(), t);
          },
          py::arg("shape"), py::arg("element_type"),
-         py::arg("layout") = py::list(), py::arg("memory_space") = 0,
+         py::arg("layout") = py::list(), py::arg("memory_space") = py::none(),
          py::arg("loc") = py::none(), "Create a memref type")
         .def_property_readonly("layout", &PyMemRefType::getLayout,
                                "The list of layout maps of the MemRef type.")
         .def_property_readonly(
             "memory_space",
-            [](PyMemRefType &self) -> unsigned {
-              return mlirMemRefTypeGetMemorySpace(self);
+            [](PyMemRefType &self) -> PyAttribute {
+              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+              return PyAttribute(self.getContext(), a);
             },
             "Returns the memory space of the given MemRef type.");
   }
@@ -2944,10 +2949,14 @@ class PyUnrankedMemRefType
   static void bindDerived(ClassTy &c) {
     c.def_static(
          "get",
-         [](PyType &elementType, unsigned memorySpace,
+         [](PyType &elementType, PyAttribute *memorySpace,
             DefaultingPyLocation loc) {
+           MlirAttribute memSpaceAttr = {};
+           if (memorySpace)
+             memSpaceAttr = *memorySpace;
+
            MlirType t =
-               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace);
+               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -2965,8 +2974,9 @@ class PyUnrankedMemRefType
          py::arg("loc") = py::none(), "Create a unranked memref type")
         .def_property_readonly(
             "memory_space",
-            [](PyUnrankedMemRefType &self) -> unsigned {
-              return mlirUnrankedMemrefGetMemorySpace(self);
+            [](PyUnrankedMemRefType &self) -> PyAttribute {
+              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+              return PyAttribute(self.getContext(), a);
             },
             "Returns the memory space of the given Unranked MemRef type.");
   }

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index e4442ac4c567c..c84ced1779f94 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -223,41 +223,41 @@ bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
                            const int64_t *shape, intptr_t numMaps,
                            MlirAffineMap const *affineMaps,
-                           unsigned memorySpace) {
+                           MlirAttribute memorySpace) {
   SmallVector<AffineMap, 1> maps;
   (void)unwrapList(numMaps, affineMaps, maps);
   return wrap(
       MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-                      unwrap(elementType), maps, memorySpace));
+                      unwrap(elementType), maps, unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
                                   intptr_t rank, const int64_t *shape,
                                   intptr_t numMaps,
                                   MlirAffineMap const *affineMaps,
-                                  unsigned memorySpace) {
+                                  MlirAttribute memorySpace) {
   SmallVector<AffineMap, 1> maps;
   (void)unwrapList(numMaps, affineMaps, maps);
   return wrap(MemRefType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType), maps, memorySpace));
+      unwrap(elementType), maps, unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
                                      const int64_t *shape,
-                                     unsigned memorySpace) {
+                                     MlirAttribute memorySpace) {
   return wrap(
       MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-                      unwrap(elementType), llvm::None, memorySpace));
+                      unwrap(elementType), llvm::None, unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
                                             MlirType elementType, intptr_t rank,
                                             const int64_t *shape,
-                                            unsigned memorySpace) {
+                                            MlirAttribute memorySpace) {
   return wrap(MemRefType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType), llvm::None, memorySpace));
+      unwrap(elementType), llvm::None, unwrap(memorySpace)));
 }
 
 intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
@@ -269,27 +269,29 @@ MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
   return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
 }
 
-unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
-  return unwrap(type).cast<MemRefType>().getMemorySpaceAsInt();
+MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
+  return wrap(unwrap(type).cast<MemRefType>().getMemorySpace());
 }
 
 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
   return unwrap(type).isa<UnrankedMemRefType>();
 }
 
-MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
-  return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
+MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
+                                   MlirAttribute memorySpace) {
+  return wrap(
+      UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace)));
 }
 
 MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
                                           MlirType elementType,
-                                          unsigned memorySpace) {
+                                          MlirAttribute memorySpace) {
   return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
-                                             memorySpace));
+                                             unwrap(memorySpace)));
 }
 
-unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
-  return unwrap(type).cast<UnrankedMemRefType>().getMemorySpaceAsInt();
+MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
+  return wrap(unwrap(type).cast<UnrankedMemRefType>().getMemorySpace());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 722fbe4f2e5fc..b6d327b1c78b5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1887,16 +1887,20 @@ void ModulePrinter::printType(Type type) {
           printAttribute(AffineMapAttr::get(map));
         }
         // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpaceAsInt())
-          os << ", " << memrefTy.getMemorySpaceAsInt();
+        if (memrefTy.getMemorySpace()) {
+          os << ", ";
+          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
+        }
         os << '>';
       })
       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
         os << "memref<*x";
         printType(memrefTy.getElementType());
         // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpaceAsInt())
-          os << ", " << memrefTy.getMemorySpaceAsInt();
+        if (memrefTy.getMemorySpace()) {
+          os << ", ";
+          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
+        }
         os << '>';
       })
       .Case<ComplexType>([&](ComplexType complexTy) {

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 48afa791f43b1..652883f745e32 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -10,6 +10,8 @@
 #include "TypeDetail.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "llvm/ADT/APFloat.h"
@@ -207,7 +209,7 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
 
   if (auto other = dyn_cast<UnrankedMemRefType>()) {
     MemRefType::Builder b(shape, elementType);
-    b.setMemorySpace(other.getMemorySpaceAsInt());
+    b.setMemorySpace(other.getMemorySpace());
     return b;
   }
 
@@ -230,7 +232,7 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
   if (auto other = dyn_cast<UnrankedMemRefType>()) {
     MemRefType::Builder b(shape, other.getElementType());
     b.setShape(shape);
-    b.setMemorySpace(other.getMemorySpaceAsInt());
+    b.setMemorySpace(other.getMemorySpace());
     return b;
   }
 
@@ -251,7 +253,7 @@ ShapedType ShapedType::clone(Type elementType) {
   }
 
   if (auto other = dyn_cast<UnrankedMemRefType>()) {
-    return UnrankedMemRefType::get(elementType, other.getMemorySpaceAsInt());
+    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
   }
 
   if (isa<TensorType>()) {
@@ -436,6 +438,12 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
 // BaseMemRefType
 //===----------------------------------------------------------------------===//
 
+Attribute BaseMemRefType::getMemorySpace() const {
+  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
+    return rankedMemRefTy.getMemorySpace();
+  return cast<UnrankedMemRefType>().getMemorySpace();
+}
+
 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
   if (auto rankedMemRefTy = dyn_cast<MemRefType>())
     return rankedMemRefTy.getMemorySpaceAsInt();
@@ -446,10 +454,63 @@ unsigned BaseMemRefType::getMemorySpaceAsInt() const {
 // MemRefType
 //===----------------------------------------------------------------------===//
 
+bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
+  // Empty attribute is allowed as default memory space.
+  if (!memorySpace)
+    return true;
+
+  // Supported built-in attributes.
+  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
+    return true;
+
+  // Allow custom dialect attributes.
+  if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
+    return true;
+
+  return false;
+}
+
+Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
+                                               MLIRContext *ctx) {
+  if (memorySpace == 0)
+    return nullptr;
+
+  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
+}
+
+Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
+  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
+  if (intMemorySpace && intMemorySpace.getValue() == 0)
+    return nullptr;
+
+  return memorySpace;
+}
+
+unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
+  if (!memorySpace)
+    return 0;
+
+  assert(memorySpace.isa<IntegerAttr>() &&
+         "Using `getMemorySpaceInteger` with non-Integer attribute");
+
+  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
+}
+
+MemRefType::Builder &
+MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
+  memorySpace =
+      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
+  return *this;
+}
+
+unsigned MemRefType::getMemorySpaceAsInt() const {
+  return detail::getMemorySpaceAsInt(getMemorySpace());
+}
+
 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
-                                 unsigned memorySpace) {
+                                 Attribute memorySpace) {
   if (!BaseMemRefType::isValidElementType(elementType))
     return emitError() << "invalid memref element type";
 
@@ -474,6 +535,11 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                        << " and affine map" << it.index() + 1 << ": " << dim
                        << " != " << map.getNumDims();
   }
+
+  if (!isSupportedMemorySpace(memorySpace)) {
+    return emitError() << "unsupported memory space Attribute";
+  }
+
   return success();
 }
 
@@ -481,11 +547,19 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 // UnrankedMemRefType
 //===----------------------------------------------------------------------===//
 
+unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
+  return detail::getMemorySpaceAsInt(getMemorySpace());
+}
+
 LogicalResult
 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
-                           Type elementType, unsigned memorySpace) {
+                           Type elementType, Attribute memorySpace) {
   if (!BaseMemRefType::isValidElementType(elementType))
     return emitError() << "invalid memref element type";
+
+  if (!isSupportedMemorySpace(memorySpace))
+    return emitError() << "unsupported memory space Attribute";
+
   return success();
 }
 

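The verifier rule added above (`isSupportedMemorySpace`) accepts the empty
attribute, built-in integer, string and dictionary attributes, and any attribute
defined outside the builtin dialect. A hypothetical standalone restatement,
where `ToyAttr` and its `kind`/`dialect` strings are illustrative and not MLIR
API:

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified description of an attribute (not MLIR API).
struct ToyAttr {
  bool isNull;          // models a null Attribute (default memory space)
  std::string kind;     // e.g. "integer", "string", "dictionary", "float"
  std::string dialect;  // "builtin" for built-in attributes
};

// Models the acceptance rule of isSupportedMemorySpace.
bool isSupportedMemorySpace(const ToyAttr &attr) {
  if (attr.isNull)
    return true; // empty attribute is the default memory space
  if (attr.dialect == "builtin")
    return attr.kind == "integer" || attr.kind == "string" ||
           attr.kind == "dictionary";
  return true; // attributes from non-builtin dialects are accepted
}
```

Under this rule a built-in float or array attribute would be rejected with
"unsupported memory space Attribute", while a custom dialect attribute passes.
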
diff  --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 694b161caba5a..5240f766e61c1 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -130,6 +130,19 @@ struct TupleTypeStorage final
   unsigned numElements;
 };
 
+/// Checks if the memorySpace has supported Attribute type.
+bool isSupportedMemorySpace(Attribute memorySpace);
+
+/// Wraps deprecated integer memory space to the new Attribute form.
+Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx);
+
+/// Replaces default memorySpace (integer == `0`) with empty Attribute.
+Attribute skipDefaultMemorySpace(Attribute memorySpace);
+
+/// [deprecated] Returns the memory space in old raw integer representation.
+/// New `Attribute getMemorySpace()` method should be used instead.
+unsigned getMemorySpaceAsInt(Attribute memorySpace);
+
 } // namespace detail
 } // namespace mlir
 

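The rewritten `parseElt` lambda in `TypeParser.cpp` below no longer recognizes
the memory space by token kind: any element that is neither an `offset:`/
`strides:` layout nor an `AffineMapAttr` is taken as the memory space. The
resulting ordering rules can be restated as a hypothetical standalone check
over already-classified elements (`EltKind` and `checkMemRefElements` are
illustrative names; the per-map dimension-count check is omitted):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical classification of each element after the memref element type.
enum class EltKind { AffineMap, OtherAttr };

// Returns "" on success, otherwise the diagnostic the parser would emit for
// the first offending element.
std::string checkMemRefElements(const std::vector<EltKind> &elts,
                                bool isUnranked) {
  bool haveMemorySpace = false;
  for (EltKind kind : elts) {
    if (kind == EltKind::OtherAttr) {
      if (haveMemorySpace)
        return "multiple memory spaces specified in memref type";
      haveMemorySpace = true;
      continue;
    }
    // A layout map: only valid for ranked memrefs, before the memory space.
    if (isUnranked)
      return "cannot have affine map for unranked memref type";
    if (haveMemorySpace)
      return "expected memory space to be last in memref type";
  }
  return "";
}
```

For example, a layout map followed by `"workgroup"` is accepted, while the
reverse order yields the "expected memory space to be last" diagnostic.
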
diff  --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 3bb8f1ac75994..378b82f3bb1f6 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -224,27 +224,14 @@ Type Parser::parseMemRefType() {
 
   // Parse semi-affine-map-composition.
   SmallVector<AffineMap, 2> affineMapComposition;
-  Optional<unsigned> memorySpace;
+  Attribute memorySpace;
   unsigned numDims = dimensions.size();
 
   auto parseElt = [&]() -> ParseResult {
-    // Check for the memory space.
-    if (getToken().is(Token::integer)) {
-      if (memorySpace)
-        return emitError("multiple memory spaces specified in memref type");
-      memorySpace = getToken().getUnsignedIntegerValue();
-      if (!memorySpace.hasValue())
-        return emitError("invalid memory space in memref type");
-      consumeToken(Token::integer);
-      return success();
-    }
-    if (isUnranked)
-      return emitError("cannot have affine map for unranked memref type");
-    if (memorySpace)
-      return emitError("expected memory space to be last in memref type");
-
     AffineMap map;
     llvm::SMLoc mapLoc = getToken().getLoc();
+
+    // Check for a strided layout written as offset/strides.
     if (getToken().is(Token::kw_offset)) {
       int64_t offset;
       SmallVector<int64_t, 4> strides;
@@ -253,16 +240,26 @@ Type Parser::parseMemRefType() {
       // Construct strided affine map.
       map = makeStridedLinearLayoutMap(strides, offset, state.context);
     } else {
-      // Parse an affine map attribute.
-      auto affineMap = parseAttribute();
-      if (!affineMap)
+      // The attribute is either an AffineMapAttr or a memory space attribute.
+      Attribute attr = parseAttribute();
+      if (!attr)
         return failure();
-      auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>();
-      if (!affineMapAttr)
-        return emitError("expected affine map in memref type");
-      map = affineMapAttr.getValue();
+
+      if (AffineMapAttr affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
+        map = affineMapAttr.getValue();
+      } else if (memorySpace) {
+        return emitError("multiple memory spaces specified in memref type");
+      } else {
+        memorySpace = attr;
+        return success();
+      }
     }
 
+    if (isUnranked)
+      return emitError("cannot have affine map for unranked memref type");
+    if (memorySpace)
+      return emitError("expected memory space to be last in memref type");
+
     if (map.getNumDims() != numDims) {
       size_t i = affineMapComposition.size();
       return emitError(mapLoc, "memref affine map dimension mismatch between ")
@@ -285,11 +282,15 @@ Type Parser::parseMemRefType() {
     }
   }
 
-  if (isUnranked)
-    return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0));
+  if (isUnranked) {
+    return UnrankedMemRefType::getChecked(
+        [&]() -> InFlightDiagnostic { return emitError(); }, elementType,
+        memorySpace);
+  }
 
-  return MemRefType::get(dimensions, elementType, affineMapComposition,
-                         memorySpace.getValueOr(0));
+  return MemRefType::getChecked(
+      [&]() -> InFlightDiagnostic { return emitError(); }, dimensions,
+      elementType, affineMapComposition, memorySpace);
 }
 
 /// Parse any type except the function type.
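The reworked `parseElt` loop no longer special-cases integer tokens: every trailing element is parsed as an attribute, layout maps are collected, and any other attribute becomes the memory space, which must appear at most once and last. Rejecting unsupported kinds is deferred to verification via `getChecked`. The sketch below models only those rules in plain Python; `AffineMap`, `TypeAttr`, `parse_memref_trailing` and `parse_error` are hypothetical stand-ins, not MLIR API, and the dimension-mismatch message is abbreviated.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class AffineMap:
    """Stand-in for an AffineMapAttr; only the dimension count matters here."""
    num_dims: int


@dataclass(frozen=True)
class TypeAttr:
    """Stand-in for a type used in attribute position, e.g. `i8`."""
    name: str


class ParseError(Exception):
    pass


def is_supported_memory_space(space: object) -> bool:
    # Hypothetical: empty, integer, string and dictionary spaces are accepted.
    if space is None:
        return True
    return isinstance(space, (int, str, dict)) and not isinstance(space, bool)


def parse_memref_trailing(elements: list, num_dims: int, unranked: bool = False):
    """Model of the new parseElt loop followed by getChecked() verification.

    `elements` are the comma-separated items after the element type.
    Returns (layout_maps, memory_space) or raises ParseError.
    """
    maps: List[AffineMap] = []
    memory_space: object = None
    for elt in elements:
        if isinstance(elt, AffineMap):
            if unranked:
                raise ParseError("cannot have affine map for unranked memref type")
            if memory_space is not None:
                raise ParseError("expected memory space to be last in memref type")
            if elt.num_dims != num_dims:
                # Message abbreviated; the real one also reports both counts.
                raise ParseError("memref affine map dimension mismatch")
            maps.append(elt)
        elif memory_space is not None:
            raise ParseError("multiple memory spaces specified in memref type")
        else:
            # Any non-map attribute is taken as the memory space.
            memory_space = elt
    # Unsupported kinds are now rejected by verification, not by the parser.
    if not is_supported_memory_space(memory_space):
        raise ParseError("unsupported memory space Attribute")
    return maps, memory_space


def parse_error(elements: list, num_dims: int, unranked: bool = False) -> Optional[str]:
    """Returns the diagnostic text, or None if the elements are accepted."""
    try:
        parse_memref_trailing(elements, num_dims, unranked)
    except ParseError as err:
        return str(err)
    return None
```

In this model `parse_error([TypeAttr("i8")], 2)` yields the same "unsupported memory space Attribute" diagnostic that the updated `invalid.mlir` test expects, because `i8` now parses as an attribute and only fails verification.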

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 7402c644a1c18..59b4b50b533d8 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -326,7 +326,7 @@ def testMemRefType():
     f32 = F32Type.get()
     shape = [2, 3]
     loc = Location.unknown()
-    memref = MemRefType.get(shape, f32, memory_space=2)
+    memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
     # CHECK: memref type: memref<2x3xf32, 2>
     print("memref type:", memref)
     # CHECK: number of affine layout maps: 0
@@ -341,7 +341,7 @@ def testMemRefType():
     assert len(memref_layout.layout) == 1
     # CHECK: memref layout: (d0, d1) -> (d1, d0)
     print("memref layout:", memref_layout.layout[0])
-    # CHECK: memory space: 0
+    # CHECK: memory space: <<NULL ATTRIBUTE>>
     print("memory space:", memref_layout.memory_space)
 
     none = NoneType.get()
@@ -361,7 +361,7 @@ def testUnrankedMemRefType():
   with Context(), Location.unknown():
     f32 = F32Type.get()
     loc = Location.unknown()
-    unranked_memref = UnrankedMemRefType.get(f32, 2)
+    unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
     # CHECK: unranked memref type: memref<*xf32, 2>
     print("unranked memref type:", unranked_memref)
     try:
@@ -388,7 +388,7 @@ def testUnrankedMemRefType():
 
     none = NoneType.get()
     try:
-      memref_invalid = UnrankedMemRefType.get(none, 2)
+      memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
     except ValueError as e:
       # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
       # CHECK: or complex type.

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8b785ee897bb5..38d200ec75971 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -707,21 +707,24 @@ static int printBuiltinTypes(MlirContext ctx) {
   // CHECK: tensor<*xf32>
 
   // MemRef type.
+  MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2);
   MlirType memRef = mlirMemRefTypeContiguousGet(
-      f32, sizeof(shape) / sizeof(int64_t), shape, 2);
+      f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
   if (!mlirTypeIsAMemRef(memRef) ||
       mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
-      mlirMemRefTypeGetMemorySpace(memRef) != 2)
+      !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
     return 18;
   mlirTypeDump(memRef);
   fprintf(stderr, "\n");
   // CHECK: memref<2x3xf32, 2>
 
   // Unranked MemRef type.
-  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4);
+  MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4);
+  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4);
   if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
       mlirTypeIsAMemRef(unrankedMemRef) ||
-      mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4)
+      !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
+                          memSpace4))
     return 19;
   mlirTypeDump(unrankedMemRef);
   fprintf(stderr, "\n");
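The C API test now compares memory spaces with `mlirAttributeEqual` instead of `!=` on integers. This works because MLIR uniques attributes in their context: two structurally identical attributes built in the same context are the same object, so equality is a pointer comparison, which is the cheap comparison the commit message cites. The toy Python model below illustrates that interning idea only; `ToyContext`, `integer_attr_get` and `attribute_equal` are hypothetical names, not MLIR API.

```python
from typing import Dict, Tuple


class IntegerAttr:
    """Toy immutable attribute; instances should only be created by ToyContext."""

    def __init__(self, width: int, value: int) -> None:
        self.width = width
        self.value = value


class ToyContext:
    """Hypothetical context that uniques attributes by their contents."""

    def __init__(self) -> None:
        self._uniquer: Dict[Tuple[int, int], IntegerAttr] = {}

    def integer_attr_get(self, width: int, value: int) -> IntegerAttr:
        key = (width, value)
        # Reuse the existing instance if an equal attribute was already built.
        if key not in self._uniquer:
            self._uniquer[key] = IntegerAttr(width, value)
        return self._uniquer[key]


def attribute_equal(a: IntegerAttr, b: IntegerAttr) -> bool:
    # Because of uniquing, identity comparison is sufficient.
    return a is b
```

As with real MLIR attributes, equal contents only imply identity within one context; attributes built in two different contexts compare unequal.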

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index bb9c6a5523855..5751cf4046069 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -36,8 +36,8 @@ func @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}}
 func @memrefs(memref<2x4xi8, #map7>) // expected-error {{undefined symbol alias id 'map7'}}
 
 // -----
-// Test non affine map in memref type.
-func @memrefs(memref<2x4xi8, i8>) // expected-error {{expected affine map in memref type}}
+// Test unsupported memory space.
+func @memrefs(memref<2x4xi8, i8>) // expected-error {{unsupported memory space Attribute}}
 
 // -----
 // Test non-existent map in map composition of memref type.

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index b63775edcfb09..eba9a71c1a958 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -137,11 +137,35 @@ func private @memrefs_drop_triv_id_multiple(memref<2xi8, affine_map<(d0) -> (d0)
 func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>,
                                              affine_map<(d0, d1) -> (d1, d0)>>)
 
+// Test memrefs with custom memory spaces.
+
+// CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>)
+func private @memrefs_nomap_nospace(memref<5x6x7xf32>)
+
+// CHECK: func private @memrefs_map_nospace(memref<5x6x7xf32, #map{{[0-9]+}}>)
+func private @memrefs_map_nospace(memref<5x6x7xf32, #map3>)
+
+// CHECK: func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>)
+func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>)
+
+// CHECK: func private @memrefs_map_intspace(memref<5x6x7xf32, #map{{[0-9]+}}, 5>)
+func private @memrefs_map_intspace(memref<5x6x7xf32, #map3, 5>)
+
+// CHECK: func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">)
+func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">)
+
+// CHECK: func private @memrefs_map_strspace(memref<5x6x7xf32, #map{{[0-9]+}}, "private">)
+func private @memrefs_map_strspace(memref<5x6x7xf32, #map3, "private">)
+
+// CHECK: func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1 : i64}>)
+func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1}>)
+
+// CHECK: func private @memrefs_map_dictspace(memref<5x6x7xf32, #map{{[0-9]+}}, {memSpace = "special", subIndex = 3 : i64}>)
+func private @memrefs_map_dictspace(memref<5x6x7xf32, #map3, {memSpace = "special", subIndex = 3}>)
 
 // CHECK: func private @complex_types(complex<i1>) -> complex<f32>
 func private @complex_types(complex<i1>) -> complex<f32>
 
-
 // CHECK: func private @memref_with_index_elems(memref<1x?xindex>)
 func private @memref_with_index_elems(memref<1x?xindex>)
 

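The new CHECK lines pin down the printed form: layout maps come first, the memory space is printed last, and a type without a memory space prints no trailing entry. The sketch below reproduces those strings in plain Python. It is a hypothetical model, not the `AsmPrinter` code; it assumes that an explicit integer `0` is normalized to the empty default (as the `skipDefaultMemorySpace` helper suggests), that dictionary entries print sorted by name, and that integers inside a dictionary carry an `i64` type suffix, matching the CHECK lines above. Layout maps are passed as pre-rendered alias strings such as `#map3`.

```python
from typing import Optional, Sequence, Union

Space = Optional[Union[int, str, dict]]


def _print_attr(value: Union[int, str, dict]) -> str:
    # Hypothetical rendering that mirrors the CHECK lines in parser.mlir.
    if isinstance(value, str):
        return '"' + value + '"'
    if isinstance(value, dict):
        # Assumption: entries print sorted by name, and integers inside a
        # dictionary carry their (assumed i64) type suffix.
        parts = []
        for key in sorted(value):
            item = value[key]
            text = f"{item} : i64" if isinstance(item, int) else _print_attr(item)
            parts.append(f"{key} = {text}")
        return "{" + ", ".join(parts) + "}"
    # Top-level integer memory spaces print as a bare number.
    return str(value)


def print_memref(shape: Sequence[int], elem: str,
                 maps: Sequence[str] = (), space: Space = None) -> str:
    """Model of the printed form: maps first, memory space last, default omitted."""
    items = ["x".join([str(d) for d in shape] + [elem])]
    items.extend(maps)
    if space is not None and space != 0:  # integer 0 is the default space
        items.append(_print_attr(space))
    return "memref<" + ", ".join(items) + ">"
```

For example, `print_memref([5, 6, 7], "f32", ["#map3"], "private")` produces `memref<5x6x7xf32, #map3, "private">`, the shape of the `@memrefs_map_strspace` CHECK line.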
diff  --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index e3e5ffe95fe12..dc7591738b873 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectInterface.h"
@@ -23,7 +24,7 @@ TEST(ShapedTypeTest, CloneMemref) {
 
   Type i32 = IntegerType::get(&context, 32);
   Type f32 = FloatType::getF32(&context);
-  int memSpace = 7;
+  Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
   Type memrefOriginalType = i32;
   llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
   AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);




More information about the Mlir-commits mailing list