[Mlir-commits] [mlir] [mlir] Add optional layout attribute to VectorType (PR #71916)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 10 12:10:55 PST 2023


================
@@ -0,0 +1,158 @@
+//===-- VectorLayoutInterfaceTest.cpp - Unit Tests for Vector Layouts -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+/// Uniqued storage backing NamedStridedLayoutAttr. It holds three parallel
+/// pieces of data: the flattened dimension names, the flattened strides, and
+/// the vector shape the layout describes.
+class NamedStridedLayoutAttrStorage : public AttributeStorage {
+public:
+  /// Uniquing key: (names, strides, vectorShape).
+  using KeyTy =
+      std::tuple<ArrayRef<std::string>, ArrayRef<int64_t>, ArrayRef<int64_t>>;
+
+  NamedStridedLayoutAttrStorage(ArrayRef<std::string> names,
+                                ArrayRef<int64_t> strides,
+                                ArrayRef<int64_t> vectorShape)
+      : names(names), strides(strides), vectorShape(vectorShape) {}
+
+  /// A key matches this storage when all three arrays compare element-wise
+  /// equal (tuple comparison checks names, then strides, then vectorShape).
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(names, strides, vectorShape);
+  }
+
+  /// Copies every array of `key` into memory owned by `allocator`, then
+  /// placement-constructs the storage object in that same allocator.
+  /// NOTE(review): if the storage allocator never runs element destructors
+  /// (confirm), heap buffers of long std::string names are never freed.
+  static NamedStridedLayoutAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    const auto &[keyNames, keyStrides, keyShape] = key;
+    ArrayRef<std::string> ownedNames = allocator.copyInto(keyNames);
+    ArrayRef<int64_t> ownedStrides = allocator.copyInto(keyStrides);
+    ArrayRef<int64_t> ownedShape = allocator.copyInto(keyShape);
+    auto *slot = allocator.allocate<NamedStridedLayoutAttrStorage>();
+    return new (slot)
+        NamedStridedLayoutAttrStorage(ownedNames, ownedStrides, ownedShape);
+  }
+
+  ArrayRef<std::string> names;
+  ArrayRef<int64_t> strides;
+  ArrayRef<int64_t> vectorShape;
+};
+
+/// Test attribute implementing VectorLayoutAttrInterface. It records a list
+/// of named strides plus the vector shape the layout applies to.
+struct NamedStridedLayoutAttr
+    : public Attribute::AttrBase<NamedStridedLayoutAttr, Attribute,
+                                 NamedStridedLayoutAttrStorage,
+                                 VectorLayoutAttrInterface::Trait> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NamedStridedLayoutAttr)
+  using Base::Base;
+
+  /// Returns the uniqued attribute for the given names, strides and vector
+  /// shape in `ctx`.
+  static NamedStridedLayoutAttr get(MLIRContext *ctx,
+                                    ArrayRef<std::string> names,
+                                    ArrayRef<int64_t> strides,
+                                    ArrayRef<int64_t> vectorShape) {
+    return Base::get(ctx, names, strides, vectorShape);
+  }
+
+  /// Accepts a vector type only if its shape equals the recorded vector
+  /// shape. The element type is not constrained. On mismatch, a diagnostic is
+  /// emitted through `emitError` so the user sees why the layout was rejected.
+  LogicalResult verifyLayout(ArrayRef<int64_t> shape, Type elementType,
+                             function_ref<InFlightDiagnostic()> emitError) {
+    if (shape == getVectorShape())
+      return success();
+    return emitError() << "vector shape [" << shape
+                       << "] does not match layout vector shape ["
+                       << getVectorShape() << "]";
+  }
+
+  ArrayRef<std::string> getNames() { return getImpl()->names; }
+  ArrayRef<int64_t> getStrides() { return getImpl()->strides; }
+  ArrayRef<int64_t> getVectorShape() { return getImpl()->vectorShape; }
+};
+
+/// Minimal test dialect ("vltest") that registers NamedStridedLayoutAttr and
+/// knows how to parse it. Load it via `ctx.loadDialect<VLTestDialect>()`.
+struct VLTestDialect : Dialect {
+  explicit VLTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<VLTestDialect>()) {
+    // Only register the dialect's attributes here. The dialect must not load
+    // itself from its own constructor; the caller is responsible for loading.
+    addAttributes<NamedStridedLayoutAttr>();
+  }
+  static StringRef getDialectNamespace() { return "vltest"; }
+
+  /// Parses `named_strided_layout<[ "n" : s, ... ], [ ... ], ...>`.
+  /// Each bracketed group contributes one vector dimension whose size is the
+  /// product of the strides in that group. Names and strides are recorded in
+  /// a single flattened list. Returns a null attribute on the first parse
+  /// error (the parser has already emitted the diagnostic).
+  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
+    SmallVector<int64_t> strides, vectorShape;
+    SmallVector<std::string> names;
+    if (failed(parser.parseKeyword("named_strided_layout")) ||
+        failed(parser.parseLess()))
+      return {};
+    do {
+      if (failed(parser.parseLSquare()))
+        return {};
+      int64_t shape = 1;
+      do {
+        std::string name;
+        int64_t stride;
+        // Abort instead of continuing after a malformed `"name" : stride`
+        // entry; otherwise an attribute could be returned after an error.
+        if (failed(parser.parseString(&name)) ||
+            failed(parser.parseColon()) ||
+            failed(parser.parseInteger(stride)))
+          return {};
+        names.push_back(std::move(name));
+        strides.push_back(stride);
+        shape *= stride;
+      } while (succeeded(parser.parseOptionalComma()));
+      if (failed(parser.parseRSquare()))
+        return {};
+      vectorShape.push_back(shape);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (failed(parser.parseGreater()))
+      return {};
+    return NamedStridedLayoutAttr::get(parser.getContext(), names, strides,
+                                       vectorShape);
+  }
+};
+
+TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
+  const char *ir = R"MLIR(
----------------
harsh-nod wrote:

Sure, I will add a round-trip test. It may have to stay a unit test, since the attribute and dialect are defined only in this .cpp file.

https://github.com/llvm/llvm-project/pull/71916


More information about the Mlir-commits mailing list