[Mlir-commits] [mlir] e41ebbe - [mlir][RFC] Refactor layout representation in MemRefType

Vladislav Vinogradov llvmlistbot at llvm.org
Tue Oct 19 02:43:01 PDT 2021


Author: Vladislav Vinogradov
Date: 2021-10-19T12:31:15+03:00
New Revision: e41ebbecf97ac70326dfa2c54e3ef5fa13df54eb

URL: https://github.com/llvm/llvm-project/commit/e41ebbecf97ac70326dfa2c54e3ef5fa13df54eb
DIFF: https://github.com/llvm/llvm-project/commit/e41ebbecf97ac70326dfa2c54e3ef5fa13df54eb.diff

LOG: [mlir][RFC] Refactor layout representation in MemRefType

The change is based on the proposal from the following discussion:
https://llvm.discourse.group/t/rfc-memreftype-affine-maps-list-vs-single-item/3968

* Introduce the `MemRefLayoutAttrInterface` attribute interface to get an
  `AffineMap` from an `Attribute` (`AffineMapAttr` implements this interface).
* Store the layout as a single generic `MemRefLayoutAttrInterface` attribute.

This change removes the affine map composition feature and the related API.
Although `MemRefType` itself supported composition, almost none of the upstream
code could work with more than one affine map in a `MemRefType`.

The new `MemRefLayoutAttrInterface` makes it possible to re-implement this
feature in a more stable way, via a separate attribute class.

The interface also allows layout representations other than affine maps.
For example, the "stride + offset" form described in the RFC, which is currently
supported only by the ASM parser, can now be expressed as a separate attribute.

Reviewed By: ftynse, bondhugula

Differential Revision: https://reviews.llvm.org/D111553

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/Analysis/LoopAnalysis.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributeInterfaces.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/Parser/TypeParser.cpp
    mlir/lib/Transforms/NormalizeMemRefs.cpp
    mlir/lib/Transforms/PipelineDataTransfer.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/lib/Transforms/Utils/Utils.cpp
    mlir/test/CAPI/ir.c
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/test/IR/invalid.mlir
    mlir/test/IR/parser.mlir
    mlir/test/python/ir/builtin_types.py
    mlir/unittests/IR/ShapedTypeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a677d4d365b11..2983627a57896 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -229,16 +229,17 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
 /// Creates a MemRef type with the given rank and shape, a potentially empty
 /// list of affine layout maps, the given memory space and element type, in the
 /// same context as element type. The type is owned by the context.
-MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(
-    MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
-    MlirAffineMap const *affineMaps, MlirAttribute memorySpace);
+MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType,
+                                              intptr_t rank,
+                                              const int64_t *shape,
+                                              MlirAttribute layout,
+                                              MlirAttribute memorySpace);
 
 /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o
 /// illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
     MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
-    intptr_t numMaps, MlirAffineMap const *affineMaps,
-    MlirAttribute memorySpace);
+    MlirAttribute layout, MlirAttribute memorySpace);
 
 /// Creates a MemRef type with the given rank, shape, memory space and element
 /// type in the same context as the element type. The type has no affine maps,
@@ -264,12 +265,11 @@ mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace);
 MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(
     MlirLocation loc, MlirType elementType, MlirAttribute memorySpace);
 
-/// Returns the number of affine layout maps in the given MemRef type.
-MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);
+/// Returns the layout of the given MemRef type.
+MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type);
 
-/// Returns the pos-th affine map of the given MemRef type.
-MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type,
-                                                            intptr_t pos);
+/// Returns the affine map of the given MemRef type.
+MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
 
 /// Returns the memory space of the given MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);

diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index 392bffc09da6e..c48a359383ff4 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
 
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
@@ -227,6 +228,21 @@ class ElementsAttrIterator
   ptrdiff_t index;
 };
 } // namespace detail
+
+//===----------------------------------------------------------------------===//
+// MemRefLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+
+// Verify the affine map 'm' can be used as a layout specification
+// for memref with 'shape'.
+LogicalResult
+verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
+                        function_ref<InFlightDiagnostic()> emitError);
+
+} // namespace detail
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index fa8cb28223a0d..30b3ea7ca09a8 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -432,4 +432,52 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
   }] # ElementsAttrInterfaceAccessors;
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    This interface is used for attributes that can represent the MemRef type's
+    layout semantics, such as dimension order in the memory, strides and offsets.
+    Such a layout attribute should be representable as a
+    [semi-affine map](Affine.md/#semi-affine-maps).
+
+    Note: the MemRef type's layout is assumed to represent simple strided buffer
+    layout. For more complicated case, like sparse storage buffers,
+    it is preferable to use separate type with more specic layout, rather then
+    introducing extra complexity to the builin MemRef type.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      "Get the MemRef layout as an AffineMap, the method must not return NULL",
+      "::mlir::AffineMap", "getAffineMap", (ins)
+    >,
+
+    InterfaceMethod<
+      "Return true if this attribute represents the identity layout",
+      "bool", "isIdentity", (ins),
+      [{}],
+      [{
+        return $_attr.getAffineMap().isIdentity();
+      }]
+    >,
+
+    InterfaceMethod<
+      "Check if the current layout is applicable to the provided shape",
+      "::mlir::LogicalResult", "verifyLayout",
+      (ins "::llvm::ArrayRef<int64_t>":$shape,
+           "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError),
+      [{}],
+      [{
+        return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
+                                                       shape, emitError);
+      }]
+    >
+  ];
+}
+
 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 0d3ead2383722..fcd6082a6ef85 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -34,7 +34,9 @@ class Builtin_Attr<string name, list<Trait> traits = [],
 // AffineMapAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap"> {
+def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
+    MemRefLayoutAttrInterface
+  ]> {
   let summary = "An Attribute containing an AffineMap object";
   let description = [{
     Syntax:
@@ -56,7 +58,10 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap"> {
       return $_get(value.getContext(), value);
     }]>
   ];
-  let extraClassDeclaration = "using ValueType = AffineMap;";
+  let extraClassDeclaration = [{
+    using ValueType = AffineMap;
+    AffineMap getAffineMap() const { return getValue(); }
+  }];
   let skipDefaultBuilders = 1;
   let typeBuilder = "IndexType::get($_value.getContext())";
 }

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 47368f378dc7b..f2f3ccf537626 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_IR_BUILTINTYPES_H
 #define MLIR_IR_BUILTINTYPES_H
 
+#include "BuiltinAttributeInterfaces.h"
 #include "SubElementInterfaces.h"
 
 namespace llvm {
@@ -209,12 +210,11 @@ class MemRefType::Builder {
   // Build from another MemRefType.
   explicit Builder(MemRefType other)
       : shape(other.getShape()), elementType(other.getElementType()),
-        affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()) {
-  }
+        layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}
 
   // Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType)
