[Mlir-commits] [mlir] [mlir][ArmSME] Add tile slice to vector intrinsics (PR #66910)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Sep 20 07:26:19 PDT 2023
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/66910
Add support for the following tile slice to vector (MOVA) intrinsics to the ArmSME dialect:
llvm.aarch64.sme.read.vert
llvm.aarch64.sme.read.horiz
This also slightly updates ArmSME_IntrOp to support return values.
>From 896ab7de80ddaf7ec36bb39a0f1d3f45950a42dd Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 12:50:24 +0000
Subject: [PATCH] [mlir][ArmSME] Add tile slice to vector intrinsics
Add support for the following tile slice to vector (MOVA) intrinsics to the
ArmSME dialect:
llvm.aarch64.sme.read.vert
llvm.aarch64.sme.read.horiz
This also slightly updates ArmSME_IntrOp to support return values.
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 23 ++-
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp | 1 +
mlir/test/Target/LLVMIR/arm-sme-invalid.mlir | 24 ++++
mlir/test/Target/LLVMIR/arm-sme.mlir | 134 ++++++++++++++++++
4 files changed, 178 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..00e1fefc0521a78 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
- list<Trait> traits = []>
+ list<Trait> traits = [], int numResults = 0, // result count; defaults to 0 (the previously hard-coded value), so existing result-less ops are unchanged
+ list<int> overloadedResults = []> // result indices forwarded to LLVM_IntrOpBase's overloadedResults (e.g. [0] for intrinsics overloaded on their return type)
: LLVM_IntrOpBase<
/*Dialect dialect=*/ArmSME_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
- /*list<int> overloadedResults=*/[],
+ /*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/0>;
+ /*int numResults=*/numResults>;
// Zero
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
Arguments<(ins Arg<I32, "Index">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
-// Vector to tile
+// Vector to tile slice
class LLVM_aarch64_sme_write<string direction>
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
[AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write<string direction>
Arg<SVEPredicate, "Vector predicate">:$pg,
Arg<SVEVector, "Vector operand">:$vector)>;
+// Tile slice to vector (MOVA): "read.horiz"/"read.vert" read a tile slice, selected by the tile ID and slice index operands, into a vector overloaded on its result type (result 0).
+class LLVM_aarch64_sme_read<string direction>
+ : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+ [AllShapesMatch<["vector", "pg", "res"]>,
+ AllElementTypesMatch<["vector", "res"]>],
+ /*numResults=*/1, /*overloadedResults=*/[0]>,
+ Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+ Arg<SVEPredicate, "Vector predicate">:$pg,
+ Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">)>;
+
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..7cbc382b0050a6e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
using namespace mlir::arm_sme;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index e119e1f1a404416..ae99ac5e02d62f0 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
(i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
llvm.return
}
+
+// -----
+// Shapes differ ([16] vs [4] vs [3]) while vector/res element types match (i8), so only the shape constraint is violated.
+llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
+ %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
+) -> vector<[3]xi8> {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
+ %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
+ (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xi8>
+ llvm.return %res : vector<[3]xi8>
+}
+
+// -----
+// Shapes all match ([4]) but vector (f32) and res (i32) element types differ; the function signature matches the returned type so this is the only error.
+llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
+ %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
+) -> vector<[4]xi32> {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
+ %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
+ (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.return %res : vector<[4]xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 9bb6b0c6574fcdb..c318e6d2d37f7fd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -334,3 +334,137 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
(i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
llvm.return
}
+
+// -----
+// External sinks: each read result below is passed to one of these calls so it has a use (per the names, presumably to prevent dead-code elimination).
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+// Each "arm_sme.intr.read.horiz" below must become a call to llvm.aarch64.sme.read.horiz mangled with its result type (see CHECK lines).
+llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
+ %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+ : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+ // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
+ %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
+ %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+ // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
+ %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+ // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
+ %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+ : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+ // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
+ %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+ // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
+ %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
+ %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+ // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
+ %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+ llvm.return
+}
+
+// -----
+// External sinks (redeclared because this is a separate split-input chunk): each read result is passed to one so it has a use (presumably to prevent DCE).
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+// Each "arm_sme.intr.read.vert" below must become a call to llvm.aarch64.sme.read.vert mangled with its result type (see CHECK lines).
+llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
+ %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+ : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+ // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
+ %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
+ %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+ // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
+ %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+ // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
+ %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+ : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+ // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
+ %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+ // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
+ %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
+ %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+ // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
+ %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+ llvm.return
+}
More information about the Mlir-commits
mailing list