[llvm] 32ff710 - [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (#169655)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 8 03:56:22 PST 2025
Author: Benjamin Maxwell
Date: 2025-12-08T11:56:18Z
New Revision: 32ff7100d737bcfce2f713dd9838df88bdd3b631
URL: https://github.com/llvm/llvm-project/commit/32ff7100d737bcfce2f713dd9838df88bdd3b631
DIFF: https://github.com/llvm/llvm-project/commit/32ff7100d737bcfce2f713dd9838df88bdd3b631.diff
LOG: [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (#169655)
Assuming the predicate is hoisted, this should give slightly better
throughput: https://godbolt.org/z/jb7aP7Efc
Note: SVE must be used to convert back to bf16, as the bfmlalb/t
instructions operate on even/odd lanes, whereas the NEON bfcvtn/bfcvtn2
instructions process the bottom/top halves of vectors.
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7199319ccdd9f..bf0b6614e5e18 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1834,6 +1834,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
else
setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
}
+
+ if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
+ setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
}
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7742,7 +7745,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
- assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
+ assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
+ "Unexpected FMUL VT");
auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
return [&, IID](EVT VT, auto... Ops) {
@@ -7751,37 +7755,56 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
};
};
- auto ReinterpretCast = [&](SDValue Value, EVT VT) {
- if (VT == Value.getValueType())
+ auto Reinterpret = [&](SDValue Value, EVT VT) {
+ EVT SrcVT = Value.getValueType();
+ if (VT == SrcVT)
return Value;
+ if (SrcVT.isFixedLengthVector())
+ return convertToScalableVector(DAG, VT, Value);
+ if (VT.isFixedLengthVector())
+ return convertFromScalableVector(DAG, VT, Value);
return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
};
- // Create helpers for building intrinsic calls.
- auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
- auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+ bool UseSVEBFMLAL = VT.isScalableVector();
auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
- // All intrinsics expect to operate on full bf16 vector types.
- SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
- SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
+ // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+ // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+ auto BFMLALB =
+ MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalb
+ : Intrinsic::aarch64_neon_bfmlalb);
+ auto BFMLALT =
+ MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalt
+ : Intrinsic::aarch64_neon_bfmlalt);
- SDValue Zero =
- DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
- SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
+ EVT AccVT = UseSVEBFMLAL ? MVT::nxv4f32 : MVT::v4f32;
+ SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
+ SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
- // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+ // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
// instructions. These result in two f32 vectors, which can be converted back
// to bf16 with FCVT and FCVTNT.
- SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+
+ // All SVE intrinsics expect to operate on full bf16 vector types.
+ if (UseSVEBFMLAL) {
+ LHS = Reinterpret(LHS, MVT::nxv8bf16);
+ RHS = Reinterpret(RHS, MVT::nxv8bf16);
+ }
+
+ SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
SDValue BottomBF16 =
FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
// Note: nxv4bf16 only uses even lanes.
if (VT == MVT::nxv4bf16)
- return ReinterpretCast(BottomBF16, VT);
- SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
- return FCVTNT(VT, BottomBF16, Pg, TopF32);
+ return Reinterpret(BottomBF16, VT);
+
+ SDValue TopF32 = Reinterpret(BFMLALT(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
+ SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
+ return Reinterpret(TopBF16, VT);
}
SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
index 6a7a4cbd8b20a..e3c0d97c08f54 100644
--- a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
@@ -1,6 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=aarch64 -mattr=-bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-CVT
-; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-NOSVE-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-SVE-BF16
define <8 x bfloat> @add_h(<8 x bfloat> %a, <8 x bfloat> %b) {
; CHECK-CVT-LABEL: add_h:
@@ -117,17 +118,29 @@ define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) {
; CHECK-CVT-NEXT: uzp2 v0.8h, v0.8h, v2.8h
; CHECK-CVT-NEXT: ret
;
-; CHECK-BF16-LABEL: mul_h:
-; CHECK-BF16: // %bb.0: // %entry
-; CHECK-BF16-NEXT: shll v2.4s, v1.4h, #16
-; CHECK-BF16-NEXT: shll v3.4s, v0.4h, #16
-; CHECK-BF16-NEXT: shll2 v1.4s, v1.8h, #16
-; CHECK-BF16-NEXT: shll2 v0.4s, v0.8h, #16
-; CHECK-BF16-NEXT: fmul v2.4s, v3.4s, v2.4s
-; CHECK-BF16-NEXT: fmul v1.4s, v0.4s, v1.4s
-; CHECK-BF16-NEXT: bfcvtn v0.4h, v2.4s
-; CHECK-BF16-NEXT: bfcvtn2 v0.8h, v1.4s
-; CHECK-BF16-NEXT: ret
+; CHECK-NOSVE-BF16-LABEL: mul_h:
+; CHECK-NOSVE-BF16: // %bb.0: // %entry
+; CHECK-NOSVE-BF16-NEXT: shll v2.4s, v1.4h, #16
+; CHECK-NOSVE-BF16-NEXT: shll v3.4s, v0.4h, #16
+; CHECK-NOSVE-BF16-NEXT: shll2 v1.4s, v1.8h, #16
+; CHECK-NOSVE-BF16-NEXT: shll2 v0.4s, v0.8h, #16
+; CHECK-NOSVE-BF16-NEXT: fmul v2.4s, v3.4s, v2.4s
+; CHECK-NOSVE-BF16-NEXT: fmul v1.4s, v0.4s, v1.4s
+; CHECK-NOSVE-BF16-NEXT: bfcvtn v0.4h, v2.4s
+; CHECK-NOSVE-BF16-NEXT: bfcvtn2 v0.8h, v1.4s
+; CHECK-NOSVE-BF16-NEXT: ret
+;
+; CHECK-SVE-BF16-LABEL: mul_h:
+; CHECK-SVE-BF16: // %bb.0: // %entry
+; CHECK-SVE-BF16-NEXT: movi v2.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT: movi v3.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT: ptrue p0.s, vl4
+; CHECK-SVE-BF16-NEXT: bfmlalb v2.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT: bfmlalt v3.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT: bfcvt z2.h, p0/m, z2.s
+; CHECK-SVE-BF16-NEXT: bfcvtnt z2.h, p0/m, z3.s
+; CHECK-SVE-BF16-NEXT: mov v0.16b, v2.16b
+; CHECK-SVE-BF16-NEXT: ret
entry:
%0 = fmul <8 x bfloat> %a, %b
ret <8 x bfloat> %0
diff --git a/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll b/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
index e6344b9eb89dc..45f8b2fa95a83 100644
--- a/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
+++ b/llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll
@@ -761,14 +761,14 @@ define <4 x bfloat> @fmul_v4bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
define <8 x bfloat> @fmul_v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
; NOB16B16-LABEL: fmul_v8bf16:
; NOB16B16: // %bb.0:
-; NOB16B16-NEXT: shll v2.4s, v1.4h, #16
-; NOB16B16-NEXT: shll v3.4s, v0.4h, #16
-; NOB16B16-NEXT: shll2 v1.4s, v1.8h, #16
-; NOB16B16-NEXT: shll2 v0.4s, v0.8h, #16
-; NOB16B16-NEXT: fmul v2.4s, v3.4s, v2.4s
-; NOB16B16-NEXT: fmul v1.4s, v0.4s, v1.4s
-; NOB16B16-NEXT: bfcvtn v0.4h, v2.4s
-; NOB16B16-NEXT: bfcvtn2 v0.8h, v1.4s
+; NOB16B16-NEXT: movi v2.4s, #128, lsl #24
+; NOB16B16-NEXT: movi v3.4s, #128, lsl #24
+; NOB16B16-NEXT: ptrue p0.s, vl4
+; NOB16B16-NEXT: bfmlalb v2.4s, v0.8h, v1.8h
+; NOB16B16-NEXT: bfmlalt v3.4s, v0.8h, v1.8h
+; NOB16B16-NEXT: bfcvt z2.h, p0/m, z2.s
+; NOB16B16-NEXT: bfcvtnt z2.h, p0/m, z3.s
+; NOB16B16-NEXT: mov v0.16b, v2.16b
; NOB16B16-NEXT: ret
;
; B16B16-LABEL: fmul_v8bf16:
More information about the llvm-commits
mailing list