[Mlir-commits] [mlir] 8ce23b8 - [mlir][ArmSME] Add vector to tile intrinsics
Cullen Rhodes
llvmlistbot at llvm.org
Mon Aug 21 03:36:22 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-21T10:35:58Z
New Revision: 8ce23b8e5c91c530d25c13f97b6f4cbacfe34b3c
URL: https://github.com/llvm/llvm-project/commit/8ce23b8e5c91c530d25c13f97b6f4cbacfe34b3c
DIFF: https://github.com/llvm/llvm-project/commit/8ce23b8e5c91c530d25c13f97b6f4cbacfe34b3c.diff
LOG: [mlir][ArmSME] Add vector to tile intrinsics
Add support for the following vector-to-tile (MOVA) intrinsics to the
ArmSME dialect:
llvm.aarch64.sme.write.vert
llvm.aarch64.sme.write.horiz
Includes the definition of a new type constraint,
'ScalableVectorOfRankAndLengthAndType', in CommonTypeConstraints.td.
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D157004
Added:
mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/include/mlir/IR/CommonTypeConstraints.td
mlir/test/Target/LLVMIR/arm-sme.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 95c5f899bdb52d..b083baf03fa96a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS
+include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -61,6 +62,12 @@ def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
+// 1-D scalable SVE data vectors and i1 predicate vectors of the same lengths.
+def SVEVector : ScalableVectorOfRankAndLengthAndType<
+ [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+ [1], [16, 8, 4, 2, 1], [I1]>;
+
// A type constraint that verifies the bitwidth of the scalar integer returned
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
def TileElementWidthMatchesTileID : TypesMatchWith<
@@ -496,6 +503,18 @@ def LLVM_aarch64_sme_str
Arguments<(ins Arg<I32, "Index">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
+// Vector to tile (MOVA), overloaded on $vector; $pg must match its shape.
+class LLVM_aarch64_sme_write<string direction>
+ : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+ [AllShapesMatch<["pg", "vector"]>]>,
+ Arguments<(ins Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">,
+ Arg<SVEPredicate, "Vector predicate">:$pg,
+ Arg<SVEVector, "Vector operand">:$vector)>;
+// One op per slice direction: write.horiz and write.vert.
+def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
+def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index ebb1de47566d62..4fc14e30b8a10d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -533,6 +533,19 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+// Any scalable vector of rank in `allowedRanks`, length in `allowedLengths`
+// and element type in `allowedTypes`. The type summary comes first since the
+// rank/length summaries are suffixes (" of ranks ...", " of length ...").
+class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
+ list<int> allowedLengths,
+ list<Type> allowedTypes> : AllOfType<
+ [ScalableVectorOf<allowedTypes>, ScalableVectorOfRank<allowedRanks>,
+ ScalableVectorOfLength<allowedLengths>],
+ ScalableVectorOf<allowedTypes>.summary #
+ ScalableVectorOfRank<allowedRanks>.summary #
+ ScalableVectorOfLength<allowedLengths>.summary,
+ "::mlir::VectorType">;
+
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
new file mode 100644
index 00000000000000..e119e1f1a40441
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// The predicate and vector operands must have the same shape ([4] vs [16]).
+llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv16i8 : vector<[16]xi8>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
+ (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 7beec1f61aa923..9bb6b0c6574fcd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -236,3 +236,101 @@ llvm.func @arm_sme_toggle_za() {
"arm_sme.intr.za.disable"() : () -> ()
llvm.return
}
+
+// -----
+// write.horiz for every SVEVector element type, with same-length predicates.
+// CHECK-LABEL: @arm_sme_vector_to_tile_horiz
+llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
+ (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
+ (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
+ (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
+ (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
+ (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
+ (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
+ (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
+ (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
+ "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
+ (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ llvm.return
+}
+
+// -----
+// write.vert for every SVEVector element type, with same-length predicates.
+// CHECK-LABEL: @arm_sme_vector_to_tile_vert
+llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
+ (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
+ (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
+ (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
+ (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
+ (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
+ (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
+ (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
+ (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
+ "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
+ (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ llvm.return
+}
More information about the Mlir-commits
mailing list