[llvm] [AArch64][SVE] Add custom lowering for bfloat FMUL (with +bf16) (PR #167502)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 17 09:41:15 PST 2025
================
@@ -7538,6 +7547,50 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
EndOfTrmp);
}
+SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+ if (VT.getScalarType() != MVT::bf16 ||
+ (Subtarget.hasSVEB16B16() &&
+ Subtarget.isNonStreamingSVEorSME2Available()))
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+
+ assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering");
+
+ auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
+ return [&, IID](EVT VT, auto... Ops) {
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(IID, DL, MVT::i32), Ops...);
+ };
+ };
+
+ // Create helpers for building intrinsic calls.
+ auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+ auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+ auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
+ auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
+
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+
+ SDValue Zero =
+ DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
+ SDValue Pg =
+ DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1);
+
+ // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+ // instructions. These result in two f32 vectors, which can be converted back
+ // to bf16 with FCVT and FCVTNT.
+ SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+ SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32);
----------------
paulwalker-arm wrote:
Someday I'll investigate why these cases do not blow up, but there's a type mismatch here because the aarch64_sve_fcvt_bf16f32_v2 expects a nxv8bf16 passthrough and result. For the nxv4bf16 case you want to use nxv8bf16 and reinterpret the result, or use a stock FP_ROUND and then reinterpret its result for the nxv8bf16 case.
https://github.com/llvm/llvm-project/pull/167502
More information about the llvm-commits
mailing list