[Mlir-commits] [mlir] [MLIR][ArmNeon] Add an ArmNeon operation which maps to `bfmmla` (PR #145038)

Momchil Velikov llvmlistbot at llvm.org
Wed Jun 25 04:11:38 PDT 2025


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/145038

>From 154b104143a097df6118c3c2fe29124a4ed8da9b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 20 Jun 2025 13:36:45 +0000
Subject: [PATCH 1/2] [MLIR][ArmNeon] Add an ArmNeon operation which maps to
 `bfmmla`

---
 mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td | 28 ++++++++++++++++++++
 mlir/test/Dialect/ArmNeon/roundtrip.mlir     | 12 +++++++++
 mlir/test/Target/LLVMIR/arm-neon.mlir        | 13 +++++++++
 3 files changed, 53 insertions(+)

diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 475b11f12c5f0..ce86ff2cfd922 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -222,6 +222,34 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
 }
 
+def BfmmlaOp : ArmNeon_IntrOp<"bfmmla", [], [], 1, [
+                 Pure,
+                 AllTypesMatch<["src1", "src2"]>,
+                 AllTypesMatch<["acc", "res"]>,
+               ]> {
+  let summary = "BFloat16 matrix multiply-accumulate to single-precision";
+  let description = [{
+    BFMMLA: BFloat16 matrix multiply-accumulate to single-precision.
+
+    The operation multiplies the 2x4 BFloat16 matrix in the first source vector
+    with the 4x2 BFloat16 matrix in the second source vector, then accumulates
+    this intermediate result with the 2x2 Float32 matrix in the accumulator
+    vector, yielding the final 2x2 Float32 result.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/vbfmmlaq_f32
+  }];
+  // Supports (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> (vector<4xf32>)
+  let arguments = (ins
+    NeonVectorOfLength<4, F32>:$acc,
+    NeonVectorOfLength<8, BF16>:$src1,
+    NeonVectorOfLength<8, BF16>:$src2
+  );
+  let results = (outs NeonVectorOfLength<4, F32>:$res);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
 class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
     : Op</*dialect=*/ArmNeon_Dialect,
          /*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index b5df0ffa8105c..60133ce0fa6f3 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -60,3 +60,15 @@ func.func @arm_neon_usmmla(%a: vector<16xi8>,
   %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
   return %0 : vector<4xi32>
 }
+
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+func.func @arm_neon_bfmmla(%lhs: vector<8xbf16>,
+                           %rhs: vector<8xbf16>,
+                           %acc: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: arm_neon.intr.bfmmla {{.*}}: vector<8xbf16> to vector<4xf32>
+  %res = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<4xf32>
+  return %res : vector<4xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index e096172667c9f..e1328ad448f0a 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -82,3 +82,16 @@ llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
         -> vector<4xi32>
   llvm.return %0 : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+llvm.func @arm_neon_bfmmla(%arg0: vector<8xbf16>,
+                           %arg1: vector<8xbf16>,
+                           %arg2: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float> %{{.*}}, <8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}})
+  %0 = "arm_neon.intr.bfmmla"(%arg2, %arg0, %arg1) :
+    (vector<4xf32>, vector<8xbf16>, vector<8xbf16>)
+        -> vector<4xf32>
+  llvm.return %0 : vector<4xf32>
+}

>From b86bc0d036068b2a9abc99ab00275d83d661aaec Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 25 Jun 2025 11:05:44 +0000
Subject: [PATCH 2/2] [fixup] Add tests to `Dialect/ArmNeon/invalid.mlir`

---
 mlir/test/Dialect/ArmNeon/invalid.mlir | 40 ++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
index be49e0e4597b0..989293a3508b4 100644
--- a/mlir/test/Dialect/ArmNeon/invalid.mlir
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -91,3 +91,43 @@ func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
   %0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
   return %0 : vector<8xi32>
 }
+
+// -----
+
+func.func @bfmmla_invalid_element_type_lhs_rhs(%acc: vector<4xf32>,
+                                               %lhs: vector<8xf16>,
+                                               %rhs: vector<8xf16>) -> vector<4xf32> {
+  // expected-error@+1 {{op operand #1 must be a vector with length 8 of bfloat16 type values, but got 'vector<8xf16>'}}
+  %0 = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<8xf16> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+func.func @bfmmla_invalid_dimension_lhs_rhs(%acc: vector<4xf32>,
+                                            %lhs: vector<4xbf16>,
+                                            %rhs: vector<4xbf16>) -> vector<4xf32> {
+  // expected-error@+1 {{op operand #1 must be a vector with length 8 of bfloat16 type values, but got 'vector<4xbf16>'}}
+  %0 = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<4xbf16> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+func.func @bfmmla_invalid_element_type_acc(%acc: vector<4xi32>,
+                                           %lhs: vector<8xbf16>,
+                                           %rhs: vector<8xbf16>) -> vector<4xi32> {
+  // expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit float values, but got 'vector<4xi32>'}}
+  %0 = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @bfmmla_invalid_dimension_acc(%acc: vector<8xf32>,
+                                        %lhs: vector<8xbf16>,
+                                        %rhs: vector<8xbf16>) -> vector<8xf32> {
+  // expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit float values, but got 'vector<8xf32>'}}
+  %0 = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<8xf32>
+  return %0 : vector<8xf32>
+}



More information about the Mlir-commits mailing list