[llvm] [AArch64][SVE] Add custom lowering for bfloat FMUL (with +bf16) (PR #167502)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 19 05:22:54 PST 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/167502

>From a0925e8924bcb4619059861134da1049a0a7969f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 11 Nov 2025 13:14:00 +0000
Subject: [PATCH 1/6] [AArch64][SVE] Add custom lowering for bfloat FMUL (with
 +bf16)

This lowers an SVE FMUL of bf16 using the BFMLAL top/bottom instructions
rather than extending to an f32 mul. This does require zeroing the
accumulator, but needs fewer extends/unpacks.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 45 +++++++++-
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  1 +
 llvm/test/CodeGen/AArch64/sve-bf16-arith.ll   | 88 +++++++++++--------
 3 files changed, 98 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 81c87ace76e56..add7397a79c4c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1803,6 +1803,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
         setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
       }
+
+      if (Subtarget->hasBF16() &&
+          (Subtarget->hasSVE() || Subtarget->hasSME())) {
+        for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
+          setOperationAction(ISD::FMUL, VT, Custom);
+      }
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7538,6 +7544,43 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
                      EndOfTrmp);
 }
 
+SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+  if (VT.getScalarType() != MVT::bf16 ||
+      !(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME())))
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+
+  SDLoc DL(Op);
+  SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
+
+  auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) {
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                       DAG.getConstant(IID, DL, MVT::i32), Ops...);
+  };
+
+  SDValue Pg =
+      getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1,
+               AArch64SVEPredPattern::all);
+  // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+  // instructions. These result in two f32 vectors, which can be converted back
+  // to bf16 with FCVT and FCVNT.
+  SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32,
+                                   Zero, LHS, RHS);
+  SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT,
+                                    DAG.getPOISON(VT), Pg, BottomF32);
+  if (VT == MVT::nxv8bf16) {
+    // Note: nxv2bf16 and nxv4bf16 only use even lanes.
+    SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32,
+                                  Zero, LHS, RHS);
+    return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT,
+                        BottomBF16, Pg, TopF32);
+  }
+  return BottomBF16;
+}
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7612,7 +7655,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::FSUB:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
   case ISD::FMUL:
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+    return LowerFMUL(Op, DAG);
   case ISD::FMA:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 70bfae717fb76..a926a4822d9cc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -609,6 +609,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
index 582e8456c05b3..95db2666dbbba 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sve,+bf16                         < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sve,+bf16                         < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-NONSTREAMING
 ; RUN: llc -mattr=+sve,+bf16,+sve-b16b16             < %s | FileCheck %s --check-prefixes=CHECK,B16B16
-; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming  < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming  < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-STREAMING
 ; RUN: llc -mattr=+sme2,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16
 
 target triple = "aarch64-unknown-linux-gnu"
@@ -514,64 +514,82 @@ define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x
 ;
 
 define <vscale x 2 x bfloat> @fmul_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
-; NOB16B16-LABEL: fmul_nxv2bf16:
-; NOB16B16:       // %bb.0:
-; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
-; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
-; NOB16B16-NEXT:    ptrue p0.d
-; NOB16B16-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
-; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
-; NOB16B16-NEXT:    ret
+; NOB16B16-NONSTREAMING-LABEL: fmul_nxv2bf16:
+; NOB16B16-NONSTREAMING:       // %bb.0:
+; NOB16B16-NONSTREAMING-NEXT:    movi v2.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT:    ptrue p0.d
+; NOB16B16-NONSTREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NONSTREAMING-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_nxv2bf16:
 ; B16B16:       // %bb.0:
 ; B16B16-NEXT:    bfmul z0.h, z0.h, z1.h
 ; B16B16-NEXT:    ret
+;
+; NOB16B16-STREAMING-LABEL: fmul_nxv2bf16:
+; NOB16B16-STREAMING:       // %bb.0:
+; NOB16B16-STREAMING-NEXT:    mov z2.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT:    ptrue p0.d
+; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-STREAMING-NEXT:    ret
   %res = fmul <vscale x 2 x bfloat> %a, %b
   ret <vscale x 2 x bfloat> %res
 }
 
 define <vscale x 4 x bfloat> @fmul_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
-; NOB16B16-LABEL: fmul_nxv4bf16:
-; NOB16B16:       // %bb.0:
-; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
-; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
-; NOB16B16-NEXT:    ptrue p0.s
-; NOB16B16-NEXT:    fmul z0.s, z0.s, z1.s
-; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
-; NOB16B16-NEXT:    ret
+; NOB16B16-NONSTREAMING-LABEL: fmul_nxv4bf16:
+; NOB16B16-NONSTREAMING:       // %bb.0:
+; NOB16B16-NONSTREAMING-NEXT:    movi v2.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT:    ptrue p0.s
+; NOB16B16-NONSTREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NONSTREAMING-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_nxv4bf16:
 ; B16B16:       // %bb.0:
 ; B16B16-NEXT:    bfmul z0.h, z0.h, z1.h
 ; B16B16-NEXT:    ret