-      : shape(shape), elementType(elementType), affineMaps() {}
+      : shape(shape), elementType(elementType) {}
 
   Builder &setShape(ArrayRef<int64_t> newShape) {
     shape = newShape;
@@ -226,8 +226,8 @@ class MemRefType::Builder {
     return *this;
   }
 
-  Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
-    affineMaps = newAffineMaps;
+  Builder &setLayout(MemRefLayoutAttrInterface newLayout) {
+    layout = newLayout;
     return *this;
   }
 
@@ -240,13 +240,13 @@ class MemRefType::Builder {
   Builder &setMemorySpace(unsigned newMemorySpace);
 
   operator MemRefType() {
-    return MemRefType::get(shape, elementType, affineMaps, memorySpace);
+    return MemRefType::get(shape, elementType, layout, memorySpace);
   }
 
 private:
   ArrayRef<int64_t> shape;
   Type elementType;
-  ArrayRef<AffineMap> affineMaps;
+  MemRefLayoutAttrInterface layout;
   Attribute memorySpace;
 };
 

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 7edb9ee5fac20..b27086144761e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -278,8 +278,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
 
     stride-list ::= `[` (dimension (`,` dimension)*)? `]`
     strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-    semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
-    layout-specification ::= semi-affine-map-composition | strided-layout
+    layout-specification ::= semi-affine-map | strided-layout | attribute-value
     memory-space ::= attribute-value
     ```
 
@@ -486,27 +485,6 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
     #layout_tiled = (i, j) -> (i floordiv 64, j floordiv 64, i mod 64, j mod 64)
     ```
 
-    ##### Affine Map Composition
-
-    A memref specifies a semi-affine map composition as part of its type. A
-    semi-affine map composition is a composition of semi-affine maps beginning
-    with zero or more index maps, and ending with a layout map. The composition
-    must be conformant: the number of dimensions of the range of one map, must
-    match the number of dimensions of the domain of the next map in the
-    composition.
-
-    The semi-affine map composition specified in the memref type, maps from
-    accesses used to index the memref in load/store operations to other index
-    spaces (i.e. logical to physical index mapping). Each of the
-    [semi-affine maps](Affine.md/#semi-affine-maps) and thus its composition is required
-    to be one-to-one.
-
-    The semi-affine map composition can be used in dependence analysis, memory
-    access pattern analysis, and for performance optimizations like
-    vectorization, copy elision and in-place updates. If an affine map
-    composition is not specified for the memref, the identity affine map is
-    assumed.
-
     ##### Strided MemRef
 
     A memref may specify a strided layout as part of its type. A stride
@@ -544,36 +522,23 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    ArrayRefParameter<"AffineMap">:$affineMaps,
+    "MemRefLayoutAttrInterface":$layout,
     "Attribute":$memorySpace
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"ArrayRef<AffineMap>", "{}">:$affineMaps,
-      CArg<"Attribute", "{}">:$memorySpace
-    ), [{
-      // Drop identity maps from the composition. This may lead to the
-      // composition becoming empty, which is interpreted as an implicit
-      // identity.
-      auto nonIdentityMaps = llvm::make_filter_range(affineMaps,
-        [](AffineMap map) { return !map.isIdentity(); });
-      // Drop default memory space value and replace it with empty attribute.
-      Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace);
-      return $_get(elementType.getContext(), shape, elementType,
-                   llvm::to_vector<4>(nonIdentityMaps), nonDefaultMemorySpace);
-    }]>,
+      CArg<"MemRefLayoutAttrInterface", "{}">:$layout,
+      CArg<"Attribute", "{}">:$memorySpace)>,
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      CArg<"AffineMap">:$map,
+      CArg<"Attribute", "{}">:$memorySpace)>,
     /// [deprecated] `Attribute`-based form should be used instead.
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      "ArrayRef<AffineMap>":$affineMaps,
-      "unsigned":$memorySpace
-    ), [{
-      // Convert deprecated integer-like memory space to Attribute.
-      Attribute memorySpaceAttr =
-          wrapIntegerMemorySpace(memorySpace, elementType.getContext());
-      return MemRefType::get(shape, elementType, affineMaps, memorySpaceAttr);
-    }]>
+      "AffineMap":$map,
+      "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
     /// This is a builder type that keeps local references to arguments.

diff  --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index d40958e877771..478b212c3b520 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -219,15 +219,8 @@ static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
   assert(memRefDim && "memRefDim == nullptr");
   auto memRefType = memoryOp.getMemRefType();
 
