[Mlir-commits] [mlir] [MLIR][ArmNeon] Add an ArmNeon operation which maps to `bfmmla` (PR #145038)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 20 06:43:46 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
Author: Momchil Velikov (momchil-velikov)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/145038.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td (+28)
- (modified) mlir/test/Dialect/ArmNeon/roundtrip.mlir (+12)
- (modified) mlir/test/Target/LLVMIR/arm-neon.mlir (+13)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 475b11f12c5f0..ce86ff2cfd922 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -222,6 +222,34 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}
+def BfmmlaOp : ArmNeon_IntrOp<"bfmmla", [], [], 1, [
+ Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "res"]>,
+ ]> {
+ let summary = "BFloat16 matrix multiply-accumulate to single-precision";
+ let description = [{
+ BFMMLA: BFloat16 matrix multiply-accumulate to single-precision.
+
+ The operation multiplies the 2x4 BFloat16 matrix in the first source vector
+ with the 4x2 BFloat16 matrix in the second source vector, then accumulates
+ this intermediate result with the 2x2 Float32 matrix in the accumulator
+ vector, yielding the final 2x2 Float32 result.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/vbfmmlaq_f32
+ }];
+ // Supports (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> (vector<4xf32>)
+ let arguments = (ins
+ NeonVectorOfLength<4, F32>:$acc,
+ NeonVectorOfLength<8, BF16>:$src1,
+ NeonVectorOfLength<8, BF16>:$src2
+ );
+ let results = (outs NeonVectorOfLength<4, F32>:$res);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
+}
+
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
: Op</*dialect=*/ArmNeon_Dialect,
/*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index b5df0ffa8105c..60133ce0fa6f3 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -60,3 +60,15 @@ func.func @arm_neon_usmmla(%a: vector<16xi8>,
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}
+
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+func.func @arm_neon_bfmmla(%a: vector<8xbf16>,
+ %b: vector<8xbf16>,
+ %c: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: arm_neon.intr.bfmmla {{.*}}: vector<8xbf16> to vector<4xf32>
+ %0 = arm_neon.intr.bfmmla %c, %a, %b : vector<8xbf16> to vector<4xf32> // operands: acc, src1, src2
+ return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index e096172667c9f..e1328ad448f0a 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -82,3 +82,16 @@ llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}
+
+// -----
+
+// CHECK-LABEL: arm_neon_bfmmla
+llvm.func @arm_neon_bfmmla(%arg0: vector<8xbf16>,
+ %arg1: vector<8xbf16>,
+ %arg2: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float
+ %0 = "arm_neon.intr.bfmmla"(%arg2, %arg0, %arg1) : // operands: acc, src1, src2
+ (vector<4xf32>, vector<8xbf16>, vector<8xbf16>)
+ -> vector<4xf32>
+ llvm.return %0 : vector<4xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/145038
More information about the Mlir-commits
mailing list