[Mlir-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to `bfmmla` (PR #145064)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 20 09:20:25 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/145064.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+35) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+7-3) 
- (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+9) 
- (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+10) 
- (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 7385bb73b449a..c4007dd02c0d3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+
+def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "BFloat16 matrix multiply-accumulate";
+  let description = [{
+    BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices.
+
+    This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
+    segment of the first source vector by the 4x2 BFloat16 matrix in the
+    corresponding segment of the second source vector, then accumulates
+    this intermediate result with the 2x2 Float32 matrix in the corresponding
+    segment of the accumulator vector, yielding the final 2x2 Float32
+    segment of the result.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
+          ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
+          ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
       "expected corresponding svbool type widened to [16]xi1",
       lhsArg, rhsArg,
@@ -590,6 +619,12 @@ def UsmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def BfmmlaIntrOp :
+  ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>,
+  Arguments<(ins Arg<ScalableVectorOfLengthAndType<[4], [F32]>, "acc">:$acc,
+                 Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "lhs">:$lhs,
+                 Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "rhs">:$rhs)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 35f2a02cc4ec6..73f388b6d81c0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -25,6 +25,7 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
 using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
+using BfmmlaOpLowering = OneToOneConvertToLLVMPattern<BfmmlaOp, BfmmlaIntrOp>;
 using DupQLaneLowering =
     OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
@@ -191,7 +192,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
   // Populate conversion patterns
 
   // clang-format off
-  patterns.add<ConvertFromSvboolOpLowering,
+  patterns.add<BfmmlaOpLowering,
+               ConvertFromSvboolOpLowering,
                ConvertToSvboolOpLowering,
                DupQLaneLowering,
                PselOpLowering,
@@ -220,7 +222,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
   // clang-format off
-  target.addLegalOp<ConvertFromSvboolIntrOp,
+  target.addLegalOp<BfmmlaIntrOp,
+                    ConvertFromSvboolIntrOp,
                     ConvertToSvboolIntrOp,
                     DupQLaneIntrOp,
                     PselIntrOp,
@@ -241,7 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ZipX2IntrOp,
                     ZipX4IntrOp,
                     SdotIntrOp>();
-  target.addIllegalOp<ConvertFromSvboolOp,
+  target.addIllegalOp<BfmmlaOp,
+                      ConvertFromSvboolOp,
                       ConvertToSvboolOp,
                       DupQLaneOp,
                       PselOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8c658db009adf..8673b994d1e71 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -60,6 +60,15 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+                          %b: vector<[8]xbf16>,
+                          %c: vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK: arm_sve.intr.bfmmla
+  %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 64e0cff39eb06..9a653df767400 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+                          %b: vector<[8]xbf16>,
+                          %c: vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
+  %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index da71cb5a63bd2..737145c74e331 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
   llvm.return %0 : vector<[4]xi32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
+llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>,
+                          %arg1: vector<[8]xbf16>,
+                          %arg2: vector<[4]xf32>)
+                          -> vector<[4]xf32> {
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
+  %0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) :
+    (vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
+        -> vector<[4]xf32>
+  llvm.return %0 : vector<[4]xf32>
+}
+
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
 llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
                           %arg1: vector<[4]xi32>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/145064


More information about the Mlir-commits mailing list