-  auto layoutMap = memRefType.getAffineMaps();
-  // TODO: remove dependence on Builder once we support non-identity layout map.
-  Builder b(memoryOp.getContext());
-  if (layoutMap.size() >= 2 ||
-      (layoutMap.size() == 1 &&
-       !(layoutMap[0] ==
-         b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) {
+  if (!memRefType.getLayout().isIdentity())
     return memoryOp.emitError("NYI: non-trivial layoutMap"), false;
-  }
 
   int uniqueVaryingIndexAlongIv = -1;
   auto accessMap = memoryOp.getAffineMap();

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index b9e7b8479c65e..9565441863018 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -616,9 +616,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
 Optional<int64_t> MemRefRegion::getRegionSize() {
   auto memRefType = memref.getType().cast<MemRefType>();
 
-  auto layoutMaps = memRefType.getAffineMaps();
-  if (layoutMaps.size() > 1 ||
-      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+  if (!memRefType.getLayout().isIdentity()) {
     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
     return false;
   }

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index fd9f3efe7405f..1cfd799bf6934 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -401,8 +401,6 @@ class PyUnrankedTensorType
   }
 };
 
-class PyMemRefLayoutMapList;
-
 /// Ranked MemRef Type subclass - MemRefType.
 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
 public:
@@ -410,26 +408,18 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
   static constexpr const char *pyClassName = "MemRefType";
   using PyConcreteType::PyConcreteType;
 
-  PyMemRefLayoutMapList getLayout();
-
   static void bindDerived(ClassTy &c) {
     c.def_static(
          "get",
          [](std::vector<int64_t> shape, PyType &elementType,
-            std::vector<PyAffineMap> layout, PyAttribute *memorySpace,
+            PyAttribute *layout, PyAttribute *memorySpace,
             DefaultingPyLocation loc) {
-           SmallVector<MlirAffineMap> maps;
-           maps.reserve(layout.size());
-           for (PyAffineMap &map : layout)
-             maps.push_back(map);
-
-           MlirAttribute memSpaceAttr = {};
-           if (memorySpace)
-             memSpaceAttr = *memorySpace;
-
-           MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
-                                                 shape.data(), maps.size(),
-                                                 maps.data(), memSpaceAttr);
+           MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
+           MlirAttribute memSpaceAttr =
+               memorySpace ? *memorySpace : mlirAttributeGetNull();
+           MlirType t =
+               mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+                                        shape.data(), layoutAttr, memSpaceAttr);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -444,10 +434,22 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
            return PyMemRefType(elementType.getContext(), t);
          },
          py::arg("shape"), py::arg("element_type"),
-         py::arg("layout") = py::list(), py::arg("memory_space") = py::none(),
+         py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
          py::arg("loc") = py::none(), "Create a memref type")
-        .def_property_readonly("layout", &PyMemRefType::getLayout,
-                               "The list of layout maps of the MemRef type.")
+        .def_property_readonly(
+            "layout",
+            [](PyMemRefType &self) -> PyAttribute {
+              MlirAttribute layout = mlirMemRefTypeGetLayout(self);
+              return PyAttribute(self.getContext(), layout);
+            },
+            "The layout of the MemRef type.")
+        .def_property_readonly(
+            "affine_map",
+            [](PyMemRefType &self) -> PyAffineMap {
+              MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
+              return PyAffineMap(self.getContext(), map);
+            },
+            "The layout of the MemRef type as an affine map.")
         .def_property_readonly(
             "memory_space",
             [](PyMemRefType &self) -> PyAttribute {
@@ -458,41 +460,6 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
   }
 };
 
-/// A list of affine layout maps in a memref type. Internally, these are stored
-/// as consecutive elements, random access is cheap. Both the type and the maps
-/// are owned by the context, no need to worry about lifetime extension.
-class PyMemRefLayoutMapList
-    : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
-public:
-  static constexpr const char *pyClassName = "MemRefLayoutMapList";
-
-  PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
-                        intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
-                  step),
-        memref(type) {}
-
-  intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
-
-  PyAffineMap getElement(intptr_t index) {
-    return PyAffineMap(memref.getContext(),
-                       mlirMemRefTypeGetAffineMap(memref, index));
-  }
-
-  PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
-                              intptr_t step) {
-    return PyMemRefLayoutMapList(memref, startIndex, length, step);
-  }
-
-private:
-  PyMemRefType memref;
-};
-
-PyMemRefLayoutMapList PyMemRefType::getLayout() {
-  return PyMemRefLayoutMapList(*this);
-}
-
 /// Unranked MemRef Type subclass - UnrankedMemRefType.
 class PyUnrankedMemRefType
     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
@@ -640,7 +607,6 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyRankedTensorType::bind(m);
   PyUnrankedTensorType::bind(m);
   PyMemRefType::bind(m);
-  PyMemRefLayoutMapList::bind(m);
   PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
   PyFunctionType::bind(m);

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index d978f17b98d5b..318b8eb10c16f 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -226,34 +226,35 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
 bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
 
 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
-                           const int64_t *shape, intptr_t numMaps,
-                           MlirAffineMap const *affineMaps,
+                           const int64_t *shape, MlirAttribute layout,
                            MlirAttribute memorySpace) {
-  SmallVector<AffineMap, 1> maps;
-  (void)unwrapList(numMaps, affineMaps, maps);
-  return wrap(
-      MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-                      unwrap(elementType), maps, unwrap(memorySpace)));
+  return wrap(MemRefType::get(
+      llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      mlirAttributeIsNull(layout)
+          ? MemRefLayoutAttrInterface()
+          : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
+      unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
                                   intptr_t rank, const int64_t *shape,
-                                  intptr_t numMaps,
-                                  MlirAffineMap const *affineMaps,
+                                  MlirAttribute layout,
                                   MlirAttribute memorySpace) {
-  SmallVector<AffineMap, 1> maps;
-  (void)unwrapList(numMaps, affineMaps, maps);
   return wrap(MemRefType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType), maps, unwrap(memorySpace)));
+      unwrap(elementType),
+      mlirAttributeIsNull(layout)
+          ? MemRefLayoutAttrInterface()
+          : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
+      unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
                                      const int64_t *shape,
                                      MlirAttribute memorySpace) {
-  return wrap(
-      MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-                      unwrap(elementType), llvm::None, unwrap(memorySpace)));
+  return wrap(MemRefType::get(
+      llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      MemRefLayoutAttrInterface(), unwrap(memorySpace)));
 }
 
 MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
@@ -262,16 +263,15 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
                                             MlirAttribute memorySpace) {
   return wrap(MemRefType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType), llvm::None, unwrap(memorySpace)));
+      unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace)));
 }
 
-intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
-  return static_cast<intptr_t>(
-      unwrap(type).cast<MemRefType>().getAffineMaps().size());
+MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
+  return wrap(unwrap(type).cast<MemRefType>().getLayout());
 }
 
-MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
-  return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
+MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
+  return wrap(unwrap(type).cast<MemRefType>().getLayout().getAffineMap());
 }
 
 MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {

diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 7745b9585a5cc..47dabc90bce5b 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -106,9 +106,7 @@ bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
     MemRefType type) const {
   if (!typeConverter->convertType(type.getElementType()))
     return false;
-  return type.getAffineMaps().empty() ||
-         llvm::all_of(type.getAffineMaps(),
-                      [](AffineMap map) { return map.isIdentity(); });
+  return type.getLayout().isIdentity();
 }
 
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 74462caf39b45..6188e8b571123 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1142,7 +1142,8 @@ class ReassociatingReshapeOpConversion
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType dstType = reshapeOp.getResultType();
     MemRefType srcType = reshapeOp.getSrcType();
