[llvm] c236883 - [X86] Optimize fdiv with reciprocal instructions for half type

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 7 18:41:21 PDT 2021


Author: Wang, Pengfei
Date: 2021-10-08T09:41:13+08:00
New Revision: c236883b6ba791881256b31cfdb8a8520a821a67

URL: https://github.com/llvm/llvm-project/commit/c236883b6ba791881256b31cfdb8a8520a821a67
DIFF: https://github.com/llvm/llvm-project/commit/c236883b6ba791881256b31cfdb8a8520a821a67.diff

LOG: [X86] Optimize fdiv with reciprocal instructions for half type

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D110557

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
    llvm/test/CodeGen/X86/avx512fp16-arith.ll
    llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
    llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2a167cbaf9a7..550c6bcb8d65 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23046,9 +23046,10 @@ SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
   if (LegalDAG)
     return SDValue();
 
-  // TODO: Handle half and/or extended types?
+  // TODO: Handle extended types?
   EVT VT = Op.getValueType();
-  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
+  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
+      VT.getScalarType() != MVT::f64)
     return SDValue();
 
   // If estimates are explicitly disabled for this function, we're done.
@@ -23185,9 +23186,10 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
   if (LegalDAG)
     return SDValue();
 
-  // TODO: Handle half and/or extended types?
+  // TODO: Handle extended types?
   EVT VT = Op.getValueType();
-  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
+  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
+      VT.getScalarType() != MVT::f64)
     return SDValue();
 
   // If estimates are explicitly disabled for this function, we're done.

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index dd32d323677e..23f539b1edb7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -23148,6 +23148,7 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
                                            int &RefinementSteps,
                                            bool &UseOneConstNR,
                                            bool Reciprocal) const {
+  SDLoc DL(Op);
   EVT VT = Op.getValueType();
 
   // SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
@@ -23169,7 +23170,23 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
     UseOneConstNR = false;
     // There is no FSQRT for 512-bits, but there is RSQRT14.
     unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT;
-    return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
+    return DAG.getNode(Opcode, DL, VT, Op);
+  }
+
+  if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
+      Subtarget.hasFP16()) {
+    if (RefinementSteps == ReciprocalEstimate::Unspecified)
+      RefinementSteps = 0;
+
+    if (VT == MVT::f16) {
+      SDValue Zero = DAG.getIntPtrConstant(0, DL);
+      SDValue Undef = DAG.getUNDEF(MVT::v8f16);
+      Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
+      Op = DAG.getNode(X86ISD::RSQRT14S, DL, MVT::v8f16, Undef, Op);
+      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
+    }
+
+    return DAG.getNode(X86ISD::RSQRT14, DL, VT, Op);
   }
   return SDValue();
 }
@@ -23179,6 +23196,7 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
 SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
                                             int Enabled,
                                             int &RefinementSteps) const {
+  SDLoc DL(Op);
   EVT VT = Op.getValueType();
 
   // SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps.
@@ -23203,7 +23221,23 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
 
     // There is no FSQRT for 512-bits, but there is RCP14.
     unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RCP14 : X86ISD::FRCP;
-    return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
+    return DAG.getNode(Opcode, DL, VT, Op);
+  }
+
+  if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
+      Subtarget.hasFP16()) {
+    if (RefinementSteps == ReciprocalEstimate::Unspecified)
+      RefinementSteps = 0;
+
+    if (VT == MVT::f16) {
+      SDValue Zero = DAG.getIntPtrConstant(0, DL);
+      SDValue Undef = DAG.getUNDEF(MVT::v8f16);
+      Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
+      Op = DAG.getNode(X86ISD::RCP14S, DL, MVT::v8f16, Undef, Op);
+      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
+    }
+
+    return DAG.getNode(X86ISD::RCP14, DL, VT, Op);
   }
   return SDValue();
 }

diff  --git a/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
index d827206318e7..89eb782fcd42 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
@@ -250,6 +250,16 @@ define <16 x half> @test_int_x86_avx512fp16_div_ph_256(<16 x half> %x1, <16 x ha
   ret <16 x half> %res
 }
 
