[llvm] 32ff710 - [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (#169655)

via llvm-commits <llvm-commits at lists.llvm.org>
Mon Dec 8 03:56:22 PST 2025


Author: Benjamin Maxwell
Date: 2025-12-08T11:56:18Z
New Revision: 32ff7100d737bcfce2f713dd9838df88bdd3b631

URL: https://github.com/llvm/llvm-project/commit/32ff7100d737bcfce2f713dd9838df88bdd3b631
DIFF: https://github.com/llvm/llvm-project/commit/32ff7100d737bcfce2f713dd9838df88bdd3b631.diff

LOG: [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (#169655)

Assuming the predicate is hoisted, this should give slightly better
throughput: https://godbolt.org/z/jb7aP7Efc

Note: SVE must be used to convert back to bf16, as the bfmlalb/t
instructions operate on even/odd lanes, whereas the NEON bfcvtn/2
instructions process the bottom/top halves of vectors.
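
For reference, this is the pattern the new lowering fires on, taken
from the mul_h test updated below (fmul_v8bf16 in the fixed-length
test has the same shape):

  define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) {
  entry:
    %0 = fmul <8 x bfloat> %a, %b
    ret <8 x bfloat> %0
  }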

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
    llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7199319ccdd9f..bf0b6614e5e18 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1834,6 +1834,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         else
           setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
       }
+
+      if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
+        setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7742,7 +7745,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
 
   assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
-  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
+  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
+         "Unexpected FMUL VT");
 
   auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
     return [&, IID](EVT VT, auto... Ops) {
@@ -7751,37 +7755,56 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     };
   };
 
-  auto ReinterpretCast = [&](SDValue Value, EVT VT) {
-    if (VT == Value.getValueType())
+  auto Reinterpret = [&](SDValue Value, EVT VT) {
+    EVT SrcVT = Value.getValueType();
+    if (VT == SrcVT)
       return Value;
+    if (SrcVT.isFixedLengthVector())
+      return convertToScalableVector(DAG, VT, Value);
+    if (VT.isFixedLengthVector())
+      return convertFromScalableVector(DAG, VT, Value);
     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
   };
 
-  // Create helpers for building intrinsic calls.
-  auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
-  auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+  bool UseSVEBFMLAL = VT.isScalableVector();
   auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
   auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
 
-  // All intrinsics expect to operate on full bf16 vector types.
-  SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
-  SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
+  // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+  // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+  auto BFMLALB =
+      MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalb
+                                    : Intrinsic::aarch64_neon_bfmlalb);
+  auto BFMLALT =
+      MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalt
+                                    : Intrinsic::aarch64_neon_bfmlalt);
 
-  SDValue Zero =
-      DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
-  SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
+  EVT AccVT = UseSVEBFMLAL ? MVT::nxv4f32 : MVT::v4f32;
+  SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
+  SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
 
-  // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+  // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
   // instructions. These result in two f32 vectors, which can be converted back
   // to bf16 with FCVT and FCVTNT.
-  SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
+
+  // All SVE intrinsics expect to operate on full bf16 vector types.
+  if (UseSVEBFMLAL) {
+    LHS = Reinterpret(LHS, MVT::nxv8bf16);
+    RHS = Reinterpret(RHS, MVT::nxv8bf16);
+  }
+
+  SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
   SDValue BottomBF16 =
       FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
   // Note: nxv4bf16 only uses even lanes.
   if (VT == MVT::nxv4bf16)
-    return ReinterpretCast(BottomBF16, VT);
-  SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
-  return FCVTNT(VT, BottomBF16, Pg, TopF32);
+    return Reinterpret(BottomBF16, VT);
+
+  SDValue TopF32 = Reinterpret(BFMLALT(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
+  SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
+  return Reinterpret(TopBF16, VT);
 }
 
 SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {

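For illustration, the new fixed-length (v8bf16) path builds roughly
the equivalent of the hand-written IR below. This is a sketch, not the
compiler's literal output: the function name is invented, and the
fixed-to-scalable transitions are spelled with the generic
llvm.vector.insert/extract intrinsics that
convertToScalableVector/convertFromScalableVector correspond to.

  declare <4 x float> @llvm.aarch64.neon.bfmlalb(<4 x float>, <8 x bfloat>, <8 x bfloat>)
  declare <4 x float> @llvm.aarch64.neon.bfmlalt(<4 x float>, <8 x bfloat>, <8 x bfloat>)
  declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
  declare <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32.v2(<vscale x 8 x bfloat>, <vscale x 4 x i1>, <vscale x 4 x float>)
  declare <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32.v2(<vscale x 8 x bfloat>, <vscale x 4 x i1>, <vscale x 4 x float>)
  declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
  declare <8 x bfloat> @llvm.vector.extract.v8bf16.nxv8bf16(<vscale x 8 x bfloat>, i64)

  define <8 x bfloat> @fmul_v8bf16_sketch(<8 x bfloat> %a, <8 x bfloat> %b) {
    ; -0.0 is the neutral element of fadd, so each BFMLAL reduces to a
    ; plain widening multiply of the even (bottom) or odd (top) lanes.
    %bot = call <4 x float> @llvm.aarch64.neon.bfmlalb(<4 x float> splat (float -0.0), <8 x bfloat> %a, <8 x bfloat> %b)
    %top = call <4 x float> @llvm.aarch64.neon.bfmlalt(<4 x float> splat (float -0.0), <8 x bfloat> %a, <8 x bfloat> %b)
    ; Move the fixed f32 results into scalable registers for the SVE converts.
    %botz = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> poison, <4 x float> %bot, i64 0)
    %topz = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> poison, <4 x float> %top, i64 0)
    %pg = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 4) ; ptrue p0.s, vl4
    ; BFCVT writes the even bf16 lanes, then BFCVTNT fills in the odd lanes.
    %even = call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32.v2(<vscale x 8 x bfloat> poison, <vscale x 4 x i1> %pg, <vscale x 4 x float> %botz)
    %both = call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32.v2(<vscale x 8 x bfloat> %even, <vscale x 4 x i1> %pg, <vscale x 4 x float> %topz)
    %res = call <8 x bfloat> @llvm.vector.extract.v8bf16.nxv8bf16(<vscale x 8 x bfloat> %both, i64 0)
    ret <8 x bfloat> %res
  }
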
diff --git a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
index 6a7a4cbd8b20a..e3c0d97c08f54 100644
--- a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc < %s -mtriple=aarch64 -mattr=-bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-CVT
-; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-NOSVE-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-SVE-BF16
 
 define <8 x bfloat> @add_h(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; CHECK-CVT-LABEL: add_h:
@@ -117,17 +118,29 @@ define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; CHECK-CVT-NEXT:    uzp2 v0.8h, v0.8h, v2.8h
 ; CHECK-CVT-NEXT:    ret
 ;
-; CHECK-BF16-LABEL: mul_h:
-; CHECK-BF16:       // %bb.0: // %entry
-; CHECK-BF16-NEXT:    shll v2.4s, v1.4h, #16
-; CHECK-BF16-NEXT:    shll v3.4s, v0.4h, #16
-; CHECK-BF16-NEXT:    shll2 v1.4s, v1.8h, #16
-; CHECK-BF16-NEXT:    shll2 v0.4s, v0.8h, #16
-; CHECK-BF16-NEXT:    fmul v2.4s, v3.4s, v2.4s
-; CHECK-BF16-NEXT:    fmul v1.4s, v0.4s, v1.4s
-; CHECK-BF16-NEXT:    bfcvtn v0.4h, v2.4s
-; CHECK-BF16-NEXT:    bfcvtn2 v0.8h, v1.4s
-; CHECK-BF16-NEXT:    ret
+; CHECK-NOSVE-BF16-LABEL: mul_h:
+; CHECK-NOSVE-BF16:       // %bb.0: // %entry
+; CHECK-NOSVE-BF16-NEXT:    shll v2.4s, v1.4h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll v3.4s, v0.4h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll2 v1.4s, v1.8h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll2 v0.4s, v0.8h, #16
+; CHECK-NOSVE-BF16-NEXT:    fmul v2.4s, v3.4s, v2.4s
+; CHECK-NOSVE-BF16-NEXT:    fmul v1.4s, v0.4s, v1.4s
+; CHECK-NOSVE-BF16-NEXT:    bfcvtn v0.4h, v2.4s
+; CHECK-NOSVE-BF16-NEXT:    bfcvtn2 v0.8h, v1.4s
+; CHECK-NOSVE-BF16-NEXT:    ret
+;
+; CHECK-SVE-BF16-LABEL: mul_h:
+; CHECK-SVE-BF16:       // %bb.0: // %entry
+; CHECK-SVE-BF16-NEXT:    movi v2.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT:    movi v3.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT:    ptrue p0.s, vl4
+; CHECK-SVE-BF16-NEXT:    bfmlalb v2.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT:    bfmlalt v3.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT:    bfcvt z2.h, p0/m, z2.s
+; CHECK-SVE-BF16-NEXT:    bfcvtnt z2.h, p0/m, z3.s
+; CHECK-SVE-BF16-NEXT:    mov v0.16b, v2.16b
+; CHECK-SVE-BF16-NEXT:    ret
 entry:
   %0 = fmul <8 x bfloat> %a, %b
   ret <8 x bfloat> %0

diff --git a/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll b/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
index e6344b9eb89dc..45f8b2fa95a83 100644
--- a/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
+++ b/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
@@ -761,14 +761,14 @@ define <4 x bfloat> @fmul_v4bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
 define <8 x bfloat> @fmul_v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; NOB16B16-LABEL: fmul_v8bf16:
 ; NOB16B16:       // %bb.0:
-; NOB16B16-NEXT:    shll v2.4s, v1.4h, #16
-; NOB16B16-NEXT:    shll v3.4s, v0.4h, #16
-; NOB16B16-NEXT:    shll2 v1.4s, v1.8h, #16
-; NOB16B16-NEXT:    shll2 v0.4s, v0.8h, #16
-; NOB16B16-NEXT:    fmul v2.4s, v3.4s, v2.4s
-; NOB16B16-NEXT:    fmul v1.4s, v0.4s, v1.4s
-; NOB16B16-NEXT:    bfcvtn v0.4h, v2.4s
-; NOB16B16-NEXT:    bfcvtn2 v0.8h, v1.4s
+; NOB16B16-NEXT:    movi v2.4s, #128, lsl #24
+; NOB16B16-NEXT:    movi v3.4s, #128, lsl #24
+; NOB16B16-NEXT:    ptrue p0.s, vl4
+; NOB16B16-NEXT:    bfmlalb v2.4s, v0.8h, v1.8h
+; NOB16B16-NEXT:    bfmlalt v3.4s, v0.8h, v1.8h
+; NOB16B16-NEXT:    bfcvt z2.h, p0/m, z2.s
+; NOB16B16-NEXT:    bfcvtnt z2.h, p0/m, z3.s
+; NOB16B16-NEXT:    mov v0.16b, v2.16b
 ; NOB16B16-NEXT:    ret
 ;
 ; B16B16-LABEL: fmul_v8bf16:


        

