[llvm] c236883 - [X86] Optimize fdiv with reciprocal instructions for half type
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 7 18:41:21 PDT 2021
Author: Wang, Pengfei
Date: 2021-10-08T09:41:13+08:00
New Revision: c236883b6ba791881256b31cfdb8a8520a821a67
URL: https://github.com/llvm/llvm-project/commit/c236883b6ba791881256b31cfdb8a8520a821a67
DIFF: https://github.com/llvm/llvm-project/commit/c236883b6ba791881256b31cfdb8a8520a821a67.diff
LOG: [X86] Optimize fdiv with reciprocal instructions for half type
Reviewed By: craig.topper
Differential Revision: https://reviews.llvm.org/D110557
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
llvm/test/CodeGen/X86/avx512fp16-arith.ll
llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2a167cbaf9a7..550c6bcb8d65 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23046,9 +23046,10 @@ SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
if (LegalDAG)
return SDValue();
- // TODO: Handle half and/or extended types?
+ // TODO: Handle extended types?
EVT VT = Op.getValueType();
- if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
+ if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
+ VT.getScalarType() != MVT::f64)
return SDValue();
// If estimates are explicitly disabled for this function, we're done.
@@ -23185,9 +23186,10 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
if (LegalDAG)
return SDValue();
- // TODO: Handle half and/or extended types?
+ // TODO: Handle extended types?
EVT VT = Op.getValueType();
- if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
+ if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
+ VT.getScalarType() != MVT::f64)
return SDValue();
// If estimates are explicitly disabled for this function, we're done.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index dd32d323677e..23f539b1edb7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -23148,6 +23148,7 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
int &RefinementSteps,
bool &UseOneConstNR,
bool Reciprocal) const {
+ SDLoc DL(Op);
EVT VT = Op.getValueType();
// SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
@@ -23169,7 +23170,23 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
UseOneConstNR = false;
// There is no FSQRT for 512-bits, but there is RSQRT14.
unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT;
- return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
+ return DAG.getNode(Opcode, DL, VT, Op);
+ }
+
+ if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
+ Subtarget.hasFP16()) {
+ if (RefinementSteps == ReciprocalEstimate::Unspecified)
+ RefinementSteps = 0;
+
+ if (VT == MVT::f16) {
+ SDValue Zero = DAG.getIntPtrConstant(0, DL);
+ SDValue Undef = DAG.getUNDEF(MVT::v8f16);
+ Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
+ Op = DAG.getNode(X86ISD::RSQRT14S, DL, MVT::v8f16, Undef, Op);
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
+ }
+
+ return DAG.getNode(X86ISD::RSQRT14, DL, VT, Op);
}
return SDValue();
}
@@ -23179,6 +23196,7 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
int Enabled,
int &RefinementSteps) const {
+ SDLoc DL(Op);
EVT VT = Op.getValueType();
// SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps.
@@ -23203,7 +23221,23 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
// There is no FSQRT for 512-bits, but there is RCP14.
unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RCP14 : X86ISD::FRCP;
- return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
+ return DAG.getNode(Opcode, DL, VT, Op);
+ }
+
+ if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
+ Subtarget.hasFP16()) {
+ if (RefinementSteps == ReciprocalEstimate::Unspecified)
+ RefinementSteps = 0;
+
+ if (VT == MVT::f16) {
+ SDValue Zero = DAG.getIntPtrConstant(0, DL);
+ SDValue Undef = DAG.getUNDEF(MVT::v8f16);
+ Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
+ Op = DAG.getNode(X86ISD::RCP14S, DL, MVT::v8f16, Undef, Op);
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
+ }
+
+ return DAG.getNode(X86ISD::RCP14, DL, VT, Op);
}
return SDValue();
}
diff --git a/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
index d827206318e7..89eb782fcd42 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
@@ -250,6 +250,16 @@ define <16 x half> @test_int_x86_avx512fp16_div_ph_256(<16 x half> %x1, <16 x ha
ret <16 x half> %res
}
+define <16 x half> @test_int_x86_avx512fp16_div_ph_256_fast(<16 x half> %x1, <16 x half> %x2) {
+; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_256_fast:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vrcpph %ymm1, %ymm1
+; CHECK-NEXT: vmulph %ymm0, %ymm1, %ymm0
+; CHECK-NEXT: retq
+ %res = fdiv fast <16 x half> %x1, %x2
+ ret <16 x half> %res
+}
+
define <16 x half> @test_int_x86_avx512fp16_mask_div_ph_256(<16 x half> %x1, <16 x half> %x2, <16 x half> %src, i16 %mask, <16 x half>* %ptr) {
; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_256:
; CHECK: # %bb.0:
@@ -290,6 +300,16 @@ define <8 x half> @test_int_x86_avx512fp16_div_ph_128(<8 x half> %x1, <8 x half>
ret <8 x half> %res
}
+define <8 x half> @test_int_x86_avx512fp16_div_ph_128_fast(<8 x half> %x1, <8 x half> %x2) {
+; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_128_fast:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vrcpph %xmm1, %xmm1
+; CHECK-NEXT: vmulph %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: retq
+ %res = fdiv fast <8 x half> %x1, %x2
+ ret <8 x half> %res
+}
+
define <8 x half> @test_int_x86_avx512fp16_mask_div_ph_128(<8 x half> %x1, <8 x half> %x2, <8 x half> %src, i8 %mask, <8 x half>* %ptr) {
; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_128:
; CHECK: # %bb.0:
diff --git a/llvm/test/CodeGen/X86/avx512fp16-arith.ll b/llvm/test/CodeGen/X86/avx512fp16-arith.ll
index e897c195b906..75f9ae66bfef 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-arith.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-arith.ll
@@ -154,6 +154,16 @@ define <32 x half> @vdivph_512_test(<32 x half> %i, <32 x half> %j) nounwind rea
ret <32 x half> %x
}
+define <32 x half> @vdivph_512_test_fast(<32 x half> %i, <32 x half> %j) nounwind readnone {
+; CHECK-LABEL: vdivph_512_test_fast:
+; CHECK: ## %bb.0:
+; CHECK-NEXT: vrcpph %zmm1, %zmm1
+; CHECK-NEXT: vmulph %zmm0, %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %x = fdiv fast <32 x half> %i, %j
+ ret <32 x half> %x
+}
+
define half @add_sh(half %i, half %j, half* %x.ptr) nounwind readnone {
; CHECK-LABEL: add_sh:
; CHECK: ## %bb.0:
@@ -228,6 +238,16 @@ define half @div_sh_2(half %i, half %j, half* %x.ptr) nounwind readnone {
ret half %r
}
+define half @div_sh_3(half %i, half %j) nounwind readnone {
+; CHECK-LABEL: div_sh_3:
+; CHECK: ## %bb.0:
+; CHECK-NEXT: vrcpsh %xmm1, %xmm1, %xmm1
+; CHECK-NEXT: vmulsh %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: retq
+ %r = fdiv fast half %i, %j
+ ret half %r
+}
+
define i1 @cmp_une_sh(half %x, half %y) {
; CHECK-LABEL: cmp_une_sh:
; CHECK: ## %bb.0: ## %entry
diff --git a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
index 170e1ea1a6a9..04ad68fdc893 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
@@ -24,6 +24,17 @@ define <32 x half> @test_sqrt_ph_512(<32 x half> %a0) {
ret <32 x half> %1
}
+define <32 x half> @test_sqrt_ph_512_fast(<32 x half> %a0, <32 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_512_fast:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vrsqrtph %zmm0, %zmm0
+; CHECK-NEXT: vmulph %zmm0, %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %1 = call fast <32 x half> @llvm.sqrt.v32f16(<32 x half> %a0)
+ %2 = fdiv fast <32 x half> %a1, %1
+ ret <32 x half> %2
+}
+
define <32 x half> @test_mask_sqrt_ph_512(<32 x half> %a0, <32 x half> %passthru, i32 %mask) {
; CHECK-LABEL: test_mask_sqrt_ph_512:
; CHECK: # %bb.0:
@@ -98,6 +109,19 @@ define <8 x half> @test_sqrt_sh(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2,
ret <8 x half> %res
}
+define half @test_sqrt_sh2(half %a0, half %a1) {
+; CHECK-LABEL: test_sqrt_sh2:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vrsqrtsh %xmm0, %xmm0, %xmm0
+; CHECK-NEXT: vmulsh %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: retq
+ %1 = call fast half @llvm.sqrt.f16(half %a0)
+ %2 = fdiv fast half %a1, %1
+ ret half %2
+}
+
+declare half @llvm.sqrt.f16(half)
+
define <8 x half> @test_sqrt_sh_r(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2, i8 %mask) {
; CHECK-LABEL: test_sqrt_sh_r:
; CHECK: # %bb.0:
diff --git a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
index 93efbace3e75..4a133e108f2c 100644
--- a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
@@ -958,6 +958,17 @@ define <8 x half> @test_sqrt_ph_128(<8 x half> %a0) {
ret <8 x half> %1
}
+define <8 x half> @test_sqrt_ph_128_fast(<8 x half> %a0, <8 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_128_fast:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vrsqrtph %xmm0, %xmm0
+; CHECK-NEXT: vmulph %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: retq
+ %1 = call fast <8 x half> @llvm.sqrt.v8f16(<8 x half> %a0)
+ %2 = fdiv fast <8 x half> %a1, %1
+ ret <8 x half> %2
+}
+
define <8 x half> @test_mask_sqrt_ph_128(<8 x half> %a0, <8 x half> %passthru, i8 %mask) {
; CHECK-LABEL: test_mask_sqrt_ph_128:
; CHECK: # %bb.0:
@@ -992,6 +1003,17 @@ define <16 x half> @test_sqrt_ph_256(<16 x half> %a0) {
ret <16 x half> %1
}
+define <16 x half> @test_sqrt_ph_256_fast(<16 x half> %a0, <16 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_256_fast:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vrsqrtph %ymm0, %ymm0
+; CHECK-NEXT: vmulph %ymm0, %ymm1, %ymm0
+; CHECK-NEXT: retq
+ %1 = call fast <16 x half> @llvm.sqrt.v16f16(<16 x half> %a0)
+ %2 = fdiv fast <16 x half> %a1, %1
+ ret <16 x half> %2
+}
+
define <16 x half> @test_mask_sqrt_ph_256(<16 x half> %a0, <16 x half> %passthru, i16 %mask) {
; CHECK-LABEL: test_mask_sqrt_ph_256:
; CHECK: # %bb.0:
More information about the llvm-commits mailing list