[llvm] 1ace9fa - [LLVM][CodeGen][SVE] Enable Bfloat fma contraction. (#147941)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 11 02:13:51 PDT 2025


Author: Paul Walker
Date: 2025-07-11T10:13:48+01:00
New Revision: 1ace9fa60b668f5a6d0bf4768ff8b4c0dd62f0dd

URL: https://github.com/llvm/llvm-project/commit/1ace9fa60b668f5a6d0bf4768ff8b4c0dd62f0dd
DIFF: https://github.com/llvm/llvm-project/commit/1ace9fa60b668f5a6d0bf4768ff8b4c0dd62f0dd.diff

LOG: [LLVM][CodeGen][SVE] Enable Bfloat fma contraction. (#147941)

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/SVEInstrFormats.td
    llvm/test/CodeGen/AArch64/sve-bf16-combines.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 331c8036e26f1..5f6832fd2e575 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17834,17 +17834,19 @@ bool AArch64TargetLowering::shouldConsiderGEPOffsetSplit() const {
 
 bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
     const MachineFunction &MF, EVT VT) const {
-  VT = VT.getScalarType();
+  EVT ScalarVT = VT.getScalarType();
 
-  if (!VT.isSimple())
+  if (!ScalarVT.isSimple())
     return false;
 
-  switch (VT.getSimpleVT().SimpleTy) {
+  switch (ScalarVT.getSimpleVT().SimpleTy) {
   case MVT::f16:
     return Subtarget->hasFullFP16();
   case MVT::f32:
   case MVT::f64:
     return true;
+  case MVT::bf16:
+    return VT.isScalableVector() && Subtarget->hasSVEB16B16();
   default:
     break;
   }

diff  --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 3b7e5a6c2b1cf..a0320f919e8c5 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2490,6 +2490,8 @@ multiclass sve_fp_3op_p_zds_a_bfloat<bits<2> opc, string asm, string Ps,
            SVEPseudo2Instr<Ps, 1>, SVEInstr2Rev<NAME, "", 0>;
 
   def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
+  def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
+  def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
 }
 
 class sve_fp_3op_p_zds_b<bits<2> sz, bits<2> opc, string asm,

diff  --git a/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll b/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll
index 8c1d41f71c1ec..5c58eab391972 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-combines.ll
@@ -6,8 +6,8 @@ target triple = "aarch64-unknown-linux-gnu"
 define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %acc, <vscale x 8 x bfloat> %m1, <vscale x 8 x bfloat> %m2) {
 ; CHECK-LABEL: fmla_nxv8bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    bfmla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 8 x bfloat> %m1, %m2
   %res = fadd contract <vscale x 8 x bfloat> %acc, %mul
@@ -17,8 +17,8 @@ define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %acc, <vscale
 define <vscale x 4 x bfloat> @fmla_nxv4bf16(<vscale x 4 x bfloat> %acc, <vscale x 4 x bfloat> %m1, <vscale x 4 x bfloat> %m2) {
 ; CHECK-LABEL: fmla_nxv4bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    bfmla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 4 x bfloat> %m1, %m2
   %res = fadd contract <vscale x 4 x bfloat> %acc, %mul
@@ -28,8 +28,8 @@ define <vscale x 4 x bfloat> @fmla_nxv4bf16(<vscale x 4 x bfloat> %acc, <vscale
 define <vscale x 2 x bfloat> @fmla_nxv2bf16(<vscale x 2 x bfloat> %acc, <vscale x 2 x bfloat> %m1, <vscale x 2 x bfloat> %m2) {
 ; CHECK-LABEL: fmla_nxv2bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    bfmla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 2 x bfloat> %m1, %m2
   %res = fadd contract <vscale x 2 x bfloat> %acc, %mul
@@ -39,8 +39,8 @@ define <vscale x 2 x bfloat> @fmla_nxv2bf16(<vscale x 2 x bfloat> %acc, <vscale
 define <vscale x 8 x bfloat> @fmls_nxv8bf16(<vscale x 8 x bfloat> %acc, <vscale x 8 x bfloat> %m1, <vscale x 8 x bfloat> %m2) {
 ; CHECK-LABEL: fmls_nxv8bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfsub z0.h, z0.h, z1.h
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    bfmls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 8 x bfloat> %m1, %m2
   %res = fsub contract <vscale x 8 x bfloat> %acc, %mul
@@ -50,8 +50,8 @@ define <vscale x 8 x bfloat> @fmls_nxv8bf16(<vscale x 8 x bfloat> %acc, <vscale
 define <vscale x 4 x bfloat> @fmls_nxv4bf16(<vscale x 4 x bfloat> %acc, <vscale x 4 x bfloat> %m1, <vscale x 4 x bfloat> %m2) {
 ; CHECK-LABEL: fmls_nxv4bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfsub z0.h, z0.h, z1.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    bfmls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 4 x bfloat> %m1, %m2
   %res = fsub contract <vscale x 4 x bfloat> %acc, %mul
@@ -61,8 +61,8 @@ define <vscale x 4 x bfloat> @fmls_nxv4bf16(<vscale x 4 x bfloat> %acc, <vscale
 define <vscale x 2 x bfloat> @fmls_nxv2bf16(<vscale x 2 x bfloat> %acc, <vscale x 2 x bfloat> %m1, <vscale x 2 x bfloat> %m2) {
 ; CHECK-LABEL: fmls_nxv2bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfsub z0.h, z0.h, z1.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    bfmls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 2 x bfloat> %m1, %m2
   %res = fsub contract <vscale x 2 x bfloat> %acc, %mul
@@ -72,9 +72,7 @@ define <vscale x 2 x bfloat> @fmls_nxv2bf16(<vscale x 2 x bfloat> %acc, <vscale
 define <vscale x 8 x bfloat> @fmla_sel_nxv8bf16(<vscale x 8 x i1> %pred, <vscale x 8 x bfloat> %acc, <vscale x 8 x bfloat> %m1, <vscale x 8 x bfloat> %m2) {
 ; CHECK-LABEL: fmla_sel_nxv8bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfadd z1.h, z0.h, z1.h
-; CHECK-NEXT:    mov z0.h, p0/m, z1.h
+; CHECK-NEXT:    bfmla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 8 x bfloat> %m1, %m2
   %add = fadd contract <vscale x 8 x bfloat> %acc, %mul
@@ -85,9 +83,7 @@ define <vscale x 8 x bfloat> @fmla_sel_nxv8bf16(<vscale x 8 x i1> %pred, <vscale
 define <vscale x 4 x bfloat> @fmla_sel_nxv4bf16(<vscale x 4 x i1> %pred, <vscale x 4 x bfloat> %acc, <vscale x 4 x bfloat> %m1, <vscale x 4 x bfloat> %m2) {
 ; CHECK-LABEL: fmla_sel_nxv4bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfadd z1.h, z0.h, z1.h
-; CHECK-NEXT:    mov z0.s, p0/m, z1.s
+; CHECK-NEXT:    bfmla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 4 x bfloat> %m1, %m2
   %add = fadd contract <vscale x 4 x bfloat> %acc, %mul
@@ -98,9 +94,7 @@ define <vscale x 4 x bfloat> @fmla_sel_nxv4bf16(<vscale x 4 x i1> %pred, <vscale
 define <vscale x 2 x bfloat> @fmla_sel_nxv2bf16(<vscale x 2 x i1> %pred, <vscale x 2 x bfloat> %acc, <vscale x 2 x bfloat> %m1, <vscale x 2 x bfloat> %m2) {
 ; CHECK-LABEL: fmla_sel_nxv2bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfadd z1.h, z0.h, z1.h
-; CHECK-NEXT:    mov z0.d, p0/m, z1.d
+; CHECK-NEXT:    bfmla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 2 x bfloat> %m1, %m2
   %add = fadd contract <vscale x 2 x bfloat> %acc, %mul
@@ -111,9 +105,7 @@ define <vscale x 2 x bfloat> @fmla_sel_nxv2bf16(<vscale x 2 x i1> %pred, <vscale
 define <vscale x 8 x bfloat> @fmls_sel_nxv8bf16(<vscale x 8 x i1> %pred, <vscale x 8 x bfloat> %acc, <vscale x 8 x bfloat> %m1, <vscale x 8 x bfloat> %m2) {
 ; CHECK-LABEL: fmls_sel_nxv8bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfsub z1.h, z0.h, z1.h
-; CHECK-NEXT:    mov z0.h, p0/m, z1.h
+; CHECK-NEXT:    bfmls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 8 x bfloat> %m1, %m2
   %sub = fsub contract <vscale x 8 x bfloat> %acc, %mul
@@ -124,9 +116,7 @@ define <vscale x 8 x bfloat> @fmls_sel_nxv8bf16(<vscale x 8 x i1> %pred, <vscale
 define <vscale x 4 x bfloat> @fmls_sel_nxv4bf16(<vscale x 4 x i1> %pred, <vscale x 4 x bfloat> %acc, <vscale x 4 x bfloat> %m1, <vscale x 4 x bfloat> %m2) {
 ; CHECK-LABEL: fmls_sel_nxv4bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfsub z1.h, z0.h, z1.h
-; CHECK-NEXT:    mov z0.s, p0/m, z1.s
+; CHECK-NEXT:    bfmls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 4 x bfloat> %m1, %m2
   %sub = fsub contract <vscale x 4 x bfloat> %acc, %mul
@@ -137,9 +127,7 @@ define <vscale x 4 x bfloat> @fmls_sel_nxv4bf16(<vscale x 4 x i1> %pred, <vscale
 define <vscale x 2 x bfloat> @fmls_sel_nxv2bf16(<vscale x 2 x i1> %pred, <vscale x 2 x bfloat> %acc, <vscale x 2 x bfloat> %m1, <vscale x 2 x bfloat> %m2) {
 ; CHECK-LABEL: fmls_sel_nxv2bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    bfmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    bfsub z1.h, z0.h, z1.h
-; CHECK-NEXT:    mov z0.d, p0/m, z1.d
+; CHECK-NEXT:    bfmls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 2 x bfloat> %m1, %m2
   %sub = fsub contract <vscale x 2 x bfloat> %acc, %mul


        


More information about the llvm-commits mailing list