[Mlir-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to `bfmmla` (PR #145064)
Momchil Velikov
llvmlistbot at llvm.org
Wed Jun 25 06:16:40 PDT 2025
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/145064
>From feb578d1922772923bcc94d6e4aea43a51718a9c Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 20 Jun 2025 16:16:05 +0000
Subject: [PATCH 1/2] [MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to
`bfmmla`
---
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 35 +++++++++++++++++++
.../Transforms/LegalizeForLLVMExport.cpp | 10 ++++--
.../Dialect/ArmSVE/legalize-for-llvm.mlir | 9 +++++
mlir/test/Dialect/ArmSVE/roundtrip.mlir | 10 ++++++
mlir/test/Target/LLVMIR/arm-sve.mlir | 12 +++++++
5 files changed, 73 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 7385bb73b449a..c4007dd02c0d3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+
+def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>]> {
+ let summary = "BFloat16 matrix multiply-accumulate";
+ let description = [{
+ BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices.
+
+ This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
+ segment of the first source vector by the 4x2 BFloat16 matrix in the
+ corresponding segment of the second source vector, then accumulates
+ this intermediate result with the 2x2 Float32 matrix in the corresponding
+ segment of the accumulator vector, yielding the final 2x2 Float32
+ segment of the result.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
+ ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
+ ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
@@ -590,6 +619,12 @@ def UsmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+def BfmmlaIntrOp :
+ ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>,
+ Arguments<(ins Arg<ScalableVectorOfLengthAndType<[4], [F32]>, "acc">:$acc,
+ Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "lhs">:$lhs,
+ Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "rhs">:$rhs)>;
+
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 35f2a02cc4ec6..73f388b6d81c0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -25,6 +25,7 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
+using BfmmlaOpLowering = OneToOneConvertToLLVMPattern<BfmmlaOp, BfmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
@@ -191,7 +192,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
// Populate conversion patterns
// clang-format off
- patterns.add<ConvertFromSvboolOpLowering,
+ patterns.add<BfmmlaOpLowering,
+ ConvertFromSvboolOpLowering,
ConvertToSvboolOpLowering,
DupQLaneLowering,
PselOpLowering,
@@ -220,7 +222,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
// clang-format off
- target.addLegalOp<ConvertFromSvboolIntrOp,
+ target.addLegalOp<BfmmlaIntrOp,
+ ConvertFromSvboolIntrOp,
ConvertToSvboolIntrOp,
DupQLaneIntrOp,
PselIntrOp,
@@ -241,7 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ZipX2IntrOp,
ZipX4IntrOp,
SdotIntrOp>();
- target.addIllegalOp<ConvertFromSvboolOp,
+ target.addIllegalOp<BfmmlaOp,
+ ConvertFromSvboolOp,
ConvertToSvboolOp,
DupQLaneOp,
PselOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8c658db009adf..8673b994d1e71 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -60,6 +60,15 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+ %b: vector<[8]xbf16>,
+ %c: vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: arm_sve.intr.bfmmla
+ %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+// -----
+
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 64e0cff39eb06..9a653df767400 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+ %b: vector<[8]xbf16>,
+ %c: vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
+ %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// -----
+
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index da71cb5a63bd2..737145c74e331 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
+llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>,
+ %arg1: vector<[8]xbf16>,
+ %arg2: vector<[4]xf32>)
+ -> vector<[4]xf32> {
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
+ %0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) :
+ (vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
+ -> vector<[4]xf32>
+ llvm.return %0 : vector<[4]xf32>
+}
+
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
>From fcae83d1364a35585b98e950e473b1ee1c6c0583 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 25 Jun 2025 13:12:04 +0000
Subject: [PATCH 2/2] [fixup] Skip the two-stage LLVM IR generation, map the op
directly to the LLVM IR intrinsic
---
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 18 ++++++------------
.../Transforms/LegalizeForLLVMExport.cpp | 9 +++------
.../test/Dialect/ArmSVE/legalize-for-llvm.mlir | 9 ---------
mlir/test/Dialect/ArmSVE/roundtrip.mlir | 4 ++--
4 files changed, 11 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index c4007dd02c0d3..8988df680b8f9 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -293,10 +293,10 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
-
-def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
- AllTypesMatch<["src1", "src2"]>,
- AllTypesMatch<["acc", "dst"]>]> {
+def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "res"]>,
+ ]> {
let summary = "BFloat16 matrix multiply-accumulate";
let description = [{
BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices.
@@ -317,9 +317,9 @@ def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
);
- let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst);
+ let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$res);
let assemblyFormat =
- "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
@@ -619,12 +619,6 @@ def UsmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
-def BfmmlaIntrOp :
- ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>,
- Arguments<(ins Arg<ScalableVectorOfLengthAndType<[4], [F32]>, "acc">:$acc,
- Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "lhs">:$lhs,
- Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "rhs">:$rhs)>;
-
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 73f388b6d81c0..006332b48325f 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -25,7 +25,6 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
-using BfmmlaOpLowering = OneToOneConvertToLLVMPattern<BfmmlaOp, BfmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
@@ -192,8 +191,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
// Populate conversion patterns
// clang-format off
- patterns.add<BfmmlaOpLowering,
- ConvertFromSvboolOpLowering,
+ patterns.add<ConvertFromSvboolOpLowering,
ConvertToSvboolOpLowering,
DupQLaneLowering,
PselOpLowering,
@@ -222,7 +220,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
// clang-format off
- target.addLegalOp<BfmmlaIntrOp,
+ target.addLegalOp<BfmmlaOp,
ConvertFromSvboolIntrOp,
ConvertToSvboolIntrOp,
DupQLaneIntrOp,
@@ -244,8 +242,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
ZipX2IntrOp,
ZipX4IntrOp,
SdotIntrOp>();
- target.addIllegalOp<BfmmlaOp,
- ConvertFromSvboolOp,
+ target.addIllegalOp<ConvertFromSvboolOp,
ConvertToSvboolOp,
DupQLaneOp,
PselOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8673b994d1e71..8c658db009adf 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -60,15 +60,6 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
// -----
-func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
- %b: vector<[8]xbf16>,
- %c: vector<[4]xf32>) -> vector<[4]xf32> {
- // CHECK: arm_sve.intr.bfmmla
- %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
- return %0 : vector<[4]xf32>
-}
-// -----
-
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 9a653df767400..b7b9329f1cb5a 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -58,8 +58,8 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
%b: vector<[8]xbf16>,
%c: vector<[4]xf32>) -> vector<[4]xf32> {
- // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
- %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+ // CHECK: arm_sve.intr.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
+ %0 = arm_sve.intr.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
return %0 : vector<[4]xf32>
}
More information about the Mlir-commits
mailing list