[Mlir-commits] [mlir] [mlir] Add optional layout attribute to VectorType (PR #71916)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 10 02:07:04 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-ods

Author: None (harsh-nod)

<details>
<summary>Changes</summary>

This patch adds an attribute interface for representing the layout on vector types. This layout could be used to represent the mapping from the vector indices to the indices of the vector fragments held by different threads of a GPU.

The interface has a verify function that can be used to validate that the layout accurately represents the vector shape.

---
Full diff: https://github.com/llvm/llvm-project/pull/71916.diff


8 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinAttributeInterfaces.td (+24) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+11-3) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+17-4) 
- (modified) mlir/lib/AsmParser/TypeParser.cpp (+26-2) 
- (modified) mlir/lib/IR/AsmPrinter.cpp (+5) 
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+7-1) 
- (modified) mlir/unittests/Interfaces/CMakeLists.txt (+1) 
- (added) mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp (+158) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index c741db9b47f34e5..9241cac8c3b98a0 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -495,4 +495,28 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+def VectorLayoutAttrInterface : AttrInterface<"VectorLayoutAttrInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    This interface is used for attributes that can represent the Vector type's
+    layout semantics, such as being able to map the vector indices to those
+    of the vector fragments held by individual threads.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      "Check if the current layout is applicable to the provided shape",
+      "::mlir::LogicalResult", "verifyLayout",
+      (ins "::llvm::ArrayRef<int64_t>":$shape,
+           "::mlir::Type":$elementType,
+           "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+    >
+  ];
+}
+
 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829b..a387390b38e7de4 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,14 @@ class VectorType::Builder {
   /// Build from another VectorType.
   explicit Builder(VectorType other)
       : elementType(other.getElementType()), shape(other.getShape()),
-        scalableDims(other.getScalableDims()) {}
+        scalableDims(other.getScalableDims()), layout(other.getLayout()) {}
 
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
-          ArrayRef<bool> scalableDims = {})
-      : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+          ArrayRef<bool> scalableDims = {},
+          VectorLayoutAttrInterface layout = {})
+      : elementType(elementType), shape(shape), scalableDims(scalableDims),
+        layout(layout) {}
 
   Builder &setShape(ArrayRef<int64_t> newShape,
                     ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,6 +344,11 @@ class VectorType::Builder {
     return *this;
   }
 
+  Builder &setLayout(VectorLayoutAttrInterface newLayout) {
+    layout = newLayout;
+    return *this;
+  }
+
   operator VectorType() {
     return VectorType::get(shape, elementType, scalableDims);
   }
@@ -350,6 +357,7 @@ class VectorType::Builder {
   Type elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
+  VectorLayoutAttrInterface layout;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 5ec986ac26de06b..3a2193d7a1768f7 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1029,11 +1029,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     Syntax:
 
     ```
-    vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
+    vector-type ::= `vector` `<` vector-dim-list vector-element-type
+                    (`,` layout-specification)? `>`
     vector-element-type ::= float-type | integer-type | index-type
     vector-dim-list := (static-dim-list `x`)?
     static-dim-list ::= static-dim (`x` static-dim)*
     static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+    layout-specification ::= attribute-value
     ```
 
     The vector type represents a SIMD style vector used by target-specific
@@ -1050,6 +1052,14 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
     2D vector with shape `(0, 42)` and zero shapes are not allowed.
 
+    ##### Layout
+
+    A vector may optionally have a layout that indicates how indices of
+    the vector are transformed to indices of the vector fragments that
+    are held by individual threads in a SIMT execution model. Such layouts
+    are common in a wide variety of GPU matrix multiplication instructions.
+    The layout can be any attribute that implements `VectorLayoutAttrInterface`.
+
     Examples:
 
     ```mlir
@@ -1068,17 +1078,20 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     // A 3D mixed fixed/scalable vector in which only the inner dimension is
     // scalable.
     vector<2x[4]x8xf32>
+
     ```
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    ArrayRefParameter<"bool">:$scalableDims
+    ArrayRefParameter<"bool">:$scalableDims,
+    "VectorLayoutAttrInterface":$layout
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"ArrayRef<bool>", "{}">:$scalableDims
+      CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+      CArg<"VectorLayoutAttrInterface", "{}">:$layout
     ), [{
       // While `scalableDims` is optional, its default value should be
       // `false` for every dim in `shape`.
@@ -1087,7 +1100,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
         isScalableVec.resize(shape.size(), false);
         scalableDims = isScalableVec;
       }
-      return $_get(elementType.getContext(), shape, elementType, scalableDims);
+      return $_get(elementType.getContext(), shape, elementType, scalableDims, layout);
     }]>
   ];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index adaefb78172c2ea..be8c84ee74e8688 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -459,14 +459,38 @@ VectorType Parser::parseVectorType() {
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
   auto elementType = parseType();
-  if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
+  if (!elementType)
     return nullptr;
 
   if (!VectorType::isValidElementType(elementType))
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType, scalableDims);
+  VectorLayoutAttrInterface layout;
+  auto parseElt = [&]() -> ParseResult {
+    Attribute attr = parseAttribute();
+    if (!attr)
+      return failure();
+    if (isa<VectorLayoutAttrInterface>(attr)) {
+      layout = cast<VectorLayoutAttrInterface>(attr);
+    }
+    return success();
+  };
+
+  // Parse the optional layout attribute if present.
+  if (!consumeIf(Token::greater)) {
+    // Parse a comma-separated attribute list; only layout attributes are kept.
+    if (parseToken(Token::comma, "expected ',' or '>' in vector type") ||
+        parseCommaSeparatedListUntil(Token::greater, parseElt,
+                                     /*allowEmptyList=*/false)) {
+      return nullptr;
+    }
+  }
+
+  if (!layout)
+    return VectorType::get(dimensions, elementType, scalableDims);
+
+  return VectorType::get(dimensions, elementType, scalableDims, layout);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index dae7fdd40b5456c..458140b0de81c50 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2572,6 +2572,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
           os << 'x';
         }
         printType(vectorTy.getElementType());
