[Mlir-commits] [mlir] [mlir][ArmSME] Provide descriptions and summaries for types (PR #70920)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Nov 1 05:03:20 PDT 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/70920
From 5ebdff179fe22868d85782f2210028a1e4b02884 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 1 Nov 2023 10:37:32 +0000
Subject: [PATCH] [mlir][ArmSME] Provide descriptions and summaries for types
The auto-generated summaries are hard to read (and pretty unhelpful).
For example, the summary for an SME tile was:
```
vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values
```
...and an SVE vector:
```
of ranks 1scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1
```
Note: The descriptions added here won't yet be shown on the MLIR docs
(only the short summaries), but this should be easy to enable like
it was for attribute descriptions in #67009.
A table of contents (TOC) is also added to the ArmSME docs page to make
it easier to navigate.
---
mlir/docs/Dialects/ArmSME.md | 2 +
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 38 ++++++++++++++-
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 48 +++++++++++++++++--
mlir/test/Dialect/ArmSME/invalid.mlir | 16 +++----
4 files changed, 90 insertions(+), 14 deletions(-)
diff --git a/mlir/docs/Dialects/ArmSME.md b/mlir/docs/Dialects/ArmSME.md
index ab7c9ffe7aa92f1..505b52938eacc05 100644
--- a/mlir/docs/Dialects/ArmSME.md
+++ b/mlir/docs/Dialects/ArmSME.md
@@ -1,5 +1,7 @@
# 'ArmSME' Dialect
+[TOC]
+
Basic dialect to target Arm SME architectures This dialect contains the
definitions necessary to target Arm SME scalable matrix operations.
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 18b9bd7a107febf..ffafb2569310ed1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -42,11 +42,45 @@ def ArmSME_Dialect : Dialect {
// ArmSME type definitions
//===----------------------------------------------------------------------===//
+// FIXME: This allows types that are not SVE vectors, e.g. vector<[16]xi128>.
def SVEVector : ScalableVectorOfRankAndLengthAndType<
- [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+ [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>
+{
+ let summary = "a vector type that matches the size of a SVE vector";
+ let description = [{
+ Possible vector types:
+
+ Integer elements:
+
+ * `vector<[16]xi8>`
+ * `vector<[8]xi16>`
+ * `vector<[4]xi32>`
+ * `vector<[2]xi64>`
+ * `vector<[1]xi128>`
+
+ Floating point elements:
+
+ * `vector<[8]xf16>`
+ * `vector<[8]xbf16>`
+ * `vector<[4]xf32>`
+ * `vector<[2]xf64>`
+ }];
+}
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
- [1], [16, 8, 4, 2, 1], [I1]>;
+ [1], [16, 8, 4, 2, 1], [I1]>
+{
+ let summary = "a vector type that matches the size of a SVE predicate";
+ let description = [{
+ Possible vector types:
+
+ * `vector<[16]xi1>`
+ * `vector<[8]xi1>`
+ * `vector<[4]xi1>`
+ * `vector<[2]xi1>`
+ * `vector<[1]xi1>`
+ }];
+}
#endif // ARMSME
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 37a2257a0015ce7..e57a8acd82de758 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -43,7 +43,47 @@ def nxnxv4f32 : SMETileType<F32, [4, 4 ], "vector<[4]x[4]xf32>">;
def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
- nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
+ nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
+ "a vector type that fits into a SME tile">
+{
+ let description = [{
+ Possible vector types:
+
+ Integer elements:
+
+ * `vector<[16]x[16]xi8>`
+ * `vector<[8]x[8]xi16>`
+ * `vector<[4]x[4]xi32>`
+ * `vector<[2]x[2]xi64>`
+ * `vector<[1]x[1]xi128>`
+
+ Floating point elements:
+
+ * `vector<[8]x[8]xf16>`
+ * `vector<[8]x[8]xbf16>`
+ * `vector<[4]x[4]xf32>`
+ * `vector<[2]x[2]xf64>`
+ }];
+}
+
+def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
+ "an identifier of a virtual tile (of a size) within the ZA storage">
+{
+ let description = [{
+ The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
+ the integer indicates the tile to use, and the bit size indicates the size
+ of tile. The number of tiles available and the element types of those depend
+ on the size. This is summarised below:
+
+ | Tile ID Type | Possible Tile IDs | Tile Vector Types |
+ |--------------|---------------------|-------------------------------------------------------------------------|
+ | `i8` | 0 | `vector<[16]x[16]xi8>` |
+ | `i16` | 0 and 1 | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
+ | `i32` | 0 to 3 (inclusive) | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` |
+ | `i64` | 0 to 7 (inclusive) | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` |
+ | `i128` | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>` |
+ }];
+}
// A type constraint that verifies the bitwidth of the scalar integer returned
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
@@ -145,7 +185,7 @@ def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthM
Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
}];
- let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+ let arguments = (ins TileID:$tile_id);
let results = (outs SMETile:$vector);
let assemblyFormat =
"$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
@@ -181,7 +221,7 @@ def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthM
the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
}];
let arguments = (ins SMETile:$vector);
- let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+ let results = (outs TileID:$tile_id);
let assemblyFormat =
"$vector attr-dict `:` type($vector) `to` type($tile_id)";
let hasCanonicalizeMethod = 1;
@@ -217,7 +257,7 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
```
}];
- let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+ let results = (outs TileID:$tile_id);
let assemblyFormat = "attr-dict `:` type($tile_id)";
}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 1d6386bbf3828fa..666847dc60f51a5 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -15,7 +15,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> v
// -----
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
- // expected-error @+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+ // expected-error @+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
return %0 : vector<[16]xi8>
}
@@ -23,7 +23,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) ->
// -----
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
- // expected-error @+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
+ // expected-error @+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
return %0 : vector<[16]x[16]xi4>
}
@@ -31,7 +31,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vec
// -----
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
- // expected-error @+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
+ // expected-error @+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
return %0 : vector<16x[16]xi8>
}
@@ -39,7 +39,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile
// -----
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
- // expected-error @+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
+ // expected-error @+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
return %0 : vector<[16]x16xi8>
}
@@ -47,7 +47,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile
// -----
func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
- // expected-error @+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
+ // expected-error @+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
return %0 : vector<[4]x[16]xi8>
}
@@ -67,7 +67,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1
// -----
func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
- // expected-error @+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+ // expected-error @+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
%0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
return %0 : i8
}
@@ -79,7 +79,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
// -----
func.func @arm_sme_get_tile_id__bad_type() -> i1 {
- // expected-error @+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
+ // expected-error @+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
%0 = arm_sme.get_tile_id : i1
return %0 : i1
}
@@ -172,7 +172,7 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
{
- // expected-error @+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
+ // expected-error @+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[2]x[2]xi16>'}}
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
return %0 : vector<[2]x[2]xi16>
}
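
The `TileID` table added in ArmSMEOps.td above follows a regular pattern: the number of tiles equals the tile ID bit width divided by 8, and each tile dimension is `[128 / bit width]`. The Python sketch below only restates that table for illustration. The helper names (`tile_count`, `tile_vector_types`, `is_valid_tile_id`) are hypothetical and are not part of the patch or of MLIR, and the range check is a reading of the "Possible Tile IDs" column, not something this patch is known to verify.

```python
# Illustrative restatement of the TileID table from the patch.
# All names here are hypothetical; this is not MLIR or ArmSME dialect code.

VALID_WIDTHS = (8, 16, 32, 64, 128)

# Float element types listed in the table for each tile ID bit width.
FLOAT_ELEMENTS = {16: ["f16", "bf16"], 32: ["f32"], 64: ["f64"]}


def tile_count(bitwidth: int) -> int:
    """Number of virtual tiles for a tile ID of the given bit width."""
    if bitwidth not in VALID_WIDTHS:
        raise ValueError(f"not a tile ID bit width: {bitwidth}")
    return bitwidth // 8


def tile_vector_types(bitwidth: int) -> list[str]:
    """Tile vector types the table lists for the given tile ID bit width."""
    if bitwidth not in VALID_WIDTHS:
        raise ValueError(f"not a tile ID bit width: {bitwidth}")
    dim = 128 // bitwidth  # scalable size of each of the two dimensions
    elements = [f"i{bitwidth}"] + FLOAT_ELEMENTS.get(bitwidth, [])
    return [f"vector<[{dim}]x[{dim}]x{elem}>" for elem in elements]


def is_valid_tile_id(bitwidth: int, tile_id: int) -> bool:
    """True if tile_id is in the 'Possible Tile IDs' range for this width."""
    return 0 <= tile_id < tile_count(bitwidth)
```

For instance, `tile_vector_types(16)` reproduces the `i16` row of the table, and `is_valid_tile_id(8, 1)` is false because an `i8` tile ID can only be 0.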