[Mlir-commits] [mlir] e41ebbe - [mlir][RFC] Refactor layout representation in MemRefType
Vladislav Vinogradov
llvmlistbot at llvm.org
Tue Oct 19 02:43:01 PDT 2021
Author: Vladislav Vinogradov
Date: 2021-10-19T12:31:15+03:00
New Revision: e41ebbecf97ac70326dfa2c54e3ef5fa13df54eb
URL: https://github.com/llvm/llvm-project/commit/e41ebbecf97ac70326dfa2c54e3ef5fa13df54eb
DIFF: https://github.com/llvm/llvm-project/commit/e41ebbecf97ac70326dfa2c54e3ef5fa13df54eb.diff
LOG: [mlir][RFC] Refactor layout representation in MemRefType
The change is based on the proposal from the following discussion:
https://llvm.discourse.group/t/rfc-memreftype-affine-maps-list-vs-single-item/3968
* Introduce the `MemRefLayoutAttrInterface` interface to get an `AffineMap` from an `Attribute`
(`AffineMapAttr` implements this interface).
* Store the layout as a single generic `MemRefLayoutAttrInterface` attribute.
This change removes the affine map composition feature and related API.
In practice, although `MemRefType` itself supported it, almost none of the upstream
code can work with more than one affine map in a `MemRefType`.
The introduced `MemRefLayoutAttrInterface` allows re-implementing this feature
in a more stable way - via a separate attribute class.
The interface also allows using layout representations other than affine maps.
For example, the described "stride + offset" form, which is currently supported only in the ASM parser,
can now be expressed as a separate attribute.
Reviewed By: ftynse, bondhugula
Differential Revision: https://reviews.llvm.org/D111553
Added:
Modified:
mlir/include/mlir-c/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/Analysis/LoopAnalysis.cpp
mlir/lib/Analysis/Utils.cpp
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributeInterfaces.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/Parser/TypeParser.cpp
mlir/lib/Transforms/NormalizeMemRefs.cpp
mlir/lib/Transforms/PipelineDataTransfer.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/lib/Transforms/Utils/Utils.cpp
mlir/test/CAPI/ir.c
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
mlir/test/python/ir/builtin_types.py
mlir/unittests/IR/ShapedTypeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a677d4d365b11..2983627a57896 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -229,16 +229,17 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
/// Creates a MemRef type with the given rank and shape, a potentially empty
/// list of affine layout maps, the given memory space and element type, in the
/// same context as element type. The type is owned by the context.
-MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(
- MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
- MlirAffineMap const *affineMaps, MlirAttribute memorySpace);
+MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType,
+ intptr_t rank,
+ const int64_t *shape,
+ MlirAttribute layout,
+ MlirAttribute memorySpace);
/// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o
/// illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
- intptr_t numMaps, MlirAffineMap const *affineMaps,
- MlirAttribute memorySpace);
+ MlirAttribute layout, MlirAttribute memorySpace);
/// Creates a MemRef type with the given rank, shape, memory space and element
/// type in the same context as the element type. The type has no affine maps,
@@ -264,12 +265,11 @@ mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace);
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(
MlirLocation loc, MlirType elementType, MlirAttribute memorySpace);
-/// Returns the number of affine layout maps in the given MemRef type.
-MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);
+/// Returns the layout of the given MemRef type.
+MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type);
-/// Returns the pos-th affine map of the given MemRef type.
-MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type,
- intptr_t pos);
+/// Returns the affine map of the given MemRef type.
+MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
/// Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index 392bffc09da6e..c48a359383ff4 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -9,6 +9,7 @@
#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
@@ -227,6 +228,21 @@ class ElementsAttrIterator
ptrdiff_t index;
};
} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// MemRefLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+
+// Verify the affine map 'm' can be used as a layout specification
+// for memref with 'shape'.
+LogicalResult
+verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
+ function_ref<InFlightDiagnostic()> emitError);
+
+} // namespace detail
+
} // namespace mlir
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index fa8cb28223a0d..30b3ea7ca09a8 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -432,4 +432,52 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
}] # ElementsAttrInterfaceAccessors;
}
+//===----------------------------------------------------------------------===//
+// MemRefLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
+ let cppNamespace = "::mlir";
+
+ let description = [{
+ This interface is used for attributes that can represent the MemRef type's
+ layout semantics, such as dimension order in the memory, strides and offsets.
+ Such a layout attribute should be representable as a
+ [semi-affine map](Affine.md/#semi-affine-maps).
+
+ Note: the MemRef type's layout is assumed to represent a simple strided buffer
+ layout. For more complicated cases, like sparse storage buffers,
+ it is preferable to use a separate type with a more specific layout, rather than
+ introducing extra complexity to the builtin MemRef type.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ "Get the MemRef layout as an AffineMap, the method must not return NULL",
+ "::mlir::AffineMap", "getAffineMap", (ins)
+ >,
+
+ InterfaceMethod<
+ "Return true if this attribute represents the identity layout",
+ "bool", "isIdentity", (ins),
+ [{}],
+ [{
+ return $_attr.getAffineMap().isIdentity();
+ }]
+ >,
+
+ InterfaceMethod<
+ "Check if the current layout is applicable to the provided shape",
+ "::mlir::LogicalResult", "verifyLayout",
+ (ins "::llvm::ArrayRef<int64_t>":$shape,
+ "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError),
+ [{}],
+ [{
+ return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
+ shape, emitError);
+ }]
+ >
+ ];
+}
+
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 0d3ead2383722..fcd6082a6ef85 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -34,7 +34,9 @@ class Builtin_Attr<string name, list<Trait> traits = [],
// AffineMapAttr
//===----------------------------------------------------------------------===//
-def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap"> {
+def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
+ MemRefLayoutAttrInterface
+ ]> {
let summary = "An Attribute containing an AffineMap object";
let description = [{
Syntax:
@@ -56,7 +58,10 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap"> {
return $_get(value.getContext(), value);
}]>
];
- let extraClassDeclaration = "using ValueType = AffineMap;";
+ let extraClassDeclaration = [{
+ using ValueType = AffineMap;
+ AffineMap getAffineMap() const { return getValue(); }
+ }];
let skipDefaultBuilders = 1;
let typeBuilder = "IndexType::get($_value.getContext())";
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 47368f378dc7b..f2f3ccf537626 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_IR_BUILTINTYPES_H
#define MLIR_IR_BUILTINTYPES_H
+#include "BuiltinAttributeInterfaces.h"
#include "SubElementInterfaces.h"
namespace llvm {
@@ -209,12 +210,11 @@ class MemRefType::Builder {
// Build from another MemRefType.
explicit Builder(MemRefType other)
: shape(other.getShape()), elementType(other.getElementType()),
- affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()) {
- }
+ layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}
// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType)
- : shape(shape), elementType(elementType), affineMaps() {}
+ : shape(shape), elementType(elementType) {}
Builder &setShape(ArrayRef<int64_t> newShape) {
shape = newShape;
@@ -226,8 +226,8 @@ class MemRefType::Builder {
return *this;
}
- Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
- affineMaps = newAffineMaps;
+ Builder &setLayout(MemRefLayoutAttrInterface newLayout) {
+ layout = newLayout;
return *this;
}
@@ -240,13 +240,13 @@ class MemRefType::Builder {
Builder &setMemorySpace(unsigned newMemorySpace);
operator MemRefType() {
- return MemRefType::get(shape, elementType, affineMaps, memorySpace);
+ return MemRefType::get(shape, elementType, layout, memorySpace);
}
private:
ArrayRef<int64_t> shape;
Type elementType;
- ArrayRef<AffineMap> affineMaps;
+ MemRefLayoutAttrInterface layout;
Attribute memorySpace;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 7edb9ee5fac20..b27086144761e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -278,8 +278,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
stride-list ::= `[` (dimension (`,` dimension)*)? `]`
strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
- semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
- layout-specification ::= semi-affine-map-composition | strided-layout
+ layout-specification ::= semi-affine-map | strided-layout | attribute-value
memory-space ::= attribute-value
```
@@ -486,27 +485,6 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
#layout_tiled = (i, j) -> (i floordiv 64, j floordiv 64, i mod 64, j mod 64)
```
- ##### Affine Map Composition
-
- A memref specifies a semi-affine map composition as part of its type. A
- semi-affine map composition is a composition of semi-affine maps beginning
- with zero or more index maps, and ending with a layout map. The composition
- must be conformant: the number of dimensions of the range of one map, must
- match the number of dimensions of the domain of the next map in the
- composition.
-
- The semi-affine map composition specified in the memref type, maps from
- accesses used to index the memref in load/store operations to other index
- spaces (i.e. logical to physical index mapping). Each of the
- [semi-affine maps](Affine.md/#semi-affine-maps) and thus its composition is required
- to be one-to-one.
-
- The semi-affine map composition can be used in dependence analysis, memory
- access pattern analysis, and for performance optimizations like
- vectorization, copy elision and in-place updates. If an affine map
- composition is not specified for the memref, the identity affine map is
- assumed.
-
##### Strided MemRef
A memref may specify a strided layout as part of its type. A stride
@@ -544,36 +522,23 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
- ArrayRefParameter<"AffineMap">:$affineMaps,
+ "MemRefLayoutAttrInterface":$layout,
"Attribute":$memorySpace
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- CArg<"ArrayRef<AffineMap>", "{}">:$affineMaps,
- CArg<"Attribute", "{}">:$memorySpace
- ), [{
- // Drop identity maps from the composition. This may lead to the
- // composition becoming empty, which is interpreted as an implicit
- // identity.
- auto nonIdentityMaps = llvm::make_filter_range(affineMaps,
- [](AffineMap map) { return !map.isIdentity(); });
- // Drop default memory space value and replace it with empty attribute.
- Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace);
- return $_get(elementType.getContext(), shape, elementType,
- llvm::to_vector<4>(nonIdentityMaps), nonDefaultMemorySpace);
- }]>,
+ CArg<"MemRefLayoutAttrInterface", "{}">:$layout,
+ CArg<"Attribute", "{}">:$memorySpace)>,
+ TypeBuilderWithInferredContext<(ins
+ "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ CArg<"AffineMap">:$map,
+ CArg<"Attribute", "{}">:$memorySpace)>,
/// [deprecated] `Attribute`-based form should be used instead.
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- "ArrayRef<AffineMap>":$affineMaps,
- "unsigned":$memorySpace
- ), [{
- // Convert deprecated integer-like memory space to Attribute.
- Attribute memorySpaceAttr =
- wrapIntegerMemorySpace(memorySpace, elementType.getContext());
- return MemRefType::get(shape, elementType, affineMaps, memorySpaceAttr);
- }]>
+ "AffineMap":$map,
+ "unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
/// This is a builder type that keeps local references to arguments.
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index d40958e877771..478b212c3b520 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -219,15 +219,8 @@ static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
assert(memRefDim && "memRefDim == nullptr");
auto memRefType = memoryOp.getMemRefType();
- auto layoutMap = memRefType.getAffineMaps();
- // TODO: remove dependence on Builder once we support non-identity layout map.
- Builder b(memoryOp.getContext());
- if (layoutMap.size() >= 2 ||
- (layoutMap.size() == 1 &&
- !(layoutMap[0] ==
- b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) {
+ if (!memRefType.getLayout().isIdentity())
return memoryOp.emitError("NYI: non-trivial layoutMap"), false;
- }
int uniqueVaryingIndexAlongIv = -1;
auto accessMap = memoryOp.getAffineMap();
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index b9e7b8479c65e..9565441863018 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -616,9 +616,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
Optional<int64_t> MemRefRegion::getRegionSize() {
auto memRefType = memref.getType().cast<MemRefType>();
- auto layoutMaps = memRefType.getAffineMaps();
- if (layoutMaps.size() > 1 ||
- (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+ if (!memRefType.getLayout().isIdentity()) {
LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
return false;
}
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index fd9f3efe7405f..1cfd799bf6934 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -401,8 +401,6 @@ class PyUnrankedTensorType
}
};
-class PyMemRefLayoutMapList;
-
/// Ranked MemRef Type subclass - MemRefType.
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
@@ -410,26 +408,18 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
static constexpr const char *pyClassName = "MemRefType";
using PyConcreteType::PyConcreteType;
- PyMemRefLayoutMapList getLayout();
-
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](std::vector<int64_t> shape, PyType &elementType,
- std::vector<PyAffineMap> layout, PyAttribute *memorySpace,
+ PyAttribute *layout, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
- SmallVector<MlirAffineMap> maps;
- maps.reserve(layout.size());
- for (PyAffineMap &map : layout)
- maps.push_back(map);
-
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
-
- MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
- shape.data(), maps.size(),
- maps.data(), memSpaceAttr);
+ MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
+ MlirAttribute memSpaceAttr =
+ memorySpace ? *memorySpace : mlirAttributeGetNull();
+ MlirType t =
+ mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+ shape.data(), layoutAttr, memSpaceAttr);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@@ -444,10 +434,22 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
return PyMemRefType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("element_type"),
- py::arg("layout") = py::list(), py::arg("memory_space") = py::none(),
+ py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
py::arg("loc") = py::none(), "Create a memref type")
- .def_property_readonly("layout", &PyMemRefType::getLayout,
- "The list of layout maps of the MemRef type.")
+ .def_property_readonly(
+ "layout",
+ [](PyMemRefType &self) -> PyAttribute {
+ MlirAttribute layout = mlirMemRefTypeGetLayout(self);
+ return PyAttribute(self.getContext(), layout);
+ },
+ "The layout of the MemRef type.")
+ .def_property_readonly(
+ "affine_map",
+ [](PyMemRefType &self) -> PyAffineMap {
+ MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
+ return PyAffineMap(self.getContext(), map);
+ },
+ "The layout of the MemRef type as an affine map.")
.def_property_readonly(
"memory_space",
[](PyMemRefType &self) -> PyAttribute {
@@ -458,41 +460,6 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
}
};
-/// A list of affine layout maps in a memref type. Internally, these are stored
-/// as consecutive elements, random access is cheap. Both the type and the maps
-/// are owned by the context, no need to worry about lifetime extension.
-class PyMemRefLayoutMapList
- : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
-public:
- static constexpr const char *pyClassName = "MemRefLayoutMapList";
-
- PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
- step),
- memref(type) {}
-
- intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
-
- PyAffineMap getElement(intptr_t index) {
- return PyAffineMap(memref.getContext(),
- mlirMemRefTypeGetAffineMap(memref, index));
- }
-
- PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyMemRefLayoutMapList(memref, startIndex, length, step);
- }
-
-private:
- PyMemRefType memref;
-};
-
-PyMemRefLayoutMapList PyMemRefType::getLayout() {
- return PyMemRefLayoutMapList(*this);
-}
-
/// Unranked MemRef Type subclass - UnrankedMemRefType.
class PyUnrankedMemRefType
: public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
@@ -640,7 +607,6 @@ void mlir::python::populateIRTypes(py::module &m) {
PyRankedTensorType::bind(m);
PyUnrankedTensorType::bind(m);
PyMemRefType::bind(m);
- PyMemRefLayoutMapList::bind(m);
PyUnrankedMemRefType::bind(m);
PyTupleType::bind(m);
PyFunctionType::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index d978f17b98d5b..318b8eb10c16f 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -226,34 +226,35 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
- const int64_t *shape, intptr_t numMaps,
- MlirAffineMap const *affineMaps,
+ const int64_t *shape, MlirAttribute layout,
MlirAttribute memorySpace) {
- SmallVector<AffineMap, 1> maps;
- (void)unwrapList(numMaps, affineMaps, maps);
- return wrap(
- MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType), maps, unwrap(memorySpace)));
+ return wrap(MemRefType::get(
+ llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+ mlirAttributeIsNull(layout)
+ ? MemRefLayoutAttrInterface()
+ : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
+ unwrap(memorySpace)));
}
MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
intptr_t rank, const int64_t *shape,
- intptr_t numMaps,
- MlirAffineMap const *affineMaps,
+ MlirAttribute layout,
MlirAttribute memorySpace) {
- SmallVector<AffineMap, 1> maps;
- (void)unwrapList(numMaps, affineMaps, maps);
return wrap(MemRefType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType), maps, unwrap(memorySpace)));
+ unwrap(elementType),
+ mlirAttributeIsNull(layout)
+ ? MemRefLayoutAttrInterface()
+ : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
+ unwrap(memorySpace)));
}
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
const int64_t *shape,
MlirAttribute memorySpace) {
- return wrap(
- MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType), llvm::None, unwrap(memorySpace)));
+ return wrap(MemRefType::get(
+ llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+ MemRefLayoutAttrInterface(), unwrap(memorySpace)));
}
MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
@@ -262,16 +263,15 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
MlirAttribute memorySpace) {
return wrap(MemRefType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType), llvm::None, unwrap(memorySpace)));
+ unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace)));
}
-intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
- return static_cast<intptr_t>(
- unwrap(type).cast<MemRefType>().getAffineMaps().size());
+MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
+ return wrap(unwrap(type).cast<MemRefType>().getLayout());
}
-MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
- return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
+MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
+ return wrap(unwrap(type).cast<MemRefType>().getLayout().getAffineMap());
}
MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 7745b9585a5cc..47dabc90bce5b 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -106,9 +106,7 @@ bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
MemRefType type) const {
if (!typeConverter->convertType(type.getElementType()))
return false;
- return type.getAffineMaps().empty() ||
- llvm::all_of(type.getAffineMaps(),
- [](AffineMap map) { return map.isIdentity(); });
+ return type.getLayout().isIdentity();
}
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 74462caf39b45..6188e8b571123 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1142,7 +1142,8 @@ class ReassociatingReshapeOpConversion
ConversionPatternRewriter &rewriter) const override {
MemRefType dstType = reshapeOp.getResultType();
MemRefType srcType = reshapeOp.getSrcType();
- if (!srcType.getAffineMaps().empty() || !dstType.getAffineMaps().empty()) {
+ if (!srcType.getLayout().isIdentity() ||
+ !dstType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(reshapeOp,
"only empty layout map is supported");
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 765b58d8c3d22..a6f25332d1331 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -950,8 +950,7 @@ computeContiguousStrides(MemRefType memRefType) {
if (!strides.empty() && strides.back() != 1)
return None;
// If no layout or identity layout, this is contiguous by definition.
- if (memRefType.getAffineMaps().empty() ||
- memRefType.getAffineMaps().front().isIdentity())
+ if (memRefType.getLayout().isIdentity())
return strides;
// Otherwise, we must determine contiguity form shapes. This can only ever
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 9a15cb26e2b24..ba1710b57a919 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1047,8 +1047,7 @@ static LogicalResult verify(SubgroupMmaLoadMatrixOp op) {
auto srcMemrefType = srcType.cast<MemRefType>();
auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
- if (!srcMemrefType.getAffineMaps().empty() &&
- !srcMemrefType.getAffineMaps().front().isIdentity())
+ if (!srcMemrefType.getLayout().isIdentity())
return op.emitError("expected identity layout map for source memref");
if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
@@ -1074,9 +1073,7 @@ static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
auto dstMemrefType = dstType.cast<MemRefType>();
auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
-
- if (!dstMemrefType.getAffineMaps().empty() &&
- !dstMemrefType.getAffineMaps().front().isIdentity())
+ if (!dstMemrefType.getLayout().isIdentity())
return op.emitError("expected identity layout map for destination memref");
if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index b7a05c3420532..1aebd90a2e660 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -159,9 +159,8 @@ struct GpuAllReduceRewriter {
Value createWorkgroupBuffer() {
int workgroupMemoryAddressSpace =
gpu::GPUDialect::getWorkgroupAddressSpace();
- auto bufferType =
- MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
- workgroupMemoryAddressSpace);
+ auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
+ workgroupMemoryAddressSpace);
return funcOp.addWorkgroupAttribution(bufferType);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 7e9a0e0ed38a6..3d7919b6e7125 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1219,27 +1219,24 @@ getEquivalentEnclosingFuncBBArg(Value v,
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace`.
static MemRefType getContiguousMemRefType(ShapedType shapedType,
- ArrayRef<AffineMap> layout = {},
- unsigned addressSpace = 0) {
- if (RankedTensorType tensorType = shapedType.dyn_cast<RankedTensorType>())
- return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
- layout, addressSpace);
- MemRefType memrefType = shapedType.cast<MemRefType>();
- return MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
- layout, addressSpace);
+ MemRefLayoutAttrInterface layout = {},
+ Attribute memorySpace = {}) {
+ return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
+ layout, memorySpace);
}
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace` or an UnrankedMemRefType otherwise.
-static Type getContiguousOrUnrankedMemRefType(Type type,
- ArrayRef<AffineMap> layout = {},
- unsigned addressSpace = 0) {
+static Type
+getContiguousOrUnrankedMemRefType(Type type,
+ MemRefLayoutAttrInterface layout = {},
+ Attribute memorySpace = {}) {
if (type.isa<RankedTensorType, MemRefType>())
return getContiguousMemRefType(type.cast<ShapedType>(), layout,
- addressSpace);
- assert(layout.empty() && "expected empty layout with UnrankedMemRefType");
- return UnrankedMemRefType::get(getElementTypeOrSelf(type), addressSpace);
+ memorySpace);
+ assert(!layout && "expected empty layout with UnrankedMemRefType");
+ return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
}
/// Return a MemRefType to which the `tensorType` can be bufferized in a
@@ -1644,16 +1641,16 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
assert(rankedMemRefType || unrankedMemRefType);
- unsigned memorySpace = rankedMemRefType
- ? rankedMemRefType.getMemorySpaceAsInt()
- : unrankedMemRefType.getMemorySpaceAsInt();
+ Attribute memorySpace = rankedMemRefType
+ ? rankedMemRefType.getMemorySpace()
+ : unrankedMemRefType.getMemorySpace();
TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
- ArrayRef<AffineMap> affineMaps =
+ MemRefLayoutAttrInterface layout =
rankedMemRefType && tensorType.isa<RankedTensorType>()
- ? rankedMemRefType.getAffineMaps()
- : ArrayRef<AffineMap>{};
+ ? rankedMemRefType.getLayout()
+ : MemRefLayoutAttrInterface();
Type memRefType = getContiguousOrUnrankedMemRefType(
- castOp.getResult().getType(), affineMaps, memorySpace);
+ castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c91cb0d4f7867..cb92185394938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -258,7 +258,7 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
// leave them unchanged.
Type actualType = opOperand->get().getType();
if (auto memref = actualType.dyn_cast<MemRefType>()) {
- if (!memref.getAffineMaps().empty())
+ if (!memref.getLayout().isIdentity())
return llvm::None;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 928b727db9a0f..e792d6581566d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -88,8 +88,8 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
"dynamic dimension count");
unsigned numSymbols = 0;
- if (!memRefType.getAffineMaps().empty())
- numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
+ if (!memRefType.getLayout().isIdentity())
+ numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
if (op.symbolOperands().size() != numSymbols)
return op.emitOpError("symbol operand count does not equal memref symbol "
"count: expected ")
@@ -496,7 +496,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (aT && bT) {
if (aT.getElementType() != bT.getElementType())
return false;
- if (aT.getAffineMaps() != bT.getAffineMaps()) {
+ if (aT.getLayout() != bT.getLayout()) {
int64_t aOffset, bOffset;
SmallVector<int64_t, 4> aStrides, bStrides;
if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
@@ -1408,7 +1408,7 @@ static LogicalResult verify(ReinterpretCastOp op) {
- // Match offset and strides in static_offset and static_strides attributes if
- // result memref type has an affine map specified.
- if (!resultType.getAffineMaps().empty()) {
+ // Match offset and strides in static_offset and static_strides attributes if
+ // result memref type has a non-identity layout.
+ if (!resultType.getLayout().isIdentity()) {
int64_t resultOffset;
SmallVector<int64_t, 4> resultStrides;
if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
@@ -1526,8 +1526,8 @@ computeReshapeCollapsedType(MemRefType type,
}
// Early-exit: if `type` is contiguous, the result must be contiguous.
- if (canonicalizeStridedLayout(type).getAffineMaps().empty())
- return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
+ if (canonicalizeStridedLayout(type).getLayout().isIdentity())
+ return MemRefType::Builder(type).setShape(newSizes).setLayout({});
// Convert back to int64_t because we don't have enough information to create
// new strided layouts from AffineExpr only. This corresponds to a case where
@@ -1546,7 +1546,8 @@ computeReshapeCollapsedType(MemRefType type,
auto layout =
makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
return canonicalizeStridedLayout(
- MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
+ MemRefType::Builder(type).setShape(newSizes).setLayout(
+ AffineMapAttr::get(layout)));
}
void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -1662,14 +1663,14 @@ static LogicalResult verify(ReshapeOp op) {
"types should be the same");
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
- if (!operandMemRefType.getAffineMaps().empty())
+ if (!operandMemRefType.getLayout().isIdentity())
return op.emitOpError(
"source memref type should have identity affine map");
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
if (resultMemRefType) {
- if (!resultMemRefType.getAffineMaps().empty())
+ if (!resultMemRefType.getLayout().isIdentity())
return op.emitOpError(
"result memref type should have identity affine map");
if (shapeSize == ShapedType::kDynamicSize)
@@ -1824,10 +1825,9 @@ Type SubViewOp::inferRankReducedResultType(
if (!dimsToProject.contains(pos))
projectedShape.push_back(shape[pos]);
- AffineMap map;
- auto maps = inferredType.getAffineMaps();
- if (!maps.empty() && maps.front())
- map = getProjectedMap(maps.front(), dimsToProject);
+ // An identity layout only encodes the rank of the unreduced type, so it must
+ // not be carried over to the projected shape (that would fail the layout
+ // rank check). Leave `map` null so MemRefType::get builds the identity layout
+ // for the reduced rank; non-identity layouts are projected.
+ AffineMap map = inferredType.getLayout().getAffineMap();
+ map = map.isIdentity() ? AffineMap() : getProjectedMap(map, dimsToProject);
inferredType =
MemRefType::get(projectedShape, inferredType.getElementType(), map,
inferredType.getMemorySpace());
@@ -2279,7 +2279,9 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
auto map =
makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
map = permutationMap ? map.compose(permutationMap) : map;
- return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
+ return MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setLayout(AffineMapAttr::get(map));
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -2387,15 +2389,11 @@ static LogicalResult verify(ViewOp op) {
auto viewType = op.getType();
// The base memref should have identity layout map (or none).
- if (baseType.getAffineMaps().size() > 1 ||
- (baseType.getAffineMaps().size() == 1 &&
- !baseType.getAffineMaps()[0].isIdentity()))
+ if (!baseType.getLayout().isIdentity())
return op.emitError("unsupported map for base memref type ") << baseType;
// The result memref should have identity layout map (or none).
- if (viewType.getAffineMaps().size() > 1 ||
- (viewType.getAffineMaps().size() == 1 &&
- !viewType.getAffineMaps()[0].isIdentity()))
+ if (!viewType.getLayout().isIdentity())
return op.emitError("unsupported map for result memref type ") << viewType;
// The base memref and the view memref should be in the same memory space.
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index fdec267934bb7..769a416278150 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3767,16 +3767,17 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType =
VectorType::get(extractShape(memRefType),
getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
- result.addTypes(
- MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
+ result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
+ memRefType.getMemorySpace()));
}
static LogicalResult verify(TypeCastOp op) {
MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
- if (!canonicalType.getAffineMaps().empty())
- return op.emitOpError("expects operand to be a memref with no layout");
- if (!op.getResultMemRefType().getAffineMaps().empty())
- return op.emitOpError("expects result to be a memref with no layout");
+ if (!canonicalType.getLayout().isIdentity())
+ return op.emitOpError(
+ "expects operand to be a memref with identity layout");
+ if (!op.getResultMemRefType().getLayout().isIdentity())
+ return op.emitOpError("expects result to be a memref with identity layout");
if (op.getResultMemRefType().getMemorySpace() !=
op.getMemRefType().getMemorySpace())
return op.emitOpError("expects result in same memory space");
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e2e72d1e00755..cf1eb8b56d807 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1971,9 +1971,9 @@ void AsmPrinter::Impl::printType(Type type) {
os << 'x';
}
printType(memrefTy.getElementType());
- for (auto map : memrefTy.getAffineMaps()) {
+ if (!memrefTy.getLayout().isIdentity()) {
os << ", ";
- printAttribute(AffineMapAttr::get(map));
+ printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
}
// Only print the memory space if it is the non-default one.
if (memrefTy.getMemorySpace()) {
diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index 6bfa1ee6633ab..96992c219aa0c 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
@@ -72,3 +73,17 @@ uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr,
}
return valueIndex;
}
+
+//===----------------------------------------------------------------------===//
+// MemRefLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the affine map `m`, used as a memref layout, has exactly as
+/// many dimensions as the rank of a memref with the given `shape`. On mismatch,
+/// emits a diagnostic through `emitError` and returns failure.
+LogicalResult mlir::detail::verifyAffineMapAsLayout(
+ AffineMap m, ArrayRef<int64_t> shape,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (m.getNumDims() != shape.size())
+ return emitError() << "memref layout mismatch between rank and affine map: "
+ << shape.size() << " != " << m.getNumDims();
+
+ return success();
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a87dd55d8b261..2ca794eddf2b0 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,9 +646,118 @@ unsigned MemRefType::getMemorySpaceAsInt() const {
return detail::getMemorySpaceAsInt(getMemorySpace());
}
+/// Returns `layout` unchanged when it is set; otherwise returns the default
+/// identity layout for a memref of rank `rank`.
+static MemRefLayoutAttrInterface
+getMemRefLayoutOrIdentity(MemRefLayoutAttrInterface layout, size_t rank,
+ MLIRContext *context) {
+ if (layout)
+ return layout;
+ return AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(rank, context));
+}
+
+/// Wraps `map` into a layout attribute. A null `map` selects the default
+/// identity layout for a memref of rank `rank`.
+static MemRefLayoutAttrInterface
+wrapAffineMapAsMemRefLayout(AffineMap map, size_t rank, MLIRContext *context) {
+ if (!map)
+ map = AffineMap::getMultiDimIdentityMap(rank, context);
+ return AffineMapAttr::get(map);
+}
+
+// All builders below go through the two helpers above, so the
+// "null layout means identity layout" rule is defined in exactly one place.
+
+MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
+ MemRefLayoutAttrInterface layout,
+ Attribute memorySpace) {
+ MLIRContext *context = elementType.getContext();
+ layout = getMemRefLayoutOrIdentity(layout, shape.size(), context);
+ // Drop default memory space value and replace it with empty attribute.
+ memorySpace = skipDefaultMemorySpace(memorySpace);
+ return Base::get(context, shape, elementType, layout, memorySpace);
+}
+
+MemRefType MemRefType::getChecked(
+ function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
+ Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
+ MLIRContext *context = elementType.getContext();
+ layout = getMemRefLayoutOrIdentity(layout, shape.size(), context);
+ // Drop default memory space value and replace it with empty attribute.
+ memorySpace = skipDefaultMemorySpace(memorySpace);
+ return Base::getChecked(emitErrorFn, context, shape, elementType, layout,
+ memorySpace);
+}
+
+MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
+ AffineMap map, Attribute memorySpace) {
+ MLIRContext *context = elementType.getContext();
+ MemRefLayoutAttrInterface layout =
+ wrapAffineMapAsMemRefLayout(map, shape.size(), context);
+ // Drop default memory space value and replace it with empty attribute.
+ memorySpace = skipDefaultMemorySpace(memorySpace);
+ return Base::get(context, shape, elementType, layout, memorySpace);
+}
+
+MemRefType
+MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
+ ArrayRef<int64_t> shape, Type elementType, AffineMap map,
+ Attribute memorySpace) {
+ MLIRContext *context = elementType.getContext();
+ MemRefLayoutAttrInterface layout =
+ wrapAffineMapAsMemRefLayout(map, shape.size(), context);
+ // Drop default memory space value and replace it with empty attribute.
+ memorySpace = skipDefaultMemorySpace(memorySpace);
+ return Base::getChecked(emitErrorFn, context, shape, elementType, layout,
+ memorySpace);
+}
+
+MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
+ AffineMap map, unsigned memorySpaceInd) {
+ MLIRContext *context = elementType.getContext();
+ MemRefLayoutAttrInterface layout =
+ wrapAffineMapAsMemRefLayout(map, shape.size(), context);
+ // Convert deprecated integer-like memory space to Attribute.
+ Attribute memorySpace = wrapIntegerMemorySpace(memorySpaceInd, context);
+ return Base::get(context, shape, elementType, layout, memorySpace);
+}
+
+MemRefType
+MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
+ ArrayRef<int64_t> shape, Type elementType, AffineMap map,
+ unsigned memorySpaceInd) {
+ MLIRContext *context = elementType.getContext();
+ MemRefLayoutAttrInterface layout =
+ wrapAffineMapAsMemRefLayout(map, shape.size(), context);
+ // Convert deprecated integer-like memory space to Attribute.
+ Attribute memorySpace = wrapIntegerMemorySpace(memorySpaceInd, context);
+ return Base::getChecked(emitErrorFn, context, shape, elementType, layout,
+ memorySpace);
+}
+
LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<AffineMap> affineMapComposition,
+ MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
if (!BaseMemRefType::isValidElementType(elementType))
return emitError() << "invalid memref element type";
@@ -658,26 +767,12 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
if (s < -1)
return emitError() << "invalid memref size";
- // Check that the structure of the composition is valid, i.e. that each
- // subsequent affine map has as many inputs as the previous map has results.
- // Take the dimensionality of the MemRef for the first map.
- size_t dim = shape.size();
- for (auto it : llvm::enumerate(affineMapComposition)) {
- AffineMap map = it.value();
- if (map.getNumDims() == dim) {
- dim = map.getNumResults();
- continue;
- }
- return emitError() << "memref affine map dimension mismatch between "
- << (it.index() == 0 ? Twine("memref rank")
- : "affine map " + Twine(it.index()))
- << " and affine map" << it.index() + 1 << ": " << dim
- << " != " << map.getNumDims();
- }
+ assert(layout && "missing layout specification");
+ if (failed(layout.verifyLayout(shape, emitError)))
+ return failure();
- if (!isSupportedMemorySpace(memorySpace)) {
+ if (!isSupportedMemorySpace(memorySpace))
return emitError() << "unsupported memory space Attribute";
- }
return success();
}
@@ -686,9 +781,9 @@ void MemRefType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkTypesFn(getElementType());
+ if (!getLayout().isIdentity())
+ walkAttrsFn(getLayout());
walkAttrsFn(getMemorySpace());
- for (AffineMap map : getAffineMaps())
- walkAttrsFn(AffineMapAttr::get(map));
}
//===----------------------------------------------------------------------===//
@@ -775,23 +870,18 @@ static LogicalResult extractStrides(AffineExpr e,
LogicalResult mlir::getStridesAndOffset(MemRefType t,
SmallVectorImpl<AffineExpr> &strides,
AffineExpr &offset) {
- auto affineMaps = t.getAffineMaps();
-
- if (affineMaps.size() > 1)
- return failure();
+ AffineMap m = t.getLayout().getAffineMap();
- if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1)
+ if (m.getNumResults() != 1 && !m.isIdentity())
return failure();
- AffineMap m = affineMaps.empty() ? AffineMap() : affineMaps.back();
-
auto zero = getAffineConstantExpr(0, t.getContext());
auto one = getAffineConstantExpr(1, t.getContext());
offset = zero;
strides.assign(t.getRank(), zero);
- // Canonical case for empty map.
- if (!m || m.isIdentity()) {
+ // Canonical case for the identity layout.
+ if (m.isIdentity()) {
// 0-D corner case, offset is already 0.
if (t.getRank() == 0)
return success();
@@ -938,21 +1028,21 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
/// `t` with simplified layout.
-/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
+/// If `t` has a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
- auto affineMaps = t.getAffineMaps();
+ AffineMap m = t.getLayout().getAffineMap();
+
// Already in canonical form.
- if (affineMaps.empty())
+ if (m.isIdentity())
return t;
// Can't reduce to canonical identity form, return in canonical form.
- if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
+ if (m.getNumResults() > 1)
return t;
// Corner-case for 0-D affine maps.
- auto m = affineMaps[0];
if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
if (cst.getValue() == 0)
- return MemRefType::Builder(t).setAffineMaps({});
+ return MemRefType::Builder(t).setLayout({});
return t;
}
@@ -970,9 +1060,9 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
auto simplifiedLayoutExpr =
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
if (expr != simplifiedLayoutExpr)
- return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
- m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
- return MemRefType::Builder(t).setAffineMaps({});
+ return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
+ m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
+ return MemRefType::Builder(t).setLayout({});
}
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
@@ -1016,8 +1106,9 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
auto val = ShapedType::kDynamicStrideOrOffset;
- return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
- SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
+ return MemRefType::Builder(t).setLayout(
+ AffineMapAttr::get(makeStridedLinearLayoutMap(
+ SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
}
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 46fbf52d25bb6..256d6c0a96a24 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -185,9 +185,8 @@ ParseResult Parser::parseStridedLayout(int64_t &offset,
///
/// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
-/// layout-specification ::= semi-affine-map-composition | strided-layout
-/// memory-space ::= integer-literal /* | TODO: address-space-id */
+/// layout-specification ::= semi-affine-map | strided-layout | attribute
+/// memory-space ::= integer-literal | attribute
///
Type Parser::parseMemRefType() {
llvm::SMLoc loc = getToken().getLoc();
@@ -221,15 +220,10 @@ Type Parser::parseMemRefType() {
if (!BaseMemRefType::isValidElementType(elementType))
return emitError(typeLoc, "invalid memref element type"), nullptr;
- // Parse semi-affine-map-composition.
- SmallVector<AffineMap, 2> affineMapComposition;
+ MemRefLayoutAttrInterface layout;
Attribute memorySpace;
- unsigned numDims = dimensions.size();
auto parseElt = [&]() -> ParseResult {
- AffineMap map;
- llvm::SMLoc mapLoc = getToken().getLoc();
-
+ // Layout composition is no longer supported: remember whether an earlier
+ // element already provided a layout, so that a second one is rejected
+ // instead of silently replacing the first.
+ bool hasPriorLayout = static_cast<bool>(layout);
+
// Check for AffineMap as offset/strides.
if (getToken().is(Token::kw_offset)) {
int64_t offset;
@@ -237,15 +231,17 @@ Type Parser::parseMemRefType() {
if (failed(parseStridedLayout(offset, strides)))
return failure();
// Construct strided affine map.
- map = makeStridedLinearLayoutMap(strides, offset, state.context);
+ AffineMap map =
+ makeStridedLinearLayoutMap(strides, offset, state.context);
+ layout = AffineMapAttr::get(map);
} else {
- // Either it is AffineMapAttr or memory space attribute.
+ // Either it is MemRefLayoutAttrInterface or memory space attribute.
Attribute attr = parseAttribute();
if (!attr)
return failure();
- if (AffineMapAttr affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
- map = affineMapAttr.getValue();
+ if (auto layoutAttr = attr.dyn_cast<MemRefLayoutAttrInterface>()) {
+ layout = layoutAttr;
} else if (memorySpace) {
return emitError("multiple memory spaces specified in memref type");
} else {
@@ -259,15 +255,6 @@ Type Parser::parseMemRefType() {
if (memorySpace)
return emitError("expected memory space to be last in memref type");
- if (map.getNumDims() != numDims) {
- size_t i = affineMapComposition.size();
- return emitError(mapLoc, "memref affine map dimension mismatch between ")
- << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
- << " and affine map" << i + 1 << ": " << numDims
- << " != " << map.getNumDims();
- }
- numDims = map.getNumResults();
- affineMapComposition.push_back(map);
+ // Checked after the memory-space-last diagnostic so that error keeps
+ // precedence for inputs like `#map0, 1, #map1`.
+ if (hasPriorLayout)
+ return emitError("multiple layouts specified in memref type");
return success();
};
@@ -284,8 +271,8 @@ Type Parser::parseMemRefType() {
if (isUnranked)
return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
- return getChecked<MemRefType>(loc, dimensions, elementType,
- affineMapComposition, memorySpace);
+ return getChecked<MemRefType>(loc, dimensions, elementType, layout,
+ memorySpace);
}
/// Parse any type except the function type.
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index c2b3a956d997e..148592a2c689e 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -225,7 +225,7 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
// memref type is normalized.
// TODO: When selective normalization is implemented, handle multiple
// results case where some are normalized, some aren't.
- if (memrefType.getAffineMaps().empty())
+ if (memrefType.getLayout().isIdentity())
resultTypes[operandEn.index()] = memrefType;
}
});
@@ -269,7 +269,7 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
if (oldResult.getType() == newResult.getType())
continue;
AffineMap layoutMap =
- oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
+ oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
/*extraIndices=*/{},
/*indexRemap=*/layoutMap,
@@ -363,7 +363,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
BlockArgument newMemRef =
funcOp.front().insertArgument(argIndex, newMemRefType);
BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
- AffineMap layoutMap = memrefType.getAffineMaps().front();
+ AffineMap layoutMap = memrefType.getLayout().getAffineMap();
// Replace all uses of the old memref.
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
/*extraIndices=*/{},
@@ -412,7 +412,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
if (oldMemRefType == newMemRefType)
continue;
- // TODO: Assume single layout map. Multiple maps not supported.
- AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
+ // A memref type now carries exactly one layout, so there is no map
+ // composition to handle here.
+ AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
if (failed(replaceAllMemRefUsesWith(oldMemRef,
/*newMemRef=*/newMemRef,
/*extraIndices=*/{},
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index eaeb3475e6f17..500203b19a185 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -74,9 +74,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
newShape[0] = 2;
std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
- return MemRefType::Builder(oldMemRefType)
- .setShape(newShape)
- .setAffineMaps({});
+ return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
};
auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 30968f9b1de4f..c43ee3f70ea63 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2648,9 +2648,7 @@ static LogicalResult generateCopy(
auto memref = region.memref;
auto memRefType = memref.getType().cast<MemRefType>();
- auto layoutMaps = memRefType.getAffineMaps();
- if (layoutMaps.size() > 1 ||
- (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+ if (!memRefType.getLayout().isIdentity()) {
LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
return failure();
}
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index f37740d5317b8..8b40085e380a4 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -647,7 +647,7 @@ LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
Value oldMemRef = allocOp->getResult();
SmallVector<Value, 4> symbolOperands(allocOp->symbolOperands());
- AffineMap layoutMap = memrefType.getAffineMaps().front();
+ AffineMap layoutMap = memrefType.getLayout().getAffineMap();
memref::AllocOp newAlloc;
// Check if `layoutMap` is a tiled layout. Only single layout map is
// supported for normalizing dynamic memrefs.
@@ -695,13 +695,12 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
if (rank == 0)
return memrefType;
- ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
- if (layoutMaps.empty() ||
- layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
- // Either no maps is associated with this memref or this memref has
- // a trivial (identity) map.
+ if (memrefType.getLayout().isIdentity()) {
+ // The memref has the default (identity) layout: nothing to normalize.
return memrefType;
}
+ AffineMap layoutMap = memrefType.getLayout().getAffineMap();
// We don't do any checks for one-to-one'ness; we assume that it is
// one-to-one.
@@ -710,7 +709,7 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
// for now.
// TODO: Normalize the other types of dynamic memrefs.
SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
- (void)getTileSizePos(layoutMaps.front(), tileSizePos);
+ (void)getTileSizePos(layoutMap, tileSizePos);
if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
return memrefType;
@@ -731,7 +730,6 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
}
// We compose this map with the original index (logical) space to derive
// the upper bounds for the new index space.
- AffineMap layoutMap = layoutMaps.front();
unsigned newRank = layoutMap.getNumResults();
if (failed(fac.composeMatchingMap(layoutMap)))
return memrefType;
@@ -763,7 +761,7 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
MemRefType newMemRefType =
MemRefType::Builder(memrefType)
.setShape(newShape)
- .setAffineMaps(b.getMultiDimIdentityMap(newRank));
+ .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank)));
return newMemRefType;
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 4746f5104e7e8..fd245cd6afb17 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -726,7 +726,6 @@ static int printBuiltinTypes(MlirContext ctx) {
MlirType memRef = mlirMemRefTypeContiguousGet(
f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
if (!mlirTypeIsAMemRef(memRef) ||
- mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
+ !mlirAffineMapIsIdentity(mlirMemRefTypeGetAffineMap(memRef)) ||
!mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
return 18;
mlirTypeDump(memRef);
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3c97530ec651b..6aa9679117cb0 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1095,7 +1095,7 @@ func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) {
// -----
func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
- // expected-error at +1 {{expects operand to be a memref with no layout}}
+ // expected-error at +1 {{expects operand to be a memref with identity layout}}
%0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index e026b54d94c5e..cafe3440ace52 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -104,7 +104,8 @@ func @test_store_zero_results2(%x: i32, %p: memref<i32>) {
func @test_alloc_memref_map_rank_mismatch() {
^bb0:
- %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1> // expected-error {{memref affine map dimension mismatch}}
+ // expected-error at +1 {{memref layout mismatch between rank and affine map: 2 != 1}}
+ %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1>
return
}
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 8b9bce076749c..b9187dc96673c 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -61,13 +61,7 @@ func @memrefs(memref<2x4xi8, #map0, 1, #map1>) // expected-error {{expected memo
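-// The error must be emitted even for the trivial identity layout maps that are
-// dropped in type creation.
+// The error must be emitted even for trivial identity layout maps, which are
+// verified against the memref rank like any other layout.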
#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @memrefs(memref<42xi8, #map0>) // expected-error {{memref affine map dimension mismatch}}
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0) -> (d0)>
-func @memrefs(memref<42x42xi8, #map0, #map1>) // expected-error {{memref affine map dimension mismatch}}
+func @memrefs(memref<42xi8, #map0>) // expected-error {{memref layout mismatch between rank and affine map: 1 != 2}}
// -----
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 1b31c89e91caa..d1f4ab64f7f37 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -12,9 +12,6 @@
// CHECK-DAG: #map{{[0-9]+}} = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-// CHECK-DAG: #map{{[0-9]+}} = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
-#map4 = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
-
// CHECK-DAG: #map{{[0-9]+}} = affine_map<()[s0] -> (0, s0 - 1)>
#inline_map_minmax_loop1 = affine_map<()[s0] -> (0, s0 - 1)>
@@ -80,28 +77,15 @@ func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
// CHECK: func private @tensor_encoding(tensor<16x32xf64, "sparse">)
func private @tensor_encoding(tensor<16x32xf64, "sparse">)
-// CHECK: func private @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>)
-func private @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)
-
-// Test memref affine map compositions.
+// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
+func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
// CHECK: func private @memrefs2(memref<2x4x8xi8, 1>)
func private @memrefs2(memref<2x4x8xi8, #map2, 1>)
-// CHECK: func private @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}>)
-func private @memrefs23(memref<2x4x8xi8, #map2, #map3, 0>)
-
-// CHECK: func private @memrefs234(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 3>)
-func private @memrefs234(memref<2x4x8xi8, #map2, #map3, #map4, 3>)
-
-// Test memref inline affine map compositions, minding that identity maps are removed.
-
// CHECK: func private @memrefs3(memref<2x4x8xi8>)
func private @memrefs3(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>>)
-// CHECK: func private @memrefs33(memref<2x4x8xi8, #map{{[0-9]+}}, 1>)
-func private @memrefs33(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 1>)
-
// CHECK: func private @memrefs_drop_triv_id_inline(memref<2xi8>)
func private @memrefs_drop_triv_id_inline(memref<2xi8, affine_map<(d0) -> (d0)>>)
@@ -111,35 +95,6 @@ func private @memrefs_drop_triv_id_inline0(memref<2xi8, affine_map<(d0) -> (d0)>
// CHECK: func private @memrefs_drop_triv_id_inline1(memref<2xi8, 1>)
func private @memrefs_drop_triv_id_inline1(memref<2xi8, affine_map<(d0) -> (d0)>, 1>)
-// Identity maps should be dropped from the composition, but not the pair of
-// "interchange" maps that, if composed, would be also an identity.
-// CHECK: func private @memrefs_drop_triv_id_composition(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_composition(memref<2x2xi8,
- affine_map<(d0, d1) -> (d1, d0)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1, d0)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_trailing(memref<2x2xi8, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_trailing(memref<2x2xi8, affine_map<(d0, d1) -> (d1, d0)>,
- affine_map<(d0, d1) -> (d0, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_middle(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
-func private @memrefs_drop_triv_id_middle(memref<2x2xi8,
- affine_map<(d0, d1) -> (d0, d1 + 1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0 + 1, d1)>>)
-
-// CHECK: func private @memrefs_drop_triv_id_multiple(memref<2xi8>)
-func private @memrefs_drop_triv_id_multiple(memref<2xi8, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>>)
-
-// These maps appeared before, so they must be uniqued and hoisted to the beginning.
-// Identity map should be removed.
-// CHECK: func private @memrefs_compose_with_id(memref<2x2xi8, #map{{[0-9]+}}>)
-func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1, d0)>>)
-
// Test memref with custom memory space
// CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>)
@@ -202,9 +157,6 @@ func private @unranked_memref_with_index_elems(memref<*xindex>)
// CHECK: func private @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
func private @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
-// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
-func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
-
// CHECK-LABEL: func @simpleCFG(%{{.*}}: i32, %{{.*}}: f32) -> i1 {
func @simpleCFG(%arg0: i32, %f: f32) -> i1 {
// CHECK: %{{.*}} = "foo"() : () -> i64
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index ab6502e1d61fc..911391f2d528b 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -372,18 +372,21 @@ def testMemRefType():
memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
# CHECK: memref type: memref<2x3xf32, 2>
print("memref type:", memref)
- # CHECK: number of affine layout maps: 0
- print("number of affine layout maps:", len(memref.layout))
+ # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
+ print("memref layout:", memref.layout)
+ # CHECK: memref affine map: (d0, d1) -> (d0, d1)
+ print("memref affine map:", memref.affine_map)
# CHECK: memory space: 2
print("memory space:", memref.memory_space)
- layout = AffineMap.get_permutation([1, 0])
- memref_layout = MemRefType.get(shape, f32, [layout])
+ layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
+ memref_layout = MemRefType.get(shape, f32, layout=layout)
# CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
print("memref type:", memref_layout)
- assert len(memref_layout.layout) == 1
- # CHECK: memref layout: (d0, d1) -> (d1, d0)
- print("memref layout:", memref_layout.layout[0])
+ # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
+ print("memref layout:", memref_layout.layout)
+ # CHECK: memref affine map: (d0, d1) -> (d1, d0)
+ print("memref affine map:", memref_layout.affine_map)
# CHECK: memory space: <<NULL ATTRIBUTE>>
print("memory space:", memref_layout.memory_space)
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index dc7591738b873..9c2a93e9429c6 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -32,26 +32,26 @@ TEST(ShapedTypeTest, CloneMemref) {
ShapedType memrefType =
MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
.setMemorySpace(memSpace)
- .setAffineMaps(map);
+ .setLayout(AffineMapAttr::get(map));
// Update shape.
llvm::SmallVector<int64_t> memrefNewShape({30, 40});
ASSERT_NE(memrefOriginalShape, memrefNewShape);
ASSERT_EQ(memrefType.clone(memrefNewShape),
(MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
.setMemorySpace(memSpace)
- .setAffineMaps(map));
+ .setLayout(AffineMapAttr::get(map)));
// Update type.
Type memrefNewType = f32;
ASSERT_NE(memrefOriginalType, memrefNewType);
ASSERT_EQ(memrefType.clone(memrefNewType),
(MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
.setMemorySpace(memSpace)
- .setAffineMaps(map));
+ .setLayout(AffineMapAttr::get(map)));
// Update both.
ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
(MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
.setMemorySpace(memSpace)
- .setAffineMaps(map));
+ .setLayout(AffineMapAttr::get(map)));
// Test unranked memref cloning.
ShapedType unrankedTensorType =
More information about the Mlir-commits
mailing list