+        VectorLayoutAttrInterface layout = vectorTy.getLayout();
+        if (layout) {
+          os << ", ";
+          printAttribute(vectorTy.getLayout(), AttrTypeElision::May);
+        }
         os << '>';
       })
       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a9284d5714637bc..b0ebec169e3c841 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 ArrayRef<bool> scalableDims) {
+                                 ArrayRef<bool> scalableDims,
+                                 VectorLayoutAttrInterface layout) {
   if (!isValidElementType(elementType))
     return emitError()
            << "vector elements must be int/index/float type but got "
@@ -242,6 +243,11 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "number of dims must match, got "
                        << scalableDims.size() << " and " << shape.size();
 
+  if (layout) {
+    if (failed(layout.verifyLayout(shape, elementType, emitError)))
+      return emitError() << "layout does not match underlying vector shape";
+  }
+
   return success();
 }
 
diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index d192b2922d6b9dc..45a518f8549a3c6 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRInterfacesTests
   DataLayoutInterfacesTest.cpp
   InferIntRangeInterfaceTest.cpp
   InferTypeOpInterfaceTest.cpp
+  VectorLayoutInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
diff --git a/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
new file mode 100644
index 000000000000000..0ea9a710138e3a4
--- /dev/null
+++ b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
@@ -0,0 +1,158 @@
+//===-- VectorLayoutInterfaceTest.cpp - Unit Tests for Vector Layouts -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+class NamedStridedLayoutAttrStorage : public AttributeStorage {
+public:
+  using KeyTy =
+      std::tuple<ArrayRef<std::string>, ArrayRef<int64_t>, ArrayRef<int64_t>>;
+
+  NamedStridedLayoutAttrStorage(ArrayRef<std::string> names,
+                                ArrayRef<int64_t> strides,
+                                ArrayRef<int64_t> vectorShape)
+      : names(names), strides(strides), vectorShape(vectorShape) {}
+
+  bool operator==(const KeyTy &key) const {
+    return (std::get<0>(key) == names) && (std::get<1>(key) == strides) &&
+           (std::get<2>(key) == vectorShape);
+  }
+
+  static NamedStridedLayoutAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    ArrayRef<std::string> names = allocator.copyInto(std::get<0>(key));
+    ArrayRef<int64_t> strides = allocator.copyInto(std::get<1>(key));
+    ArrayRef<int64_t> vectorShape = allocator.copyInto(std::get<2>(key));
+    return new (allocator.allocate<NamedStridedLayoutAttrStorage>())
+        NamedStridedLayoutAttrStorage(names, strides, vectorShape);
+  }
+
+  ArrayRef<std::string> names;
+  ArrayRef<int64_t> strides;
+  ArrayRef<int64_t> vectorShape;
+};
+
+struct NamedStridedLayoutAttr
+    : public Attribute::AttrBase<NamedStridedLayoutAttr, Attribute,
+                                 NamedStridedLayoutAttrStorage,
+                                 VectorLayoutAttrInterface::Trait> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NamedStridedLayoutAttr)
+  using Base::Base;
+  static NamedStridedLayoutAttr get(MLIRContext *ctx,
+                                    ArrayRef<std::string> names,
+                                    ArrayRef<int64_t> strides,
+                                    ArrayRef<int64_t> vectorShape) {
+    return Base::get(ctx, names, strides, vectorShape);
+  }
+
+  LogicalResult verifyLayout(ArrayRef<int64_t> shape, Type elementType,
+                             function_ref<InFlightDiagnostic()> emitError) {
+    if (shape == getVectorShape())
+      return success();
+    return failure();
+  }
+
+  ArrayRef<std::string> getNames() { return getImpl()->names; }
+  ArrayRef<int64_t> getStrides() { return getImpl()->strides; }
+  ArrayRef<int64_t> getVectorShape() { return getImpl()->vectorShape; }
+};
+
+struct VLTestDialect : Dialect {
+  explicit VLTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<VLTestDialect>()) {
+    ctx->loadDialect<VLTestDialect>();
+    addAttributes<NamedStridedLayoutAttr>();
+  }
+  static StringRef getDialectNamespace() { return "vltest"; }
+
+  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
+    SmallVector<int64_t> strides, vectorShape;
+    SmallVector<std::string> names;
+    if (!succeeded(parser.parseKeyword("named_strided_layout")))
+      return {};
+    if (!succeeded(parser.parseLess()))
+      return {};
+    do {
+      if (!succeeded(parser.parseLSquare()))
+        return {};
+      std::string name;
+      int64_t stride;
+      int64_t shape = 1;
+      do {
+        if (succeeded(parser.parseString(&name)) &&
+            succeeded(parser.parseColon()) &&
+            succeeded(parser.parseInteger(stride))) {
+          names.push_back(name);
+          strides.push_back(stride);
+          shape *= stride;
+        }
+      } while (succeeded(parser.parseOptionalComma()));
+      if (!succeeded(parser.parseRSquare()))
+        return {};
+      vectorShape.push_back(shape);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (!succeeded(parser.parseGreater()))
+      return {};
+    return NamedStridedLayoutAttr::get(parser.getContext(), names, strides,
+                                       vectorShape);
+  }
+};
+
+TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
+  const char *ir = R"MLIR(
+    #layout = #vltest.named_strided_layout<["BatchX" : 2, "LaneX" : 4, "VectorX" : 2],
+                                           ["BatchY" : 1, "LaneY" : 8, "VectorY" : 2]>
+    %lhs = "arith.constant"() {value = dense<0.0> : vector<16x16xf16, #layout>}
+        : () -> (vector<16x16xf16, #layout>)
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<VLTestDialect, func::FuncDialect, arith::ArithDialect>();
+  MLIRContext ctx(registry);
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+
+  arith::ConstantOp op =
+      llvm::cast<arith::ConstantOp>(module->getBody()->getOperations().front());
+  Type type = op.getResult().getType();
+  if (auto vectorType = llvm::cast<VectorType>(type)) {
+    VectorLayoutAttrInterface layout = vectorType.getLayout();
+    auto namedStridedLayout = llvm::cast<NamedStridedLayoutAttr>(layout);
+    ArrayRef<std::string> names = namedStridedLayout.getNames();
+    ArrayRef<int64_t> strides = namedStridedLayout.getStrides();
+    ArrayRef<int64_t> vectorShape = namedStridedLayout.getVectorShape();
+    EXPECT_EQ(vectorShape.size(), 2u);
+    EXPECT_EQ(vectorShape[0], 16u);
+    EXPECT_EQ(vectorShape[1], 16u);
+    EXPECT_EQ(strides.size(), 6u);
+    EXPECT_EQ(strides[0], 2u);
+    EXPECT_EQ(strides[1], 4u);
+    EXPECT_EQ(strides[2], 2u);
+    EXPECT_EQ(strides[3], 1u);
+    EXPECT_EQ(strides[4], 8u);
+    EXPECT_EQ(strides[5], 2u);
+    EXPECT_EQ(names.size(), 6u);
+    EXPECT_EQ(names[0], "BatchX");
+    EXPECT_EQ(names[1], "LaneX");
+    EXPECT_EQ(names[2], "VectorX");
+    EXPECT_EQ(names[3], "BatchY");
+    EXPECT_EQ(names[4], "LaneY");
+    EXPECT_EQ(names[5], "VectorY");
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/71916


More information about the Mlir-commits mailing list