[Mlir-commits] [mlir] [mlir] Add encoding attribute to VectorType. (PR #99029)

Alexander Belyaev llvmlistbot at llvm.org
Tue Jul 16 05:52:49 PDT 2024


https://github.com/pifon2a created https://github.com/llvm/llvm-project/pull/99029

RankedTensorType already has an optional encoding attribute. This patch adds the same optional attribute to VectorType.

>From 8672bec00111b34626d057185d96595576947c54 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon at google.com>
Date: Tue, 16 Jul 2024 14:47:44 +0200
Subject: [PATCH] [mlir] Add encoding attribute to VectorType.

RankedTensorType already has an optional encoding attribute. Add the same
optional attribute to VectorType.
---
 mlir/include/mlir/IR/BuiltinTypes.h  | 15 +++++++++++----
 mlir/include/mlir/IR/BuiltinTypes.td | 17 ++++++++++++++---
 mlir/lib/AsmParser/TypeParser.cpp    | 20 ++++++++++++++++++--
 mlir/lib/IR/AsmPrinter.cpp           |  5 +++++
 mlir/lib/IR/BuiltinTypes.cpp         |  8 ++++++--
 mlir/test/IR/parser.mlir             |  3 +++
 6 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 5579b138668d2..564a27e01240c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,13 @@ class VectorType::Builder {
   /// Build from another VectorType.
   explicit Builder(VectorType other)
       : elementType(other.getElementType()), shape(other.getShape()),
-        scalableDims(other.getScalableDims()) {}
+        scalableDims(other.getScalableDims()), encoding(other.getEncoding()) {}
 
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
-          ArrayRef<bool> scalableDims = {})
-      : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+          ArrayRef<bool> scalableDims = {}, Attribute encoding = nullptr)
+      : elementType(elementType), shape(shape), scalableDims(scalableDims),
+        encoding(encoding) {}
 
   Builder &setShape(ArrayRef<int64_t> newShape,
                     ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,14 +343,20 @@ class VectorType::Builder {
     return *this;
   }
 
+  Builder &setEncoding(Attribute newEncoding) {
+    encoding = newEncoding;
+    return *this;
+  }
+
   operator VectorType() {
-    return VectorType::get(shape, elementType, scalableDims);
+    return VectorType::get(shape, elementType, scalableDims, encoding);
   }
 
 private:
   Type elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
+  Attribute encoding;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..7edc8d228b340 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1060,6 +1060,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     vector-dim-list := (static-dim-list `x`)?
     static-dim-list ::= static-dim (`x` static-dim)*
     static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+    encoding ::= attribute-value
     ```
 
     The vector type represents a SIMD style vector used by target-specific
@@ -1072,6 +1073,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     Vector shapes must be positive decimal integers. 0D vectors are allowed by
     omitting the dimension: `vector<f32>`.
 
+    The optional `encoding` attribute provides additional information about
+    the vector. An absent (null) encoding denotes a plain vector without any
+    specific structure.
+
     Note: hexadecimal integer literals are not allowed in vector type
     declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
     2D vector with shape `(0, 42)` and zero shapes are not allowed.
@@ -1094,17 +1099,22 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     // A 3D mixed fixed/scalable vector in which only the inner dimension is
     // scalable.
     vector<2x[4]x8xf32>
+
+    // Vector with an encoding attribute (where #ENCODING is a named alias).
+    vector<4x2xf64, #ENCODING>
     ```
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    ArrayRefParameter<"bool">:$scalableDims
+    ArrayRefParameter<"bool">:$scalableDims,
+    "Attribute":$encoding
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"ArrayRef<bool>", "{}">:$scalableDims
+      CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+      CArg<"Attribute", "{}">:$encoding
     ), [{
       // While `scalableDims` is optional, its default value should be
       // `false` for every dim in `shape`.
@@ -1113,7 +1123,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
         isScalableVec.resize(shape.size(), false);
         scalableDims = isScalableVec;
       }
-      return $_get(elementType.getContext(), shape, elementType, scalableDims);
+      auto ctx = elementType.getContext();
+      return $_get(ctx, shape, elementType, scalableDims, encoding);
     }]>
   ];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 0b46c96bbc04d..8d38d5d1c4dea 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -448,6 +448,7 @@ Type Parser::parseTupleType() {
 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
+/// encoding ::= attribute-value
 ///
 VectorType Parser::parseVectorType() {
   consumeToken(Token::kw_vector);
@@ -467,14 +468,33 @@ VectorType Parser::parseVectorType() {
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
   auto elementType = parseType();
-  if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
+  if (!elementType)
     return nullptr;
-
   if (!VectorType::isValidElementType(elementType))
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType, scalableDims);
+  // Parse the optional `, encoding` suffix. Once the comma is consumed the
+  // attribute is mandatory, so a dangling comma (`vector<4xf32,>`) is an
+  // error rather than being silently accepted.
+  Attribute encoding;
+  if (consumeIf(Token::comma)) {
+    encoding = parseAttribute();
+    if (!encoding)
+      return nullptr;
+    // The element type has already been validated above, matching the
+    // check order in VectorType::verify (element type first, encoding last).
+    if (auto v = dyn_cast<VerifiableTensorEncoding>(encoding)) {
+      if (failed(v.verifyEncoding(dimensions, elementType,
+                                  [&] { return emitError(); })))
+        return nullptr;
+    }
+  }
+
+  if (parseToken(Token::greater, "expected '>' in vector type"))
+    return nullptr;
+
+  return VectorType::get(dimensions, elementType, scalableDims, encoding);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 13eb18036eeec..f336ca061a55b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2622,6 +2622,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
           os << 'x';
         }
         printType(vectorTy.getElementType());
+        // Only print the encoding attribute value if set.
+        if (vectorTy.getEncoding()) {
+          os << ", ";
+          printAttribute(vectorTy.getEncoding());
+        }
         os << '>';
       })
       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 179797cb943a1..b15a35d1e4126 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 ArrayRef<bool> scalableDims) {
+                                 ArrayRef<bool> scalableDims,
+                                 Attribute encoding) {
   if (!isValidElementType(elementType))
     return emitError()
            << "vector elements must be int/index/float type but got "
@@ -242,6 +243,9 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "number of dims must match, got "
                        << scalableDims.size() << " and " << shape.size();
 
+  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
+    if (failed(v.verifyEncoding(shape, elementType, emitError)))
+      return failure();
   return success();
 }
 
@@ -260,7 +264,7 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                  Type elementType) const {
   return VectorType::get(shape.value_or(getShape()), elementType,
-                         getScalableDims());
+                         getScalableDims(), getEncoding());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index cace1fefa43d6..57ccbd6c02da5 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -73,6 +73,13 @@ func.func private @float_types(f80, f128)
 // CHECK: func private @vectors(vector<f32>, vector<1xf32>, vector<2x4xf32>)
 func.func private @vectors(vector<f32>, vector<1 x f32>, vector<2x4xf32>)
 
+// CHECK: func private @vector_encoding(vector<16x32xf64, "indexed">)
+func.func private @vector_encoding(vector<16x32xf64, "indexed">)
+
+// An encoding combined with scalable dimensions must also round-trip.
+// CHECK: func private @scalable_vector_encoding(vector<[4]x8xf32, "indexed">)
+func.func private @scalable_vector_encoding(vector<[4]x8xf32, "indexed">)
+
 // CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
 func.func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
                  tensor<1x?x4x?x?xi32>, tensor<i8>)



More information about the Mlir-commits mailing list