-    if (!srcType.getAffineMaps().empty() || !dstType.getAffineMaps().empty()) {
+    if (!srcType.getLayout().isIdentity() ||
+        !dstType.getLayout().isIdentity()) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "only empty layout map is supported");
     }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 765b58d8c3d22..a6f25332d1331 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -950,8 +950,7 @@ computeContiguousStrides(MemRefType memRefType) {
   if (!strides.empty() && strides.back() != 1)
     return None;
   // If no layout or identity layout, this is contiguous by definition.
-  if (memRefType.getAffineMaps().empty() ||
-      memRefType.getAffineMaps().front().isIdentity())
+  if (memRefType.getLayout().isIdentity())
     return strides;
 
   // Otherwise, we must determine contiguity form shapes. This can only ever

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 9a15cb26e2b24..ba1710b57a919 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1047,8 +1047,7 @@ static LogicalResult verify(SubgroupMmaLoadMatrixOp op) {
   auto srcMemrefType = srcType.cast<MemRefType>();
   auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
 
-  if (!srcMemrefType.getAffineMaps().empty() &&
-      !srcMemrefType.getAffineMaps().front().isIdentity())
+  if (!srcMemrefType.getLayout().isIdentity())
     return op.emitError("expected identity layout map for source memref");
 
   if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
@@ -1074,9 +1073,7 @@ static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
   auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
   auto dstMemrefType = dstType.cast<MemRefType>();
   auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
-
-  if (!dstMemrefType.getAffineMaps().empty() &&
-      !dstMemrefType.getAffineMaps().front().isIdentity())
+  if (!dstMemrefType.getLayout().isIdentity())
     return op.emitError("expected identity layout map for destination memref");
 
   if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&

diff  --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index b7a05c3420532..1aebd90a2e660 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -159,9 +159,8 @@ struct GpuAllReduceRewriter {
   Value createWorkgroupBuffer() {
     int workgroupMemoryAddressSpace =
         gpu::GPUDialect::getWorkgroupAddressSpace();
-    auto bufferType =
-        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
-                        workgroupMemoryAddressSpace);
+    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
+                                      workgroupMemoryAddressSpace);
     return funcOp.addWorkgroupAttribution(bufferType);
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 7e9a0e0ed38a6..3d7919b6e7125 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1219,27 +1219,24 @@ getEquivalentEnclosingFuncBBArg(Value v,
 /// with the same shape as `shapedType` and specified `layout` and
 /// `addressSpace`.
 static MemRefType getContiguousMemRefType(ShapedType shapedType,
-                                          ArrayRef<AffineMap> layout = {},
-                                          unsigned addressSpace = 0) {
-  if (RankedTensorType tensorType = shapedType.dyn_cast<RankedTensorType>())
-    return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
-                           layout, addressSpace);
-  MemRefType memrefType = shapedType.cast<MemRefType>();
-  return MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
-                         layout, addressSpace);
+                                          MemRefLayoutAttrInterface layout = {},
+                                          Attribute memorySpace = {}) {
+  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
+                         layout, memorySpace);
 }
 
 /// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
 /// with the same shape as `shapedType` and specified `layout` and
 /// `addressSpace` or an UnrankedMemRefType otherwise.
-static Type getContiguousOrUnrankedMemRefType(Type type,
-                                              ArrayRef<AffineMap> layout = {},
-                                              unsigned addressSpace = 0) {
+static Type
+getContiguousOrUnrankedMemRefType(Type type,
+                                  MemRefLayoutAttrInterface layout = {},
+                                  Attribute memorySpace = {}) {
   if (type.isa<RankedTensorType, MemRefType>())
     return getContiguousMemRefType(type.cast<ShapedType>(), layout,
-                                   addressSpace);
-  assert(layout.empty() && "expected empty layout with UnrankedMemRefType");
-  return UnrankedMemRefType::get(getElementTypeOrSelf(type), addressSpace);
+                                   memorySpace);
+  assert(!layout && "expected empty layout with UnrankedMemRefType");
+  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
 }
 
 /// Return a MemRefType to which the `tensorType` can be bufferized in a
@@ -1644,16 +1641,16 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
   auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
   auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
   assert(rankedMemRefType || unrankedMemRefType);
-  unsigned memorySpace = rankedMemRefType
-                             ? rankedMemRefType.getMemorySpaceAsInt()
-                             : unrankedMemRefType.getMemorySpaceAsInt();
+  Attribute memorySpace = rankedMemRefType
+                              ? rankedMemRefType.getMemorySpace()
+                              : unrankedMemRefType.getMemorySpace();
   TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
-  ArrayRef<AffineMap> affineMaps =
+  MemRefLayoutAttrInterface layout =
       rankedMemRefType && tensorType.isa<RankedTensorType>()
-          ? rankedMemRefType.getAffineMaps()
-          : ArrayRef<AffineMap>{};
+          ? rankedMemRefType.getLayout()
+          : MemRefLayoutAttrInterface();
   Type memRefType = getContiguousOrUnrankedMemRefType(
-      castOp.getResult().getType(), affineMaps, memorySpace);
+      castOp.getResult().getType(), layout, memorySpace);
   Value res =
       b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
   aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c91cb0d4f7867..cb92185394938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -258,7 +258,7 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
   // leave them unchanged.
   Type actualType = opOperand->get().getType();
   if (auto memref = actualType.dyn_cast<MemRefType>()) {
-    if (!memref.getAffineMaps().empty())
+    if (!memref.getLayout().isIdentity())
       return llvm::None;
   }
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 928b727db9a0f..e792d6581566d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -88,8 +88,8 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
                           "dynamic dimension count");
 
   unsigned numSymbols = 0;
-  if (!memRefType.getAffineMaps().empty())
-    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
+  if (!memRefType.getLayout().isIdentity())
+    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
   if (op.symbolOperands().size() != numSymbols)
     return op.emitOpError("symbol operand count does not equal memref symbol "
                           "count: expected ")
