[Mlir-commits] [mlir] [mlir][ArmSME] Add tile slice to vector intrinsics (PR #66910)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Sep 21 09:56:56 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/66910

>From 896ab7de80ddaf7ec36bb39a0f1d3f45950a42dd Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 12:50:24 +0000
Subject: [PATCH 1/4] [mlir][ArmSME] Add tile slice to vector intrinsics

Add support for the following tile slice to vector (MOVA) intrinsics in
the ArmSME dialect:

  llvm.aarch64.sme.read.vert
  llvm.aarch64.sme.read.horiz

This also slightly updates ArmSME_IntrOp to support return values.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td |  23 ++-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |   1 +
 mlir/test/Target/LLVMIR/arm-sme-invalid.mlir  |  24 ++++
 mlir/test/Target/LLVMIR/arm-sme.mlir          | 134 ++++++++++++++++++
 4 files changed, 178 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..00e1fefc0521a78 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
 def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
 
 class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
-                    list<Trait> traits = []>
+                    list<Trait> traits = [], int numResults = 0,
+                    list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
           /*Dialect dialect=*/ArmSME_Dialect,
           /*string opName=*/"intr." # mnemonic,
           /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
-          /*list<int> overloadedResults=*/[],
+          /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/0>;
+          /*int numResults=*/numResults>;
 
 // Zero
 def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
       Arguments<(ins Arg<I32, "Index">,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
 
-// Vector to tile
+// Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
     : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
                     [AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write<string direction>
                      Arg<SVEPredicate, "Vector predicate">:$pg,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+                    [AllShapesMatch<["vector", "pg", "res"]>,
+                     AllElementTypesMatch<["vector", "res"]>],
+                    /*numResults*/1, /*overloadedResults*/[0]>,
+      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">)>;
+
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
 def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
 
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..7cbc382b0050a6e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 
 using namespace mlir;
 using namespace mlir::arm_sme;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index e119e1f1a404416..ae99ac5e02d62f0 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
       (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
+) -> vector<[3]xf32> {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
+  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
+      (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+  llvm.return %res : vector<[3]xf32>
+}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
+) -> vector<[3]xi32> {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
+  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
+      (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.return %res : vector<[4]xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 9bb6b0c6574fcdb..c318e6d2d37f7fd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -334,3 +334,137 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
       (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
+                                              %nxv16i1 : vector<[16]xi1>,
+                                              %nxv8i1 : vector<[8]xi1>,
+                                              %nxv4i1 : vector<[4]xi1>,
+                                              %nxv2i1 : vector<[2]xi1>,
+                                              %nxv1i1 : vector<[1]xi1>,
+                                              %nxv16i8 : vector<[16]xi8>,
+                                              %nxv8i16 : vector<[8]xi16>,
+                                              %nxv4i32 : vector<[4]xi32>,
+                                              %nxv2i64 : vector<[2]xi64>,
+                                              %nxv1i128 : vector<[1]xi128>,
+                                              %nxv8f16 : vector<[8]xf16>,
+                                              %nxv8bf16 : vector<[8]xbf16>,
+                                              %nxv4f32 : vector<[4]xf32>,
+                                              %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
+  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
+  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
+  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
+  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+  // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
+  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
+  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
+  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
+  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
+  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
+                                              %nxv16i1 : vector<[16]xi1>,
+                                              %nxv8i1 : vector<[8]xi1>,
+                                              %nxv4i1 : vector<[4]xi1>,
+                                              %nxv2i1 : vector<[2]xi1>,
+                                              %nxv1i1 : vector<[1]xi1>,
+                                              %nxv16i8 : vector<[16]xi8>,
+                                              %nxv8i16 : vector<[8]xi16>,
+                                              %nxv4i32 : vector<[4]xi32>,
+                                              %nxv2i64 : vector<[2]xi64>,
+                                              %nxv1i128 : vector<[1]xi128>,
+                                              %nxv8f16 : vector<[8]xf16>,
+                                              %nxv8bf16 : vector<[8]xbf16>,
+                                              %nxv4f32 : vector<[4]xf32>,
+                                              %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
+  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
+  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
+  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
+  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+  // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
+  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
+  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
+  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
+  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
+  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+  llvm.return
+}

>From a6d389c8733b633958dfa4f8de37915eb8cd84d0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 14:54:17 +0000
Subject: [PATCH 2/4] Review fixups

Fixes: https://github.com/llvm/llvm-project/pull/66910#discussion_r1331730482
Fixes: https://github.com/llvm/llvm-project/pull/66910#discussion_r1331734063
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td |  2 +-
 mlir/test/Target/LLVMIR/arm-sme.mlir          | 37 -------------------
 2 files changed, 1 insertion(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 00e1fefc0521a78..1ca284a3e70dcec 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -563,7 +563,7 @@ class LLVM_aarch64_sme_read<string direction>
     : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
                     [AllShapesMatch<["vector", "pg", "res"]>,
                      AllElementTypesMatch<["vector", "res"]>],
-                    /*numResults*/1, /*overloadedResults*/[0]>,
+                    /*numResults=*/1, /*overloadedResults=*/[0]>,
       Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
                      Arg<SVEPredicate, "Vector predicate">:$pg,
                      Arg<I32, "Virtual tile ID">,
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index c318e6d2d37f7fd..5d697993c323245 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -337,15 +337,6 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
 
 // -----
 
-llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
-llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
-llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
-llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
-llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
-llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
-llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
-llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
-llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
 
 llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
                                               %nxv16i1 : vector<[16]xi1>,
@@ -366,54 +357,35 @@ llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
   %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
     : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
   // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
   %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
     : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
   %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
     : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
-  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
   // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
   %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
     : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
   // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
   %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
     : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
   // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
   %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
     : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
   // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
   %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
     : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
   %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
     : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
   // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
   %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
     : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
   llvm.return
 }
 
 // -----
 
-llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
-llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
-llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
-llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
-llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
-llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
-llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
-llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
-llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
-
 llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
                                               %nxv16i1 : vector<[16]xi1>,
                                               %nxv8i1 : vector<[8]xi1>,
@@ -433,38 +405,29 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
   %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
     : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
   // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
   %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
     : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
   %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
     : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
-  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
   // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
   %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
     : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
   // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
   %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
     : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
   // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
   %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
     : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
   // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
   %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
     : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
   %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
     : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
   // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
   %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
     : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
   llvm.return
 }

>From ec90210e146e12648354363ea402130aaed03697 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Sep 2023 11:57:37 +0000
Subject: [PATCH 3/4] [mlir][docgen] Display full attribute descriptions in
 expandable regions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This updates the table of op attributes so that clicking the summary
expands to show the complete description.

   Attribute | MLIR Type | Description
   <name>      <type>      ▶ <summary>  <-- Click to expand

Enum attributes have now also been updated to generate a description
that lists all the cases (with both their MLIR and C++ names). This
makes viewing enums on the MLIR docs much nicer.
---
 mlir/include/mlir/IR/EnumAttr.td    |  8 +++++++
 mlir/tools/mlir-tblgen/OpDocGen.cpp | 36 +++++++++++++++++++++++++----
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 485c5266f3cfdfa..cb918b5eceb1a19 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -113,6 +113,13 @@ class I64BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
 class EnumAttrInfo<
     string name, list<EnumAttrCaseInfo> cases, Attr baseClass> :
       Attr<baseClass.predicate, baseClass.summary> {
+
+  // Generate a description of this enums members for the MLIR docs.
+  let description =
+        "Enum cases:\n" # !interleave(
+          !foreach(case, cases,
+              "* " # case.str  # " (`" # case.symbol # "`)"), "\n");
+
   // The C++ enum class name
   string className = name;
 
@@ -381,6 +388,7 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
                list <Trait> traits = []>
     : AttrDef<dialect, enumInfo.className, traits> {
   let summary = enumInfo.summary;
+  let description = enumInfo.description;
 
   // The backing enumeration.
   EnumAttrInfo enum = enumInfo;
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index c546763880a853f..d2538fc49b20d0d 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -173,6 +173,13 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
   }
 }
 
+static StringRef resolveAttrDescription(const Attribute &attr) {
+  StringRef description = attr.getDescription();
+  if (description.empty())
+    return attr.getBaseAttr().getDescription();
+  return description;
+}
+
 static void emitOpDoc(const Operator &op, raw_ostream &os) {
   std::string classNameStr = op.getQualCppClassName();
   StringRef className = classNameStr;
@@ -195,13 +202,34 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) {
     // TODO: Attributes are only documented by TableGen name, with no further
     // info. This should be improved.
     os << "\n#### Attributes:\n\n";
-    os << "| Attribute | MLIR Type | Description |\n"
-       << "| :-------: | :-------: | ----------- |\n";
+    // Note: This table is HTML rather than markdown so the attribute's
+    // description can appear in an expandable region. The description may be
+    // multiple lines, which is not supported in a markdown table cell.
+    os << "<table>\n";
+    // Header.
+    os << "<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>\n";
     for (const auto &it : op.getAttributes()) {
       StringRef storageType = it.attr.getStorageType();
-      os << "| `" << it.name << "` | " << storageType << " | "
-         << it.attr.getSummary() << "\n";
+      // Name and storage type.
+      os << "<tr>";
+      os << "<td><code>" << it.name << "</code></td><td>" << storageType
+         << "</td><td>";
+      StringRef description = resolveAttrDescription(it.attr);
+      if (!description.empty()) {
+        // Expandable description.
+        // This appears as just the summary, but when clicked shows the full
+        // description.
+        os << "<details>"
+           << "<summary>" << it.attr.getSummary() << "</summary>"
+           << "{{% markdown %}}" << description << "{{% /markdown %}}"
+           << "</details>";
+      } else {
+        // Fallback: Single-line summary.
+        os << it.attr.getSummary();
+      }
+      os << "</td></tr>\n";
     }
+    os << "<table>\n";
   }
 
   // Emit each of the operands.

>From 48160912f657fc94a55127f7728c133320bfa25e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Sep 2023 16:56:16 +0000
Subject: [PATCH 4/4] Align `:` in params

---
 mlir/test/Target/LLVMIR/arm-sme.mlir | 52 ++++++++++++++--------------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 5d697993c323245..628d7ba4b649e51 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -339,20 +339,20 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
 
 
 llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
-                                              %nxv16i1 : vector<[16]xi1>,
-                                              %nxv8i1 : vector<[8]xi1>,
-                                              %nxv4i1 : vector<[4]xi1>,
-                                              %nxv2i1 : vector<[2]xi1>,
-                                              %nxv1i1 : vector<[1]xi1>,
-                                              %nxv16i8 : vector<[16]xi8>,
-                                              %nxv8i16 : vector<[8]xi16>,
-                                              %nxv4i32 : vector<[4]xi32>,
-                                              %nxv2i64 : vector<[2]xi64>,
-                                              %nxv1i128 : vector<[1]xi128>,
-                                              %nxv8f16 : vector<[8]xf16>,
-                                              %nxv8bf16 : vector<[8]xbf16>,
-                                              %nxv4f32 : vector<[4]xf32>,
-                                              %nxv2f64 : vector<[2]xf64>) {
+                                              %nxv16i1   : vector<[16]xi1>,
+                                              %nxv8i1    : vector<[8]xi1>,
+                                              %nxv4i1    : vector<[4]xi1>,
+                                              %nxv2i1    : vector<[2]xi1>,
+                                              %nxv1i1    : vector<[1]xi1>,
+                                              %nxv16i8   : vector<[16]xi8>,
+                                              %nxv8i16   : vector<[8]xi16>,
+                                              %nxv4i32   : vector<[4]xi32>,
+                                              %nxv2i64   : vector<[2]xi64>,
+                                              %nxv1i128  : vector<[1]xi128>,
+                                              %nxv8f16   : vector<[8]xf16>,
+                                              %nxv8bf16  : vector<[8]xbf16>,
+                                              %nxv4f32   : vector<[4]xf32>,
+                                              %nxv2f64   : vector<[2]xf64>) {
   %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
   %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
@@ -387,20 +387,20 @@ llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
 // -----
 
 llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
-                                              %nxv16i1 : vector<[16]xi1>,
-                                              %nxv8i1 : vector<[8]xi1>,
-                                              %nxv4i1 : vector<[4]xi1>,
-                                              %nxv2i1 : vector<[2]xi1>,
-                                              %nxv1i1 : vector<[1]xi1>,
-                                              %nxv16i8 : vector<[16]xi8>,
-                                              %nxv8i16 : vector<[8]xi16>,
-                                              %nxv4i32 : vector<[4]xi32>,
-                                              %nxv2i64 : vector<[2]xi64>,
+                                              %nxv16i1  : vector<[16]xi1>,
+                                              %nxv8i1   : vector<[8]xi1>,
+                                              %nxv4i1   : vector<[4]xi1>,
+                                              %nxv2i1   : vector<[2]xi1>,
+                                              %nxv1i1   : vector<[1]xi1>,
+                                              %nxv16i8  : vector<[16]xi8>,
+                                              %nxv8i16  : vector<[8]xi16>,
+                                              %nxv4i32  : vector<[4]xi32>,
+                                              %nxv2i64  : vector<[2]xi64>,
                                               %nxv1i128 : vector<[1]xi128>,
-                                              %nxv8f16 : vector<[8]xf16>,
+                                              %nxv8f16  : vector<[8]xf16>,
                                               %nxv8bf16 : vector<[8]xbf16>,
-                                              %nxv4f32 : vector<[4]xf32>,
-                                              %nxv2f64 : vector<[2]xf64>) {
+                                              %nxv4f32  : vector<[4]xf32>,
+                                              %nxv2f64  : vector<[2]xf64>) {
   %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
   %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)



More information about the Mlir-commits mailing list