[llvm] [X86][AVX10.2] Fix wrong predicates for BF16 feature (PR #113800)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 27 03:11:27 PDT 2024


https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/113800

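Starting with AVX10.2, the 128-bit and 256-bit BF16 vector operations (v8bf16/v16bf16) should be legal by default, while the 512-bit ones (v32bf16) must be gated on the AVX10.2-512 feature (and the v32bf16 register class on 512-bit register support).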

>From 6194f4d7d227057b3bf9cc46f98e9965e737bc17 Mon Sep 17 00:00:00 2001
From: "Wang, Phoebe" <phoebe.wang at intel.com>
Date: Sun, 27 Oct 2024 18:06:34 +0800
Subject: [PATCH] [X86][AVX10.2] Fix wrong predicates for BF16 feature

Starting with AVX10.2, the 128-bit and 256-bit BF16 vector operations
should be legal by default, while 512-bit BF16 vector operations must be
gated on the AVX10.2-512 feature.
---
 llvm/lib/Target/X86/X86ISelLowering.cpp    | 41 ++++++++++------------
 llvm/test/CodeGen/X86/avx10_2bf16-arith.ll | 22 ++++++++++++
 2 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a6d77873ec2901..9d447959faf55a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2406,7 +2406,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
   }
 
-  if (!Subtarget.useSoftFloat() && Subtarget.hasBF16()) {
+  if (!Subtarget.useSoftFloat() && Subtarget.hasBF16() &&
+      Subtarget.useAVX512Regs()) {
     addRegisterClass(MVT::v32bf16, &X86::VR512RegClass);
     setF16Action(MVT::v32bf16, Expand);
     for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
@@ -2419,27 +2420,23 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
   }
 
   if (!Subtarget.useSoftFloat() && Subtarget.hasAVX10_2()) {
-    addRegisterClass(MVT::v8bf16, &X86::VR128XRegClass);
-    addRegisterClass(MVT::v16bf16, &X86::VR256XRegClass);
-    addRegisterClass(MVT::v32bf16, &X86::VR512RegClass);
-
-    setOperationAction(ISD::FADD, MVT::v32bf16, Legal);
-    setOperationAction(ISD::FSUB, MVT::v32bf16, Legal);
-    setOperationAction(ISD::FMUL, MVT::v32bf16, Legal);
-    setOperationAction(ISD::FDIV, MVT::v32bf16, Legal);
-    setOperationAction(ISD::FSQRT, MVT::v32bf16, Legal);
-    setOperationAction(ISD::FMA, MVT::v32bf16, Legal);
-    setOperationAction(ISD::SETCC, MVT::v32bf16, Custom);
-    if (Subtarget.hasVLX()) {
-      for (auto VT : {MVT::v8bf16, MVT::v16bf16}) {
-        setOperationAction(ISD::FADD, VT, Legal);
-        setOperationAction(ISD::FSUB, VT, Legal);
-        setOperationAction(ISD::FMUL, VT, Legal);
-        setOperationAction(ISD::FDIV, VT, Legal);
-        setOperationAction(ISD::FSQRT, VT, Legal);
-        setOperationAction(ISD::FMA, VT, Legal);
-        setOperationAction(ISD::SETCC, VT, Custom);
-      }
+    for (auto VT : {MVT::v8bf16, MVT::v16bf16}) {
+      setOperationAction(ISD::FADD, VT, Legal);
+      setOperationAction(ISD::FSUB, VT, Legal);
+      setOperationAction(ISD::FMUL, VT, Legal);
+      setOperationAction(ISD::FDIV, VT, Legal);
+      setOperationAction(ISD::FSQRT, VT, Legal);
+      setOperationAction(ISD::FMA, VT, Legal);
+      setOperationAction(ISD::SETCC, VT, Custom);
+    }
+    if (Subtarget.hasAVX10_2_512()) {
+      setOperationAction(ISD::FADD, MVT::v32bf16, Legal);
+      setOperationAction(ISD::FSUB, MVT::v32bf16, Legal);
+      setOperationAction(ISD::FMUL, MVT::v32bf16, Legal);
+      setOperationAction(ISD::FDIV, MVT::v32bf16, Legal);
+      setOperationAction(ISD::FSQRT, MVT::v32bf16, Legal);
+      setOperationAction(ISD::FMA, MVT::v32bf16, Legal);
+      setOperationAction(ISD::SETCC, MVT::v32bf16, Custom);
     }
   }
 
diff --git a/llvm/test/CodeGen/X86/avx10_2bf16-arith.ll b/llvm/test/CodeGen/X86/avx10_2bf16-arith.ll
index e0f5679e8ac96d..c97d27ff324bbb 100644
--- a/llvm/test/CodeGen/X86/avx10_2bf16-arith.ll
+++ b/llvm/test/CodeGen/X86/avx10_2bf16-arith.ll
@@ -1166,3 +1166,25 @@ entry:
   %2 = select <8 x i1> %1, <8 x bfloat> %0, <8 x bfloat> zeroinitializer
   ret <8 x bfloat> %2
 }
+
+define <32 x bfloat> @addv(<32 x bfloat> %a, <32 x bfloat> %b) nounwind {
+; X64-LABEL: addv:
+; X64:       # %bb.0:
+; X64-NEXT:    vaddnepbf16 %ymm2, %ymm0, %ymm0 # encoding: [0x62,0xf5,0x7d,0x28,0x58,0xc2]
+; X64-NEXT:    vaddnepbf16 %ymm3, %ymm1, %ymm1 # encoding: [0x62,0xf5,0x75,0x28,0x58,0xcb]
+; X64-NEXT:    retq # encoding: [0xc3]
+;
+; X86-LABEL: addv:
+; X86:       # %bb.0:
+; X86-NEXT:    pushl %ebp # encoding: [0x55]
+; X86-NEXT:    movl %esp, %ebp # encoding: [0x89,0xe5]
+; X86-NEXT:    andl $-32, %esp # encoding: [0x83,0xe4,0xe0]
+; X86-NEXT:    subl $32, %esp # encoding: [0x83,0xec,0x20]
+; X86-NEXT:    vaddnepbf16 %ymm2, %ymm0, %ymm0 # encoding: [0x62,0xf5,0x7d,0x28,0x58,0xc2]
+; X86-NEXT:    vaddnepbf16 8(%ebp), %ymm1, %ymm1 # encoding: [0x62,0xf5,0x75,0x28,0x58,0x8d,0x08,0x00,0x00,0x00]
+; X86-NEXT:    movl %ebp, %esp # encoding: [0x89,0xec]
+; X86-NEXT:    popl %ebp # encoding: [0x5d]
+; X86-NEXT:    retl # encoding: [0xc3]
+  %add = fadd <32 x bfloat> %a, %b
+  ret <32 x bfloat> %add
+}



More information about the llvm-commits mailing list