[Mlir-commits] [mlir] [mlir][ArmSME] Add tile slice to vector intrinsics (PR #66910)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Sep 20 07:26:19 PDT 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/66910

Add support for the following tile slice to vector (MOVA) intrinsics to the ArmSME dialect:

  llvm.aarch64.sme.read.vert
  llvm.aarch64.sme.read.horiz

This also updates ArmSME_IntrOp to support intrinsics with results (including overloaded result types).

>From 896ab7de80ddaf7ec36bb39a0f1d3f45950a42dd Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 12:50:24 +0000
Subject: [PATCH] [mlir][ArmSME] Add tile slice to vector intrinsics

Add support for the following tile slice to vector (MOVA) intrinsics to
the ArmSME dialect:

  llvm.aarch64.sme.read.vert
  llvm.aarch64.sme.read.horiz

This also updates ArmSME_IntrOp to support intrinsics with results (including overloaded result types).
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td |  23 ++-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |   1 +
 mlir/test/Target/LLVMIR/arm-sme-invalid.mlir  |  24 ++++
 mlir/test/Target/LLVMIR/arm-sme.mlir          | 134 ++++++++++++++++++
 4 files changed, 178 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..00e1fefc0521a78 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
 def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
 
 class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
-                    list<Trait> traits = []>
+                    list<Trait> traits = [], int numResults = 0, // 0: no result
+                    list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
           /*Dialect dialect=*/ArmSME_Dialect,
           /*string opName=*/"intr." # mnemonic,
           /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
-          /*list<int> overloadedResults=*/[],
+          /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/0>;
+          /*int numResults=*/numResults>;
 
 // Zero
 def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
       Arguments<(ins Arg<I32, "Index">,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
 
-// Vector to tile
+// Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
     : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
                     [AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write<string direction>
                      Arg<SVEPredicate, "Vector predicate">:$pg,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+                    [AllShapesMatch<["vector", "pg", "res"]>,
+                     AllElementTypesMatch<["vector", "res"]>],
+                    /*numResults=*/1, /*overloadedResults=*/[0]>,
+      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">)>;
+
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
 def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
 
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..7cbc382b0050a6e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 
 using namespace mlir;
 using namespace mlir::arm_sme;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index e119e1f1a404416..ae99ac5e02d62f0 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
       (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
+) -> vector<[3]xi8> {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
+  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
+      (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xi8>
+  llvm.return %res : vector<[3]xi8>
+}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
+) -> vector<[4]xi32> {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
+  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
+      (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.return %res : vector<[4]xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 9bb6b0c6574fcdb..c318e6d2d37f7fd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -334,3 +334,137 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
       (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
+                                              %nxv16i1 : vector<[16]xi1>,
+                                              %nxv8i1 : vector<[8]xi1>,
+                                              %nxv4i1 : vector<[4]xi1>,
+                                              %nxv2i1 : vector<[2]xi1>,
+                                              %nxv1i1 : vector<[1]xi1>,
+                                              %nxv16i8 : vector<[16]xi8>,
+                                              %nxv8i16 : vector<[8]xi16>,
+                                              %nxv4i32 : vector<[4]xi32>,
+                                              %nxv2i64 : vector<[2]xi64>,
+                                              %nxv1i128 : vector<[1]xi128>,
+                                              %nxv8f16 : vector<[8]xf16>,
+                                              %nxv8bf16 : vector<[8]xbf16>,
+                                              %nxv4f32 : vector<[4]xf32>,
+                                              %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
+  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
+  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
+  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
+  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+  // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
+  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
+  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
+  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
+  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
+  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
+                                             %nxv16i1 : vector<[16]xi1>,
+                                             %nxv8i1 : vector<[8]xi1>,
+                                             %nxv4i1 : vector<[4]xi1>,
+                                             %nxv2i1 : vector<[2]xi1>,
+                                             %nxv1i1 : vector<[1]xi1>,
+                                             %nxv16i8 : vector<[16]xi8>,
+                                             %nxv8i16 : vector<[8]xi16>,
+                                             %nxv4i32 : vector<[4]xi32>,
+                                             %nxv2i64 : vector<[2]xi64>,
+                                             %nxv1i128 : vector<[1]xi128>,
+                                             %nxv8f16 : vector<[8]xf16>,
+                                             %nxv8bf16 : vector<[8]xbf16>,
+                                             %nxv4f32 : vector<[4]xf32>,
+                                             %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
+  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
+  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
+  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
+  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+  // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
+  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
+  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
+  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
+  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
+  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+  llvm.return
+}



More information about the Mlir-commits mailing list