[llvm] [X86][BF16] Do not combine FP_TRUNC + FP_EXTEND if they come from user (PR #91420)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Tue May 7 19:26:08 PDT 2024


https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/91420

As discussed in https://github.com/llvm/llvm-project/commit/3cf8535dbf0bf5fafa99ea1f300e2384a7254fba, we are not allowed to combine an explicit fptrunc/fpext pair that comes from the user.

>From 65891448f0e6677ea357568f69a261996c966966 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Wed, 8 May 2024 10:20:48 +0800
Subject: [PATCH] [X86][BF16] Do not combine FP_TRUNC + FP_EXTEND if they come
 from user

As discussed in https://github.com/llvm/llvm-project/commit/3cf8535dbf0bf5fafa99ea1f300e2384a7254fba,
we are not allowed to combine an explicit fptrunc/fpext pair that comes from the user.
---
 llvm/lib/Target/X86/X86ISelLowering.cpp |  7 ++--
 llvm/test/CodeGen/X86/bfloat.ll         | 44 +++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 8ec4984dfa557..678ba01ba0487 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56732,6 +56732,7 @@ static SDValue combineFP16_TO_FP(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
+                                TargetLowering::DAGCombinerInfo &DCI,
                                 const X86Subtarget &Subtarget) {
   EVT VT = N->getValueType(0);
   bool IsStrict = N->isStrictFPOpcode();
@@ -56740,8 +56741,8 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
 
   SDLoc dl(N);
   if (SrcVT.getScalarType() == MVT::bf16) {
-    if (!IsStrict && Src.getOpcode() == ISD::FP_ROUND &&
-        Src.getOperand(0).getValueType() == VT)
+    if (DCI.isAfterLegalizeDAG() && Src.getOpcode() == ISD::FP_ROUND &&
+        !IsStrict && Src.getOperand(0).getValueType() == VT)
       return Src.getOperand(0);
 
     if (!SrcVT.isVector())
@@ -57159,7 +57160,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::KSHIFTR:     return combineKSHIFT(N, DAG, DCI);
   case ISD::FP16_TO_FP:     return combineFP16_TO_FP(N, DAG, Subtarget);
   case ISD::STRICT_FP_EXTEND:
-  case ISD::FP_EXTEND:      return combineFP_EXTEND(N, DAG, Subtarget);
+  case ISD::FP_EXTEND:      return combineFP_EXTEND(N, DAG, DCI, Subtarget);
   case ISD::STRICT_FP_ROUND:
   case ISD::FP_ROUND:       return combineFP_ROUND(N, DAG, Subtarget);
   case X86ISD::VBROADCAST_LOAD:
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index 39d8e2d50c91e..b3e04590075f8 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -2420,3 +2420,47 @@ define <16 x bfloat> @concat_dup_v8bf16(<8 x bfloat> %x, <8 x bfloat> %y) {
   %a = shufflevector <8 x bfloat> %x, <8 x bfloat> %y, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
   ret <16 x bfloat> %a
 }
+
+define float @trunc_ext(float %a) nounwind {
+; X86-LABEL: trunc_ext:
+; X86:       # %bb.0:
+; X86-NEXT:    pushl %eax
+; X86-NEXT:    vmovss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; X86-NEXT:    vmovw %xmm0, %eax
+; X86-NEXT:    shll $16, %eax
+; X86-NEXT:    vmovd %eax, %xmm0
+; X86-NEXT:    vmovd %xmm0, (%esp)
+; X86-NEXT:    flds (%esp)
+; X86-NEXT:    popl %eax
+; X86-NEXT:    retl
+;
+; SSE2-LABEL: trunc_ext:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pushq %rax
+; SSE2-NEXT:    callq __truncsfbf2 at PLT
+; SSE2-NEXT:    pextrw $0, %xmm0, %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    popq %rax
+; SSE2-NEXT:    retq
+;
+; FP16-LABEL: trunc_ext:
+; FP16:       # %bb.0:
+; FP16-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; FP16-NEXT:    vmovw %xmm0, %eax
+; FP16-NEXT:    shll $16, %eax
+; FP16-NEXT:    vmovd %eax, %xmm0
+; FP16-NEXT:    retq
+;
+; AVXNC-LABEL: trunc_ext:
+; AVXNC:       # %bb.0:
+; AVXNC-NEXT:    {vex} vcvtneps2bf16 %xmm0, %xmm0
+; AVXNC-NEXT:    vmovd %xmm0, %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm0
+; AVXNC-NEXT:    retq
+  %b = fptrunc float %a to bfloat
+  %c = fpext bfloat %b to float
+  ret float %c
+}



More information about the llvm-commits mailing list