@@ -496,7 +496,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (aT && bT) {
     if (aT.getElementType() != bT.getElementType())
       return false;
-    if (aT.getAffineMaps() != bT.getAffineMaps()) {
+    if (aT.getLayout() != bT.getLayout()) {
       int64_t aOffset, bOffset;
       SmallVector<int64_t, 4> aStrides, bStrides;
       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
@@ -1408,7 +1408,7 @@ static LogicalResult verify(ReinterpretCastOp op) {
 
   // Match offset and strides in static_offset and static_strides attributes if
   // result memref type has an affine map specified.
-  if (!resultType.getAffineMaps().empty()) {
+  if (!resultType.getLayout().isIdentity()) {
     int64_t resultOffset;
     SmallVector<int64_t, 4> resultStrides;
     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
@@ -1526,8 +1526,8 @@ computeReshapeCollapsedType(MemRefType type,
   }
 
   // Early-exit: if `type` is contiguous, the result must be contiguous.
-  if (canonicalizeStridedLayout(type).getAffineMaps().empty())
-    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
+  if (canonicalizeStridedLayout(type).getLayout().isIdentity())
+    return MemRefType::Builder(type).setShape(newSizes).setLayout({});
 
   // Convert back to int64_t because we don't have enough information to create
   // new strided layouts from AffineExpr only. This corresponds to a case where
@@ -1546,7 +1546,8 @@ computeReshapeCollapsedType(MemRefType type,
   auto layout =
       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
   return canonicalizeStridedLayout(
-      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
+      MemRefType::Builder(type).setShape(newSizes).setLayout(
+          AffineMapAttr::get(layout)));
 }
 
 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -1662,14 +1663,14 @@ static LogicalResult verify(ReshapeOp op) {
                           "types should be the same");
 
   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
-    if (!operandMemRefType.getAffineMaps().empty())
+    if (!operandMemRefType.getLayout().isIdentity())
       return op.emitOpError(
           "source memref type should have identity affine map");
 
   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
   if (resultMemRefType) {
-    if (!resultMemRefType.getAffineMaps().empty())
+    if (!resultMemRefType.getLayout().isIdentity())
       return op.emitOpError(
           "result memref type should have identity affine map");
     if (shapeSize == ShapedType::kDynamicSize)
@@ -1824,10 +1825,9 @@ Type SubViewOp::inferRankReducedResultType(
       if (!dimsToProject.contains(pos))
         projectedShape.push_back(shape[pos]);
 
-    AffineMap map;
-    auto maps = inferredType.getAffineMaps();
-    if (!maps.empty() && maps.front())
-      map = getProjectedMap(maps.front(), dimsToProject);
+    AffineMap map = inferredType.getLayout().getAffineMap();
+    if (!map.isIdentity())
+      map = getProjectedMap(map, dimsToProject);
     inferredType =
         MemRefType::get(projectedShape, inferredType.getElementType(), map,
                         inferredType.getMemorySpace());
@@ -2279,7 +2279,9 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
   auto map =
       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
   map = permutationMap ? map.compose(permutationMap) : map;
-  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
+  return MemRefType::Builder(memRefType)
+      .setShape(sizes)
+      .setLayout(AffineMapAttr::get(map));
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -2387,15 +2389,11 @@ static LogicalResult verify(ViewOp op) {
   auto viewType = op.getType();
 
   // The base memref should have identity layout map (or none).
-  if (baseType.getAffineMaps().size() > 1 ||
-      (baseType.getAffineMaps().size() == 1 &&
-       !baseType.getAffineMaps()[0].isIdentity()))
+  if (!baseType.getLayout().isIdentity())
     return op.emitError("unsupported map for base memref type ") << baseType;
 
   // The result memref should have identity layout map (or none).
-  if (viewType.getAffineMaps().size() > 1 ||
-      (viewType.getAffineMaps().size() == 1 &&
-       !viewType.getAffineMaps()[0].isIdentity()))
+  if (!viewType.getLayout().isIdentity())
     return op.emitError("unsupported map for result memref type ") << viewType;
 
   // The base memref and the view memref should be in the same memory space.

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index fdec267934bb7..769a416278150 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3767,16 +3767,17 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
   VectorType vectorType =
       VectorType::get(extractShape(memRefType),
                       getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
-  result.addTypes(
-      MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
+  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
+                                  memRefType.getMemorySpace()));
 }
 
 static LogicalResult verify(TypeCastOp op) {
   MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
-  if (!canonicalType.getAffineMaps().empty())
-    return op.emitOpError("expects operand to be a memref with no layout");
-  if (!op.getResultMemRefType().getAffineMaps().empty())
-    return op.emitOpError("expects result to be a memref with no layout");
+  if (!canonicalType.getLayout().isIdentity())
+    return op.emitOpError(
+        "expects operand to be a memref with identity layout");
+  if (!op.getResultMemRefType().getLayout().isIdentity())
+    return op.emitOpError("expects result to be a memref with identity layout");
   if (op.getResultMemRefType().getMemorySpace() !=
       op.getMemRefType().getMemorySpace())
     return op.emitOpError("expects result in same memory space");

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e2e72d1e00755..cf1eb8b56d807 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1971,9 +1971,9 @@ void AsmPrinter::Impl::printType(Type type) {
           os << 'x';
         }
         printType(memrefTy.getElementType());
-        for (auto map : memrefTy.getAffineMaps()) {
+        if (!memrefTy.getLayout().isIdentity()) {
           os << ", ";
-          printAttribute(AffineMapAttr::get(map));
+          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
         }
         // Only print the memory space if it is the non-default one.
         if (memrefTy.getMemorySpace()) {

diff  --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index 6bfa1ee6633ab..96992c219aa0c 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "llvm/ADT/Sequence.h"
 
 using namespace mlir;
@@ -72,3 +73,17 @@ uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr,
   }
   return valueIndex;
 }
+
+//===----------------------------------------------------------------------===//
+// MemRefLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::detail::verifyAffineMapAsLayout(
+    AffineMap m, ArrayRef<int64_t> shape,
+    function_ref<InFlightDiagnostic()> emitError) {
+  if (m.getNumDims() != shape.size())
+    return emitError() << "memref layout mismatch between rank and affine map: "
+                       << shape.size() << " != " << m.getNumDims();
+
+  return success();
+}

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a87dd55d8b261..2ca794eddf2b0 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,9 +646,118 @@ unsigned MemRefType::getMemorySpaceAsInt() const {
   return detail::getMemorySpaceAsInt(getMemorySpace());
 }
 
+MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
+                           MemRefLayoutAttrInterface layout,
+                           Attribute memorySpace) {
+  // Use default layout for empty attribute.
+  if (!layout)
+    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
+        shape.size(), elementType.getContext()));
+
+  // Drop default memory space value and replace it with empty attribute.
+  memorySpace = skipDefaultMemorySpace(memorySpace);
+
+  return Base::get(elementType.getContext(), shape, elementType, layout,
+                   memorySpace);
+}
+
+MemRefType MemRefType::getChecked(
+    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
+    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
+
+  // Use default layout for empty attribute.
+  if (!layout)
+    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
+        shape.size(), elementType.getContext()));
+
+  // Drop default memory space value and replace it with empty attribute.
+  memorySpace = skipDefaultMemorySpace(memorySpace);
+
+  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
+                          elementType, layout, memorySpace);
+}
+
+MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
+                           AffineMap map, Attribute memorySpace) {
+
+  // Use default layout for empty map.
+  if (!map)
+    map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                            elementType.getContext());
+
+  // Wrap AffineMap into Attribute.
+  Attribute layout = AffineMapAttr::get(map);
+
+  // Drop default memory space value and replace it with empty attribute.
+  memorySpace = skipDefaultMemorySpace(memorySpace);
+
+  return Base::get(elementType.getContext(), shape, elementType, layout,
+                   memorySpace);
+}
+
+MemRefType
+MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
+                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
+                       Attribute memorySpace) {
+
+  // Use default layout for empty map.
+  if (!map)
+    map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                            elementType.getContext());
+
+  // Wrap AffineMap into Attribute.
+  Attribute layout = AffineMapAttr::get(map);
+
+  // Drop default memory space value and replace it with empty attribute.
+  memorySpace = skipDefaultMemorySpace(memorySpace);
+
+  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
+                          elementType, layout, memorySpace);
+}
+
+MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
+                           AffineMap map, unsigned memorySpaceInd) {
+
+  // Use default layout for empty map.
+  if (!map)
+    map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                            elementType.getContext());
+
+  // Wrap AffineMap into Attribute.
+  Attribute layout = AffineMapAttr::get(map);
+
+  // Convert deprecated integer-like memory space to Attribute.
+  Attribute memorySpace =
+      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
+
+  return Base::get(elementType.getContext(), shape, elementType, layout,
+                   memorySpace);
+}
+
+MemRefType
+MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
+                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
+                       unsigned memorySpaceInd) {
+
+  // Use default layout for empty map.
+  if (!map)
+    map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                            elementType.getContext());
+
+  // Wrap AffineMap into Attribute.
+  Attribute layout = AffineMapAttr::get(map);
+
+  // Convert deprecated integer-like memory space to Attribute.
+  Attribute memorySpace =
+      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
+
+  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
+                          elementType, layout, memorySpace);
+}
+
 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 ArrayRef<AffineMap> affineMapComposition,