+define <16 x half> @test_int_x86_avx512fp16_div_ph_256_fast(<16 x half> %x1, <16 x half> %x2) {
+; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_256_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrcpph %ymm1, %ymm1
+; CHECK-NEXT:    vmulph %ymm0, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+  %res = fdiv fast <16 x half> %x1, %x2
+  ret <16 x half> %res
+}
+
 define <16 x half> @test_int_x86_avx512fp16_mask_div_ph_256(<16 x half> %x1, <16 x half> %x2, <16 x half> %src, i16 %mask, <16 x half>* %ptr) {
 ; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_256:
 ; CHECK:       # %bb.0:
@@ -290,6 +300,16 @@ define <8 x half> @test_int_x86_avx512fp16_div_ph_128(<8 x half> %x1, <8 x half>
   ret <8 x half> %res
 }
 
+define <8 x half> @test_int_x86_avx512fp16_div_ph_128_fast(<8 x half> %x1, <8 x half> %x2) {
+; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_128_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrcpph %xmm1, %xmm1
+; CHECK-NEXT:    vmulph %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %res = fdiv fast <8 x half> %x1, %x2
+  ret <8 x half> %res
+}
+
 define <8 x half> @test_int_x86_avx512fp16_mask_div_ph_128(<8 x half> %x1, <8 x half> %x2, <8 x half> %src, i8 %mask, <8 x half>* %ptr) {
 ; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_128:
 ; CHECK:       # %bb.0:

diff  --git a/llvm/test/CodeGen/X86/avx512fp16-arith.ll b/llvm/test/CodeGen/X86/avx512fp16-arith.ll
index e897c195b906..75f9ae66bfef 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-arith.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-arith.ll
@@ -154,6 +154,16 @@ define <32 x half> @vdivph_512_test(<32 x half> %i, <32 x half> %j) nounwind rea
   ret <32 x half> %x
 }
 
+define <32 x half> @vdivph_512_test_fast(<32 x half> %i, <32 x half> %j) nounwind readnone {
+; CHECK-LABEL: vdivph_512_test_fast:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    vrcpph %zmm1, %zmm1
+; CHECK-NEXT:    vmulph %zmm0, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+  %x = fdiv fast <32 x half> %i, %j
+  ret <32 x half> %x
+}
+
 define half @add_sh(half %i, half %j, half* %x.ptr) nounwind readnone {
 ; CHECK-LABEL: add_sh:
 ; CHECK:       ## %bb.0:
@@ -228,6 +238,16 @@ define half @div_sh_2(half %i, half %j, half* %x.ptr) nounwind readnone {
   ret half %r
 }
 
+define half @div_sh_3(half %i, half %j) nounwind readnone {
+; CHECK-LABEL: div_sh_3:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    vrcpsh %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vmulsh %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %r = fdiv fast half %i, %j
+  ret half %r
+}
+
 define i1 @cmp_une_sh(half %x, half %y) {
 ; CHECK-LABEL: cmp_une_sh:
 ; CHECK:       ## %bb.0: ## %entry

diff  --git a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
index 170e1ea1a6a9..04ad68fdc893 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
@@ -24,6 +24,17 @@ define <32 x half> @test_sqrt_ph_512(<32 x half> %a0) {
   ret <32 x half> %1
 }
 
+define <32 x half> @test_sqrt_ph_512_fast(<32 x half> %a0, <32 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_512_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrsqrtph %zmm0, %zmm0
+; CHECK-NEXT:    vmulph %zmm0, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+  %1 = call fast <32 x half> @llvm.sqrt.v32f16(<32 x half> %a0)
+  %2 = fdiv fast <32 x half> %a1, %1
+  ret <32 x half> %2
+}
+
 define <32 x half> @test_mask_sqrt_ph_512(<32 x half> %a0, <32 x half> %passthru, i32 %mask) {
 ; CHECK-LABEL: test_mask_sqrt_ph_512:
 ; CHECK:       # %bb.0:
@@ -98,6 +109,19 @@ define <8 x half> @test_sqrt_sh(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2,
   ret <8 x half> %res
 }
 
+define half @test_sqrt_sh2(half %a0, half %a1) {
+; CHECK-LABEL: test_sqrt_sh2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrsqrtsh %xmm0, %xmm0, %xmm0
+; CHECK-NEXT:    vmulsh %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %1 = call fast half @llvm.sqrt.f16(half %a0)
+  %2 = fdiv fast half %a1, %1
+  ret half %2
+}
+
+declare half @llvm.sqrt.f16(half)
+
 define <8 x half> @test_sqrt_sh_r(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2, i8 %mask) {
 ; CHECK-LABEL: test_sqrt_sh_r:
 ; CHECK:       # %bb.0:

diff  --git a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
index 93efbace3e75..4a133e108f2c 100644
--- a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
@@ -958,6 +958,17 @@ define <8 x half> @test_sqrt_ph_128(<8 x half> %a0) {
   ret <8 x half> %1
 }
 
+define <8 x half> @test_sqrt_ph_128_fast(<8 x half> %a0, <8 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_128_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrsqrtph %xmm0, %xmm0
+; CHECK-NEXT:    vmulph %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %1 = call fast <8 x half> @llvm.sqrt.v8f16(<8 x half> %a0)
+  %2 = fdiv fast <8 x half> %a1, %1
+  ret <8 x half> %2
+}
+
 define <8 x half> @test_mask_sqrt_ph_128(<8 x half> %a0, <8 x half> %passthru, i8 %mask) {
 ; CHECK-LABEL: test_mask_sqrt_ph_128:
 ; CHECK:       # %bb.0:
@@ -992,6 +1003,17 @@ define <16 x half> @test_sqrt_ph_256(<16 x half> %a0) {
   ret <16 x half> %1
 }
 
+define <16 x half> @test_sqrt_ph_256_fast(<16 x half> %a0, <16 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_256_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrsqrtph %ymm0, %ymm0
+; CHECK-NEXT:    vmulph %ymm0, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+  %1 = call fast <16 x half> @llvm.sqrt.v16f16(<16 x half> %a0)
+  %2 = fdiv fast <16 x half> %a1, %1
+  ret <16 x half> %2
+}
+
 define <16 x half> @test_mask_sqrt_ph_256(<16 x half> %a0, <16 x half> %passthru, i16 %mask) {
 ; CHECK-LABEL: test_mask_sqrt_ph_256:
 ; CHECK:       # %bb.0:


        


More information about the llvm-commits mailing list