[Mlir-commits] [mlir] [mlir] Add optional layout attribute to VectorType (PR #71916)
Thomas Raoux
llvmlistbot at llvm.org
Fri Nov 10 09:40:49 PST 2023
@@ -0,0 +1,158 @@
+//===-- VectorLayoutInterfaceTest.cpp - Unit Tests for Vector Layouts -----===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser/Parser.h"
+#include <gtest/gtest.h>
+using namespace mlir;
+using namespace mlir::detail;
+/// Uniqued storage backing NamedStridedLayoutAttr.
+///
+/// Stores three arrays: `names` and `strides` (one entry per named layout
+/// component) and `vectorShape` (the vector shape the layout applies to).
+class NamedStridedLayoutAttrStorage : public AttributeStorage {
+public:
+  /// Uniquing key: (names, strides, vectorShape).
+  using KeyTy =
+      std::tuple<ArrayRef<std::string>, ArrayRef<int64_t>, ArrayRef<int64_t>>;
+  NamedStridedLayoutAttrStorage(ArrayRef<std::string> names,
+                                ArrayRef<int64_t> strides,
+                                ArrayRef<int64_t> vectorShape)
+      : names(names), strides(strides), vectorShape(vectorShape) {}
+  /// A key matches this storage when all three arrays compare element-wise
+  /// equal.
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == names && std::get<1>(key) == strides &&
+           std::get<2>(key) == vectorShape;
+  }
+  /// Copies the key's arrays into memory obtained from `allocator` and
+  /// placement-constructs the storage there, so the stored ArrayRefs do not
+  /// point into the caller's buffers.
+  ///
+  /// NOTE(review): copyInto copy-constructs each std::string into allocator
+  /// memory; presumably their destructors never run, so names longer than
+  /// the small-string buffer would leak. Consider storing StringRefs copied
+  /// with allocator.copyInto(StringRef) instead — TODO confirm.
+  static NamedStridedLayoutAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    ArrayRef<std::string> names = allocator.copyInto(std::get<0>(key));
+    ArrayRef<int64_t> strides = allocator.copyInto(std::get<1>(key));
+    ArrayRef<int64_t> vectorShape = allocator.copyInto(std::get<2>(key));
+    return new (allocator.allocate<NamedStridedLayoutAttrStorage>())
+        NamedStridedLayoutAttrStorage(names, strides, vectorShape);
+  }
+  ArrayRef<std::string> names;
+  ArrayRef<int64_t> strides;
+  ArrayRef<int64_t> vectorShape;
+};
+/// Test attribute implementing VectorLayoutAttrInterface.
+///
+/// The layout accepts a vector exactly when the vector's shape equals the
+/// attribute's stored vector shape; the element type is not constrained.
+struct NamedStridedLayoutAttr
+    : public Attribute::AttrBase<NamedStridedLayoutAttr, Attribute,
+                                 NamedStridedLayoutAttrStorage,
+                                 VectorLayoutAttrInterface::Trait> {
+  using Base::Base;
+  /// Returns the uniqued attribute for the given names, strides and vector
+  /// shape in `ctx`.
+  static NamedStridedLayoutAttr get(MLIRContext *ctx,
+                                    ArrayRef<std::string> names,
+                                    ArrayRef<int64_t> strides,
+                                    ArrayRef<int64_t> vectorShape) {
+    return Base::get(ctx, names, strides, vectorShape);
+  }
+  /// VectorLayoutAttrInterface hook: succeeds iff `shape` equals the stored
+  /// vector shape. On mismatch the failure is reported through `emitError`
+  /// so callers get a diagnostic explaining why verification failed, rather
+  /// than a silent failure.
+  LogicalResult verifyLayout(ArrayRef<int64_t> shape, Type elementType,
+                             function_ref<InFlightDiagnostic()> emitError) {
+    if (shape == getVectorShape())
+      return success();
+    return emitError()
+           << "vector shape does not match the shape expected by the layout";
+  }
+  ArrayRef<std::string> getNames() { return getImpl()->names; }
+  ArrayRef<int64_t> getStrides() { return getImpl()->strides; }
+  ArrayRef<int64_t> getVectorShape() { return getImpl()->vectorShape; }
+};
+/// Minimal test dialect ("vltest") that registers NamedStridedLayoutAttr and
+/// parses its textual form, e.g.
+///
+///   named_strided_layout<["a" : 2, "b" : 4], ["c" : 8]>
+///
+/// Each bracketed group describes one vector dimension whose size is the
+/// product of the integers in that group; names and integers are collected
+/// into flat lists across all groups.
+///
+/// NOTE(review): no printAttribute override is defined, so printing this
+/// attribute falls back to the base Dialect hook — TODO add a printer before
+/// relying on a round-trip lit test.
+struct VLTestDialect : Dialect {
+  explicit VLTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<VLTestDialect>()) {
+    // Only register attributes here; a dialect must not (re)load itself from
+    // its own constructor.
+    addAttributes<NamedStridedLayoutAttr>();
+  }
+  static StringRef getDialectNamespace() { return "vltest"; }
+  /// Parses `named_strided_layout<[...], ...>`. Returns a null Attribute on
+  /// any parse failure; the failing parser call has already emitted the
+  /// diagnostic.
+  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
+    SmallVector<std::string> names;
+    SmallVector<int64_t> strides, vectorShape;
+    if (failed(parser.parseKeyword("named_strided_layout")) ||
+        failed(parser.parseLess()))
+      return {};
+    do {
+      if (failed(parser.parseLSquare()))
+        return {};
+      // Size of the current dimension: product of its entries' integers.
+      int64_t dimSize = 1;
+      do {
+        std::string name;
+        int64_t stride = 0;
+        // Propagate entry failures instead of skipping the entry: the parse
+        // helpers already emitted an error, so continuing would return an
+        // attribute after a diagnostic was reported.
+        if (failed(parser.parseString(&name)) ||
+            failed(parser.parseColon()) ||
+            failed(parser.parseInteger(stride)))
+          return {};
+        names.push_back(name);
+        strides.push_back(stride);
+        dimSize *= stride;
+      } while (succeeded(parser.parseOptionalComma()));
+      if (failed(parser.parseRSquare()))
+        return {};
+      vectorShape.push_back(dimSize);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (failed(parser.parseGreater()))
+      return {};
+    return NamedStridedLayoutAttr::get(parser.getContext(), names, strides,
+                                       vectorShape);
+  }
+};
+TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
+ const char *ir = R"MLIR(
ThomasRaoux wrote:
Can you also add a lit test to check round-trip printing/parsing?
More information about the Mlir-commits
mailing list