+                                 MemRefLayoutAttrInterface layout,
                                  Attribute memorySpace) {
   if (!BaseMemRefType::isValidElementType(elementType))
     return emitError() << "invalid memref element type";
@@ -658,26 +767,12 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
     if (s < -1)
       return emitError() << "invalid memref size";
 
-  // Check that the structure of the composition is valid, i.e. that each
-  // subsequent affine map has as many inputs as the previous map has results.
-  // Take the dimensionality of the MemRef for the first map.
-  size_t dim = shape.size();
-  for (auto it : llvm::enumerate(affineMapComposition)) {
-    AffineMap map = it.value();
-    if (map.getNumDims() == dim) {
-      dim = map.getNumResults();
-      continue;
-    }
-    return emitError() << "memref affine map dimension mismatch between "
-                       << (it.index() == 0 ? Twine("memref rank")
-                                           : "affine map " + Twine(it.index()))
-                       << " and affine map" << it.index() + 1 << ": " << dim
-                       << " != " << map.getNumDims();
-  }
+  assert(layout && "missing layout specification");
+  if (failed(layout.verifyLayout(shape, emitError)))
+    return failure();
 
-  if (!isSupportedMemorySpace(memorySpace)) {
+  if (!isSupportedMemorySpace(memorySpace))
     return emitError() << "unsupported memory space Attribute";
-  }
 
   return success();
 }
@@ -686,9 +781,9 @@ void MemRefType::walkImmediateSubElements(
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) const {
   walkTypesFn(getElementType());
+  if (!getLayout().isIdentity())
+    walkAttrsFn(getLayout());
   walkAttrsFn(getMemorySpace());
-  for (AffineMap map : getAffineMaps())
-    walkAttrsFn(AffineMapAttr::get(map));
 }
 
 //===----------------------------------------------------------------------===//
@@ -775,23 +870,18 @@ static LogicalResult extractStrides(AffineExpr e,
 LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                         SmallVectorImpl<AffineExpr> &strides,
                                         AffineExpr &offset) {
-  auto affineMaps = t.getAffineMaps();
-
-  if (affineMaps.size() > 1)
-    return failure();
+  AffineMap m = t.getLayout().getAffineMap();
 
-  if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1)
+  if (m.getNumResults() != 1 && !m.isIdentity())
     return failure();
 
-  AffineMap m = affineMaps.empty() ? AffineMap() : affineMaps.back();
-
   auto zero = getAffineConstantExpr(0, t.getContext());
   auto one = getAffineConstantExpr(1, t.getContext());
   offset = zero;
   strides.assign(t.getRank(), zero);
 
   // Canonical case for empty map.
