[llvm] [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (PR #169655)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 26 06:23:13 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

Assuming the predicate is hoisted, this should give slightly better throughput than the existing SHLL/FMUL/BFCVTN sequence: https://godbolt.org/z/jb7aP7Efc

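Note: SVE must be used to convert the result back to bf16. The BFMLALB/BFMLALT instructions operate on the even/odd lanes, whereas the NEON BFCVTN/BFCVTN2 instructions write the bottom/top halves of the vector, so the NEON conversions would produce the lanes in the wrong order.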

---
Full diff: https://github.com/llvm/llvm-project/pull/169655.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+44-18) 
- (modified) llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll (+25-12) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 83ce39fa314d1..9451a508033a1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1824,6 +1824,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         else
           setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
       }
+
+      if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
+        setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7688,7 +7691,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
 
   assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
-  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
+  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
+         "Unexpected FMUL VT");
 
   auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
     return [&, IID](EVT VT, auto... Ops) {
@@ -7697,37 +7701,59 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     };
   };
 
-  auto ReinterpretCast = [&](SDValue Value, EVT VT) {
-    if (VT == Value.getValueType())
+  auto Reinterpret = [&](SDValue Value, EVT VT) {
+    EVT SrcVT = Value.getValueType();
+    if (VT == SrcVT)
       return Value;
+    if (SrcVT.isFixedLengthVector())
+      return convertToScalableVector(DAG, VT, Value);
+    if (VT.isFixedLengthVector())
+      return convertFromScalableVector(DAG, VT, Value);
     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
   };
 
-  // Create helpers for building intrinsic calls.
-  auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
-  auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
   auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
   auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
 
-  // All intrinsics expect to operate on full bf16 vector types.
-  SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
-  SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
-
-  SDValue Zero =
-      DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
-  SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
+  EVT AccVT = VT.isFixedLengthVector() ? MVT::v4f32 : MVT::nxv4f32;
+  SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
+  SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
 
-  // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+  // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
   // instructions. These result in two f32 vectors, which can be converted back
   // to bf16 with FCVT and FCVTNT.
-  SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+  SDValue TopF32;
+  SDValue BottomF32;
+  if (VT == MVT::v8bf16) {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
+
+    auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalb);
+    auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalt);
+
+    // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+    // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+    BottomF32 = Reinterpret(BFMLALB(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+    TopF32 = Reinterpret(BFMLALT(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+  } else {
+    // All SVE intrinsics expect to operate on full bf16 vector types.
+    SDValue LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16);
+    SDValue RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16);
+
+    auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+    auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+
+    BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+    TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
+  }
+
   SDValue BottomBF16 =
       FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
   // Note: nxv4bf16 only uses even lanes.
   if (VT == MVT::nxv4bf16)
-    return ReinterpretCast(BottomBF16, VT);
-  SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
-  return FCVTNT(VT, BottomBF16, Pg, TopF32);
+    return Reinterpret(BottomBF16, VT);
+  SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
+  return Reinterpret(TopBF16, VT);
 }
 
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
diff --git a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
index 6a7a4cbd8b20a..e3c0d97c08f54 100644
--- a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc < %s -mtriple=aarch64 -mattr=-bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-CVT
-; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-NOSVE-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-SVE-BF16
 
 define <8 x bfloat> @add_h(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; CHECK-CVT-LABEL: add_h:
@@ -117,17 +118,29 @@ define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; CHECK-CVT-NEXT:    uzp2 v0.8h, v0.8h, v2.8h
 ; CHECK-CVT-NEXT:    ret
 ;
-; CHECK-BF16-LABEL: mul_h:
-; CHECK-BF16:       // %bb.0: // %entry
-; CHECK-BF16-NEXT:    shll v2.4s, v1.4h, #16
-; CHECK-BF16-NEXT:    shll v3.4s, v0.4h, #16
-; CHECK-BF16-NEXT:    shll2 v1.4s, v1.8h, #16
-; CHECK-BF16-NEXT:    shll2 v0.4s, v0.8h, #16
-; CHECK-BF16-NEXT:    fmul v2.4s, v3.4s, v2.4s
-; CHECK-BF16-NEXT:    fmul v1.4s, v0.4s, v1.4s
-; CHECK-BF16-NEXT:    bfcvtn v0.4h, v2.4s
-; CHECK-BF16-NEXT:    bfcvtn2 v0.8h, v1.4s
-; CHECK-BF16-NEXT:    ret
+; CHECK-NOSVE-BF16-LABEL: mul_h:
+; CHECK-NOSVE-BF16:       // %bb.0: // %entry
+; CHECK-NOSVE-BF16-NEXT:    shll v2.4s, v1.4h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll v3.4s, v0.4h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll2 v1.4s, v1.8h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll2 v0.4s, v0.8h, #16
+; CHECK-NOSVE-BF16-NEXT:    fmul v2.4s, v3.4s, v2.4s
+; CHECK-NOSVE-BF16-NEXT:    fmul v1.4s, v0.4s, v1.4s
+; CHECK-NOSVE-BF16-NEXT:    bfcvtn v0.4h, v2.4s
+; CHECK-NOSVE-BF16-NEXT:    bfcvtn2 v0.8h, v1.4s
+; CHECK-NOSVE-BF16-NEXT:    ret
+;
+; CHECK-SVE-BF16-LABEL: mul_h:
+; CHECK-SVE-BF16:       // %bb.0: // %entry
+; CHECK-SVE-BF16-NEXT:    movi v2.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT:    movi v3.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT:    ptrue p0.s, vl4
+; CHECK-SVE-BF16-NEXT:    bfmlalb v2.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT:    bfmlalt v3.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT:    bfcvt z2.h, p0/m, z2.s
+; CHECK-SVE-BF16-NEXT:    bfcvtnt z2.h, p0/m, z3.s
+; CHECK-SVE-BF16-NEXT:    mov v0.16b, v2.16b
+; CHECK-SVE-BF16-NEXT:    ret
 entry:
   %0 = fmul <8 x bfloat> %a, %b
   ret <8 x bfloat> %0

``````````

</details>


https://github.com/llvm/llvm-project/pull/169655


More information about the llvm-commits mailing list