[Mlir-commits] [mlir] [mlir][ArmSME] Provide descriptions and summaries for types (PR #70920)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Nov 1 05:03:20 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/70920

>From 5ebdff179fe22868d85782f2210028a1e4b02884 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 1 Nov 2023 10:37:32 +0000
Subject: [PATCH] [mlir][ArmSME] Provide descriptions and summaries for types

The auto-generated summaries are hard to read (and pretty unhelpful).
For example, the auto-generated summary for an SME tile was:

```
vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values
```

...and the auto-generated summary for an SVE vector was:

```
of ranks 1scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1
```

Note: The descriptions added here won't be shown in the MLIR docs yet
(only the short summaries will be), but this should be easy to enable,
as was done for attribute descriptions in #67009.

A table of contents (TOC) is also added to the ArmSME docs page to make
it easier to navigate.
---
 mlir/docs/Dialects/ArmSME.md                  |  2 +
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 38 ++++++++++++++-
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 48 +++++++++++++++++--
 mlir/test/Dialect/ArmSME/invalid.mlir         | 16 +++----
 4 files changed, 90 insertions(+), 14 deletions(-)

diff --git a/mlir/docs/Dialects/ArmSME.md b/mlir/docs/Dialects/ArmSME.md
index ab7c9ffe7aa92f1..505b52938eacc05 100644
--- a/mlir/docs/Dialects/ArmSME.md
+++ b/mlir/docs/Dialects/ArmSME.md
@@ -1,5 +1,7 @@
 # 'ArmSME' Dialect
 
+[TOC]
+
 Basic dialect to target Arm SME architectures This dialect contains the
 definitions necessary to target Arm SME scalable matrix operations.
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 18b9bd7a107febf..ffafb2569310ed1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -42,11 +42,45 @@ def ArmSME_Dialect : Dialect {
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
 
+// FIXME: This allows types that are not SVE vectors, e.g. vector<[16]xi128>.
 def SVEVector : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>
+{
+  let summary = "a vector type that matches the size of a SVE vector";
+  let description = [{
+    Possible vector types:
+
+    Integer elements:
+
+    * `vector<[16]xi8>`
+    * `vector<[8]xi16>`
+    * `vector<[4]xi32>`
+    * `vector<[2]xi64>`
+    * `vector<[1]xi128>`
+
+    Floating point elements:
+
+    * `vector<[8]xf16>`
+    * `vector<[8]xbf16>`
+    * `vector<[4]xf32>`
+    * `vector<[2]xf64>`
+  }];
+}
 
 def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I1]>;
+  [1], [16, 8, 4, 2, 1], [I1]>
+{
+  let summary = "a vector type that matches the size of a SVE predicate";
+  let description = [{
+    Possible vector types:
+
+    * `vector<[16]xi1>`
+    * `vector<[8]xi1>`
+    * `vector<[4]xi1>`
+    * `vector<[2]xi1>`
+    * `vector<[1]xi1>`
+  }];
+}
 
 
 #endif // ARMSME
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 37a2257a0015ce7..e57a8acd82de758 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -43,7 +43,47 @@ def nxnxv4f32  : SMETileType<F32,  [4,  4 ], "vector<[4]x[4]xf32>">;
 def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
-                         nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
+                         nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
+                        "a vector type that fits into a SME tile">
+{
+  let description = [{
+    Possible vector types:
+
+    Integer elements:
+
+    * `vector<[16]x[16]xi8>`
+    * `vector<[8]x[8]xi16>`
+    * `vector<[4]x[4]xi32>`
+    * `vector<[2]x[2]xi64>`
+    * `vector<[1]x[1]xi128>`
+
+    Floating point elements:
+
+    * `vector<[8]x[8]xf16>`
+    * `vector<[8]x[8]xbf16>`
+    * `vector<[4]x[4]xf32>`
+    * `vector<[2]x[2]xf64>`
+  }];
+}
+
+def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
+                       "an identifier of a virtual tile (of a size) within the ZA storage">
+{
+  let description = [{
+    The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
+    the integer indicates the tile to use, and the bit size indicates the size
+    of tile. The number of tiles available and the element types of those depend
+    on the size. This is summarised below:
+
+    | Tile ID Type | Possible Tile IDs   | Tile Vector Types                                                       |
+    |--------------|---------------------|-------------------------------------------------------------------------|
+    | `i8`         | 0                   | `vector<[16]x[16]xi8>`                                                  |
+    | `i16`        | 0 and 1             | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
+    | `i32`        | 0 to 3 (inclusive)  | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          |
+    | `i64`        | 0 to 7 (inclusive)  | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          |
+    | `i128`       | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>`                                                  |
+  }];
+}
 
 // A type constraint that verifies the bitwidth of the scalar integer returned
 // from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
@@ -145,7 +185,7 @@ def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthM
     Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
     the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
   }];
-  let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let arguments = (ins TileID:$tile_id);
   let results = (outs SMETile:$vector);
   let assemblyFormat =
     "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
@@ -181,7 +221,7 @@ def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthM
     the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
   }];
   let arguments = (ins SMETile:$vector);
-  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let results = (outs TileID:$tile_id);
   let assemblyFormat =
     "$vector attr-dict `:` type($vector) `to` type($tile_id)";
   let hasCanonicalizeMethod = 1;
@@ -217,7 +257,7 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
     ```
   }];
 
-  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let results = (outs TileID:$tile_id);
   let assemblyFormat = "attr-dict `:` type($tile_id)";
 }
 
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 1d6386bbf3828fa..666847dc60f51a5 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -15,7 +15,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> v
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
-  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
   return %0 : vector<[16]xi8>
 }
@@ -23,7 +23,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) ->
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
-  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
+  // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
   return %0 : vector<[16]x[16]xi4>
 }
@@ -31,7 +31,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vec
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
-  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
+  // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
   return %0 : vector<16x[16]xi8>
 }
@@ -39,7 +39,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
-  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
+  // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
   return %0 : vector<[16]x16xi8>
 }
@@ -47,7 +47,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile
 // -----
 
 func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
-  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
+  // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
   return %0 : vector<[4]x[16]xi8>
 }
@@ -67,7 +67,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1
 // -----
 
 func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
-  // expected-error at +1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  // expected-error at +1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
   %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
   return %0 : i8
 }
@@ -79,7 +79,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
 // -----
 
 func.func @arm_sme_get_tile_id__bad_type() -> i1 {
-  // expected-error at +1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
+  // expected-error at +1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
   %0 = arm_sme.get_tile_id : i1
   return %0 : i1
 }
@@ -172,7 +172,7 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
 
 func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
 {
-  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
+  // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[2]x[2]xi16>'}}
   %0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
   return %0 : vector<[2]x[2]xi16>
 }



More information about the Mlir-commits mailing list