[llvm] [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (PR #169655)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 5 02:20:42 PST 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/169655

>From 441b21c14686ed15e63333e2e0b55e021a3c1d17 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 25 Nov 2025 16:18:52 +0000
Subject: [PATCH 1/4] [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with
 +sve

Assuming the predicate is hoisted, this should have slightly better
throughput: https://godbolt.org/z/jb7aP7Efc

Note: SVE must be used to convert back to bf16, as the bfmlalb/bfmlalt
instructions operate on even/odd lanes, whereas the NEON bfcvtn/bfcvtn2
instructions process the low/high halves of vectors.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 62 +++++++++++++------
 .../CodeGen/AArch64/bf16-v8-instructions.ll   | 37 +++++++----
 2 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7a15d7b75f1b9..64e8edf48cba7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1831,6 +1831,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         else
           setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
       }
+
+      if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
+        setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7739,7 +7742,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
 
   assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
-  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
+  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
+         "Unexpected FMUL VT");
 
   auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
     return [&, IID](EVT VT, auto... Ops) {
@@ -7748,37 +7752,59 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     };
   };
 
-  auto ReinterpretCast = [&](SDValue Value, EVT VT) {
-    if (VT == Value.getValueType())
+  auto Reinterpret = [&](SDValue Value, EVT VT) {
+    EVT SrcVT = Value.getValueType();
+    if (VT == SrcVT)
       return Value;
+    if (SrcVT.isFixedLengthVector())
+      return convertToScalableVector(DAG, VT, Value);
+    if (VT.isFixedLengthVector())
+      return convertFromScalableVector(DAG, VT, Value);
     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
   };
 
-  // Create helpers for building intrinsic calls.
-  auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
-  auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
   auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
   auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
 
-  // All intrinsics expect to operate on full bf16 vector types.
-  SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
-  SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
-
-  SDValue Zero =
-      DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
-  SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
+  EVT AccVT = VT.isFixedLengthVector() ? MVT::v4f32 : MVT::nxv4f32;
+  SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
+  SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
 
-  // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+  // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
   // instructions. These result in two f32 vectors, which can be converted back
   // to bf16 with FCVT and FCVTNT.
-  SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+  SDValue TopF32;
+  SDValue BottomF32;
+  if (VT == MVT::v8bf16) {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
+
+    auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalb);
+    auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalt);
+
+    // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+    // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+    BottomF32 = Reinterpret(BFMLALB(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+    TopF32 = Reinterpret(BFMLALT(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+  } else {
+    // All SVE intrinsics expect to operate on full bf16 vector types.
+    SDValue LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16);
+    SDValue RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16);
+
+    auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+    auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+
+    BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+    TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
+  }
+
   SDValue BottomBF16 =
       FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
   // Note: nxv4bf16 only uses even lanes.
   if (VT == MVT::nxv4bf16)
-    return ReinterpretCast(BottomBF16, VT);
-  SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
-  return FCVTNT(VT, BottomBF16, Pg, TopF32);
+    return Reinterpret(BottomBF16, VT);
+  SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
+  return Reinterpret(TopBF16, VT);
 }
 
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
diff --git a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
index 6a7a4cbd8b20a..e3c0d97c08f54 100644
--- a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc < %s -mtriple=aarch64 -mattr=-bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-CVT
-; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-NOSVE-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-SVE-BF16
 
 define <8 x bfloat> @add_h(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; CHECK-CVT-LABEL: add_h:
@@ -117,17 +118,29 @@ define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; CHECK-CVT-NEXT:    uzp2 v0.8h, v0.8h, v2.8h
 ; CHECK-CVT-NEXT:    ret
 ;
-; CHECK-BF16-LABEL: mul_h:
-; CHECK-BF16:       // %bb.0: // %entry
-; CHECK-BF16-NEXT:    shll v2.4s, v1.4h, #16
-; CHECK-BF16-NEXT:    shll v3.4s, v0.4h, #16
-; CHECK-BF16-NEXT:    shll2 v1.4s, v1.8h, #16
-; CHECK-BF16-NEXT:    shll2 v0.4s, v0.8h, #16
-; CHECK-BF16-NEXT:    fmul v2.4s, v3.4s, v2.4s
-; CHECK-BF16-NEXT:    fmul v1.4s, v0.4s, v1.4s
-; CHECK-BF16-NEXT:    bfcvtn v0.4h, v2.4s
-; CHECK-BF16-NEXT:    bfcvtn2 v0.8h, v1.4s
-; CHECK-BF16-NEXT:    ret
+; CHECK-NOSVE-BF16-LABEL: mul_h:
+; CHECK-NOSVE-BF16:       // %bb.0: // %entry
+; CHECK-NOSVE-BF16-NEXT:    shll v2.4s, v1.4h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll v3.4s, v0.4h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll2 v1.4s, v1.8h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll2 v0.4s, v0.8h, #16
+; CHECK-NOSVE-BF16-NEXT:    fmul v2.4s, v3.4s, v2.4s
+; CHECK-NOSVE-BF16-NEXT:    fmul v1.4s, v0.4s, v1.4s
+; CHECK-NOSVE-BF16-NEXT:    bfcvtn v0.4h, v2.4s
+; CHECK-NOSVE-BF16-NEXT:    bfcvtn2 v0.8h, v1.4s
+; CHECK-NOSVE-BF16-NEXT:    ret
+;
+; CHECK-SVE-BF16-LABEL: mul_h:
+; CHECK-SVE-BF16:       // %bb.0: // %entry
+; CHECK-SVE-BF16-NEXT:    movi v2.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT:    movi v3.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT:    ptrue p0.s, vl4
+; CHECK-SVE-BF16-NEXT:    bfmlalb v2.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT:    bfmlalt v3.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT:    bfcvt z2.h, p0/m, z2.s
+; CHECK-SVE-BF16-NEXT:    bfcvtnt z2.h, p0/m, z3.s
+; CHECK-SVE-BF16-NEXT:    mov v0.16b, v2.16b
+; CHECK-SVE-BF16-NEXT:    ret
 entry:
   %0 = fmul <8 x bfloat> %a, %b
   ret <8 x bfloat> %0

>From 2d50b45d92cb148f38ba47f1fa060c0df3b6a159 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 4 Dec 2025 17:40:52 +0000
Subject: [PATCH 2/4] Rework

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 43 +++++++++----------
 1 file changed, 20 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 64e8edf48cba7..6160afee826ea 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7763,46 +7763,43 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
   };
 
