[Mlir-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to `bfmmla` (PR #145064)

Momchil Velikov llvmlistbot at llvm.org
Wed Jun 25 06:16:40 PDT 2025


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/145064

>From feb578d1922772923bcc94d6e4aea43a51718a9c Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 20 Jun 2025 16:16:05 +0000
Subject: [PATCH 1/2] [MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to
 `bfmmla`

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 35 +++++++++++++++++++
 .../Transforms/LegalizeForLLVMExport.cpp      | 10 ++++--
 .../Dialect/ArmSVE/legalize-for-llvm.mlir     |  9 +++++
 mlir/test/Dialect/ArmSVE/roundtrip.mlir       | 10 ++++++
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 12 +++++++
 5 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 7385bb73b449a..c4007dd02c0d3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+
+def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "BFloat16 matrix multiply-accumulate";
+  let description = [{
+    BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices.
+
+    This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
+    segment of the first source vector by the 4x2 BFloat16 matrix in the
+    corresponding segment of the second source vector, then accumulates
+    this intermediate result with the 2x2 Float32 matrix in the corresponding
+    segment of the accumulator vector, yielding the final 2x2 Float32
+    segment of the result.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
+          ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
+          ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
       "expected corresponding svbool type widened to [16]xi1",
       lhsArg, rhsArg,
@@ -590,6 +619,12 @@ def UsmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def BfmmlaIntrOp :
+  ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>,
+  Arguments<(ins Arg<ScalableVectorOfLengthAndType<[4], [F32]>, "acc">:$acc,
+                 Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "lhs">:$lhs,
+                 Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "rhs">:$rhs)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 35f2a02cc4ec6..73f388b6d81c0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -25,6 +25,7 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
 using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
+using BfmmlaOpLowering = OneToOneConvertToLLVMPattern<BfmmlaOp, BfmmlaIntrOp>;
 using DupQLaneLowering =
     OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
@@ -191,7 +192,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
   // Populate conversion patterns
 
   // clang-format off
-  patterns.add<ConvertFromSvboolOpLowering,
+  patterns.add<BfmmlaOpLowering,
+               ConvertFromSvboolOpLowering,
                ConvertToSvboolOpLowering,
                DupQLaneLowering,
                PselOpLowering,
@@ -220,7 +222,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
   // clang-format off
-  target.addLegalOp<ConvertFromSvboolIntrOp,
+  target.addLegalOp<BfmmlaIntrOp,
+                    ConvertFromSvboolIntrOp,
                     ConvertToSvboolIntrOp,
                     DupQLaneIntrOp,
                     PselIntrOp,
@@ -241,7 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ZipX2IntrOp,
                     ZipX4IntrOp,
                     SdotIntrOp>();
-  target.addIllegalOp<ConvertFromSvboolOp,
+  target.addIllegalOp<BfmmlaOp,
+                      ConvertFromSvboolOp,
                       ConvertToSvboolOp,
                       DupQLaneOp,
                       PselOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8c658db009adf..8673b994d1e71 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -60,6 +60,15 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+                          %b: vector<[8]xbf16>,
+                          %c: vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK: arm_sve.intr.bfmmla
+  %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 64e0cff39eb06..9a653df767400 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+                          %b: vector<[8]xbf16>,
+                          %c: vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
+  %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index da71cb5a63bd2..737145c74e331 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
   llvm.return %0 : vector<[4]xi32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
+llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>,
+                          %arg1: vector<[8]xbf16>,
+                          %arg2: vector<[4]xf32>)
+                          -> vector<[4]xf32> {
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
+  %0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) :
+    (vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
+        -> vector<[4]xf32>
+  llvm.return %0 : vector<[4]xf32>
+}
+
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
 llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
                           %arg1: vector<[4]xi32>,

>From fcae83d1364a35585b98e950e473b1ee1c6c0583 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 25 Jun 2025 13:12:04 +0000
Subject: [PATCH 2/2] [fixup] Skip the two-stage LLVM IR generation, map the op
 directly to the LLVM IR intrinsic

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td  | 18 ++++++------------
 .../Transforms/LegalizeForLLVMExport.cpp       |  9 +++------
 .../test/Dialect/ArmSVE/legalize-for-llvm.mlir |  9 ---------
 mlir/test/Dialect/ArmSVE/roundtrip.mlir        |  4 ++--
 4 files changed, 11 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index c4007dd02c0d3..8988df680b8f9 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -293,10 +293,10 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-
-def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
-                                    AllTypesMatch<["src1", "src2"]>,
-                                    AllTypesMatch<["acc", "dst"]>]> {
+def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure,
+                                        AllTypesMatch<["src1", "src2"]>,
+                                        AllTypesMatch<["acc", "res"]>,
+                                        ]> {
   let summary = "BFloat16 matrix multiply-accumulate";
   let description = [{
     BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices.
@@ -317,9 +317,9 @@ def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
           ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
           ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
   );
-  let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst);
+  let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$res);
   let assemblyFormat =
-    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
 }
 
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
@@ -619,12 +619,6 @@ def UsmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
-def BfmmlaIntrOp :
-  ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>,
-  Arguments<(ins Arg<ScalableVectorOfLengthAndType<[4], [F32]>, "acc">:$acc,
-                 Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "lhs">:$lhs,
-                 Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "rhs">:$rhs)>;
-
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 73f388b6d81c0..006332b48325f 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -25,7 +25,6 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
 using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
-using BfmmlaOpLowering = OneToOneConvertToLLVMPattern<BfmmlaOp, BfmmlaIntrOp>;
 using DupQLaneLowering =
     OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
@@ -192,8 +191,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
   // Populate conversion patterns
 
   // clang-format off
-  patterns.add<BfmmlaOpLowering,
-               ConvertFromSvboolOpLowering,
+  patterns.add<ConvertFromSvboolOpLowering,
                ConvertToSvboolOpLowering,
                DupQLaneLowering,
                PselOpLowering,
@@ -222,7 +220,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
   // clang-format off
-  target.addLegalOp<BfmmlaIntrOp,
+  target.addLegalOp<BfmmlaOp,
                     ConvertFromSvboolIntrOp,
                     ConvertToSvboolIntrOp,
                     DupQLaneIntrOp,
@@ -244,8 +242,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ZipX2IntrOp,
                     ZipX4IntrOp,
                     SdotIntrOp>();
-  target.addIllegalOp<BfmmlaOp,
-                      ConvertFromSvboolOp,
+  target.addIllegalOp<ConvertFromSvboolOp,
                       ConvertToSvboolOp,
                       DupQLaneOp,
                       PselOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8673b994d1e71..8c658db009adf 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -60,15 +60,6 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
 
 // -----
 
-func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
-                          %b: vector<[8]xbf16>,
-                          %c: vector<[4]xf32>) -> vector<[4]xf32> {
-  // CHECK: arm_sve.intr.bfmmla
-  %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
-  return %0 : vector<[4]xf32>
-}
-// -----
-
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 9a653df767400..b7b9329f1cb5a 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -58,8 +58,8 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
 func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
                           %b: vector<[8]xbf16>,
                           %c: vector<[4]xf32>) -> vector<[4]xf32> {
-  // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
-  %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+  // CHECK: arm_sve.intr.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
+  %0 = arm_sve.intr.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
   return %0 : vector<[4]xf32>
 }
 



More information about the Mlir-commits mailing list