[Mlir-commits] [mlir] [mlir][ArmSME] Add tile slice to vector intrinsics (PR #66910)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Sep 21 09:56:56 PDT 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/66910
>From 896ab7de80ddaf7ec36bb39a0f1d3f45950a42dd Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 12:50:24 +0000
Subject: [PATCH 1/4] [mlir][ArmSME] Add tile slice to vector intrinsics
Add support for the following tile slice to vector (MOVA) intrinsics to the
ArmSME dialect:
llvm.aarch64.sme.read.vert
llvm.aarch64.sme.read.horiz
This also slightly updates ArmSME_IntrOp to support return values.
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 23 ++-
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp | 1 +
mlir/test/Target/LLVMIR/arm-sme-invalid.mlir | 24 ++++
mlir/test/Target/LLVMIR/arm-sme.mlir | 134 ++++++++++++++++++
4 files changed, 178 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..00e1fefc0521a78 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
- list<Trait> traits = []>
+ list<Trait> traits = [], int numResults = 0,
+ list<int> overloadedResults = []>
: LLVM_IntrOpBase<
/*Dialect dialect=*/ArmSME_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
- /*list<int> overloadedResults=*/[],
+ /*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/0>;
+ /*int numResults=*/numResults>;
// Zero
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
Arguments<(ins Arg<I32, "Index">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
-// Vector to tile
+// Vector to tile slice
class LLVM_aarch64_sme_write<string direction>
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
[AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write<string direction>
Arg<SVEPredicate, "Vector predicate">:$pg,
Arg<SVEVector, "Vector operand">:$vector)>;
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+ : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+ [AllShapesMatch<["vector", "pg", "res"]>,
+ AllElementTypesMatch<["vector", "res"]>],
+ /*numResults*/1, /*overloadedResults*/[0]>,
+ Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+ Arg<SVEPredicate, "Vector predicate">:$pg,
+ Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">)>;
+
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..7cbc382b0050a6e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
using namespace mlir::arm_sme;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index e119e1f1a404416..ae99ac5e02d62f0 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
(i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
llvm.return
}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
+ %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
+) -> vector<[3]xf32> {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
+ %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
+ (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+ llvm.return %res : vector<[3]xf32>
+}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
+ %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
+) -> vector<[4]xi32> {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
+ %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
+ (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.return %res : vector<[4]xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 9bb6b0c6574fcdb..c318e6d2d37f7fd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -334,3 +334,137 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
(i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
llvm.return
}
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
+ %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+ : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+ // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
+ %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
+ %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+ // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
+ %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+ // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
+ %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+ : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+ // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
+ %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+ // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
+ %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
+ %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+ // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
+ %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+ llvm.return
+}
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
+ %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+ : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+ // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
+ %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
+ %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+ // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
+ %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+ // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
+ %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+ : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+ // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
+ %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+ // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
+ %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
+ %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+ // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
+ %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+ llvm.return
+}
>From a6d389c8733b633958dfa4f8de37915eb8cd84d0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 14:54:17 +0000
Subject: [PATCH 2/4] Review fixups
Fixes: https://github.com/llvm/llvm-project/pull/66910#discussion_r1331730482
Fixes: https://github.com/llvm/llvm-project/pull/66910#discussion_r1331734063
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 2 +-
mlir/test/Target/LLVMIR/arm-sme.mlir | 37 -------------------
2 files changed, 1 insertion(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 00e1fefc0521a78..1ca284a3e70dcec 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -563,7 +563,7 @@ class LLVM_aarch64_sme_read<string direction>
: ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
[AllShapesMatch<["vector", "pg", "res"]>,
AllElementTypesMatch<["vector", "res"]>],
- /*numResults*/1, /*overloadedResults*/[0]>,
+ /*numResults=*/1, /*overloadedResults=*/[0]>,
Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
Arg<SVEPredicate, "Vector predicate">:$pg,
Arg<I32, "Virtual tile ID">,
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index c318e6d2d37f7fd..5d697993c323245 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -337,15 +337,6 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
// -----
-llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
-llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
-llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
-llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
-llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
-llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
-llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
-llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
-llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
%nxv16i1 : vector<[16]xi1>,
@@ -366,54 +357,35 @@ llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
%res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
: (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
- llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
%res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
: (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
- llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
%res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
: (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
- llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
%res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
: (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
- llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
// CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
%res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
: (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
- llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
// CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
%res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
: (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
- llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
%res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
: (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
- llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
%res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
: (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
- llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
// CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
%res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
: (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
- llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
llvm.return
}
// -----
-llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
-llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
-llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
-llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
-llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
-llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
-llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
-llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
-llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
-
llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
%nxv16i1 : vector<[16]xi1>,
%nxv8i1 : vector<[8]xi1>,
@@ -433,38 +405,29 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
%res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
: (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
- llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
%res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
: (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
- llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
%res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
: (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
- llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
%res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
: (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
- llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
// CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
%res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
: (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
- llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
// CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
%res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
: (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
- llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
%res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
: (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
- llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
%res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
: (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
- llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
// CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
%res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
: (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
- llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
llvm.return
}
>From ec90210e146e12648354363ea402130aaed03697 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Sep 2023 11:57:37 +0000
Subject: [PATCH 3/4] [mlir][docgen] Display full attribute descriptions in
expandable regions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This updates the table of op attributes so that clicking the summary
expands to show the complete description.
Attribute | MLIR Type | Description
<name> <type> ▶ <summary> <-- Click to expand
Enum attributes have now also been updated to generate a description
that lists all the cases (with both their MLIR and C++ names). This
makes viewing enums on the MLIR docs much nicer.
---
mlir/include/mlir/IR/EnumAttr.td | 8 +++++++
mlir/tools/mlir-tblgen/OpDocGen.cpp | 36 +++++++++++++++++++++++++----
2 files changed, 40 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 485c5266f3cfdfa..cb918b5eceb1a19 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -113,6 +113,13 @@ class I64BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
class EnumAttrInfo<
string name, list<EnumAttrCaseInfo> cases, Attr baseClass> :
Attr<baseClass.predicate, baseClass.summary> {
+
+ // Generate a description of this enum's members for the MLIR docs.
+ let description =
+ "Enum cases:\n" # !interleave(
+ !foreach(case, cases,
+ "* " # case.str # " (`" # case.symbol # "`)"), "\n");
+
// The C++ enum class name
string className = name;
@@ -381,6 +388,7 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
list <Trait> traits = []>
: AttrDef<dialect, enumInfo.className, traits> {
let summary = enumInfo.summary;
+ let description = enumInfo.description;
// The backing enumeration.
EnumAttrInfo enum = enumInfo;
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index c546763880a853f..d2538fc49b20d0d 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -173,6 +173,13 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
}
}
+static StringRef resolveAttrDescription(const Attribute &attr) {
+ StringRef description = attr.getDescription();
+ if (description.empty())
+ return attr.getBaseAttr().getDescription();
+ return description;
+}
+
static void emitOpDoc(const Operator &op, raw_ostream &os) {
std::string classNameStr = op.getQualCppClassName();
StringRef className = classNameStr;
@@ -195,13 +202,34 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) {
// TODO: Attributes are only documented by TableGen name, with no further
// info. This should be improved.
os << "\n#### Attributes:\n\n";
- os << "| Attribute | MLIR Type | Description |\n"
- << "| :-------: | :-------: | ----------- |\n";
+ // Note: This table is HTML rather than markdown so the attribute's
+ // description can appear in an expandable region. The description may be
+ // multiple lines, which is not supported in a markdown table cell.
+ os << "<table>\n";
+ // Header.
+ os << "<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>\n";
for (const auto &it : op.getAttributes()) {
StringRef storageType = it.attr.getStorageType();
- os << "| `" << it.name << "` | " << storageType << " | "
- << it.attr.getSummary() << "\n";
+ // Name and storage type.
+ os << "<tr>";
+ os << "<td><code>" << it.name << "</code></td><td>" << storageType
+ << "</td><td>";
+ StringRef description = resolveAttrDescription(it.attr);
+ if (!description.empty()) {
+ // Expandable description.
+ // This appears as just the summary, but when clicked shows the full
+ // description.
+ os << "<details>"
+ << "<summary>" << it.attr.getSummary() << "</summary>"
+ << "{{% markdown %}}" << description << "{{% /markdown %}}"
+ << "</details>";
+ } else {
+ // Fallback: Single-line summary.
+ os << it.attr.getSummary();
+ }
+ os << "</td></tr>\n";
}
+ os << "</table>\n"; // Close the attribute table opened before the loop.
}
// Emit each of the operands.
>From 48160912f657fc94a55127f7728c133320bfa25e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Sep 2023 16:56:16 +0000
Subject: [PATCH 4/4] Align `:` in params
---
mlir/test/Target/LLVMIR/arm-sme.mlir | 52 ++++++++++++++--------------
1 file changed, 26 insertions(+), 26 deletions(-)
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 5d697993c323245..628d7ba4b649e51 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -339,20 +339,20 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
- %nxv16i1 : vector<[16]xi1>,
- %nxv8i1 : vector<[8]xi1>,
- %nxv4i1 : vector<[4]xi1>,
- %nxv2i1 : vector<[2]xi1>,
- %nxv1i1 : vector<[1]xi1>,
- %nxv16i8 : vector<[16]xi8>,
- %nxv8i16 : vector<[8]xi16>,
- %nxv4i32 : vector<[4]xi32>,
- %nxv2i64 : vector<[2]xi64>,
- %nxv1i128 : vector<[1]xi128>,
- %nxv8f16 : vector<[8]xf16>,
- %nxv8bf16 : vector<[8]xbf16>,
- %nxv4f32 : vector<[4]xf32>,
- %nxv2f64 : vector<[2]xf64>) {
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
%tile = llvm.mlir.constant(0 : index) : i32
// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
%res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
@@ -387,20 +387,20 @@ llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
// -----
llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
- %nxv16i1 : vector<[16]xi1>,
- %nxv8i1 : vector<[8]xi1>,
- %nxv4i1 : vector<[4]xi1>,
- %nxv2i1 : vector<[2]xi1>,
- %nxv1i1 : vector<[1]xi1>,
- %nxv16i8 : vector<[16]xi8>,
- %nxv8i16 : vector<[8]xi16>,
- %nxv4i32 : vector<[4]xi32>,
- %nxv2i64 : vector<[2]xi64>,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
%nxv1i128 : vector<[1]xi128>,
- %nxv8f16 : vector<[8]xf16>,
+ %nxv8f16 : vector<[8]xf16>,
%nxv8bf16 : vector<[8]xbf16>,
- %nxv4f32 : vector<[4]xf32>,
- %nxv2f64 : vector<[2]xf64>) {
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
%tile = llvm.mlir.constant(0 : index) : i32
// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
%res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
More information about the Mlir-commits
mailing list