+  bool UseSVEBFMLAL = VT.isScalableVector();
   auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
   auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
 
-  EVT AccVT = VT.isFixedLengthVector() ? MVT::v4f32 : MVT::nxv4f32;
+  // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+  // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+  auto BFMLALB =
+      MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalb
+                                    : Intrinsic::aarch64_neon_bfmlalb);
+  auto BFMLALT =
+      MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalt
+                                    : Intrinsic::aarch64_neon_bfmlalt);
+
+  EVT AccVT = UseSVEBFMLAL ? MVT::nxv4f32 : MVT::v4f32;
   SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
   SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
 
   // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
   // instructions. These result in two f32 vectors, which can be converted back
   // to bf16 with FCVT and FCVTNT.
-  SDValue TopF32;
-  SDValue BottomF32;
-  if (VT == MVT::v8bf16) {
-    SDValue LHS = Op.getOperand(0);
-    SDValue RHS = Op.getOperand(1);
-
-    auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalb);
-    auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalt);
-
-    // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
-    // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
-    BottomF32 = Reinterpret(BFMLALB(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
-    TopF32 = Reinterpret(BFMLALT(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
-  } else {
-    // All SVE intrinsics expect to operate on full bf16 vector types.
-    SDValue LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16);
-    SDValue RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16);
-
-    auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
-    auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
 
-    BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
-    TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
+  // All SVE intrinsics expect to operate on full bf16 vector types.
+  if (UseSVEBFMLAL) {
+    LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16);
+    RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16);
   }
 
+  SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
   SDValue BottomBF16 =
       FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
   // Note: nxv4bf16 only uses even lanes.
   if (VT == MVT::nxv4bf16)
     return Reinterpret(BottomBF16, VT);
+
+  SDValue TopF32 = Reinterpret(BFMLALT(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
   SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
   return Reinterpret(TopBF16, VT);
 }

>From 3aa5ab6497f3f6f47910cd0bc38b795db689a7fa Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 4 Dec 2025 18:16:13 +0000
Subject: [PATCH 3/4] Rebase to update test

---
 .../CodeGen/AArch64/fixed-length-bf16-arith.ll   | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll b/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
index e6344b9eb89dc..45f8b2fa95a83 100644
--- a/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
+++ b/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
@@ -761,14 +761,14 @@ define <4 x bfloat> @fmul_v4bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
 define <8 x bfloat> @fmul_v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; NOB16B16-LABEL: fmul_v8bf16:
 ; NOB16B16:       // %bb.0:
-; NOB16B16-NEXT:    shll v2.4s, v1.4h, #16
-; NOB16B16-NEXT:    shll v3.4s, v0.4h, #16
-; NOB16B16-NEXT:    shll2 v1.4s, v1.8h, #16
-; NOB16B16-NEXT:    shll2 v0.4s, v0.8h, #16
-; NOB16B16-NEXT:    fmul v2.4s, v3.4s, v2.4s
-; NOB16B16-NEXT:    fmul v1.4s, v0.4s, v1.4s
-; NOB16B16-NEXT:    bfcvtn v0.4h, v2.4s
-; NOB16B16-NEXT:    bfcvtn2 v0.8h, v1.4s
+; NOB16B16-NEXT:    movi v2.4s, #128, lsl #24
+; NOB16B16-NEXT:    movi v3.4s, #128, lsl #24
+; NOB16B16-NEXT:    ptrue p0.s, vl4
+; NOB16B16-NEXT:    bfmlalb v2.4s, v0.8h, v1.8h
+; NOB16B16-NEXT:    bfmlalt v3.4s, v0.8h, v1.8h
+; NOB16B16-NEXT:    bfcvt z2.h, p0/m, z2.s
+; NOB16B16-NEXT:    bfcvtnt z2.h, p0/m, z3.s
+; NOB16B16-NEXT:    mov v0.16b, v2.16b
 ; NOB16B16-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_v8bf16:

>From 845b6713f75e4b3e1cd29fb544f081b05408a073 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 5 Dec 2025 10:19:54 +0000
Subject: [PATCH 4/4] Tweak

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6160afee826ea..4b25f9be5a728 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7788,8 +7788,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
 
   // All SVE intrinsics expect to operate on full bf16 vector types.
   if (UseSVEBFMLAL) {
-    LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16);
-    RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16);
+    LHS = Reinterpret(LHS, MVT::nxv8bf16);
+    RHS = Reinterpret(RHS, MVT::nxv8bf16);
   }
 
   SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32);



More information about the llvm-commits mailing list