[Mlir-commits] [mlir] [mlir] Add optional layout attribute to VectorType (PR #71916)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 10 14:28:41 PST 2023
https://github.com/harsh-nod updated https://github.com/llvm/llvm-project/pull/71916
>From bb7c9463e3f5aba2e1f31334839d7e40627eec4b Mon Sep 17 00:00:00 2001
From: Harsh Menon <harsh at nod-labs.com>
Date: Thu, 9 Nov 2023 11:10:21 -0800
Subject: [PATCH 1/2] [mlir] Add optional layout attribute to VectorType
This patch adds an attribute interface for representing the
layout on vector types. This layout could be used to represent
the mapping from the vector indices to the indices of the
vector fragments held by different threads of a GPU.
The interface has a verify function that can be used to validate
that the layout accurately represents the vector shape.
---
.../mlir/IR/BuiltinAttributeInterfaces.td | 24 +++
mlir/include/mlir/IR/BuiltinTypes.h | 14 +-
mlir/include/mlir/IR/BuiltinTypes.td | 21 ++-
mlir/lib/AsmParser/TypeParser.cpp | 28 +++-
mlir/lib/IR/AsmPrinter.cpp | 5 +
mlir/lib/IR/BuiltinTypes.cpp | 8 +-
mlir/unittests/Interfaces/CMakeLists.txt | 1 +
.../Interfaces/VectorLayoutInterfaceTest.cpp | 158 ++++++++++++++++++
8 files changed, 249 insertions(+), 10 deletions(-)
create mode 100644 mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index c741db9b47f34e5..9241cac8c3b98a0 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -495,4 +495,28 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
];
}
+//===----------------------------------------------------------------------===//
+// VectorLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+def VectorLayoutAttrInterface : AttrInterface<"VectorLayoutAttrInterface"> {
+ let cppNamespace = "::mlir";
+
+ let description = [{
+ This interface is used for attributes that can represent the Vector type's
+ layout semantics, such as being able to map the vector indices to those
+ of the vector fragments held by individual threads.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ "Check if the current layout is applicable to the provided shape",
+ "::mlir::LogicalResult", "verifyLayout",
+ (ins "::llvm::ArrayRef<int64_t>":$shape,
+ "::mlir::Type":$elementType,
+ "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+ >
+ ];
+}
+
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829b..a387390b38e7de4 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,14 @@ class VectorType::Builder {
/// Build from another VectorType.
explicit Builder(VectorType other)
: elementType(other.getElementType()), shape(other.getShape()),
- scalableDims(other.getScalableDims()) {}
+ scalableDims(other.getScalableDims()), layout(other.getLayout()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims = {})
- : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+ ArrayRef<bool> scalableDims = {},
+ VectorLayoutAttrInterface layout = {})
+ : elementType(elementType), shape(shape), scalableDims(scalableDims),
+ layout(layout) {}
Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,6 +344,11 @@ class VectorType::Builder {
return *this;
}
+ Builder &setLayout(VectorLayoutAttrInterface newLayout) {
+ layout = newLayout;
+ return *this;
+ }
+
operator VectorType() {
- return VectorType::get(shape, elementType, scalableDims);
+ return VectorType::get(shape, elementType, scalableDims, layout);
}
@@ -350,6 +357,7 @@ class VectorType::Builder {
Type elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
+ VectorLayoutAttrInterface layout;
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 5ec986ac26de06b..3a2193d7a1768f7 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1029,11 +1029,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
Syntax:
```
- vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
+ vector-type ::= `vector` `<` vector-dim-list vector-element-type
+ (`,` layout-specification)? `>`
vector-element-type ::= float-type | integer-type | index-type
vector-dim-list := (static-dim-list `x`)?
static-dim-list ::= static-dim (`x` static-dim)*
static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+ layout-specification ::= attribute-value
```
The vector type represents a SIMD style vector used by target-specific
@@ -1050,6 +1052,14 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
2D vector with shape `(0, 42)` and zero shapes are not allowed.
+ ##### Layout
+
+ A vector may optionally have a layout that indicates how indices of
+ the vector are transformed to indices of the vector fragments that
+ are held by individual threads in a SIMT execution model. Such layouts
+ are common in a wide variety of GPU matrix multiplication instructions.
+ The layout can be any attribute that implements `VectorLayoutAttrInterface`.
+
Examples:
```mlir
@@ -1068,17 +1078,20 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
// A 3D mixed fixed/scalable vector in which only the inner dimension is
// scalable.
vector<2x[4]x8xf32>
+
```
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
- ArrayRefParameter<"bool">:$scalableDims
+ ArrayRefParameter<"bool">:$scalableDims,
+ "VectorLayoutAttrInterface":$layout
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- CArg<"ArrayRef<bool>", "{}">:$scalableDims
+ CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+ CArg<"VectorLayoutAttrInterface", "{}">:$layout
), [{
// While `scalableDims` is optional, its default value should be
// `false` for every dim in `shape`.
@@ -1087,7 +1100,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
- return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ return $_get(elementType.getContext(), shape, elementType, scalableDims, layout);
}]>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index adaefb78172c2ea..be8c84ee74e8688 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -459,14 +459,38 @@ VectorType Parser::parseVectorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto elementType = parseType();
- if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
+ if (!elementType)
return nullptr;
if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
- return VectorType::get(dimensions, elementType, scalableDims);
+  VectorLayoutAttrInterface layout;
+  auto parseElt = [&]() -> ParseResult {
+    auto attrLoc = getToken().getLoc();
+    Attribute attr = parseAttribute();
+    if (!attr)
+      return failure();
+    if (layout || !(layout = dyn_cast<VectorLayoutAttrInterface>(attr)))
+      return emitError(attrLoc, "expected a single vector layout attribute");
+ return success();
+ };
+
+ // Parse a list of mappings and address space if present.
+ if (!consumeIf(Token::greater)) {
+ // Parse comma separated list of affine maps, followed by memory space.
+ if (parseToken(Token::comma, "expected ',' or '>' in vector type") ||
+ parseCommaSeparatedListUntil(Token::greater, parseElt,
+ /*allowEmptyList=*/false)) {
+ return nullptr;
+ }
+ }
+
+ if (!layout)
+ return VectorType::get(dimensions, elementType, scalableDims);
+
+ return VectorType::get(dimensions, elementType, scalableDims, layout);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index dae7fdd40b5456c..458140b0de81c50 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2572,6 +2572,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << 'x';
}
printType(vectorTy.getElementType());
+ VectorLayoutAttrInterface layout = vectorTy.getLayout();
+ if (layout) {
+ os << ", ";
+ printAttribute(layout, AttrTypeElision::May);
+ }
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a9284d5714637bc..b0ebec169e3c841 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims) {
+ ArrayRef<bool> scalableDims,
+ VectorLayoutAttrInterface layout) {
if (!isValidElementType(elementType))
return emitError()
<< "vector elements must be int/index/float type but got "
@@ -242,6 +243,11 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();
+ if (layout) {
+ if (failed(layout.verifyLayout(shape, elementType, emitError)))
+ return emitError() << "layout does not match underlying vector shape";
+ }
+
return success();
}
diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index d192b2922d6b9dc..45a518f8549a3c6 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRInterfacesTests
DataLayoutInterfacesTest.cpp
InferIntRangeInterfaceTest.cpp
InferTypeOpInterfaceTest.cpp
+ VectorLayoutInterfaceTest.cpp
)
target_link_libraries(MLIRInterfacesTests
diff --git a/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
new file mode 100644
index 000000000000000..0ea9a710138e3a4
--- /dev/null
+++ b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
@@ -0,0 +1,158 @@
+//===-- VectorLayoutInterfaceTest.cpp - Unit Tests for Vector Layouts -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+class NamedStridedLayoutAttrStorage : public AttributeStorage {
+public:
+ using KeyTy =
+ std::tuple<ArrayRef<std::string>, ArrayRef<int64_t>, ArrayRef<int64_t>>;
+
+ NamedStridedLayoutAttrStorage(ArrayRef<std::string> names,
+ ArrayRef<int64_t> strides,
+ ArrayRef<int64_t> vectorShape)
+ : names(names), strides(strides), vectorShape(vectorShape) {}
+
+ bool operator==(const KeyTy &key) const {
+ return (std::get<0>(key) == names) && (std::get<1>(key) == strides) &&
+ (std::get<2>(key) == vectorShape);
+ }
+
+ static NamedStridedLayoutAttrStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ ArrayRef<std::string> names = allocator.copyInto(std::get<0>(key));
+ ArrayRef<int64_t> strides = allocator.copyInto(std::get<1>(key));
+ ArrayRef<int64_t> vectorShape = allocator.copyInto(std::get<2>(key));
+ return new (allocator.allocate<NamedStridedLayoutAttrStorage>())
+ NamedStridedLayoutAttrStorage(names, strides, vectorShape);
+ }
+
+ ArrayRef<std::string> names;
+ ArrayRef<int64_t> strides;
+ ArrayRef<int64_t> vectorShape;
+};
+
+struct NamedStridedLayoutAttr
+    : public Attribute::AttrBase<NamedStridedLayoutAttr, Attribute,
+                                 NamedStridedLayoutAttrStorage,
+                                 VectorLayoutAttrInterface::Trait> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NamedStridedLayoutAttr)
+  using Base::Base;
+  static NamedStridedLayoutAttr get(MLIRContext *ctx,
+                                    ArrayRef<std::string> names,
+                                    ArrayRef<int64_t> strides,
+                                    ArrayRef<int64_t> vectorShape) {
+    return Base::get(ctx, names, strides, vectorShape);
+  }
+
+  LogicalResult verifyLayout(ArrayRef<int64_t> shape, Type elementType,
+                             function_ref<InFlightDiagnostic()> emitError) {
+    if (shape == getVectorShape())
+      return success();
+    return emitError() << "expected layout shape to match the vector shape";
+  }
+
+  ArrayRef<std::string> getNames() { return getImpl()->names; }
+  ArrayRef<int64_t> getStrides() { return getImpl()->strides; }
+  ArrayRef<int64_t> getVectorShape() { return getImpl()->vectorShape; }
+};
+
+struct VLTestDialect : Dialect {
+  explicit VLTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<VLTestDialect>()) {
+    ctx->loadDialect<VLTestDialect>();
+    addAttributes<NamedStridedLayoutAttr>();
+  }
+  static StringRef getDialectNamespace() { return "vltest"; }
+
+  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
+    SmallVector<int64_t> strides, vectorShape;
+    SmallVector<std::string> names;
+    if (!succeeded(parser.parseKeyword("named_strided_layout")))
+      return {};
+    if (!succeeded(parser.parseLess()))
+      return {};
+    do {
+      if (!succeeded(parser.parseLSquare()))
+        return {};
+    std::string name;
+    int64_t stride;
+    int64_t shape = 1;
+    do {
+      // Abort on a malformed `"name" : stride` entry (already diagnosed).
+      if (failed(parser.parseString(&name)) || failed(parser.parseColon()) ||
+          failed(parser.parseInteger(stride)))
+        return {};
+      names.push_back(name);
+      strides.push_back(stride);
+      shape *= stride;
+    } while (succeeded(parser.parseOptionalComma()));
+    if (!succeeded(parser.parseRSquare()))
+      return {};
+    vectorShape.push_back(shape);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (!succeeded(parser.parseGreater()))
+      return {};
+    return NamedStridedLayoutAttr::get(parser.getContext(), names, strides,
+                                       vectorShape);
+  }
+};
+
+TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
+  const char *ir = R"MLIR(
+  #layout = #vltest.named_strided_layout<["BatchX" : 2, "LaneX" : 4, "VectorX" : 2],
+                                          ["BatchY" : 1, "LaneY" : 8, "VectorY" : 2]>
+  %lhs = "arith.constant"() {value = dense<0.0> : vector<16x16xf16, #layout>}
+    : () -> (vector<16x16xf16, #layout>)
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<VLTestDialect, func::FuncDialect, arith::ArithDialect>();
+  MLIRContext ctx(registry);
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+
+  ASSERT_TRUE(module);
+  auto op = llvm::cast<arith::ConstantOp>(module->getBody()->front());
+  auto vectorType = llvm::dyn_cast<VectorType>(op.getResult().getType());
+  ASSERT_TRUE(vectorType);
+  auto namedStridedLayout =
+      llvm::dyn_cast_if_present<NamedStridedLayoutAttr>(vectorType.getLayout());
+  EXPECT_TRUE(namedStridedLayout);
+  if (namedStridedLayout) {
+    ArrayRef<std::string> names = namedStridedLayout.getNames();
+    ArrayRef<int64_t> strides = namedStridedLayout.getStrides();
+    ArrayRef<int64_t> vectorShape = namedStridedLayout.getVectorShape();
+    EXPECT_EQ(vectorShape, ArrayRef<int64_t>({16, 16}));
+    EXPECT_EQ(strides.size(), 6u);
+    EXPECT_EQ(strides[0], 2);
+    EXPECT_EQ(strides[1], 4);
+    EXPECT_EQ(strides[2], 2);
+    EXPECT_EQ(strides[3], 1);
+    EXPECT_EQ(strides[4], 8);
+    EXPECT_EQ(strides[5], 2);
+    EXPECT_EQ(names.size(), 6u);
+    EXPECT_EQ(names[0], "BatchX");
+    EXPECT_EQ(names[1], "LaneX");
+    EXPECT_EQ(names[2], "VectorX");
+    EXPECT_EQ(names[3], "BatchY");
+    EXPECT_EQ(names[4], "LaneY");
+    EXPECT_EQ(names[5], "VectorY");
+  }
+}
>From cdefb686a5351adaa263b68de5c4d5f70c8ea87f Mon Sep 17 00:00:00 2001
From: Harsh Menon <harsh at nod-labs.com>
Date: Fri, 10 Nov 2023 14:27:27 -0800
Subject: [PATCH 2/2] Address Thomas' comments
- Fix out of date comments
- More generic error message on layout verification failure
- Add roundtrip test
- More generic description of layout
---
mlir/include/mlir/IR/BuiltinTypes.td | 12 ++--
mlir/lib/AsmParser/TypeParser.cpp | 3 +-
mlir/lib/IR/BuiltinTypes.cpp | 2 +-
.../Interfaces/VectorLayoutInterfaceTest.cpp | 57 ++++++++++++++++++-
4 files changed, 65 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 3a2193d7a1768f7..f02335c6bf5bee3 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1054,11 +1054,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
##### Layout
- A vector may optionally have a layout that indicates how indices of
- the vector are transformed to indices of the vector fragments that
- are held by individual threads in a SIMT execution model. Such layouts
- are common in a wide variety of GPU matrix multiplication instructions.
- The layout can be any attribute that implements `VectorLayoutAttrInterface`.
+ A vector may optionally have a layout that can be used to capture
+ the mapping of the vector indices to an arbitrary coordinate system.
+ An example of such a mapping is the mapping of vector indices to
+ indices of the vector fragments that are held by individual threads
+ in a SIMT execution model. Such layouts are common in a wide variety of
+ GPU matrix multiplication instructions. The layout can be any attribute
+ that implements `VectorLayoutAttrInterface`.
Examples:
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index be8c84ee74e8688..0888ad7224a1070 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -477,9 +477,8 @@ VectorType Parser::parseVectorType() {
return success();
};
- // Parse a list of mappings and address space if present.
+ // Parse the vector layout
if (!consumeIf(Token::greater)) {
- // Parse comma separated list of affine maps, followed by memory space.
if (parseToken(Token::comma, "expected ',' or '>' in vector type") ||
parseCommaSeparatedListUntil(Token::greater, parseElt,
/*allowEmptyList=*/false)) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index b0ebec169e3c841..90617fd479b4b70 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -245,7 +245,7 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
if (layout) {
if (failed(layout.verifyLayout(shape, elementType, emitError)))
- return emitError() << "layout does not match underlying vector shape";
+ return emitError() << "layout verification failed";
}
return success();
diff --git a/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
index 0ea9a710138e3a4..96713122c34483a 100644
--- a/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
@@ -82,6 +82,37 @@ struct VLTestDialect : Dialect {
}
static StringRef getDialectNamespace() { return "vltest"; }
+ void printAttribute(Attribute attr,
+ DialectAsmPrinter &printer) const override {
+ auto layoutAttr = llvm::cast<NamedStridedLayoutAttr>(attr);
+ SmallVector<int64_t> mutableVectorShape(layoutAttr.getVectorShape());
+ size_t i{0}, j{0};
+ auto addCommaIf = [&](bool condition) {
+ if (condition)
+ printer << ", ";
+ };
+ auto addLParenIf = [&](bool condition) {
+ if (condition)
+ printer << "[";
+ };
+ auto addRParenIf = [&](bool condition) {
+ if (condition)
+ printer << "]";
+ };
+ for (const auto &[name, stride] :
+ llvm::zip(layoutAttr.getNames(), layoutAttr.getStrides())) {
+ addLParenIf(j == 0);
+ printer << name << " : " << stride;
+ mutableVectorShape[i] /= stride;
+ addCommaIf(mutableVectorShape[i] > 1);
+ bool finishedParsingList = mutableVectorShape[i] == 1;
+ addRParenIf(finishedParsingList);
+ addCommaIf(finishedParsingList && (i < mutableVectorShape.size() - 1));
+ j = finishedParsingList ? 0 : j + 1;
+ i = finishedParsingList ? i + 1 : i;
+ }
+ }
+
Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
SmallVector<int64_t> strides, vectorShape;
SmallVector<std::string> names;
@@ -124,7 +155,7 @@ TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
)MLIR";
DialectRegistry registry;
- registry.insert<VLTestDialect, func::FuncDialect, arith::ArithDialect>();
+ registry.insert<VLTestDialect, arith::ArithDialect>();
MLIRContext ctx(registry);
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
@@ -156,3 +187,27 @@ TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
EXPECT_EQ(names[5], "VectorY");
}
}
+
+TEST(VectorLayoutAttrInterface, RoundTripTest) {
+ const char *ir = R"MLIR(
+ #layout = #vltest.named_strided_layout<["BatchX" : 2, "LaneX" : 4, "VectorX" : 2],
+ ["BatchY" : 1, "LaneY" : 8, "VectorY" : 2]>
+ %lhs = "arith.constant"() {value = dense<0.0> : vector<16x16xf16, #layout>}
+ : () -> (vector<16x16xf16, #layout>)
+ )MLIR";
+
+ DialectRegistry registry;
+ registry.insert<VLTestDialect, arith::ArithDialect>();
+ MLIRContext ctx(registry);
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ std::string moduleStr;
+ llvm::raw_string_ostream stream(moduleStr);
+ stream << *module;
+ stream.flush();
+ const std::string expectedResult =
+ "module {\n"
+ " %cst = arith.constant dense<0.000000e+00> :"
+ " vector<16x16xf16, #vltest<[BatchX : 2, LaneX : 4, VectorX : 2],"
+ " [BatchY : 1, LaneY : 8, VectorY : 2]>>\n}";
+ EXPECT_EQ(moduleStr, expectedResult);
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list