[Mlir-commits] [mlir] [mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics (PR #80511)

Kojo Acquah llvmlistbot at llvm.org
Mon Feb 5 14:26:34 PST 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/80511

>From dbe547d068dc97134fbe3a6566e45cbe1424f30e Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Thu, 1 Feb 2024 13:34:43 -0500
Subject: [PATCH 1/3] implemented roundtrip and target for neon smmla, ummla
 and usmmla

---
 mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td | 93 ++++++++++++++++++++
 mlir/test/Dialect/ArmNeon/invalid.mlir       | 66 ++++++++++++++
 mlir/test/Dialect/ArmNeon/roundtrip.mlir     | 35 +++++++-
 mlir/test/Target/LLVMIR/arm-neon.mlir        | 44 ++++++++-
 4 files changed, 236 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index e298963f3d19f..c515a858ee8a1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -120,6 +120,99 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
     "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
   }
 
+def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
+                Pure,
+                AllTypesMatch<["b", "c"]>,
+                AllTypesMatch<["a", "res"]>,
+              ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    SMMLA: Signed integer matrix multiply-accumulate.
+
+    Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
+    the 2x8 matrix of signed 8-bit integer values in the first source vector by
+    the 8x2 matrix of signed 8-bit integer values in the second source vector.
+    The resulting 2x2 32-bit integer matrix product is destructively added to
+    the 32-bit integer matrix accumulator in the destination vector. This is
+    equivalent to performing an 8-way dot product per destination element.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          VectorOfLengthAndType<[4], [I32]>:$a,
+          VectorOfLengthAndType<[16], [I8]>:$b,
+          VectorOfLengthAndType<[16], [I8]>:$c
+  );
+  let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
+def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
+                Pure,
+                AllTypesMatch<["b", "c"]>,
+                AllTypesMatch<["a", "res"]>,
+              ]> {
+  let summary = "Unsinged matrix-matrix multiply and accumulate op";
+  let description = [{
+    UMMLA: Unsigned integer matrix multiply-accumulate.
+
+    Unsigned 8-bit integer matrix multiply-accumulate. This instruction
+    multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
+    source vector by the 8x2 matrix of unsigned 8-bit integer values in the
+    second source vector. The resulting 2x2 32-bit integer matrix product is
+    destructively added to the 32-bit integer matrix accumulator in the
+    destination vector. This is equivalent to performing an 8-way dot product
+    per destination element.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          VectorOfLengthAndType<[4], [I32]>:$a,
+          VectorOfLengthAndType<[16], [I8]>:$b,
+          VectorOfLengthAndType<[16], [I8]>:$c
+  );
+  let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
+def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
+                Pure,
+                AllTypesMatch<["b", "c"]>,
+                AllTypesMatch<["a", "res"]>,
+              ]> {
+  let summary = "Unsignged and signed matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned and signed integer matrix multiply-accumulate.
+
+    Unsigned and signed 8-bit integer matrix multiply-accumulate. This
+    instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
+    the first source vector by the 8x2 matrix of signed 8-bit integer values in
+    the second source vector. The resulting 2x2 32-bit integer matrix product is
+    destructively added to the 32-bit integer matrix accumulator in the
+    destination vector. This is equivalent to performing an 8-way dot product
+    per destination element.
+
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          VectorOfLengthAndType<[4], [I32]>:$a,
+          VectorOfLengthAndType<[16], [I8]>:$b,
+          VectorOfLengthAndType<[16], [I8]>:$c
+  );
+  let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
 class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
     : Op</*dialect=*/ArmNeon_Dialect,
          /*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
index 62caf04160020..d8879c56857dc 100644
--- a/mlir/test/Dialect/ArmNeon/invalid.mlir
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -31,3 +31,69 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
     %0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
     return %0 : vector<4xi32>
 }
+
+// -----
+
+func.func @smmla_invalid_input_types(%a: vector<16xi4>,
+                    %b: vector<16xi4>,
+                    %c: vector<4xi32>) -> vector<4xi32> {
+  // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+  %0 = arm_neon.intr.smmla %c, %a, %b :
+             vector<16xi4> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @smmla_invalid_dimensions(%a: vector<32xi8>,
+                    %b: vector<32xi8>,
+                    %c: vector<8xi32>) -> vector<8xi32> {
+  // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+  %0 = arm_neon.intr.smmla %c, %a, %b :
+             vector<32xi8> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @ummla_invalid_input_types(%a: vector<16xi4>,
+                    %b: vector<16xi4>,
+                    %c: vector<4xi32>) -> vector<4xi32> {
+  // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+  %0 = arm_neon.intr.ummla %c, %a, %b :
+             vector<16xi4> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @ummla_invalid_dimensions(%a: vector<32xi8>,
+                    %b: vector<32xi8>,
+                    %c: vector<8xi32>) -> vector<8xi32> {
+  // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+  %0 = arm_neon.intr.ummla %c, %a, %b :
+             vector<32xi8> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @usmmla_invalid_input_types(%a: vector<16xi4>,
+                    %b: vector<16xi4>,
+                    %c: vector<4xi32>) -> vector<4xi32> {
+  // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+  %0 = arm_neon.intr.usmmla %c, %a, %b :
+             vector<16xi4> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @usmmla_invalid_dimensions(%a: vector<32xi8>,
+                    %b: vector<32xi8>,
+                    %c: vector<8xi32>) -> vector<8xi32> {
+  // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+  %0 = arm_neon.intr.usmmla %c, %a, %b :
+             vector<32xi8> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index 704bfe8c084a5..c3a4692c4a39a 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: arm_neon_smull
 func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
@@ -25,3 +25,36 @@ func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>)
   %0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
   return %0 : vector<2xi32>
 }
+
+// -----
+
+func.func @arm_neon_smmla(%a: vector<16xi8>,
+                    %b: vector<16xi8>,
+                    %c: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi3
+  %0 = arm_neon.intr.smmla %c, %a, %b :
+             vector<16xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @arm_neon_ummla(%a: vector<16xi8>,
+                    %b: vector<16xi8>,
+                    %c: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi3
+  %0 = arm_neon.intr.ummla %c, %a, %b :
+             vector<16xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @arm_neon_usmmla(%a: vector<16xi8>,
+                    %b: vector<16xi8>,
+                    %c: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi3
+  %0 = arm_neon.intr.usmmla %c, %a, %b :
+             vector<16xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index f4716fe58f203..2577f2d49e326 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file  %s | FileCheck %s
 
 // CHECK-LABEL: arm_neon_smull
 llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
@@ -39,3 +39,45 @@ llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vecto
   %0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
   llvm.return %0 : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_smmla
+llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
+                         %arg1: vector<16xi8>,
+                         %arg2: vector<4xi32>)
+                         -> vector<4xi32> {
+  // CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
+  %0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
+    (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+        -> vector<4xi32>
+  llvm.return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_ummla
+llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
+                         %arg1: vector<16xi8>,
+                         %arg2: vector<4xi32>)
+                         -> vector<4xi32> {
+  // CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
+  %0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
+    (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+        -> vector<4xi32>
+  llvm.return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_usmmla
+llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
+                         %arg1: vector<16xi8>,
+                         %arg2: vector<4xi32>)
+                         -> vector<4xi32> {
+  // CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
+  %0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
+    (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+        -> vector<4xi32>
+  llvm.return %0 : vector<4xi32>
+}

>From 034d6ed5cb697e988cd59f7d53884acc70e3b293 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Fri, 2 Feb 2024 20:08:34 -0500
Subject: [PATCH 2/3] diego nits

---
 mlir/test/Dialect/ArmNeon/invalid.mlir   | 42 ++++++++++--------------
 mlir/test/Dialect/ArmNeon/roundtrip.mlir | 24 +++++++-------
 mlir/test/Target/LLVMIR/arm-neon.mlir    | 12 +++----
 3 files changed, 36 insertions(+), 42 deletions(-)

diff --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
index d8879c56857dc..3ad763e7b4982 100644
--- a/mlir/test/Dialect/ArmNeon/invalid.mlir
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -35,65 +35,59 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
 // -----
 
 func.func @smmla_invalid_input_types(%a: vector<16xi4>,
-                    %b: vector<16xi4>,
-                    %c: vector<4xi32>) -> vector<4xi32> {
+                                     %b: vector<16xi4>,
+                                     %c: vector<4xi32>) -> vector<4xi32> {
   // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
-  %0 = arm_neon.intr.smmla %c, %a, %b :
-             vector<16xi4> to vector<4xi32>
+  %0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi4> to vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // -----
 
 func.func @smmla_invalid_dimensions(%a: vector<32xi8>,
-                    %b: vector<32xi8>,
-                    %c: vector<8xi32>) -> vector<8xi32> {
+                                    %b: vector<32xi8>,
+                                    %c: vector<8xi32>) -> vector<8xi32> {
   // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
-  %0 = arm_neon.intr.smmla %c, %a, %b :
-             vector<32xi8> to vector<8xi32>
+  %0 = arm_neon.intr.smmla %c, %a, %b : vector<32xi8> to vector<8xi32>
   return %0 : vector<8xi32>
 }
 
 // -----
 
 func.func @ummla_invalid_input_types(%a: vector<16xi4>,
-                    %b: vector<16xi4>,
-                    %c: vector<4xi32>) -> vector<4xi32> {
+                                     %b: vector<16xi4>,
+                                     %c: vector<4xi32>) -> vector<4xi32> {
   // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
-  %0 = arm_neon.intr.ummla %c, %a, %b :
-             vector<16xi4> to vector<4xi32>
+  %0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi4> to vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // -----
 
 func.func @ummla_invalid_dimensions(%a: vector<32xi8>,
-                    %b: vector<32xi8>,
-                    %c: vector<8xi32>) -> vector<8xi32> {
+                                    %b: vector<32xi8>,
+                                    %c: vector<8xi32>) -> vector<8xi32> {
   // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
-  %0 = arm_neon.intr.ummla %c, %a, %b :
-             vector<32xi8> to vector<8xi32>
+  %0 = arm_neon.intr.ummla %c, %a, %b : vector<32xi8> to vector<8xi32>
   return %0 : vector<8xi32>
 }
 
 // -----
 
 func.func @usmmla_invalid_input_types(%a: vector<16xi4>,
-                    %b: vector<16xi4>,
-                    %c: vector<4xi32>) -> vector<4xi32> {
+                                      %b: vector<16xi4>,
+                                      %c: vector<4xi32>) -> vector<4xi32> {
   // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
-  %0 = arm_neon.intr.usmmla %c, %a, %b :
-             vector<16xi4> to vector<4xi32>
+  %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi4> to vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // -----
 
 func.func @usmmla_invalid_dimensions(%a: vector<32xi8>,
-                    %b: vector<32xi8>,
-                    %c: vector<8xi32>) -> vector<8xi32> {
+                                     %b: vector<32xi8>,
+                                     %c: vector<8xi32>) -> vector<8xi32> {
   // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
-  %0 = arm_neon.intr.usmmla %c, %a, %b :
-             vector<32xi8> to vector<8xi32>
+  %0 = arm_neon.intr.usmmla %c, %a, %b : vector<32xi8> to vector<8xi32>
   return %0 : vector<8xi32>
 }
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index c3a4692c4a39a..30afe325a482c 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -28,33 +28,33 @@ func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>)
 
 // -----
 
+// CHECK-LABEL: arm_neon_smmla
 func.func @arm_neon_smmla(%a: vector<16xi8>,
-                    %b: vector<16xi8>,
-                    %c: vector<4xi32>) -> vector<4xi32> {
+                          %b: vector<16xi8>,
+                          %c: vector<4xi32>) -> vector<4xi32> {
   // CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi3
-  %0 = arm_neon.intr.smmla %c, %a, %b :
-             vector<16xi8> to vector<4xi32>
+  %0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // -----
 
+// CHECK-LABEL: arm_neon_ummla
 func.func @arm_neon_ummla(%a: vector<16xi8>,
-                    %b: vector<16xi8>,
-                    %c: vector<4xi32>) -> vector<4xi32> {
+                          %b: vector<16xi8>,
+                          %c: vector<4xi32>) -> vector<4xi32> {
   // CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi3
-  %0 = arm_neon.intr.ummla %c, %a, %b :
-             vector<16xi8> to vector<4xi32>
+  %0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // -----
 
+// CHECK-LABEL: arm_neon_usmmla
 func.func @arm_neon_usmmla(%a: vector<16xi8>,
-                    %b: vector<16xi8>,
-                    %c: vector<4xi32>) -> vector<4xi32> {
+                            %b: vector<16xi8>,
+                            %c: vector<4xi32>) -> vector<4xi32> {
   // CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi3
-  %0 = arm_neon.intr.usmmla %c, %a, %b :
-             vector<16xi8> to vector<4xi32>
+  %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index 2577f2d49e326..e5b37ea3c8a5d 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -44,8 +44,8 @@ llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vecto
 
 // CHECK-LABEL: define <4 x i32> @arm_neon_smmla
 llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
-                         %arg1: vector<16xi8>,
-                         %arg2: vector<4xi32>)
+                          %arg1: vector<16xi8>,
+                          %arg2: vector<4xi32>)
                          -> vector<4xi32> {
   // CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
   %0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
@@ -58,8 +58,8 @@ llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
 
 // CHECK-LABEL: define <4 x i32> @arm_neon_ummla
 llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
-                         %arg1: vector<16xi8>,
-                         %arg2: vector<4xi32>)
+                          %arg1: vector<16xi8>,
+                          %arg2: vector<4xi32>)
                          -> vector<4xi32> {
   // CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
   %0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
@@ -72,8 +72,8 @@ llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
 
 // CHECK-LABEL: define <4 x i32> @arm_neon_usmmla
 llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
-                         %arg1: vector<16xi8>,
-                         %arg2: vector<4xi32>)
+                          %arg1: vector<16xi8>,
+                          %arg2: vector<4xi32>)
                          -> vector<4xi32> {
   // CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
   %0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :

>From 97527f80dbfac724952bcaca9145c6013361581a Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Mon, 5 Feb 2024 17:08:22 -0500
Subject: [PATCH 3/3]  banach-space comments

---
 mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td | 36 ++++++++++----------
 mlir/test/Dialect/ArmNeon/invalid.mlir       | 36 ++++++++++----------
 mlir/test/Dialect/ArmNeon/roundtrip.mlir     |  8 +++--
 mlir/test/Target/LLVMIR/arm-neon.mlir        | 20 ++++++-----
 4 files changed, 52 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index c515a858ee8a1..25ae5a33692f9 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -122,8 +122,8 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
 
 def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
                 Pure,
-                AllTypesMatch<["b", "c"]>,
-                AllTypesMatch<["a", "res"]>,
+                AllTypesMatch<["src1", "src2"]>,
+                AllTypesMatch<["acc", "res"]>,
               ]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
@@ -141,19 +141,19 @@ def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
   }];
   // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
   let arguments = (ins
-          VectorOfLengthAndType<[4], [I32]>:$a,
-          VectorOfLengthAndType<[16], [I8]>:$b,
-          VectorOfLengthAndType<[16], [I8]>:$c
+          VectorOfLengthAndType<[4], [I32]>:$acc,
+          VectorOfLengthAndType<[16], [I8]>:$src1,
+          VectorOfLengthAndType<[16], [I8]>:$src2
   );
   let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
   let assemblyFormat =
-    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
 }
 
 def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
                 Pure,
-                AllTypesMatch<["b", "c"]>,
-                AllTypesMatch<["a", "res"]>,
+                AllTypesMatch<["src1", "src2"]>,
+                AllTypesMatch<["acc", "res"]>,
               ]> {
   let summary = "Unsinged matrix-matrix multiply and accumulate op";
   let description = [{
@@ -172,19 +172,19 @@ def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
   }];
   // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
   let arguments = (ins
-          VectorOfLengthAndType<[4], [I32]>:$a,
-          VectorOfLengthAndType<[16], [I8]>:$b,
-          VectorOfLengthAndType<[16], [I8]>:$c
+          VectorOfLengthAndType<[4], [I32]>:$acc,
+          VectorOfLengthAndType<[16], [I8]>:$src1,
+          VectorOfLengthAndType<[16], [I8]>:$src2
   );
   let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
   let assemblyFormat =
-    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
 }
 
 def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
                 Pure,
-                AllTypesMatch<["b", "c"]>,
-                AllTypesMatch<["a", "res"]>,
+                AllTypesMatch<["src1", "src2"]>,
+                AllTypesMatch<["acc", "res"]>,
               ]> {
   let summary = "Unsignged and signed matrix-matrix multiply and accumulate op";
   let description = [{
@@ -204,13 +204,13 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
   }];
   // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
   let arguments = (ins
-          VectorOfLengthAndType<[4], [I32]>:$a,
-          VectorOfLengthAndType<[16], [I8]>:$b,
-          VectorOfLengthAndType<[16], [I8]>:$c
+          VectorOfLengthAndType<[4], [I32]>:$acc,
+          VectorOfLengthAndType<[16], [I8]>:$src1,
+          VectorOfLengthAndType<[16], [I8]>:$src2
   );
   let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
   let assemblyFormat =
-    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
 }
 
 class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
diff --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
index 3ad763e7b4982..e679cf1280d18 100644
--- a/mlir/test/Dialect/ArmNeon/invalid.mlir
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -34,60 +34,60 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
 
 // -----
 
-func.func @smmla_invalid_input_types(%a: vector<16xi4>,
+func.func @smmla_invalid_input_types(%a: vector<4xi32>,
                                      %b: vector<16xi4>,
-                                     %c: vector<4xi32>) -> vector<4xi32> {
+                                     %c: vector<16xi4>) -> vector<4xi32> {
   // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
-  %0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi4> to vector<4xi32>
+  %0 = arm_neon.intr.smmla %a, %b, %c : vector<16xi4> to vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // -----
 
-func.func @smmla_invalid_dimensions(%a: vector<32xi8>,
+func.func @smmla_invalid_dimensions(%a: vector<8xi32>,
                                     %b: vector<32xi8>,
-                                    %c: vector<8xi32>) -> vector<8xi32> {
+                                    %c: vector<32xi8>) -> vector<8xi32> {
   // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
-  %0 = arm_neon.intr.smmla %c, %a, %b : vector<32xi8> to vector<8xi32>
+  %0 = arm_neon.intr.smmla %a, %b, %c : vector<32xi8> to vector<8xi32>
   return %0 : vector<8xi32>
 }
 
 // -----
 
-func.func @ummla_invalid_input_types(%a: vector<16xi4>,
+func.func @ummla_invalid_input_types(%a: vector<4xi32>,
                                      %b: vector<16xi4>,
-                                     %c: vector<4xi32>) -> vector<4xi32> {
+                                     %c: vector<16xi4>) -> vector<4xi32> {
   // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
-  %0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi4> to vector<4xi32>
+  %0 = arm_neon.intr.ummla %a, %b, %c : vector<16xi4> to vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // -----
 
-func.func @ummla_invalid_dimensions(%a: vector<32xi8>,
+func.func @ummla_invalid_dimensions(%a: vector<8xi32>,
                                     %b: vector<32xi8>,
-                                    %c: vector<8xi32>) -> vector<8xi32> {
+                                    %c: vector<32xi8>) -> vector<8xi32> {
   // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
-  %0 = arm_neon.intr.ummla %c, %a, %b : vector<32xi8> to vector<8xi32>
+  %0 = arm_neon.intr.ummla %a, %b, %c : vector<32xi8> to vector<8xi32>
   return %0 : vector<8xi32>
 }
 
 // -----
 
-func.func @usmmla_invalid_input_types(%a: vector<16xi4>,
+func.func @usmmla_invalid_input_types(%a: vector<4xi32>,
                                       %b: vector<16xi4>,
-                                      %c: vector<4xi32>) -> vector<4xi32> {
+                                      %c: vector<16xi4>) -> vector<4xi32> {
   // expected-error at +1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
-  %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi4> to vector<4xi32>
+  %0 = arm_neon.intr.usmmla %a, %b, %c : vector<16xi4> to vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // -----
 
-func.func @usmmla_invalid_dimensions(%a: vector<32xi8>,
+func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
                                      %b: vector<32xi8>,
-                                     %c: vector<8xi32>) -> vector<8xi32> {
+                                     %c: vector<32xi8>) -> vector<8xi32> {
   // expected-error at +1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
-  %0 = arm_neon.intr.usmmla %c, %a, %b : vector<32xi8> to vector<8xi32>
+  %0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
   return %0 : vector<8xi32>
 }
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index 30afe325a482c..b5df0ffa8105c 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -19,6 +19,8 @@ func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
   return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
 }
 
+// -----
+
 // CHECK-LABEL: arm_neon_sdot
 func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
   // CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
@@ -32,7 +34,7 @@ func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>)
 func.func @arm_neon_smmla(%a: vector<16xi8>,
                           %b: vector<16xi8>,
                           %c: vector<4xi32>) -> vector<4xi32> {
-  // CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi3
+  // CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi32>
   %0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
@@ -43,7 +45,7 @@ func.func @arm_neon_smmla(%a: vector<16xi8>,
 func.func @arm_neon_ummla(%a: vector<16xi8>,
                           %b: vector<16xi8>,
                           %c: vector<4xi32>) -> vector<4xi32> {
-  // CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi3
+  // CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi32>
   %0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
@@ -54,7 +56,7 @@ func.func @arm_neon_ummla(%a: vector<16xi8>,
 func.func @arm_neon_usmmla(%a: vector<16xi8>,
                             %b: vector<16xi8>,
                             %c: vector<4xi32>) -> vector<4xi32> {
-  // CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi3
+  // CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi32>
   %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index e5b37ea3c8a5d..7cd8bf8006008 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -24,6 +24,8 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.str
   llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
 }
 
+// -----
+
 // CHECK-LABEL: arm_neon_sdot_8_i8i8
 llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
   // CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
@@ -32,6 +34,9 @@ llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<
   llvm.return %0 : vector<2xi32>
 }
 
+// -----
+
+
 // CHECK-LABEL: arm_neon_sdot_16_i8i8
 llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
   // CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
@@ -42,11 +47,10 @@ llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vecto
 
 // -----
 
-// CHECK-LABEL: define <4 x i32> @arm_neon_smmla
+// CHECK-LABEL: arm_neon_smmla
 llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
                           %arg1: vector<16xi8>,
-                          %arg2: vector<4xi32>)
-                         -> vector<4xi32> {
+                          %arg2: vector<4xi32>) -> vector<4xi32> {
   // CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
   %0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
     (vector<4xi32>, vector<16xi8>, vector<16xi8>)
@@ -56,11 +60,10 @@ llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
 
 // -----
 
-// CHECK-LABEL: define <4 x i32> @arm_neon_ummla
+// CHECK-LABEL: arm_neon_ummla
 llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
                           %arg1: vector<16xi8>,
-                          %arg2: vector<4xi32>)
-                         -> vector<4xi32> {
+                          %arg2: vector<4xi32>) -> vector<4xi32> {
   // CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
   %0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
     (vector<4xi32>, vector<16xi8>, vector<16xi8>)
@@ -70,11 +73,10 @@ llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
 
 // -----
 
-// CHECK-LABEL: define <4 x i32> @arm_neon_usmmla
+// CHECK-LABEL: arm_neon_usmmla
 llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
                           %arg1: vector<16xi8>,
-                          %arg2: vector<4xi32>)
-                         -> vector<4xi32> {
+                          %arg2: vector<4xi32>) -> vector<4xi32> {
   // CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
   %0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
     (vector<4xi32>, vector<16xi8>, vector<16xi8>)



More information about the Mlir-commits mailing list