[Mlir-commits] [mlir] [mlir] Add encoding attribute to VectorType. (PR #99029)
Alexander Belyaev
llvmlistbot at llvm.org
Tue Jul 16 05:52:49 PDT 2024
https://github.com/pifon2a created https://github.com/llvm/llvm-project/pull/99029
RankedTensorType already has an encoding attribute. This patch adds one to VectorType as well.
>From 8672bec00111b34626d057185d96595576947c54 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon at google.com>
Date: Tue, 16 Jul 2024 14:47:44 +0200
Subject: [PATCH] [mlir] Add encoding attribute to VectorType.
RankedTensorType already has an encoding attribute. This patch adds one to
VectorType as well.
---
mlir/include/mlir/IR/BuiltinTypes.h | 15 +++++++++++----
mlir/include/mlir/IR/BuiltinTypes.td | 17 ++++++++++++++---
mlir/lib/AsmParser/TypeParser.cpp | 20 ++++++++++++++++++--
mlir/lib/IR/AsmPrinter.cpp | 5 +++++
mlir/lib/IR/BuiltinTypes.cpp | 8 ++++++--
mlir/test/IR/parser.mlir | 3 +++
6 files changed, 57 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 5579b138668d2..564a27e01240c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,13 @@ class VectorType::Builder {
/// Build from another VectorType.
explicit Builder(VectorType other)
: elementType(other.getElementType()), shape(other.getShape()),
- scalableDims(other.getScalableDims()) {}
+ scalableDims(other.getScalableDims()), encoding(other.getEncoding()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims = {})
- : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+ ArrayRef<bool> scalableDims = {}, Attribute encoding = nullptr)
+ : elementType(elementType), shape(shape), scalableDims(scalableDims),
+ encoding(encoding) {}
Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,14 +343,20 @@ class VectorType::Builder {
return *this;
}
+ Builder &setEncoding(Attribute newEncoding) {
+ encoding = newEncoding;
+ return *this;
+ }
+
operator VectorType() {
- return VectorType::get(shape, elementType, scalableDims);
+ return VectorType::get(shape, elementType, scalableDims, encoding);
}
private:
Type elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
+ Attribute encoding;
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..7edc8d228b340 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1060,6 +1060,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
vector-dim-list := (static-dim-list `x`)?
static-dim-list ::= static-dim (`x` static-dim)*
static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+ encoding ::= attribute-value
```
The vector type represents a SIMD style vector used by target-specific
@@ -1072,6 +1073,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
Vector shapes must be positive decimal integers. 0D vectors are allowed by
omitting the dimension: `vector<f32>`.
+ The `encoding` attribute provides additional information on the vector.
+ An empty attribute denotes a straightforward vector without any specific
+ structure.
+
Note: hexadecimal integer literals are not allowed in vector type
declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
2D vector with shape `(0, 42)` and zero shapes are not allowed.
@@ -1094,17 +1099,22 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
// A 3D mixed fixed/scalable vector in which only the inner dimension is
// scalable.
vector<2x[4]x8xf32>
+
+ // Vector with an encoding attribute (where #ENCODING is a named alias).
+ vector<4x2xf64, #ENCODING>
```
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
- ArrayRefParameter<"bool">:$scalableDims
+ ArrayRefParameter<"bool">:$scalableDims,
+ "Attribute":$encoding
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- CArg<"ArrayRef<bool>", "{}">:$scalableDims
+ CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+ CArg<"Attribute", "{}">:$encoding
), [{
// While `scalableDims` is optional, its default value should be
// `false` for every dim in `shape`.
@@ -1113,7 +1123,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
- return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ auto ctx = elementType.getContext();
+ return $_get(ctx, shape, elementType, scalableDims, encoding);
}]>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 0b46c96bbc04d..8d38d5d1c4dea 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -448,6 +448,7 @@ Type Parser::parseTupleType() {
/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
+/// encoding ::= attribute-value
///
VectorType Parser::parseVectorType() {
consumeToken(Token::kw_vector);
@@ -467,14 +468,29 @@ VectorType Parser::parseVectorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto elementType = parseType();
+
+ // Parse an optional encoding attribute.
+ Attribute encoding;
+ if (consumeIf(Token::comma)) {
+ auto parseResult = parseOptionalAttribute(encoding);
+ if (parseResult.has_value()) {
+ if (failed(parseResult.value()))
+ return nullptr;
+ if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
+ if (failed(v.verifyEncoding(dimensions, elementType,
+ [&] { return emitError(); })))
+ return nullptr;
+ }
+ }
+ }
+
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
-
if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
- return VectorType::get(dimensions, elementType, scalableDims);
+ return VectorType::get(dimensions, elementType, scalableDims, encoding);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 13eb18036eeec..f336ca061a55b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2622,6 +2622,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << 'x';
}
printType(vectorTy.getElementType());
+ // Only print the encoding attribute value if set.
+ if (vectorTy.getEncoding()) {
+ os << ", ";
+ printAttribute(vectorTy.getEncoding());
+ }
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 179797cb943a1..b15a35d1e4126 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims) {
+ ArrayRef<bool> scalableDims,
+ Attribute encoding) {
if (!isValidElementType(elementType))
return emitError()
<< "vector elements must be int/index/float type but got "
@@ -242,6 +243,9 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();
+ if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
+ if (failed(v.verifyEncoding(shape, elementType, emitError)))
+ return failure();
return success();
}
@@ -260,7 +264,7 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return VectorType::get(shape.value_or(getShape()), elementType,
- getScalableDims());
+ getScalableDims(), getEncoding());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index cace1fefa43d6..57ccbd6c02da5 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -73,6 +73,9 @@ func.func private @float_types(f80, f128)
// CHECK: func private @vectors(vector<f32>, vector<1xf32>, vector<2x4xf32>)
func.func private @vectors(vector<f32>, vector<1 x f32>, vector<2x4xf32>)
+// CHECK: func private @vector_encoding(vector<16x32xf64, "indexed">)
+func.func private @vector_encoding(vector<16x32xf64, "indexed">)
+
// CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
func.func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
tensor<1x?x4x?x?xi32>, tensor<i8>)
More information about the Mlir-commits
mailing list