[Mlir-commits] [mlir] [mlir][ArmSVE] Add intrinsics for the SME2 multi-vector zips (PR #80985)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Feb 7 05:57:13 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80985

>From 0925ac611238f92b20160535bcfca66c4a4ffde1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 12:33:50 +0000
Subject: [PATCH] [mlir][ArmSVE] Add intrinsics for the SME2 multi-vector zips

These are added to the ArmSVE dialect for consistency with LLVM, which
registers the SME2 intrinsics that don't require ZA as SVE intrinsics
(`llvm.aarch64.sve.*`).
---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 25 ++++++++++-
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 42 +++++++++++++++++++
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index e3f3d9e62e8fb..f237f232487e5 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -59,14 +59,15 @@ class ArmSVE_Op<string mnemonic, list<Trait> traits = []> :
 class ArmSVE_IntrOp<string mnemonic,
                     list<Trait> traits = [],
                     list<int> overloadedOperands = [],
-                    list<int> overloadedResults = []> :
+                    list<int> overloadedResults = [],
+                    int numResults = 1> : // >1 for multi-result intrinsics (e.g. zip.x2/zip.x4), whose result is an !llvm.struct
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
                   /*list<int> overloadedResults=*/overloadedResults,
                   /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
-                  /*int numResults=*/1>;
+                  /*int numResults=*/numResults>;
 
 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<Trait> traits = []>:
@@ -410,4 +411,24 @@ def ConvertToSvboolIntrOp :
     /*overloadedResults=*/[]>,
     Arguments<(ins SVEPredicate:$mask)>;
 
+// Note: This multi-vector intrinsic requires SME2. Lowers to llvm.aarch64.sve.zip.x2.
+def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
+    /*traits=*/[],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[],
+    /*numResults=*/2>,
+    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
+                   Arg<AnyScalableVector, "v2">:$v2)>;
+
+// Note: This multi-vector intrinsic requires SME2. Lowers to llvm.aarch64.sve.zip.x4.
+def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
+    /*traits=*/[],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[],
+    /*numResults=*/4>,
+    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
+                   Arg<AnyScalableVector, "v2">:$v2,
+                   Arg<AnyScalableVector, "v3">:$v3,
+                   Arg<AnyScalableVector, "v4">:$v4)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index b63d3f0651569..c7cd1b74ccdb5 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -314,3 +314,45 @@ llvm.func @arm_sve_convert_to_svbool(
     : (vector<[1]xi1>) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: arm_sve_zip_x2(
+// CHECK-SAME:                 <vscale x 16 x i8> %[[V1:[0-9]+]],
+// CHECK-SAME:                 <vscale x 8 x i16> %[[V2:[0-9]+]],
+// CHECK-SAME:                 <vscale x 4 x i32> %[[V3:[0-9]+]],
+// CHECK-SAME:                 <vscale x 2 x i64> %[[V4:[0-9]+]])
+llvm.func @arm_sve_zip_x2(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>, %nxv4i32: vector<[4]xi32>, %nxv2i64: vector<[2]xi64>) {
+  // CHECK: call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.zip.x2.nxv16i8(<vscale x 16 x i8> %[[V1]], <vscale x 16 x i8> %[[V1]])
+  %0 = "arm_sve.intr.zip.x2"(%nxv16i8, %nxv16i8) : (vector<[16]xi8>, vector<[16]xi8>)
+    -> !llvm.struct<(vector<[16]xi8>, vector<[16]xi8>)>
+  // CHECK: call { <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.zip.x2.nxv8i16(<vscale x 8 x i16> %[[V2]], <vscale x 8 x i16> %[[V2]])
+  %1 = "arm_sve.intr.zip.x2"(%nxv8i16, %nxv8i16) : (vector<[8]xi16>, vector<[8]xi16>)
+    -> !llvm.struct<(vector<[8]xi16>, vector<[8]xi16>)>
+  // CHECK: call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.aarch64.sve.zip.x2.nxv4i32(<vscale x 4 x i32> %[[V3]], <vscale x 4 x i32> %[[V3]])
+  %2 = "arm_sve.intr.zip.x2"(%nxv4i32, %nxv4i32) : (vector<[4]xi32>, vector<[4]xi32>)
+    -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+  // CHECK: call { <vscale x 2 x i64>, <vscale x 2 x i64> } @llvm.aarch64.sve.zip.x2.nxv2i64(<vscale x 2 x i64> %[[V4]], <vscale x 2 x i64> %[[V4]])
+  %3 = "arm_sve.intr.zip.x2"(%nxv2i64, %nxv2i64) : (vector<[2]xi64>, vector<[2]xi64>)
+    -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+  llvm.return
+}
+
+// CHECK-LABEL: arm_sve_zip_x4(
+// CHECK-SAME:                 <vscale x 16 x i8> %[[V1:[0-9]+]],
+// CHECK-SAME:                 <vscale x 8 x i16> %[[V2:[0-9]+]],
+// CHECK-SAME:                 <vscale x 4 x i32> %[[V3:[0-9]+]],
+// CHECK-SAME:                 <vscale x 2 x i64> %[[V4:[0-9]+]])
+llvm.func @arm_sve_zip_x4(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>, %nxv4i32: vector<[4]xi32>, %nxv2i64: vector<[2]xi64>) {
+  // CHECK: call { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.zip.x4.nxv16i8(<vscale x 16 x i8> %[[V1]], <vscale x 16 x i8> %[[V1]], <vscale x 16 x i8> %[[V1]], <vscale x 16 x i8> %[[V1]])
+  %0 = "arm_sve.intr.zip.x4"(%nxv16i8, %nxv16i8, %nxv16i8, %nxv16i8) : (vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)
+    -> !llvm.struct<(vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)>
+  // CHECK: call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.zip.x4.nxv8i16(<vscale x 8 x i16> %[[V2]], <vscale x 8 x i16> %[[V2]], <vscale x 8 x i16> %[[V2]], <vscale x 8 x i16> %[[V2]])
+  %1 = "arm_sve.intr.zip.x4"(%nxv8i16, %nxv8i16, %nxv8i16, %nxv8i16) : (vector<[8]xi16>, vector<[8]xi16>, vector<[8]xi16>, vector<[8]xi16>)
+    -> !llvm.struct<(vector<[8]xi16>, vector<[8]xi16>, vector<[8]xi16>, vector<[8]xi16>)>
+  // CHECK: call { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.aarch64.sve.zip.x4.nxv4i32(<vscale x 4 x i32> %[[V3]], <vscale x 4 x i32> %[[V3]], <vscale x 4 x i32> %[[V3]], <vscale x 4 x i32> %[[V3]])
+  %2 = "arm_sve.intr.zip.x4"(%nxv4i32, %nxv4i32, %nxv4i32, %nxv4i32) : (vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi32>)
+    -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi32>)>
+  // CHECK: call { <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64> } @llvm.aarch64.sve.zip.x4.nxv2i64(<vscale x 2 x i64> %[[V4]], <vscale x 2 x i64> %[[V4]], <vscale x 2 x i64> %[[V4]], <vscale x 2 x i64> %[[V4]])
+  %3 = "arm_sve.intr.zip.x4"(%nxv2i64, %nxv2i64, %nxv2i64, %nxv2i64) : (vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)
+    -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)>
+  llvm.return
+}



More information about the Mlir-commits mailing list