[Mlir-commits] [mlir] 16d890c - [mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics (#80511)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 6 14:12:45 PST 2024
Author: Kojo Acquah
Date: 2024-02-06T14:12:40-08:00
New Revision: 16d890ced68aafae4cc8ba3efc9213bfab84ba54
URL: https://github.com/llvm/llvm-project/commit/16d890ced68aafae4cc8ba3efc9213bfab84ba54
DIFF: https://github.com/llvm/llvm-project/commit/16d890ced68aafae4cc8ba3efc9213bfab84ba54.diff
LOG: [mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics (#80511)
This adds the SMMLA, UMMLA, and USMMLA intrinsics to Neon dialect bringing it in line with the SVE dialect.
These ops expose the matrix multiply-accumulate instructions, which multiply a 2x8 matrix by an 8x2 matrix of 8-bit integers (with the respective signedness of each op) and accumulate the 2x2 product into a 32-bit integer accumulator. This is equivalent to performing an 8-way dot product per destination element.
Op details:
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=mmla
Added:
Modified:
mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
mlir/test/Dialect/ArmNeon/invalid.mlir
mlir/test/Dialect/ArmNeon/roundtrip.mlir
mlir/test/Target/LLVMIR/arm-neon.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index e298963f3d19f..9cc792093bf83 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -30,6 +30,15 @@ def ArmNeon_Dialect : Dialect {
// to the LLVMDialect (ops or types).
}
+//===----------------------------------------------------------------------===//
+// ArmNeon type constraints
+//===----------------------------------------------------------------------===//
+// Fixed-length (non-scalable) 1-D vector of `length` x `elementType`.
+class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
+ [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
+ "a vector with length " # length,
+ "::mlir::VectorType">;
+
//===----------------------------------------------------------------------===//
// ArmNeon op definitions
//===----------------------------------------------------------------------===//
@@ -120,6 +129,99 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
"$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
}
+def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
+ Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "res"]>,
+ ]> {
+ let summary = "Signed matrix-matrix multiply and accumulate op";
+ let description = [{
+ SMMLA: Signed integer matrix multiply-accumulate.
+
+ Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
+ the 2x8 matrix of signed 8-bit integer values in the first source vector by
+ the 8x2 matrix of signed 8-bit integer values in the second source vector.
+ The resulting 2x2 32-bit integer matrix product is destructively added to
+ the 32-bit integer matrix accumulator in the destination vector. This is
+ equivalent to performing an 8-way dot product per destination element.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
+ }];
+ // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ NeonVectorOfLength<4, I32>:$acc,
+ NeonVectorOfLength<16, I8>:$src1,
+ NeonVectorOfLength<16, I8>:$src2
+ );
+ let results = (outs NeonVectorOfLength<4, I32>:$res);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
+def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
+ Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "res"]>,
+ ]> {
+ let summary = "Unsigned matrix-matrix multiply and accumulate op";
+ let description = [{
+ UMMLA: Unsigned integer matrix multiply-accumulate.
+
+ Unsigned 8-bit integer matrix multiply-accumulate. This instruction
+ multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
+ source vector by the 8x2 matrix of unsigned 8-bit integer values in the
+ second source vector. The resulting 2x2 32-bit integer matrix product is
+ destructively added to the 32-bit integer matrix accumulator in the
+ destination vector. This is equivalent to performing an 8-way dot product
+ per destination element.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
+ }];
+ // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ NeonVectorOfLength<4, I32>:$acc,
+ NeonVectorOfLength<16, I8>:$src1,
+ NeonVectorOfLength<16, I8>:$src2
+ );
+ let results = (outs NeonVectorOfLength<4, I32>:$res);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
+def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
+ Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "res"]>,
+ ]> {
+ let summary = "Unsigned and signed matrix-matrix multiply and accumulate op";
+ let description = [{
+ USMMLA: Unsigned and signed integer matrix multiply-accumulate.
+
+ Unsigned and signed 8-bit integer matrix multiply-accumulate. This
+ instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
+ the first source vector by the 8x2 matrix of signed 8-bit integer values in
+ the second source vector. The resulting 2x2 32-bit integer matrix product is
+ destructively added to the 32-bit integer matrix accumulator in the
+ destination vector. This is equivalent to performing an 8-way dot product
+ per destination element.
+
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
+ }];
+ // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ NeonVectorOfLength<4, I32>:$acc,
+ NeonVectorOfLength<16, I8>:$src1,
+ NeonVectorOfLength<16, I8>:$src2
+ );
+ let results = (outs NeonVectorOfLength<4, I32>:$res);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
: Op</*dialect=*/ArmNeon_Dialect,
/*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
index 62caf04160020..be49e0e4597b0 100644
--- a/mlir/test/Dialect/ArmNeon/invalid.mlir
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -31,3 +31,63 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
%0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
return %0 : vector<4xi32>
}
+
+// -----
+
+func.func @smmla_invalid_input_types(%a: vector<4xi32>,
+ %b: vector<16xi4>,
+ %c: vector<16xi4>) -> vector<4xi32> {
+ // expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
+ %0 = arm_neon.intr.smmla %a, %b, %c : vector<16xi4> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @smmla_invalid_dimensions(%a: vector<8xi32>,
+ %b: vector<32xi8>,
+ %c: vector<32xi8>) -> vector<8xi32> {
+ // expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
+ %0 = arm_neon.intr.smmla %a, %b, %c : vector<32xi8> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @ummla_invalid_input_types(%a: vector<4xi32>,
+ %b: vector<16xi4>,
+ %c: vector<16xi4>) -> vector<4xi32> {
+ // expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
+ %0 = arm_neon.intr.ummla %a, %b, %c : vector<16xi4> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @ummla_invalid_dimensions(%a: vector<8xi32>,
+ %b: vector<32xi8>,
+ %c: vector<32xi8>) -> vector<8xi32> {
+ // expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
+ %0 = arm_neon.intr.ummla %a, %b, %c : vector<32xi8> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @usmmla_invalid_input_types(%a: vector<4xi32>,
+ %b: vector<16xi4>,
+ %c: vector<16xi4>) -> vector<4xi32> {
+ // expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
+ %0 = arm_neon.intr.usmmla %a, %b, %c : vector<16xi4> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
+ %b: vector<32xi8>,
+ %c: vector<32xi8>) -> vector<8xi32> {
+ // expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
+ %0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index 704bfe8c084a5..b5df0ffa8105c 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s
// CHECK-LABEL: arm_neon_smull
func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
@@ -19,9 +19,44 @@ func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
}
+// -----
+
// CHECK-LABEL: arm_neon_sdot
func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
// CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
%0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
return %0 : vector<2xi32>
}
+
+// -----
+
+// CHECK-LABEL: arm_neon_smmla
+func.func @arm_neon_smmla(%src1: vector<16xi8>,
+ %src2: vector<16xi8>,
+ %acc: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi32>
+ %res = arm_neon.intr.smmla %acc, %src1, %src2 : vector<16xi8> to vector<4xi32>
+ return %res : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_neon_ummla
+func.func @arm_neon_ummla(%src1: vector<16xi8>,
+ %src2: vector<16xi8>,
+ %acc: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi32>
+ %res = arm_neon.intr.ummla %acc, %src1, %src2 : vector<16xi8> to vector<4xi32>
+ return %res : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_neon_usmmla
+func.func @arm_neon_usmmla(%src1: vector<16xi8>,
+ %src2: vector<16xi8>,
+ %acc: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi32>
+ %res = arm_neon.intr.usmmla %acc, %src1, %src2 : vector<16xi8> to vector<4xi32>
+ return %res : vector<4xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index f4716fe58f203..e096172667c9f 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
// CHECK-LABEL: arm_neon_smull
llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
@@ -24,6 +24,8 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.str
llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
}
+// -----
+
// CHECK-LABEL: arm_neon_sdot_8_i8i8
llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
// CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
@@ -32,6 +34,8 @@ llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<
llvm.return %0 : vector<2xi32>
}
+// -----
+
// CHECK-LABEL: arm_neon_sdot_16_i8i8
llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
// CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
@@ -39,3 +43,42 @@ llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vecto
%0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
llvm.return %0 : vector<4xi32>
}
+
+// -----
+
+// CHECK-LABEL: arm_neon_smmla
+llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
+ %arg1: vector<16xi8>,
+ %arg2: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: call <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
+ %0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
+ (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+ -> vector<4xi32>
+ llvm.return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_neon_ummla
+llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
+ %arg1: vector<16xi8>,
+ %arg2: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: call <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
+ %0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
+ (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+ -> vector<4xi32>
+ llvm.return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_neon_usmmla
+llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
+ %arg1: vector<16xi8>,
+ %arg2: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: call <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
+ %0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
+ (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+ -> vector<4xi32>
+ llvm.return %0 : vector<4xi32>
+}
More information about the Mlir-commits
mailing list