[Mlir-commits] [mlir] 8ce23b8 - [mlir][ArmSME] Add vector to tile intrinsics

Cullen Rhodes llvmlistbot at llvm.org
Mon Aug 21 03:36:22 PDT 2023


Author: Cullen Rhodes
Date: 2023-08-21T10:35:58Z
New Revision: 8ce23b8e5c91c530d25c13f97b6f4cbacfe34b3c

URL: https://github.com/llvm/llvm-project/commit/8ce23b8e5c91c530d25c13f97b6f4cbacfe34b3c
DIFF: https://github.com/llvm/llvm-project/commit/8ce23b8e5c91c530d25c13f97b6f4cbacfe34b3c.diff

LOG: [mlir][ArmSME] Add vector to tile intrinsics

Add support for the following vector-to-tile (MOVA) intrinsics to the ArmSME
dialect:

  llvm.aarch64.sme.write.vert
  llvm.aarch64.sme.write.horiz

Also includes the definition of a new type constraint,
'ScalableVectorOfRankAndLengthAndType', in CommonTypeConstraints.td.

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D157004

Added: 
    mlir/test/Target/LLVMIR/arm-sme-invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
    mlir/include/mlir/IR/CommonTypeConstraints.td
    mlir/test/Target/LLVMIR/arm-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 95c5f899bdb52d..b083baf03fa96a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
 #ifndef ARMSME_OPS
 #define ARMSME_OPS
 
+include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -61,6 +62,12 @@ def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
 
+def SVEVector : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I1]>;
+
 // A type constraint that verifies the bitwidth of the scalar integer returned
 // from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
 def TileElementWidthMatchesTileID : TypesMatchWith<
@@ -496,6 +503,18 @@ def LLVM_aarch64_sme_str
       Arguments<(ins Arg<I32, "Index">,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
 
+// Vector to tile
+class LLVM_aarch64_sme_write<string direction>
+    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+                    [AllShapesMatch<["pg", "vector"]>]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<SVEVector, "Vector operand">:$vector)>;
+
+def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
+def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
 

diff  --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index ebb1de47566d62..4fc14e30b8a10d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -533,6 +533,19 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Any scalable vector where the rank is from the given `allowedRanks` list and
+// the number of elements is from the given `allowedLengths` list and the type
+// is from the given `allowedTypes` list
+class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
+                                           list<int> allowedLengths,
+                                           list<Type> allowedTypes> : AllOfType<
+  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
+   ScalableVectorOfLength<allowedLengths>],
+  ScalableVectorOfRank<allowedRanks>.summary #
+  ScalableVectorOf<allowedTypes>.summary #
+  ScalableVectorOfLength<allowedLengths>.summary,
+  "::mlir::VectorType">;
+
 def AnyVector : VectorOf<[AnyType]>;
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

diff  --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
new file mode 100644
index 00000000000000..e119e1f1a40441
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// Verify shape of predicate and vector must match
+llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
+                                                %nxv4i1 : vector<[4]xi1>,
+                                                %nxv16i8 : vector<[16]xi8>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
+      (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 7beec1f61aa923..9bb6b0c6574fcd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -236,3 +236,101 @@ llvm.func @arm_sme_toggle_za() {
   "arm_sme.intr.za.disable"() : () -> ()
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sme_vector_to_tile_horiz
+llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
+                                        %nxv16i1 : vector<[16]xi1>,
+                                        %nxv8i1 : vector<[8]xi1>,
+                                        %nxv4i1 : vector<[4]xi1>,
+                                        %nxv2i1 : vector<[2]xi1>,
+                                        %nxv1i1 : vector<[1]xi1>,
+                                        %nxv16i8 : vector<[16]xi8>,
+                                        %nxv8i16 : vector<[8]xi16>,
+                                        %nxv4i32 : vector<[4]xi32>,
+                                        %nxv2i64 : vector<[2]xi64>,
+                                        %nxv1i128 : vector<[1]xi128>,
+                                        %nxv8f16 : vector<[8]xf16>,
+                                        %nxv8bf16 : vector<[8]xbf16>,
+                                        %nxv4f32 : vector<[4]xf32>,
+                                        %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
+      (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
+      (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
+      (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
+      (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
+      (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
+      (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
+      (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
+      (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
+      (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_vector_to_tile_vert
+llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
+                                       %nxv16i1 : vector<[16]xi1>,
+                                       %nxv8i1 : vector<[8]xi1>,
+                                       %nxv4i1 : vector<[4]xi1>,
+                                       %nxv2i1 : vector<[2]xi1>,
+                                       %nxv1i1 : vector<[1]xi1>,
+                                       %nxv16i8 : vector<[16]xi8>,
+                                       %nxv8i16 : vector<[8]xi16>,
+                                       %nxv4i32 : vector<[4]xi32>,
+                                       %nxv2i64 : vector<[2]xi64>,
+                                       %nxv1i128 : vector<[1]xi128>,
+                                       %nxv8f16 : vector<[8]xf16>,
+                                       %nxv8bf16 : vector<[8]xbf16>,
+                                       %nxv4f32 : vector<[4]xf32>,
+                                       %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
+      (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
+      (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
+      (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
+      (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
+      (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
+      (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
+      (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
+      (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
+      (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  llvm.return
+}


        


More information about the Mlir-commits mailing list