[llvm] [SelectionDAG] Make ARITH_FENCE support half and bfloat type (PR #90836)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 2 02:16:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Phoebe Wang (phoebewang)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/90836.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+7) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1) 
- (modified) llvm/test/CodeGen/X86/arithmetic_fence2.ll (+85) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index abe5be76382556..00f94e48a3f9ad 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2825,6 +2825,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     report_fatal_error("Do not know how to soft promote this operator's "
                        "result!");
 
+  case ISD::ARITH_FENCE:
+    R = SoftPromoteHalfRes_ARITH_FENCE(N); break;
   case ISD::BITCAST:    R = SoftPromoteHalfRes_BITCAST(N); break;
   case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break;
   case ISD::EXTRACT_VECTOR_ELT:
@@ -2904,6 +2906,11 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
     SetSoftPromotedHalf(SDValue(N, ResNo), R);
 }
 
+SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ARITH_FENCE(SDNode *N) {
+  return DAG.getNode(ISD::ARITH_FENCE, SDLoc(N), MVT::i16,
+                     BitConvertToInteger(N->getOperand(0)));
+}
+
 SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BITCAST(SDNode *N) {
   return BitConvertToInteger(N->getOperand(0));
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 49be824deb5134..e9714f6f72b6bb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -726,6 +726,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SetSoftPromotedHalf(SDValue Op, SDValue Result);
 
   void SoftPromoteHalfResult(SDNode *N, unsigned ResNo);
+  SDValue SoftPromoteHalfRes_ARITH_FENCE(SDNode *N);
   SDValue SoftPromoteHalfRes_BinOp(SDNode *N);
   SDValue SoftPromoteHalfRes_BITCAST(SDNode *N);
   SDValue SoftPromoteHalfRes_ConstantFP(SDNode *N);
diff --git a/llvm/test/CodeGen/X86/arithmetic_fence2.ll b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
index 6a854b58fc02d0..bc80e1a70112d3 100644
--- a/llvm/test/CodeGen/X86/arithmetic_fence2.ll
+++ b/llvm/test/CodeGen/X86/arithmetic_fence2.ll
@@ -157,6 +157,91 @@ define <8 x float> @f6(<8 x float> %a) {
   ret <8 x float> %3
 }
 
+define half @f7(half %a) nounwind {
+; X86-LABEL: f7:
+; X86:       # %bb.0:
+; X86-NEXT:    subl $12, %esp
+; X86-NEXT:    pinsrw $0, {{[0-9]+}}(%esp), %xmm0
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    movw %ax, (%esp)
+; X86-NEXT:    calll __extendhfsf2
+; X86-NEXT:    fstps {{[0-9]+}}(%esp)
+; X86-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfhf2
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    movw %ax, (%esp)
+; X86-NEXT:    calll __extendhfsf2
+; X86-NEXT:    fstps {{[0-9]+}}(%esp)
+; X86-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfhf2
+; X86-NEXT:    addl $12, %esp
+; X86-NEXT:    retl
+;
+; X64-LABEL: f7:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    callq __extendhfsf2 at PLT
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfhf2 at PLT
+; X64-NEXT:    callq __extendhfsf2 at PLT
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfhf2 at PLT
+; X64-NEXT:    popq %rax
+; X64-NEXT:    retq
+  %1 = fadd fast half %a, %a
+  %t = call half @llvm.arithmetic.fence.f16(half %1)
+  %2 = fadd fast half %a, %a
+  %3 = fadd fast half %1, %2
+  ret half %3
+}
+
+define bfloat @f8(bfloat %a) nounwind {
+; X86-LABEL: f8:
+; X86:       # %bb.0:
+; X86-NEXT:    pushl %eax
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    shll $16, %eax
+; X86-NEXT:    movd %eax, %xmm0
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfbf2
+; X86-NEXT:    pextrw $0, %xmm0, %eax
+; X86-NEXT:    shll $16, %eax
+; X86-NEXT:    movd %eax, %xmm0
+; X86-NEXT:    addss %xmm0, %xmm0
+; X86-NEXT:    movss %xmm0, (%esp)
+; X86-NEXT:    calll __truncsfbf2
+; X86-NEXT:    popl %eax
+; X86-NEXT:    retl
+;
+; X64-LABEL: f8:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    pextrw $0, %xmm0, %eax
+; X64-NEXT:    shll $16, %eax
+; X64-NEXT:    movd %eax, %xmm0
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfbf2 at PLT
+; X64-NEXT:    pextrw $0, %xmm0, %eax
+; X64-NEXT:    shll $16, %eax
+; X64-NEXT:    movd %eax, %xmm0
+; X64-NEXT:    addss %xmm0, %xmm0
+; X64-NEXT:    callq __truncsfbf2 at PLT
+; X64-NEXT:    popq %rax
+; X64-NEXT:    retq
+  %1 = fadd fast bfloat %a, %a
+  %t = call bfloat @llvm.arithmetic.fence.bf16(bfloat %1)
+  %2 = fadd fast bfloat %a, %a
+  %3 = fadd fast bfloat %1, %2
+  ret bfloat %3
+}
+
+declare half @llvm.arithmetic.fence.f16(half)
+declare bfloat @llvm.arithmetic.fence.bf16(bfloat)
 declare float @llvm.arithmetic.fence.f32(float)
 declare double @llvm.arithmetic.fence.f64(double)
 declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/90836


More information about the llvm-commits mailing list