[llvm] 02dfbbf - [SelectionDAG] Make ARITH_FENCE support half and bfloat type (#90836)

via llvm-commits llvm-commits at lists.llvm.org
Sat May 4 22:08:38 PDT 2024


Author: Phoebe Wang
Date: 2024-05-05T13:08:34+08:00
New Revision: 02dfbbff1937b3e2c1ee1cd4a5ad0a9f03ee23ea

URL: https://github.com/llvm/llvm-project/commit/02dfbbff1937b3e2c1ee1cd4a5ad0a9f03ee23ea
DIFF: https://github.com/llvm/llvm-project/commit/02dfbbff1937b3e2c1ee1cd4a5ad0a9f03ee23ea.diff

LOG: [SelectionDAG] Make ARITH_FENCE support half and bfloat type (#90836)

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
    llvm/test/CodeGen/X86/arithmetic_fence2.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 8bd4839c17a628..bf87437b8dfd57 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2863,6 +2863,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     report_fatal_error("Do not know how to soft promote this operator's "
                        "result!");
 
+  case ISD::ARITH_FENCE:
+    R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
   case ISD::BITCAST:    R = SoftPromoteHalfRes_BITCAST(N); break;
   case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
   case ISD::EXTRACT_VECTOR_ELT:
@@ -2942,6 +2944,11 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     SetSoftPromotedHalf(SDValue(N, ResNo), R);
 }
 
+SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ARITH_FENCE(SDNode *N) {
+  return DAG.getNode(ISD::ARITH_FENCE, SDLoc(N), MVT::i16,
+                     BitConvertToInteger(N->getOperand(0)));
+}
+
 SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BITCAST(SDNode *N) {
   return BitConvertToInteger(N->getOperand(0));
 }

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 0252e3d6febca9..4b06e19656ce60 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -728,6 +728,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SetSoftPromotedHalf(SDValue Op, SDValue Result);
 
   void SoftPromoteHalfResult(SDNode *N, unsigned ResNo);
+  SDValue SoftPromoteHalfRes_ARITH_FENCE(SDNode *N);
   SDValue SoftPromoteHalfRes_BinOp(SDNode *N);
   SDValue SoftPromoteHalfRes_BITCAST(SDNode *N);
   SDValue SoftPromoteHalfRes_ConstantFP(SDNode *N);

diff --git a/llvm/test/CodeGen/X86/arithmetic_fence2.ll b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
index 6a854b58fc02d0..3c2ef21527f501 100644
--- a/llvm/test/CodeGen/X86/arithmetic_fence2.ll
+++ b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
@@ -157,6 +157,160 @@ define <8 x float> @f6(<8 x float> %a) {
   ret <8 x float> %3
 }
 
+define half @f7(half %a) nounwind {
+; X86-LABEL: f7:
+; X86:       # %bb.0:
+; X86-NEXT:    pinsrw $0, {{[0-9]+}}(%esp), %xmm0
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    retl
+;
+; X64-LABEL: f7:
+; X64:       # %bb.0:
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    retq
+  %b = call half @llvm.arithmetic.fence.f16(half %a)
+  ret half %b
+}
+
+define bfloat @f8(bfloat %a) nounwind {
+; X86-LABEL: f8:
+; X86:       # %bb.0:
+; X86-NEXT:    movzwl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    pinsrw $0, %eax, %xmm0
+; X86-NEXT:    retl
+;
+; X64-LABEL: f8:
+; X64:       # %bb.0:
+; X64-NEXT:    pextrw $0, %xmm0, %eax
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    pinsrw $0, %eax, %xmm0
+; X64-NEXT:    retq
+  %b = call bfloat @llvm.arithmetic.fence.bf16(bfloat %a)
+  ret bfloat %b
+}
+
+define <2 x half> @f9(<2 x half> %a) nounwind {
+; X86-LABEL: f9:
+; X86:       # %bb.0:
+; X86-NEXT:    movdqa %xmm0, %xmm1
+; X86-NEXT:    psrld $16, %xmm1
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; X86-NEXT:    retl
+;
+; X64-LABEL: f9:
+; X64:       # %bb.0:
+; X64-NEXT:    movdqa %xmm0, %xmm1
+; X64-NEXT:    psrld $16, %xmm1
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; X64-NEXT:    retq
+  %b = call <2 x half> @llvm.arithmetic.fence.v2f16(<2 x half> %a)
+  ret <2 x half> %b
+}
+
+define <3 x bfloat> @f10(<3 x bfloat> %a) nounwind {
+; X86-LABEL: f10:
+; X86:       # %bb.0:
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    movdqa %xmm0, %xmm1
+; X86-NEXT:    psrld $16, %xmm1
+; X86-NEXT:    pextrw $0, %xmm1, %ecx
+; X86-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
+; X86-NEXT:    pextrw $0, %xmm0, %edx
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    pinsrw $0, %eax, %xmm0
+; X86-NEXT:    pinsrw $0, %ecx, %xmm1
+; X86-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; X86-NEXT:    pinsrw $0, %edx, %xmm1
+; X86-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; X86-NEXT:    retl
+;
+; X64-LABEL: f10:
+; X64:       # %bb.0:
+; X64-NEXT:    pextrw $0, %xmm0, %eax
+; X64-NEXT:    movdqa %xmm0, %xmm1
+; X64-NEXT:    psrld $16, %xmm1
+; X64-NEXT:    pextrw $0, %xmm1, %ecx
+; X64-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
+; X64-NEXT:    pextrw $0, %xmm0, %edx
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    pinsrw $0, %eax, %xmm0
+; X64-NEXT:    pinsrw $0, %ecx, %xmm1
+; X64-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; X64-NEXT:    pinsrw $0, %edx, %xmm1
+; X64-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; X64-NEXT:    retq
+  %b = call <3 x bfloat> @llvm.arithmetic.fence.v3bf16(<3 x bfloat> %a)
+  ret <3 x bfloat> %b
+}
+
+define <4 x bfloat> @f11(<4 x bfloat> %a) nounwind {
+; X86-LABEL: f11:
+; X86:       # %bb.0:
+; X86-NEXT:    pushl %esi
+; X86-NEXT:    movdqa %xmm0, %xmm1
+; X86-NEXT:    psrlq $48, %xmm1
+; X86-NEXT:    pextrw $0, %xmm1, %eax
+; X86-NEXT:    movdqa %xmm0, %xmm1
+; X86-NEXT:    shufps {{.*#+}} xmm1 = xmm1[1,1],xmm0[1,1]
+; X86-NEXT:    pextrw $0, %xmm1, %edx
+; X86-NEXT:    pextrw $0, %xmm0, %ecx
+; X86-NEXT:    psrld $16, %xmm0
+; X86-NEXT:    pextrw $0, %xmm0, %esi
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    #ARITH_FENCE
+; X86-NEXT:    pinsrw $0, %eax, %xmm0
+; X86-NEXT:    pinsrw $0, %edx, %xmm1
+; X86-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
+; X86-NEXT:    pinsrw $0, %ecx, %xmm0
+; X86-NEXT:    pinsrw $0, %esi, %xmm2
+; X86-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; X86-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; X86-NEXT:    popl %esi
+; X86-NEXT:    retl
+;
+; X64-LABEL: f11:
+; X64:       # %bb.0:
+; X64-NEXT:    movdqa %xmm0, %xmm1
+; X64-NEXT:    psrlq $48, %xmm1
+; X64-NEXT:    pextrw $0, %xmm1, %eax
+; X64-NEXT:    movdqa %xmm0, %xmm1
+; X64-NEXT:    shufps {{.*#+}} xmm1 = xmm1[1,1],xmm0[1,1]
+; X64-NEXT:    pextrw $0, %xmm1, %ecx
+; X64-NEXT:    pextrw $0, %xmm0, %edx
+; X64-NEXT:    psrld $16, %xmm0
+; X64-NEXT:    pextrw $0, %xmm0, %esi
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    #ARITH_FENCE
+; X64-NEXT:    pinsrw $0, %eax, %xmm0
+; X64-NEXT:    pinsrw $0, %ecx, %xmm1
+; X64-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
+; X64-NEXT:    pinsrw $0, %edx, %xmm0
+; X64-NEXT:    pinsrw $0, %esi, %xmm2
+; X64-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; X64-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; X64-NEXT:    retq
+  %b = call <4 x bfloat> @llvm.arithmetic.fence.v4bf16(<4 x bfloat> %a)
+  ret <4 x bfloat> %b
+}
+
+declare half @llvm.arithmetic.fence.f16(half)
+declare bfloat @llvm.arithmetic.fence.bf16(bfloat)
+declare <2 x half> @llvm.arithmetic.fence.v2f16(<2 x half>)
+declare <3 x bfloat> @llvm.arithmetic.fence.v3bf16(<3 x bfloat>)
+declare <4 x bfloat> @llvm.arithmetic.fence.v4bf16(<4 x bfloat>)
 declare float @llvm.arithmetic.fence.f32(float)
 declare double @llvm.arithmetic.fence.f64(double)
 declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)


        


More information about the llvm-commits mailing list