[llvm] [AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (PR #169655)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 26 06:23:13 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-aarch64
Author: Benjamin Maxwell (MacDue)
Changes:
Assuming the predicate is hoisted, this should have slightly better throughput: https://godbolt.org/z/jb7aP7Efc
Note: SVE must be used to convert back to bf16, as the bfmlalb/t instructions operate on even/odd lanes, whereas the NEON bfcvtn/bfcvtn2 instructions process the bottom/top halves of vectors.
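For illustration, a sketch of the lane placement of the two narrowing-convert families (lane numbers follow the architectural semantics of these instructions; the register assignments are taken from the mul_h test below):

```llvm
; SVE narrowing converts interleave results within each 32-bit container:
;   bfcvt   z2.h, p0/m, z2.s   ; f32 -> bf16, written to even .h lanes (0,2,4,6)
;   bfcvtnt z2.h, p0/m, z3.s   ; f32 -> bf16, written to odd  .h lanes (1,3,5,7)
; NEON narrowing converts fill contiguous halves of the destination instead:
;   bfcvtn  v0.4h, v2.4s       ; results written to lanes 0-3
;   bfcvtn2 v0.8h, v1.4s       ; results written to lanes 4-7
```

Since bfmlalb/bfmlalt produce f32 products for the even/odd bf16 lanes, only the SVE converts put the narrowed results back in the lanes they came from.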
---
Full diff: https://github.com/llvm/llvm-project/pull/169655.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+44-18)
- (modified) llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll (+25-12)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 83ce39fa314d1..9451a508033a1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1824,6 +1824,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
else
setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
}
+
+ if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
+ setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
}
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7688,7 +7691,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
- assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
+ assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
+ "Unexpected FMUL VT");
auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
return [&, IID](EVT VT, auto... Ops) {
@@ -7697,37 +7701,59 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
};
};
- auto ReinterpretCast = [&](SDValue Value, EVT VT) {
- if (VT == Value.getValueType())
+ auto Reinterpret = [&](SDValue Value, EVT VT) {
+ EVT SrcVT = Value.getValueType();
+ if (VT == SrcVT)
return Value;
+ if (SrcVT.isFixedLengthVector())
+ return convertToScalableVector(DAG, VT, Value);
+ if (VT.isFixedLengthVector())
+ return convertFromScalableVector(DAG, VT, Value);
return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
};
- // Create helpers for building intrinsic calls.
- auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
- auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
- // All intrinsics expect to operate on full bf16 vector types.
- SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
- SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
-
- SDValue Zero =
- DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
- SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
+ EVT AccVT = VT.isFixedLengthVector() ? MVT::v4f32 : MVT::nxv4f32;
+ SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
+ SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
- // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+ // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
// instructions. These result in two f32 vectors, which can be converted back
// to bf16 with FCVT and FCVTNT.
- SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+ SDValue TopF32;
+ SDValue BottomF32;
+ if (VT == MVT::v8bf16) {
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+
+ auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalb);
+ auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalt);
+
+ // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+ // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+ BottomF32 = Reinterpret(BFMLALB(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+ TopF32 = Reinterpret(BFMLALT(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+ } else {
+ // All SVE intrinsics expect to operate on full bf16 vector types.
+ SDValue LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16);
+ SDValue RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16);
+
+ auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+ auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+
+ BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+ TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
+ }
+
SDValue BottomBF16 =
FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
// Note: nxv4bf16 only uses even lanes.
if (VT == MVT::nxv4bf16)
- return ReinterpretCast(BottomBF16, VT);
- SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
- return FCVTNT(VT, BottomBF16, Pg, TopF32);
+ return Reinterpret(BottomBF16, VT);
+ SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
+ return Reinterpret(TopBF16, VT);
}
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
diff --git a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
index 6a7a4cbd8b20a..e3c0d97c08f54 100644
--- a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
@@ -1,6 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=aarch64 -mattr=-bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-CVT
-; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-NOSVE-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-SVE-BF16
define <8 x bfloat> @add_h(<8 x bfloat> %a, <8 x bfloat> %b) {
; CHECK-CVT-LABEL: add_h:
@@ -117,17 +118,29 @@ define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) {
; CHECK-CVT-NEXT: uzp2 v0.8h, v0.8h, v2.8h
; CHECK-CVT-NEXT: ret
;
-; CHECK-BF16-LABEL: mul_h:
-; CHECK-BF16: // %bb.0: // %entry
-; CHECK-BF16-NEXT: shll v2.4s, v1.4h, #16
-; CHECK-BF16-NEXT: shll v3.4s, v0.4h, #16
-; CHECK-BF16-NEXT: shll2 v1.4s, v1.8h, #16
-; CHECK-BF16-NEXT: shll2 v0.4s, v0.8h, #16
-; CHECK-BF16-NEXT: fmul v2.4s, v3.4s, v2.4s
-; CHECK-BF16-NEXT: fmul v1.4s, v0.4s, v1.4s
-; CHECK-BF16-NEXT: bfcvtn v0.4h, v2.4s
-; CHECK-BF16-NEXT: bfcvtn2 v0.8h, v1.4s
-; CHECK-BF16-NEXT: ret
+; CHECK-NOSVE-BF16-LABEL: mul_h:
+; CHECK-NOSVE-BF16: // %bb.0: // %entry
+; CHECK-NOSVE-BF16-NEXT: shll v2.4s, v1.4h, #16
+; CHECK-NOSVE-BF16-NEXT: shll v3.4s, v0.4h, #16
+; CHECK-NOSVE-BF16-NEXT: shll2 v1.4s, v1.8h, #16
+; CHECK-NOSVE-BF16-NEXT: shll2 v0.4s, v0.8h, #16
+; CHECK-NOSVE-BF16-NEXT: fmul v2.4s, v3.4s, v2.4s
+; CHECK-NOSVE-BF16-NEXT: fmul v1.4s, v0.4s, v1.4s
+; CHECK-NOSVE-BF16-NEXT: bfcvtn v0.4h, v2.4s
+; CHECK-NOSVE-BF16-NEXT: bfcvtn2 v0.8h, v1.4s
+; CHECK-NOSVE-BF16-NEXT: ret
+;
+; CHECK-SVE-BF16-LABEL: mul_h:
+; CHECK-SVE-BF16: // %bb.0: // %entry
+; CHECK-SVE-BF16-NEXT: movi v2.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT: movi v3.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT: ptrue p0.s, vl4
+; CHECK-SVE-BF16-NEXT: bfmlalb v2.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT: bfmlalt v3.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT: bfcvt z2.h, p0/m, z2.s
+; CHECK-SVE-BF16-NEXT: bfcvtnt z2.h, p0/m, z3.s
+; CHECK-SVE-BF16-NEXT: mov v0.16b, v2.16b
+; CHECK-SVE-BF16-NEXT: ret
entry:
%0 = fmul <8 x bfloat> %a, %b
ret <8 x bfloat> %0
``````````
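For reference, a minimal IR reproducer for the new lowering, mirroring the mul_h test above (the function name here is illustrative):

```llvm
; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve
define <8 x bfloat> @mul_v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
entry:
  %mul = fmul <8 x bfloat> %a, %b
  ret <8 x bfloat> %mul
}
```

With +bf16,+sve this should now select the bfmlalb/bfmlalt plus bfcvt/bfcvtnt sequence shown in the CHECK-SVE-BF16 lines, rather than the shll/fmul/bfcvtn expansion.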
https://github.com/llvm/llvm-project/pull/169655