+;
+; NOB16B16-STREAMING-LABEL: fmul_nxv4bf16:
+; NOB16B16-STREAMING:       // %bb.0:
+; NOB16B16-STREAMING-NEXT:    mov z2.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT:    ptrue p0.s
+; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-STREAMING-NEXT:    ret
   %res = fmul <vscale x 4 x bfloat> %a, %b
   ret <vscale x 4 x bfloat> %res
 }
 
 define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
-; NOB16B16-LABEL: fmul_nxv8bf16:
-; NOB16B16:       // %bb.0:
-; NOB16B16-NEXT:    uunpkhi z2.s, z1.h
-; NOB16B16-NEXT:    uunpkhi z3.s, z0.h
-; NOB16B16-NEXT:    uunpklo z1.s, z1.h
-; NOB16B16-NEXT:    uunpklo z0.s, z0.h
-; NOB16B16-NEXT:    ptrue p0.s
-; NOB16B16-NEXT:    lsl z2.s, z2.s, #16
-; NOB16B16-NEXT:    lsl z3.s, z3.s, #16
-; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
-; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
-; NOB16B16-NEXT:    fmul z2.s, z3.s, z2.s
-; NOB16B16-NEXT:    fmul z0.s, z0.s, z1.s
-; NOB16B16-NEXT:    bfcvt z1.h, p0/m, z2.s
-; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
-; NOB16B16-NEXT:    uzp1 z0.h, z0.h, z1.h
-; NOB16B16-NEXT:    ret
+; NOB16B16-NONSTREAMING-LABEL: fmul_nxv8bf16:
+; NOB16B16-NONSTREAMING:       // %bb.0:
+; NOB16B16-NONSTREAMING-NEXT:    movi v2.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT:    movi v3.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT:    ptrue p0.s
+; NOB16B16-NONSTREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT:    bfmlalt z3.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NONSTREAMING-NEXT:    bfcvtnt z0.h, p0/m, z3.s
+; NOB16B16-NONSTREAMING-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_nxv8bf16:
 ; B16B16:       // %bb.0:
 ; B16B16-NEXT:    bfmul z0.h, z0.h, z1.h
 ; B16B16-NEXT:    ret
+;
+; NOB16B16-STREAMING-LABEL: fmul_nxv8bf16:
+; NOB16B16-STREAMING:       // %bb.0:
+; NOB16B16-STREAMING-NEXT:    mov z2.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT:    mov z3.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT:    ptrue p0.s
+; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT:    bfmlalt z3.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-STREAMING-NEXT:    bfcvtnt z0.h, p0/m, z3.s
+; NOB16B16-STREAMING-NEXT:    ret
   %res = fmul <vscale x 8 x bfloat> %a, %b
   ret <vscale x 8 x bfloat> %res
 }

>From 9edf68482bfd49130aac67d65b7a4ca0eb81f49b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 14 Nov 2025 15:19:40 +0000
Subject: [PATCH 2/6] Fixups

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 71 +++++++++++--------
 1 file changed, 40 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index add7397a79c4c..65a61ce5ccc51 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1797,17 +1797,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     if (!Subtarget->hasSVEB16B16() ||
         !Subtarget->isNonStreamingSVEorSME2Available()) {
-      for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
-                          ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
-        setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
-        setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
-        setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
-      }
-
-      if (Subtarget->hasBF16() &&
-          (Subtarget->hasSVE() || Subtarget->hasSME())) {
-        for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
+      for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
+        MVT PromotedVT = VT.changeVectorElementType(MVT::f32);
+        setOperationPromotedToType(ISD::FADD, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMA, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FSUB, VT, PromotedVT);
+
+        if (Subtarget->hasBF16())
           setOperationAction(ISD::FMUL, VT, Custom);
+        else
+          setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
       }
     }
 
@@ -7545,40 +7548,46 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
 }
 
 SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
   EVT VT = Op.getValueType();
   auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
   if (VT.getScalarType() != MVT::bf16 ||
-      !(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME())))
+      (Subtarget.hasSVEB16B16() &&
+       Subtarget.isNonStreamingSVEorSME2Available()))
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
 
-  SDLoc DL(Op);
+  assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering");
+
+  auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
+    return [&, IID](EVT VT, auto... Ops) {
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(IID, DL, MVT::i32), Ops...);
+    };
+  };
+
+  // Create helpers for building intrinsic calls.
+  auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+  auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+  auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
+  auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
+
   SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
 
-  auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) {
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
-                       DAG.getConstant(IID, DL, MVT::i32), Ops...);
-  };
-
   SDValue Pg =