-  if (!m || m.isIdentity()) {
+  if (m.isIdentity()) {
     // 0-D corner case, offset is already 0.
     if (t.getRank() == 0)
       return success();
@@ -938,21 +1028,21 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
 /// `t` with simplified layout.
 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
-  auto affineMaps = t.getAffineMaps();
+  AffineMap m = t.getLayout().getAffineMap();
+
   // Already in canonical form.
-  if (affineMaps.empty())
+  if (m.isIdentity())
     return t;
 
   // Can't reduce to canonical identity form, return in canonical form.
-  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
+  if (m.getNumResults() > 1)
     return t;
 
   // Corner-case for 0-D affine maps.
-  auto m = affineMaps[0];
   if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
     if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
       if (cst.getValue() == 0)
-        return MemRefType::Builder(t).setAffineMaps({});
+        return MemRefType::Builder(t).setLayout({});
     return t;
   }
 
@@ -970,9 +1060,9 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
-    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
-        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
-  return MemRefType::Builder(t).setAffineMaps({});
+    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
+        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
+  return MemRefType::Builder(t).setLayout({});
 }
 
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
@@ -1016,8 +1106,9 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
 /// strides. This is used to erase the static layout.
 MemRefType mlir::eraseStridedLayout(MemRefType t) {
   auto val = ShapedType::kDynamicStrideOrOffset;
-  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
-      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
+  return MemRefType::Builder(t).setLayout(
+      AffineMapAttr::get(makeStridedLinearLayoutMap(
+          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
 }
 
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,

diff  --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 46fbf52d25bb6..256d6c0a96a24 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -185,9 +185,8 @@ ParseResult Parser::parseStridedLayout(int64_t &offset,
 ///
 ///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
 ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
-///   layout-specification ::= semi-affine-map-composition | strided-layout
-///   memory-space ::= integer-literal /* | TODO: address-space-id */
+///   layout-specification ::= semi-affine-map | strided-layout | attribute
+///   memory-space ::= integer-literal | attribute
 ///
 Type Parser::parseMemRefType() {
   llvm::SMLoc loc = getToken().getLoc();
@@ -221,15 +220,10 @@ Type Parser::parseMemRefType() {
   if (!BaseMemRefType::isValidElementType(elementType))
     return emitError(typeLoc, "invalid memref element type"), nullptr;
 
-  // Parse semi-affine-map-composition.
-  SmallVector<AffineMap, 2> affineMapComposition;
+  MemRefLayoutAttrInterface layout;
   Attribute memorySpace;
-  unsigned numDims = dimensions.size();
 
   auto parseElt = [&]() -> ParseResult {
-    AffineMap map;
-    llvm::SMLoc mapLoc = getToken().getLoc();
-
     // Check for AffineMap as offset/strides.
     if (getToken().is(Token::kw_offset)) {
       int64_t offset;
@@ -237,15 +231,17 @@ Type Parser::parseMemRefType() {
       if (failed(parseStridedLayout(offset, strides)))
         return failure();
       // Construct strided affine map.
-      map = makeStridedLinearLayoutMap(strides, offset, state.context);
+      AffineMap map =
+          makeStridedLinearLayoutMap(strides, offset, state.context);
+      layout = AffineMapAttr::get(map);
     } else {
-      // Either it is AffineMapAttr or memory space attribute.
+      // Either it is MemRefLayoutAttrInterface or memory space attribute.
       Attribute attr = parseAttribute();
       if (!attr)
         return failure();
 
-      if (AffineMapAttr affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
-        map = affineMapAttr.getValue();
+      if (attr.isa<MemRefLayoutAttrInterface>()) {
+        layout = attr.cast<MemRefLayoutAttrInterface>();
       } else if (memorySpace) {
         return emitError("multiple memory spaces specified in memref type");
       } else {
@@ -259,15 +255,6 @@ Type Parser::parseMemRefType() {
     if (memorySpace)
       return emitError("expected memory space to be last in memref type");
 
-    if (map.getNumDims() != numDims) {
-      size_t i = affineMapComposition.size();
-      return emitError(mapLoc, "memref affine map dimension mismatch between ")
-             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
-             << " and affine map" << i + 1 << ": " << numDims
-             << " != " << map.getNumDims();
-    }
-    numDims = map.getNumResults();
-    affineMapComposition.push_back(map);
     return success();
   };
 
@@ -284,8 +271,8 @@ Type Parser::parseMemRefType() {
   if (isUnranked)
     return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
 
-  return getChecked<MemRefType>(loc, dimensions, elementType,
-                                affineMapComposition, memorySpace);
+  return getChecked<MemRefType>(loc, dimensions, elementType, layout,
+                                memorySpace);
 }
 
 /// Parse any type except the function type.

diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index c2b3a956d997e..148592a2c689e 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -225,7 +225,7 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
         // memref type is normalized.
         // TODO: When selective normalization is implemented, handle multiple
         // results case where some are normalized, some aren't.
-        if (memrefType.getAffineMaps().empty())
+        if (memrefType.getLayout().isIdentity())
           resultTypes[operandEn.index()] = memrefType;
       }
     });
@@ -269,7 +269,7 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
       if (oldResult.getType() == newResult.getType())
         continue;
       AffineMap layoutMap =
-          oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
+          oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
       if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
                                           /*extraIndices=*/{},
                                           /*indexRemap=*/layoutMap,
@@ -363,7 +363,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
     BlockArgument newMemRef =
         funcOp.front().insertArgument(argIndex, newMemRefType);
     BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
-    AffineMap layoutMap = memrefType.getAffineMaps().front();
+    AffineMap layoutMap = memrefType.getLayout().getAffineMap();
     // Replace all uses of the old memref.
     if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
                                         /*extraIndices=*/{},
@@ -412,7 +412,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
           if (oldMemRefType == newMemRefType)
             continue;
           // TODO: Assume single layout map. Multiple maps not supported.
-          AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
+          AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
           if (failed(replaceAllMemRefUsesWith(oldMemRef,
                                               /*newMemRef=*/newMemRef,
                                               /*extraIndices=*/{},

diff  --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index eaeb3475e6f17..500203b19a185 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -74,9 +74,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
     SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
     newShape[0] = 2;
     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
-    return MemRefType::Builder(oldMemRefType)
-        .setShape(newShape)
-        .setAffineMaps({});
+    return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
   };
 
   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 30968f9b1de4f..c43ee3f70ea63 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2648,9 +2648,7 @@ static LogicalResult generateCopy(
   auto memref = region.memref;
   auto memRefType = memref.getType().cast<MemRefType>();
 
-  auto layoutMaps = memRefType.getAffineMaps();
-  if (layoutMaps.size() > 1 ||
-      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+  if (!memRefType.getLayout().isIdentity()) {
     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
     return failure();
   }

diff  --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index f37740d5317b8..8b40085e380a4 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -647,7 +647,7 @@ LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
   Value oldMemRef = allocOp->getResult();
 
   SmallVector<Value, 4> symbolOperands(allocOp->symbolOperands());
-  AffineMap layoutMap = memrefType.getAffineMaps().front();
+  AffineMap layoutMap = memrefType.getLayout().getAffineMap();
   memref::AllocOp newAlloc;
   // Check if `layoutMap` is a tiled layout. Only single layout map is
   // supported for normalizing dynamic memrefs.
@@ -695,13 +695,12 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
   if (rank == 0)
     return memrefType;
 
-  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
-  if (layoutMaps.empty() ||
-      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
+  if (memrefType.getLayout().isIdentity()) {
     // Either no maps is associated with this memref or this memref has
     // a trivial (identity) map.
     return memrefType;
   }
+  AffineMap layoutMap = memrefType.getLayout().getAffineMap();
 
   // We don't do any checks for one-to-one'ness; we assume that it is
   // one-to-one.
@@ -710,7 +709,7 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
   // for now.
   // TODO: Normalize the other types of dynamic memrefs.
   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
-  (void)getTileSizePos(layoutMaps.front(), tileSizePos);
+  (void)getTileSizePos(layoutMap, tileSizePos);
   if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
     return memrefType;
 
@@ -731,7 +730,6 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
   }
   // We compose this map with the original index (logical) space to derive
   // the upper bounds for the new index space.
-  AffineMap layoutMap = layoutMaps.front();
   unsigned newRank = layoutMap.getNumResults();
   if (failed(fac.composeMatchingMap(layoutMap)))
     return memrefType;
@@ -763,7 +761,7 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
   MemRefType newMemRefType =
       MemRefType::Builder(memrefType)
           .setShape(newShape)
-          .setAffineMaps(b.getMultiDimIdentityMap(newRank));
+          .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank)));
 
   return newMemRefType;
 }

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 4746f5104e7e8..fd245cd6afb17 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -726,7 +726,6 @@ static int printBuiltinTypes(MlirContext ctx) {
   MlirType memRef = mlirMemRefTypeContiguousGet(
       f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
   if (!mlirTypeIsAMemRef(memRef) ||
-      mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
       !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
     return 18;
   mlirTypeDump(memRef);

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3c97530ec651b..6aa9679117cb0 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1095,7 +1095,7 @@ func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) {
 // -----
 
 func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
-  // expected-error at +1 {{expects operand to be a memref with no layout}}
+  // expected-error at +1 {{expects operand to be a memref with identity layout}}
   %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
 }
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index e026b54d94c5e..cafe3440ace52 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -104,7 +104,8 @@ func @test_store_zero_results2(%x: i32, %p: memref<i32>) {
 
 func @test_alloc_memref_map_rank_mismatch() {
 ^bb0:
-  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1> // expected-error {{memref affine map dimension mismatch}}
+  // expected-error at +1 {{memref layout mismatch between rank and affine map: 2 != 1}}
+  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1>
   return
 }
 

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 8b9bce076749c..b9187dc96673c 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -61,13 +61,7 @@ func @memrefs(memref<2x4xi8, #map0, 1, #map1>) // expected-error {{expected memo
 // The error must be emitted even for the trivial identity layout maps that are
 // dropped in type creation.
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @memrefs(memref<42xi8, #map0>) // expected-error {{memref affine map dimension mismatch}}
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0) -> (d0)>
-func @memrefs(memref<42x42xi8, #map0, #map1>) // expected-error {{memref affine map dimension mismatch}}
+func @memrefs(memref<42xi8, #map0>) // expected-error {{memref layout mismatch between rank and affine map: 1 != 2}}
 
 // -----
 

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 1b31c89e91caa..d1f4ab64f7f37 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -12,9 +12,6 @@
 // CHECK-DAG: #map{{[0-9]+}} = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 
-// CHECK-DAG: #map{{[0-9]+}} = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
-#map4 = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
-
 // CHECK-DAG: #map{{[0-9]+}} = affine_map<()[s0] -> (0, s0 - 1)>
 #inline_map_minmax_loop1 = affine_map<()[s0] -> (0, s0 - 1)>
 
@@ -80,28 +77,15 @@ func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
 // CHECK: func private @tensor_encoding(tensor<16x32xf64, "sparse">)
 func private @tensor_encoding(tensor<16x32xf64, "sparse">)
 
-// CHECK: func private @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>)
-func private @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)
-
-// Test memref affine map compositions.
+// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
+func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
 
 // CHECK: func private @memrefs2(memref<2x4x8xi8, 1>)
 func private @memrefs2(memref<2x4x8xi8, #map2, 1>)
 
-// CHECK: func private @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}>)
-func private @memrefs23(memref<2x4x8xi8, #map2, #map3, 0>)
-
-// CHECK: func private @memrefs234(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 3>)
-func private @memrefs234(memref<2x4x8xi8, #map2, #map3, #map4, 3>)
-
-// Test memref inline affine map compositions, minding that identity maps are removed.
-
 // CHECK: func private @memrefs3(memref<2x4x8xi8>)
 func private @memrefs3(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>>)
 
-// CHECK: func private @memrefs33(memref<2x4x8xi8, #map{{[0-9]+}}, 1>)
-func private @memrefs33(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 1>)
-
 // CHECK: func private @memrefs_drop_triv_id_inline(memref<2xi8>)
 func private @memrefs_drop_triv_id_inline(memref<2xi8, affine_map<(d0) -> (d0)>>)
 
@@ -111,35 +95,6 @@ func private @memrefs_drop_triv_id_inline0(memref<2xi8, affine_map<(d0) -> (d0)>
 // CHECK: func private @memrefs_drop_triv_id_inline1(memref<2xi8, 1>)
 func private @memrefs_drop_triv_id_inline1(memref<2xi8, affine_map<(d0) -> (d0)>, 1>)
 
-// Identity maps should be dropped from the composition, but not the pair of
-// "interchange" maps that, if composed, would be also an identity.
-// CHECK: func private @memrefs_drop_triv_id_composition(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_composition(memref<2x2xi8,
-                                                affine_map<(d0, d1) -> (d1, d0)>,
-                                                affine_map<(d0, d1) -> (d0, d1)>,
-                                                affine_map<(d0, d1) -> (d1, d0)>,
-                                                affine_map<(d0, d1) -> (d0, d1)>,
-                                                affine_map<(d0, d1) -> (d0, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_trailing(memref<2x2xi8, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_trailing(memref<2x2xi8, affine_map<(d0, d1) -> (d1, d0)>,
-                                                   affine_map<(d0, d1) -> (d0, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_middle(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_middle(memref<2x2xi8,
-                                         affine_map<(d0, d1) -> (d0, d1 + 1)>,
-                                         affine_map<(d0, d1) -> (d0, d1)>,
-                                         affine_map<(d0, d1) -> (d0 + 1, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_multiple(memref<2xi8>)
-func private @memrefs_drop_triv_id_multiple(memref<2xi8, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>>)
-
-// These maps appeared before, so they must be uniqued and hoisted to the beginning.
-// Identity map should be removed.
-// CHECK: func private @memrefs_compose_with_id(memref<2x2xi8, #map{{[0-9]+}}>)
-func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>,
-                                             affine_map<(d0, d1) -> (d1, d0)>>)
-
 // Test memref with custom memory space
 
 // CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>)
@@ -202,9 +157,6 @@ func private @unranked_memref_with_index_elems(memref<*xindex>)
 // CHECK: func private @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
 func private @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
 
-// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
-func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
-
 // CHECK-LABEL: func @simpleCFG(%{{.*}}: i32, %{{.*}}: f32) -> i1 {
 func @simpleCFG(%arg0: i32, %f: f32) -> i1 {
   // CHECK: %{{.*}} = "foo"() : () -> i64

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index ab6502e1d61fc..911391f2d528b 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -372,18 +372,21 @@ def testMemRefType():
     memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
     # CHECK: memref type: memref<2x3xf32, 2>
     print("memref type:", memref)
-    # CHECK: number of affine layout maps: 0
-    print("number of affine layout maps:", len(memref.layout))
+    # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
+    print("memref layout:", memref.layout)
+    # CHECK: memref affine map: (d0, d1) -> (d0, d1)
+    print("memref affine map:", memref.affine_map)
     # CHECK: memory space: 2
     print("memory space:", memref.memory_space)
 
-    layout = AffineMap.get_permutation([1, 0])
-    memref_layout = MemRefType.get(shape, f32, [layout])
+    layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
+    memref_layout = MemRefType.get(shape, f32, layout=layout)
     # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
     print("memref type:", memref_layout)
-    assert len(memref_layout.layout) == 1
-    # CHECK: memref layout: (d0, d1) -> (d1, d0)
-    print("memref layout:", memref_layout.layout[0])
+    # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
+    print("memref layout:", memref_layout.layout)
+    # CHECK: memref affine map: (d0, d1) -> (d1, d0)
+    print("memref affine map:", memref_layout.affine_map)
     # CHECK: memory space: <<NULL ATTRIBUTE>>
     print("memory space:", memref_layout.memory_space)
 

diff  --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index dc7591738b873..9c2a93e9429c6 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -32,26 +32,26 @@ TEST(ShapedTypeTest, CloneMemref) {
   ShapedType memrefType =
       MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
           .setMemorySpace(memSpace)
-          .setAffineMaps(map);
+          .setLayout(AffineMapAttr::get(map));
   // Update shape.
   llvm::SmallVector<int64_t> memrefNewShape({30, 40});
   ASSERT_NE(memrefOriginalShape, memrefNewShape);
   ASSERT_EQ(memrefType.clone(memrefNewShape),
             (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                 .setMemorySpace(memSpace)
-                .setAffineMaps(map));
+                .setLayout(AffineMapAttr::get(map)));
   // Update type.
   Type memrefNewType = f32;
   ASSERT_NE(memrefOriginalType, memrefNewType);
   ASSERT_EQ(memrefType.clone(memrefNewType),
             (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
                 .setMemorySpace(memSpace)
-                .setAffineMaps(map));
+                .setLayout(AffineMapAttr::get(map)));
   // Update both.
   ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
             (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                 .setMemorySpace(memSpace)
-                .setAffineMaps(map));
+                .setLayout(AffineMapAttr::get(map)));
 
   // Test unranked memref cloning.
   ShapedType unrankedTensorType =


        


More information about the Mlir-commits mailing list