[llvm] [AArch64][SVE] Add custom lowering for bfloat FMUL (with +bf16) (PR #167502)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 11 05:22:20 PST 2025


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/167502

This lowers an SVE FMUL of bf16 using the BFMLALB/BFMLALT (bottom/top) instructions rather than extending to an f32 multiply. This does require zeroing the accumulator, but needs fewer extend/unpack operations.
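
For illustration, the nxv8bf16 case from the updated sve-bf16-arith.ll test below: BFMLALB/BFMLALT multiply the even/odd bf16 lanes into zeroed f32 accumulators, and BFCVT/BFCVTNT then narrow the two f32 results back into the even/odd bf16 lanes of the destination.

  define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
    %res = fmul <vscale x 8 x bfloat> %a, %b
    ret <vscale x 8 x bfloat> %res
  }

With +sve,+bf16 (and without +sve-b16b16) this now compiles to:

    movi    v2.2d, #0000000000000000
    movi    v3.2d, #0000000000000000
    ptrue   p0.s
    bfmlalb z2.s, z0.h, z1.h
    bfmlalt z3.s, z0.h, z1.h
    bfcvt   z0.h, p0/m, z2.s
    bfcvtnt z0.h, p0/m, z3.s
    ret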

From eaf1129fe8e1561c1ad04c65cba197a5ca1d0712 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 11 Nov 2025 13:14:00 +0000
Subject: [PATCH] [AArch64][SVE] Add custom lowering for bfloat FMUL (with
 +bf16)

This lowers an SVE FMUL of bf16 using the BFMLALB/BFMLALT (bottom/top)
instructions rather than extending to an f32 multiply. This does
require zeroing the accumulator, but needs fewer extend/unpack
operations.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 45 +++++++++-
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  1 +
 llvm/test/CodeGen/AArch64/sve-bf16-arith.ll   | 88 +++++++++++--------
 3 files changed, 98 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8457f6178fdc2..9508e63630669 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1801,6 +1801,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
         setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
       }
+
+      if (Subtarget->hasBF16() &&
+          (Subtarget->hasSVE() || Subtarget->hasSME())) {
+        for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
+          setOperationAction(ISD::FMUL, VT, Custom);
+      }
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7529,6 +7535,43 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
                      EndOfTrmp);
 }
 
+SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+  if (VT.getScalarType() != MVT::bf16 ||
+      !(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME())))
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+
+  SDLoc DL(Op);
+  SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
+
+  auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) {
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                       DAG.getConstant(IID, DL, MVT::i32), Ops...);
+  };
+
+  SDValue Pg =
+      getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1,
+               AArch64SVEPredPattern::all);
+  // Lower bf16 FMUL as a pair of BFMLAL top/bottom instructions (only the
+  // bottom for VT != nxv8bf16). The f32 results are then converted back to
+  // bf16 with BFCVT and BFCVTNT.
+  SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32,
+                                   Zero, LHS, RHS);
+  SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT,
+                                    DAG.getPOISON(VT), Pg, BottomF32);
+  if (VT == MVT::nxv8bf16) {
+    // Note: nxv2bf16 and nxv4bf16 only use even lanes.
+    SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32,
+                                  Zero, LHS, RHS);
+    return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT,
+                        BottomBF16, Pg, TopF32);
+  }
+  return BottomBF16;
+}
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7603,7 +7646,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::FSUB:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
   case ISD::FMUL:
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+    return LowerFMUL(Op, DAG);
   case ISD::FMA:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 70bfae717fb76..a926a4822d9cc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -609,6 +609,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
index 0580f5e0b019a..d8f1ec0241a17 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sve,+bf16                         < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sve,+bf16                         < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-NONSTREAMING
 ; RUN: llc -mattr=+sve,+bf16,+sve-b16b16             < %s | FileCheck %s --check-prefixes=CHECK,B16B16
-; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming  < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming  < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-STREAMING
 ; RUN: llc -mattr=+sme2,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16
 
 target triple = "aarch64-unknown-linux-gnu"
@@ -520,64 +520,82 @@ define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x
 ;
 
 define <vscale x 2 x bfloat> @fmul_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
-; NOB16B16-LABEL: fmul_nxv2bf16:
-; NOB16B16:       // %bb.0:
-; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
-; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
-; NOB16B16-NEXT:    ptrue p0.d
-; NOB16B16-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
-; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
-; NOB16B16-NEXT:    ret
+; NOB16B16-NONSTREAMING-LABEL: fmul_nxv2bf16:
+; NOB16B16-NONSTREAMING:       // %bb.0:
+; NOB16B16-NONSTREAMING-NEXT:    movi v2.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT:    ptrue p0.d
+; NOB16B16-NONSTREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NONSTREAMING-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_nxv2bf16:
 ; B16B16:       // %bb.0:
 ; B16B16-NEXT:    bfmul z0.h, z0.h, z1.h
 ; B16B16-NEXT:    ret
+;
+; NOB16B16-STREAMING-LABEL: fmul_nxv2bf16:
+; NOB16B16-STREAMING:       // %bb.0:
+; NOB16B16-STREAMING-NEXT:    mov z2.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT:    ptrue p0.d
+; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-STREAMING-NEXT:    ret
   %res = fmul <vscale x 2 x bfloat> %a, %b
   ret <vscale x 2 x bfloat> %res
 }
 
 define <vscale x 4 x bfloat> @fmul_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
-; NOB16B16-LABEL: fmul_nxv4bf16:
-; NOB16B16:       // %bb.0:
-; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
-; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
-; NOB16B16-NEXT:    ptrue p0.s
-; NOB16B16-NEXT:    fmul z0.s, z0.s, z1.s
-; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
-; NOB16B16-NEXT:    ret
+; NOB16B16-NONSTREAMING-LABEL: fmul_nxv4bf16:
+; NOB16B16-NONSTREAMING:       // %bb.0:
+; NOB16B16-NONSTREAMING-NEXT:    movi v2.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT:    ptrue p0.s
+; NOB16B16-NONSTREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NONSTREAMING-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_nxv4bf16:
 ; B16B16:       // %bb.0:
 ; B16B16-NEXT:    bfmul z0.h, z0.h, z1.h
 ; B16B16-NEXT:    ret
+;
+; NOB16B16-STREAMING-LABEL: fmul_nxv4bf16:
+; NOB16B16-STREAMING:       // %bb.0:
+; NOB16B16-STREAMING-NEXT:    mov z2.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT:    ptrue p0.s
+; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-STREAMING-NEXT:    ret
   %res = fmul <vscale x 4 x bfloat> %a, %b
   ret <vscale x 4 x bfloat> %res
 }
 
 define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
-; NOB16B16-LABEL: fmul_nxv8bf16:
-; NOB16B16:       // %bb.0:
-; NOB16B16-NEXT:    uunpkhi z2.s, z1.h
-; NOB16B16-NEXT:    uunpkhi z3.s, z0.h
-; NOB16B16-NEXT:    uunpklo z1.s, z1.h
-; NOB16B16-NEXT:    uunpklo z0.s, z0.h
-; NOB16B16-NEXT:    ptrue p0.s
-; NOB16B16-NEXT:    lsl z2.s, z2.s, #16
-; NOB16B16-NEXT:    lsl z3.s, z3.s, #16
-; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
-; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
-; NOB16B16-NEXT:    fmul z2.s, z3.s, z2.s
-; NOB16B16-NEXT:    fmul z0.s, z0.s, z1.s
-; NOB16B16-NEXT:    bfcvt z1.h, p0/m, z2.s
-; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
-; NOB16B16-NEXT:    uzp1 z0.h, z0.h, z1.h
-; NOB16B16-NEXT:    ret
+; NOB16B16-NONSTREAMING-LABEL: fmul_nxv8bf16:
+; NOB16B16-NONSTREAMING:       // %bb.0:
+; NOB16B16-NONSTREAMING-NEXT:    movi v2.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT:    movi v3.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT:    ptrue p0.s
+; NOB16B16-NONSTREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT:    bfmlalt z3.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NONSTREAMING-NEXT:    bfcvtnt z0.h, p0/m, z3.s
+; NOB16B16-NONSTREAMING-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_nxv8bf16:
 ; B16B16:       // %bb.0:
 ; B16B16-NEXT:    bfmul z0.h, z0.h, z1.h
 ; B16B16-NEXT:    ret
+;
+; NOB16B16-STREAMING-LABEL: fmul_nxv8bf16:
+; NOB16B16-STREAMING:       // %bb.0:
+; NOB16B16-STREAMING-NEXT:    mov z2.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT:    mov z3.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT:    ptrue p0.s
+; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT:    bfmlalt z3.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-STREAMING-NEXT:    bfcvtnt z0.h, p0/m, z3.s
+; NOB16B16-STREAMING-NEXT:    ret
   %res = fmul <vscale x 8 x bfloat> %a, %b
   ret <vscale x 8 x bfloat> %res
 }


