[Mlir-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to `bfmmla` (PR #145064)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 20 09:20:25 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Momchil Velikov (momchil-velikov)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/145064.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+35)
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+7-3)
- (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+9)
- (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+10)
- (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+12)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 7385bb73b449a..c4007dd02c0d3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+
+def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>]> {
+ let summary = "BFloat16 matrix multiply-accumulate";
+ let description = [{
+ BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices.
+
+ This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
+ segment of the first source vector by the 4x2 BFloat16 matrix in the
+ corresponding segment of the second source vector, then accumulates
+ this intermediate result with the 2x2 Float32 matrix in the corresponding
+ segment of the accumulator vector, yielding the final 2x2 Float32
+ segment of the result.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
+ ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
+ ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
@@ -590,6 +619,12 @@ def UsmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+def BfmmlaIntrOp :
+ ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>,
+ Arguments<(ins Arg<ScalableVectorOfLengthAndType<[4], [F32]>, "acc">:$acc,
+ Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "lhs">:$lhs,
+ Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "rhs">:$rhs)>;
+
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 35f2a02cc4ec6..73f388b6d81c0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -25,6 +25,7 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
+using BfmmlaOpLowering = OneToOneConvertToLLVMPattern<BfmmlaOp, BfmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
@@ -191,7 +192,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
// Populate conversion patterns
// clang-format off
- patterns.add<ConvertFromSvboolOpLowering,
+ patterns.add<BfmmlaOpLowering,
+ ConvertFromSvboolOpLowering,
ConvertToSvboolOpLowering,
DupQLaneLowering,
PselOpLowering,
@@ -220,7 +222,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
// clang-format off
- target.addLegalOp<ConvertFromSvboolIntrOp,
+ target.addLegalOp<BfmmlaIntrOp,
+ ConvertFromSvboolIntrOp,
ConvertToSvboolIntrOp,
DupQLaneIntrOp,
PselIntrOp,
@@ -241,7 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ZipX2IntrOp,
ZipX4IntrOp,
SdotIntrOp>();
- target.addIllegalOp<ConvertFromSvboolOp,
+ target.addIllegalOp<BfmmlaOp,
+ ConvertFromSvboolOp,
ConvertToSvboolOp,
DupQLaneOp,
PselOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8c658db009adf..8673b994d1e71 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -60,6 +60,15 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+ %b: vector<[8]xbf16>,
+ %c: vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: arm_sve.intr.bfmmla
+ %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+// -----
+
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 64e0cff39eb06..9a653df767400 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+ %b: vector<[8]xbf16>,
+ %c: vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
+ %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// -----
+
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index da71cb5a63bd2..737145c74e331 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
+llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>,
+ %arg1: vector<[8]xbf16>,
+ %arg2: vector<[4]xf32>)
+ -> vector<[4]xf32> {
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
+ %0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) :
+ (vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
+ -> vector<[4]xf32>
+ llvm.return %0 : vector<[4]xf32>
+}
+
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/145064
More information about the Mlir-commits
mailing list