[llvm] [AArch64][SVE] Add custom lowering for bfloat FMUL (with +bf16) (PR #167502)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 17 09:41:15 PST 2025


================
@@ -7538,6 +7547,50 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
                      EndOfTrmp);
 }
 
+SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+  if (VT.getScalarType() != MVT::bf16 ||
+      (Subtarget.hasSVEB16B16() &&
+       Subtarget.isNonStreamingSVEorSME2Available()))
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+
+  assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering");
+
+  auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
+    return [&, IID](EVT VT, auto... Ops) {
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(IID, DL, MVT::i32), Ops...);
+    };
+  };
+
+  // Create helpers for building intrinsic calls.
+  auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+  auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+  auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
+  auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
+
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
+
+  SDValue Zero =
+      DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
+  SDValue Pg =
+      DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1);
+
+  // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+  // instructions. These result in two f32 vectors, which can be converted back
+  // to bf16 with FCVT and FCVTNT.
+  SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+  SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32);
----------------
paulwalker-arm wrote:

Someday I'll investigate why these cases do not blow up, but there's a type mismatch here because the aarch64_sve_fcvt_bf16f32_v2 intrinsic expects an nxv8bf16 passthrough and result. For the nxv4bf16 case you want to use nxv8bf16 and reinterpret the result, or use a stock FP_ROUND and then reinterpret its result for the nxv8bf16 case.

https://github.com/llvm/llvm-project/pull/167502


More information about the llvm-commits mailing list