-      getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1,
-               AArch64SVEPredPattern::all);
+      DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1);
+
   // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
   // instructions. These result in two f32 vectors, which can be converted back
   // to bf16 with FCVT and FCVNT.
-  SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32,
-                                   Zero, LHS, RHS);
-  SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT,
-                                    DAG.getPOISON(VT), Pg, BottomF32);
-  if (VT == MVT::nxv8bf16) {
-    // Note: nxv2bf16 and nxv4bf16 only use even lanes.
-    SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32,
-                                  Zero, LHS, RHS);
-    return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT,
-                        BottomBF16, Pg, TopF32);
-  }
-  return BottomBF16;
+  SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+  SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32);
+  // Note: nxv2bf16 and nxv4bf16 only use even lanes.
+  if (VT != MVT::nxv8bf16)
+    return BottomBF16;
+  SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
+  return FCVNT(VT, BottomBF16, Pg, TopF32);
 }
 
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,

>From 3a5d7b33e7399e649884285234203a0e0e8ad1aa Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 14 Nov 2025 16:30:09 +0000
Subject: [PATCH 3/6] Use DAG.getNeutralElement()

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  3 ++-
 llvm/test/CodeGen/AArch64/sve-bf16-arith.ll   | 25 +++++++++++++++++--
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 65a61ce5ccc51..094726ce5ba7a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7571,10 +7571,11 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
   auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
   auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
 
-  SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
 
+  SDValue Zero =
+      DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
   SDValue Pg =
       DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1);
 
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
index 95db2666dbbba..7d5aed2898bfd 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
@@ -534,7 +534,7 @@ define <vscale x 2 x bfloat> @fmul_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x
 ; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
 ; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
 ; NOB16B16-STREAMING-NEXT:    ret
-  %res = fmul <vscale x 2 x bfloat> %a, %b
+  %res = fmul nsz <vscale x 2 x bfloat> %a, %b
   ret <vscale x 2 x bfloat> %res
 }
 
@@ -559,7 +559,7 @@ define <vscale x 4 x bfloat> @fmul_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x
 ; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
 ; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
 ; NOB16B16-STREAMING-NEXT:    ret
-  %res = fmul <vscale x 4 x bfloat> %a, %b
+  %res = fmul nsz <vscale x 4 x bfloat> %a, %b
   ret <vscale x 4 x bfloat> %res
 }
 
@@ -590,6 +590,27 @@ define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x
 ; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
 ; NOB16B16-STREAMING-NEXT:    bfcvtnt z0.h, p0/m, z3.s
 ; NOB16B16-STREAMING-NEXT:    ret
+  %res = fmul nsz <vscale x 8 x bfloat> %a, %b
+  ret <vscale x 8 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fmul_nxv8bf16_no_nsz(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
+; NOB16B16-LABEL: fmul_nxv8bf16_no_nsz:
+; NOB16B16:       // %bb.0:
+; NOB16B16-NEXT:    mov w8, #-2147483648 // =0x80000000
+; NOB16B16-NEXT:    ptrue p0.s
+; NOB16B16-NEXT:    mov z2.s, w8
+; NOB16B16-NEXT:    mov z3.d, z2.d
+; NOB16B16-NEXT:    bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NEXT:    bfmlalt z3.s, z0.h, z1.h
+; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NEXT:    bfcvtnt z0.h, p0/m, z3.s
+; NOB16B16-NEXT:    ret
+;
+; B16B16-LABEL: fmul_nxv8bf16_no_nsz:
+; B16B16:       // %bb.0:
+; B16B16-NEXT:    bfmul z0.h, z0.h, z1.h
+; B16B16-NEXT:    ret
   %res = fmul <vscale x 8 x bfloat> %a, %b
   ret <vscale x 8 x bfloat> %res
 }

>From cb0f859dbc0b5e2866359fa0052537289888a726 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 14 Nov 2025 17:17:24 +0000
Subject: [PATCH 4/6] Update bf16-combines checks

---
 .../test/CodeGen/AArch64/sve-bf16-combines.ll | 210 ++++++++----------
 1 file changed, 90 insertions(+), 120 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll b/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll
index 16e8feb0dc5bb..a22f857e5f985 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll
@@ -414,28 +414,22 @@ define <vscale x 8 x bfloat> @fsub_sel_negzero_nxv8bf16(<vscale x 8 x bfloat> %a
 define <vscale x 8 x bfloat> @fadd_sel_fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, <vscale x 8 x bfloat> %c, <vscale x 8 x i1> %mask) {
 ; SVE-LABEL: fadd_sel_fmul_nxv8bf16:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    uunpkhi z3.s, z2.h
-; SVE-NEXT:    uunpkhi z4.s, z1.h
-; SVE-NEXT:    uunpklo z2.s, z2.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov w8, #-2147483648 // =0x80000000
 ; SVE-NEXT:    ptrue p1.s
-; SVE-NEXT:    lsl z3.s, z3.s, #16
-; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z1.s, z1.s, #16
-; SVE-NEXT:    fmul z3.s, z4.s, z3.s
-; SVE-NEXT:    fmul z1.s, z1.s, z2.s
-; SVE-NEXT:    bfcvt z2.h, p1/m, z3.s
-; SVE-NEXT:    movi v3.2d, #0000000000000000
-; SVE-NEXT:    bfcvt z1.h, p1/m, z1.s
-; SVE-NEXT:    uzp1 z1.h, z1.h, z2.h
-; SVE-NEXT:    sel z1.h, p0, z1.h, z3.h
+; SVE-NEXT:    mov z3.s, w8
+; SVE-NEXT:    mov z4.d, z3.d
+; SVE-NEXT:    bfmlalb z3.s, z1.h, z2.h
+; SVE-NEXT:    bfmlalt z4.s, z1.h, z2.h
+; SVE-NEXT:    movi v2.2d, #0000000000000000
+; SVE-NEXT:    bfcvt z1.h, p1/m, z3.s
 ; SVE-NEXT:    uunpkhi z3.s, z0.h
 ; SVE-NEXT:    uunpklo z0.s, z0.h
-; SVE-NEXT:    uunpkhi z2.s, z1.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    bfcvtnt z1.h, p1/m, z4.s
 ; SVE-NEXT:    lsl z3.s, z3.s, #16
 ; SVE-NEXT:    lsl z0.s, z0.s, #16
+; SVE-NEXT:    sel z1.h, p0, z1.h, z2.h
+; SVE-NEXT:    uunpkhi z2.s, z1.h
+; SVE-NEXT:    uunpklo z1.s, z1.h
 ; SVE-NEXT:    lsl z2.s, z2.s, #16
 ; SVE-NEXT:    lsl z1.s, z1.s, #16
 ; SVE-NEXT:    fadd z2.s, z3.s, z2.s
@@ -461,24 +455,21 @@ define <vscale x 8 x bfloat> @fadd_sel_fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <
 define <vscale x 8 x bfloat> @fsub_sel_fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, <vscale x 8 x bfloat> %c, <vscale x 8 x i1> %mask) {
 ; SVE-LABEL: fsub_sel_fmul_nxv8bf16:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    uunpkhi z3.s, z2.h
-; SVE-NEXT:    uunpkhi z4.s, z1.h
-; SVE-NEXT:    uunpklo z2.s, z2.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov w8, #-2147483648 // =0x80000000
 ; SVE-NEXT:    ptrue p1.s
-; SVE-NEXT:    lsl z3.s, z3.s, #16
-; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z1.s, z1.s, #16
-; SVE-NEXT:    fmul z3.s, z4.s, z3.s
-; SVE-NEXT:    uunpklo z4.s, z0.h
-; SVE-NEXT:    fmul z1.s, z1.s, z2.s
-; SVE-NEXT:    bfcvt z2.h, p1/m, z3.s
+; SVE-NEXT:    mov z3.s, w8
+; SVE-NEXT:    mov z4.d, z3.d
+; SVE-NEXT:    bfmlalb z3.s, z1.h, z2.h
+; SVE-NEXT:    bfmlalt z4.s, z1.h, z2.h
+; SVE-NEXT:    bfcvt z1.h, p1/m, z3.s
 ; SVE-NEXT:    uunpkhi z3.s, z0.h
+; SVE-NEXT:    bfcvtnt z1.h, p1/m, z4.s
+; SVE-NEXT:    uunpklo z4.s, z0.h
+; SVE-NEXT:    lsl z3.s, z3.s, #16
+; SVE-NEXT:    uunpkhi z2.s, z1.h
+; SVE-NEXT:    uunpklo z1.s, z1.h
 ; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    bfcvt z1.h, p1/m, z1.s
 ; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z3.s, z3.s, #16
 ; SVE-NEXT:    lsl z1.s, z1.s, #16
 ; SVE-NEXT:    fsub z2.s, z3.s, z2.s
 ; SVE-NEXT:    fsub z1.s, z4.s, z1.s
@@ -503,24 +494,21 @@ define <vscale x 8 x bfloat> @fsub_sel_fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <
 define <vscale x 8 x bfloat> @fadd_sel_fmul_nsz_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, <vscale x 8 x bfloat> %c, <vscale x 8 x i1> %mask) {
 ; SVE-LABEL: fadd_sel_fmul_nsz_nxv8bf16:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    uunpkhi z3.s, z2.h
-; SVE-NEXT:    uunpkhi z4.s, z1.h
-; SVE-NEXT:    uunpklo z2.s, z2.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov w8, #-2147483648 // =0x80000000
 ; SVE-NEXT:    ptrue p1.s
-; SVE-NEXT:    lsl z3.s, z3.s, #16
-; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z1.s, z1.s, #16
-; SVE-NEXT:    fmul z3.s, z4.s, z3.s
-; SVE-NEXT:    uunpklo z4.s, z0.h
-; SVE-NEXT:    fmul z1.s, z1.s, z2.s
-; SVE-NEXT:    bfcvt z2.h, p1/m, z3.s
+; SVE-NEXT:    mov z3.s, w8
+; SVE-NEXT:    mov z4.d, z3.d
+; SVE-NEXT:    bfmlalb z3.s, z1.h, z2.h
+; SVE-NEXT:    bfmlalt z4.s, z1.h, z2.h
+; SVE-NEXT:    bfcvt z1.h, p1/m, z3.s
 ; SVE-NEXT:    uunpkhi z3.s, z0.h
+; SVE-NEXT:    bfcvtnt z1.h, p1/m, z4.s
+; SVE-NEXT:    uunpklo z4.s, z0.h
+; SVE-NEXT:    lsl z3.s, z3.s, #16
+; SVE-NEXT:    uunpkhi z2.s, z1.h
+; SVE-NEXT:    uunpklo z1.s, z1.h
 ; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    bfcvt z1.h, p1/m, z1.s
 ; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z3.s, z3.s, #16
 ; SVE-NEXT:    lsl z1.s, z1.s, #16
 ; SVE-NEXT:    fadd z2.s, z3.s, z2.s
 ; SVE-NEXT:    fadd z1.s, z4.s, z1.s
@@ -545,24 +533,21 @@ define <vscale x 8 x bfloat> @fadd_sel_fmul_nsz_nxv8bf16(<vscale x 8 x bfloat> %
 define <vscale x 8 x bfloat> @fsub_sel_fmul_nsz_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, <vscale x 8 x bfloat> %c, <vscale x 8 x i1> %mask) {
 ; SVE-LABEL: fsub_sel_fmul_nsz_nxv8bf16:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    uunpkhi z3.s, z2.h
-; SVE-NEXT:    uunpkhi z4.s, z1.h
-; SVE-NEXT:    uunpklo z2.s, z2.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov w8, #-2147483648 // =0x80000000
 ; SVE-NEXT:    ptrue p1.s
-; SVE-NEXT:    lsl z3.s, z3.s, #16
-; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z1.s, z1.s, #16
-; SVE-NEXT:    fmul z3.s, z4.s, z3.s
-; SVE-NEXT:    uunpklo z4.s, z0.h
-; SVE-NEXT:    fmul z1.s, z1.s, z2.s
-; SVE-NEXT:    bfcvt z2.h, p1/m, z3.s
+; SVE-NEXT:    mov z3.s, w8
+; SVE-NEXT:    mov z4.d, z3.d
+; SVE-NEXT:    bfmlalb z3.s, z1.h, z2.h
+; SVE-NEXT:    bfmlalt z4.s, z1.h, z2.h
+; SVE-NEXT:    bfcvt z1.h, p1/m, z3.s
 ; SVE-NEXT:    uunpkhi z3.s, z0.h
+; SVE-NEXT:    bfcvtnt z1.h, p1/m, z4.s
+; SVE-NEXT:    uunpklo z4.s, z0.h
+; SVE-NEXT:    lsl z3.s, z3.s, #16
+; SVE-NEXT:    uunpkhi z2.s, z1.h
+; SVE-NEXT:    uunpklo z1.s, z1.h
 ; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    bfcvt z1.h, p1/m, z1.s
 ; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z3.s, z3.s, #16
 ; SVE-NEXT:    lsl z1.s, z1.s, #16
 ; SVE-NEXT:    fsub z2.s, z3.s, z2.s
 ; SVE-NEXT:    fsub z1.s, z4.s, z1.s
@@ -587,24 +572,21 @@ define <vscale x 8 x bfloat> @fsub_sel_fmul_nsz_nxv8bf16(<vscale x 8 x bfloat> %
 define <vscale x 8 x bfloat> @fadd_sel_fmul_negzero_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, <vscale x 8 x bfloat> %c, <vscale x 8 x i1> %mask) {
 ; SVE-LABEL: fadd_sel_fmul_negzero_nxv8bf16:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    uunpkhi z3.s, z2.h
-; SVE-NEXT:    uunpkhi z4.s, z1.h
-; SVE-NEXT:    uunpklo z2.s, z2.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov w8, #-2147483648 // =0x80000000
 ; SVE-NEXT:    ptrue p1.s
-; SVE-NEXT:    lsl z3.s, z3.s, #16
-; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z1.s, z1.s, #16
-; SVE-NEXT:    fmul z3.s, z4.s, z3.s
-; SVE-NEXT:    uunpklo z4.s, z0.h
-; SVE-NEXT:    fmul z1.s, z1.s, z2.s
-; SVE-NEXT:    bfcvt z2.h, p1/m, z3.s
+; SVE-NEXT:    mov z3.s, w8
+; SVE-NEXT:    mov z4.d, z3.d
+; SVE-NEXT:    bfmlalb z3.s, z1.h, z2.h
+; SVE-NEXT:    bfmlalt z4.s, z1.h, z2.h
+; SVE-NEXT:    bfcvt z1.h, p1/m, z3.s
 ; SVE-NEXT:    uunpkhi z3.s, z0.h
+; SVE-NEXT:    bfcvtnt z1.h, p1/m, z4.s
+; SVE-NEXT:    uunpklo z4.s, z0.h
+; SVE-NEXT:    lsl z3.s, z3.s, #16
+; SVE-NEXT:    uunpkhi z2.s, z1.h
+; SVE-NEXT:    uunpklo z1.s, z1.h
 ; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    bfcvt z1.h, p1/m, z1.s
 ; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z3.s, z3.s, #16
 ; SVE-NEXT:    lsl z1.s, z1.s, #16
 ; SVE-NEXT:    fadd z2.s, z3.s, z2.s
 ; SVE-NEXT:    fadd z1.s, z4.s, z1.s
@@ -630,30 +612,24 @@ define <vscale x 8 x bfloat> @fadd_sel_fmul_negzero_nxv8bf16(<vscale x 8 x bfloa
 define <vscale x 8 x bfloat> @fsub_sel_fmul_negzero_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, <vscale x 8 x bfloat> %c, <vscale x 8 x i1> %mask) {
 ; SVE-LABEL: fsub_sel_fmul_negzero_nxv8bf16:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    uunpkhi z3.s, z2.h
-; SVE-NEXT:    uunpkhi z4.s, z1.h
-; SVE-NEXT:    mov w8, #32768 // =0x8000
-; SVE-NEXT:    uunpklo z2.s, z2.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov w8, #-2147483648 // =0x80000000
 ; SVE-NEXT:    ptrue p1.s
-; SVE-NEXT:    lsl z3.s, z3.s, #16
-; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z1.s, z1.s, #16
-; SVE-NEXT:    fmul z3.s, z4.s, z3.s
-; SVE-NEXT:    fmul z1.s, z1.s, z2.s
-; SVE-NEXT:    bfcvt z2.h, p1/m, z3.s
-; SVE-NEXT:    fmov h3, w8
-; SVE-NEXT:    bfcvt z1.h, p1/m, z1.s
-; SVE-NEXT:    mov z3.h, h3
-; SVE-NEXT:    uzp1 z1.h, z1.h, z2.h
-; SVE-NEXT:    sel z1.h, p0, z1.h, z3.h
+; SVE-NEXT:    mov z3.s, w8
+; SVE-NEXT:    mov w8, #32768 // =0x8000
+; SVE-NEXT:    mov z4.d, z3.d
+; SVE-NEXT:    bfmlalb z3.s, z1.h, z2.h
+; SVE-NEXT:    bfmlalt z4.s, z1.h, z2.h
+; SVE-NEXT:    fmov h2, w8
+; SVE-NEXT:    bfcvt z1.h, p1/m, z3.s
 ; SVE-NEXT:    uunpkhi z3.s, z0.h
 ; SVE-NEXT:    uunpklo z0.s, z0.h
-; SVE-NEXT:    uunpkhi z2.s, z1.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov z2.h, h2
+; SVE-NEXT:    bfcvtnt z1.h, p1/m, z4.s
 ; SVE-NEXT:    lsl z3.s, z3.s, #16
 ; SVE-NEXT:    lsl z0.s, z0.s, #16
+; SVE-NEXT:    sel z1.h, p0, z1.h, z2.h
+; SVE-NEXT:    uunpkhi z2.s, z1.h
+; SVE-NEXT:    uunpklo z1.s, z1.h
 ; SVE-NEXT:    lsl z2.s, z2.s, #16
 ; SVE-NEXT:    lsl z1.s, z1.s, #16
 ; SVE-NEXT:    fsub z2.s, z3.s, z2.s
@@ -682,24 +658,21 @@ define <vscale x 8 x bfloat> @fsub_sel_fmul_negzero_nxv8bf16(<vscale x 8 x bfloa
 define <vscale x 8 x bfloat> @fadd_sel_fmul_negzero_nsz_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, <vscale x 8 x bfloat> %c, <vscale x 8 x i1> %mask) {
 ; SVE-LABEL: fadd_sel_fmul_negzero_nsz_nxv8bf16:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    uunpkhi z3.s, z2.h
-; SVE-NEXT:    uunpkhi z4.s, z1.h
-; SVE-NEXT:    uunpklo z2.s, z2.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov w8, #-2147483648 // =0x80000000
 ; SVE-NEXT:    ptrue p1.s
-; SVE-NEXT:    lsl z3.s, z3.s, #16
-; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z1.s, z1.s, #16
-; SVE-NEXT:    fmul z3.s, z4.s, z3.s
-; SVE-NEXT:    uunpklo z4.s, z0.h
-; SVE-NEXT:    fmul z1.s, z1.s, z2.s
-; SVE-NEXT:    bfcvt z2.h, p1/m, z3.s
+; SVE-NEXT:    mov z3.s, w8
+; SVE-NEXT:    mov z4.d, z3.d
+; SVE-NEXT:    bfmlalb z3.s, z1.h, z2.h
+; SVE-NEXT:    bfmlalt z4.s, z1.h, z2.h
+; SVE-NEXT:    bfcvt z1.h, p1/m, z3.s
 ; SVE-NEXT:    uunpkhi z3.s, z0.h
+; SVE-NEXT:    bfcvtnt z1.h, p1/m, z4.s
+; SVE-NEXT:    uunpklo z4.s, z0.h
+; SVE-NEXT:    lsl z3.s, z3.s, #16
+; SVE-NEXT:    uunpkhi z2.s, z1.h
+; SVE-NEXT:    uunpklo z1.s, z1.h
 ; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    bfcvt z1.h, p1/m, z1.s
 ; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z3.s, z3.s, #16
 ; SVE-NEXT:    lsl z1.s, z1.s, #16
 ; SVE-NEXT:    fadd z2.s, z3.s, z2.s
 ; SVE-NEXT:    fadd z1.s, z4.s, z1.s
@@ -725,24 +698,21 @@ define <vscale x 8 x bfloat> @fadd_sel_fmul_negzero_nsz_nxv8bf16(<vscale x 8 x b
 define <vscale x 8 x bfloat> @fsub_sel_fmul_negzero_nsz_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, <vscale x 8 x bfloat> %c, <vscale x 8 x i1> %mask) {
 ; SVE-LABEL: fsub_sel_fmul_negzero_nsz_nxv8bf16:
 ; SVE:       // %bb.0:
-; SVE-NEXT:    uunpkhi z3.s, z2.h
-; SVE-NEXT:    uunpkhi z4.s, z1.h
-; SVE-NEXT:    uunpklo z2.s, z2.h
-; SVE-NEXT:    uunpklo z1.s, z1.h
+; SVE-NEXT:    mov w8, #-2147483648 // =0x80000000
 ; SVE-NEXT:    ptrue p1.s
-; SVE-NEXT:    lsl z3.s, z3.s, #16
-; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z1.s, z1.s, #16
-; SVE-NEXT:    fmul z3.s, z4.s, z3.s
-; SVE-NEXT:    uunpklo z4.s, z0.h
-; SVE-NEXT:    fmul z1.s, z1.s, z2.s
-; SVE-NEXT:    bfcvt z2.h, p1/m, z3.s
+; SVE-NEXT:    mov z3.s, w8
+; SVE-NEXT:    mov z4.d, z3.d
+; SVE-NEXT:    bfmlalb z3.s, z1.h, z2.h
+; SVE-NEXT:    bfmlalt z4.s, z1.h, z2.h
+; SVE-NEXT:    bfcvt z1.h, p1/m, z3.s
 ; SVE-NEXT:    uunpkhi z3.s, z0.h
+; SVE-NEXT:    bfcvtnt z1.h, p1/m, z4.s
+; SVE-NEXT:    uunpklo z4.s, z0.h
+; SVE-NEXT:    lsl z3.s, z3.s, #16
+; SVE-NEXT:    uunpkhi z2.s, z1.h
+; SVE-NEXT:    uunpklo z1.s, z1.h
 ; SVE-NEXT:    lsl z4.s, z4.s, #16
-; SVE-NEXT:    bfcvt z1.h, p1/m, z1.s
 ; SVE-NEXT:    lsl z2.s, z2.s, #16
-; SVE-NEXT:    lsl z3.s, z3.s, #16
 ; SVE-NEXT:    lsl z1.s, z1.s, #16
 ; SVE-NEXT:    fsub z2.s, z3.s, z2.s
 ; SVE-NEXT:    fsub z1.s, z4.s, z1.s

>From a064ab9607027260159db095567833b2492e736b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 14 Nov 2025 21:00:59 +0000
Subject: [PATCH 5/6] Fix typo

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 094726ce5ba7a..ae154605d09bd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7569,7 +7569,7 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
   auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
   auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
   auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
-  auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
+  auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
 
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
@@ -7581,14 +7581,14 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
 
   // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
   // instructions. These result in two f32 vectors, which can be converted back
-  // to bf16 with FCVT and FCVNT.
+  // to bf16 with FCVT and FCVTNT.
   SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
   SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32);
   // Note: nxv2bf16 and nxv4bf16 only use even lanes.
   if (VT != MVT::nxv8bf16)
     return BottomBF16;
   SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
-  return FCVNT(VT, BottomBF16, Pg, TopF32);
+  return FCVTNT(VT, BottomBF16, Pg, TopF32);
 }
 
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,

>From 5bef05840cd5cfa094b6bb83d1fc8db6c6ff2606 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 19 Nov 2025 13:20:43 +0000
Subject: [PATCH 6/6] Fixups

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 32 +++++++++++--------
 llvm/test/CodeGen/AArch64/sve-bf16-arith.ll   | 23 +++++--------
 2 files changed, 27 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ae154605d09bd..7d769b3803d03 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1807,7 +1807,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT);
         setOperationPromotedToType(ISD::FSUB, VT, PromotedVT);
 
-        if (Subtarget->hasBF16())
+        if (VT != MVT::nxv2bf16 && Subtarget->hasBF16())
           setOperationAction(ISD::FMUL, VT, Custom);
         else
           setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
@@ -7550,13 +7550,12 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
 SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
   SDLoc DL(Op);
   EVT VT = Op.getValueType();
-  auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
   if (VT.getScalarType() != MVT::bf16 ||
-      (Subtarget.hasSVEB16B16() &&
-       Subtarget.isNonStreamingSVEorSME2Available()))
+      (Subtarget->hasSVEB16B16() &&
+       Subtarget->isNonStreamingSVEorSME2Available()))
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
 
-  assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering");
+  assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
 
   auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
     return [&, IID](EVT VT, auto... Ops) {
@@ -7565,28 +7564,35 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     };
   };
 
+  auto ReinterpretCast = [&](SDValue Value, EVT VT) {
+    if (VT == Value.getValueType())
+      return Value;
+    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
+  };
+
   // Create helpers for building intrinsic calls.
   auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
   auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
   auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
   auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
 
-  SDValue LHS = Op.getOperand(0);
-  SDValue RHS = Op.getOperand(1);
+  // All intrinsics expect to operate on full bf16 vector types.
+  SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
+  SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
 
   SDValue Zero =
       DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
-  SDValue Pg =
-      DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1);
+  SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
 
   // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
   // instructions. These result in two f32 vectors, which can be converted back
   // to bf16 with FCVT and FCVTNT.
   SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
