[llvm] [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (PR #169655)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 3 08:43:37 PST 2025
================
@@ -7697,37 +7701,59 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
};
};
- auto ReinterpretCast = [&](SDValue Value, EVT VT) {
- if (VT == Value.getValueType())
+ auto Reinterpret = [&](SDValue Value, EVT VT) {
+ EVT SrcVT = Value.getValueType();
+ if (VT == SrcVT)
return Value;
+ if (SrcVT.isFixedLengthVector())
+ return convertToScalableVector(DAG, VT, Value);
+ if (VT.isFixedLengthVector())
+ return convertFromScalableVector(DAG, VT, Value);
return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
};
- // Create helpers for building intrinsic calls.
- auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
- auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
- // All intrinsics expect to operate on full bf16 vector types.
- SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
- SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
-
- SDValue Zero =
- DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
- SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
+ EVT AccVT = VT.isFixedLengthVector() ? MVT::v4f32 : MVT::nxv4f32;
+ SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
+ SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
- // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+ // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
// instructions. These result in two f32 vectors, which can be converted back
// to bf16 with FCVT and FCVTNT.
- SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+ SDValue TopF32;
+ SDValue BottomF32;
+ if (VT == MVT::v8bf16) {
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+
+ auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalb);
+ auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalt);
+
+ // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+ // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+ BottomF32 = Reinterpret(BFMLALB(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+ TopF32 = Reinterpret(BFMLALT(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+ } else {
+ // All SVE intrinsics expect to operate on full bf16 vector types.
+ SDValue LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16);
+ SDValue RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16);
+
+ auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+ auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+
+ BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+ TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
----------------
paulwalker-arm wrote:
It's not a good idea to create unused nodes. I'd suggest just moving the nxv4bf16 part into the else block, but then is having two common lines worth it compared to just having a fixed/scalable split?
With the latter you can use convert[From,To]ScalableVector directly with the fixed-length part, which I kind of prefer because it's clearer what is going on.
I'll leave the solution up to you, just as long as we don't create any unused nodes.
https://github.com/llvm/llvm-project/pull/169655
More information about the llvm-commits
mailing list