-  SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32);
-  // Note: nxv2bf16 and nxv4bf16 only use even lanes.
-  if (VT != MVT::nxv8bf16)
-    return BottomBF16;
+  SDValue BottomBF16 =
+      FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
+  // Note: nxv4bf16 only uses even lanes.
+  if (VT == MVT::nxv4bf16)
+    return ReinterpretCast(BottomBF16, VT);
   SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
   return FCVTNT(VT, BottomBF16, Pg, TopF32);
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
index 7d5aed2898bfd..f9441e0151f86 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
@@ -514,26 +514,19 @@ define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x
 ;
 
 define <vscale x 2 x bfloat> @fmul_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
-; NOB16B16-NONSTREAMING-LABEL: fmul_nxv2bf16:
-; NOB16B16-NONSTREAMING:       // %bb.0:
-; NOB16B16-NONSTREAMING-NEXT:    movi v2.2d, #0000000000000000
-; NOB16B16-NONSTREAMING-NEXT:    ptrue p0.d
-; NOB16B16-NONSTREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
-; NOB16B16-NONSTREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
-; NOB16B16-NONSTREAMING-NEXT:    ret
+; NOB16B16-LABEL: fmul_nxv2bf16:
+; NOB16B16:       // %bb.0:
+; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
+; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
+; NOB16B16-NEXT:    ptrue p0.d
+; NOB16B16-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_nxv2bf16:
 ; B16B16:       // %bb.0:
 ; B16B16-NEXT:    bfmul z0.h, z0.h, z1.h
 ; B16B16-NEXT:    ret
-;
-; NOB16B16-STREAMING-LABEL: fmul_nxv2bf16:
-; NOB16B16-STREAMING:       // %bb.0:
-; NOB16B16-STREAMING-NEXT:    mov z2.s, #0 // =0x0
-; NOB16B16-STREAMING-NEXT:    ptrue p0.d
-; NOB16B16-STREAMING-NEXT:    bfmlalb z2.s, z0.h, z1.h
-; NOB16B16-STREAMING-NEXT:    bfcvt z0.h, p0/m, z2.s
-; NOB16B16-STREAMING-NEXT:    ret
   %res = fmul nsz <vscale x 2 x bfloat> %a, %b
   ret <vscale x 2 x bfloat> %res
 }



More information about the llvm-commits mailing list