[llvm] AMDGPU: Emit 1/llvm.sqrt(x) instead of rsqrt calls in libcall handling (PR #92863)
via llvm-commits
llvm-commits at lists.llvm.org
Mon May 20 23:17:19 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Matt Arsenault (arsenm)
<details>
<summary>Changes</summary>
With the contract flag, we should end up codegenning to the rsqrt instruction, or to the denormal-corrected rsqrt sequence present in the library.
---
Full diff: https://github.com/llvm/llvm-project/pull/92863.diff
3 Files Affected:
- (modified) llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp (+29-9)
- (modified) llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll (+25-18)
- (modified) llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll (+2-2)
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index aab79ceb57f22..c515138d95a2a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -1215,16 +1215,36 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
"__rootn2div");
replaceCall(FPOp, nval);
return true;
- } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
- if (FunctionCallee FPExpr =
- getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
- LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
- << ")\n");
- Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
- replaceCall(FPOp, nval);
- return true;
- }
}
+
+ if (ci_opr1 == -2 &&
+ shouldReplaceLibcallWithIntrinsic(CI,
+ /*AllowMinSizeF32=*/true,
+ /*AllowF64=*/true)) {
+ // rootn(x, -2) = rsqrt(x)
+
+ // The original rootn had looser ulp requirements than the resultant sqrt
+ // and fdiv.
+ MDBuilder MDHelper(M->getContext());
+ MDNode *FPMD = MDHelper.createFPMath(std::max(FPOp->getFPAccuracy(), 2.0f));
+
+ // TODO: Could handle strictfp but need to fix strict sqrt emission
+ FastMathFlags FMF = FPOp->getFastMathFlags();
+ FMF.setAllowContract(true);
+
+ CallInst *Sqrt = B.CreateUnaryIntrinsic(Intrinsic::sqrt, opr0, CI);
+ Instruction *RSqrt = cast<Instruction>(
+ B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0), Sqrt));
+ Sqrt->setFastMathFlags(FMF);
+ RSqrt->setFastMathFlags(FMF);
+ RSqrt->setMetadata(LLVMContext::MD_fpmath, FPMD);
+
+ LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
+ << ")\n");
+ replaceCall(CI, RSqrt);
+ return true;
+ }
+
return false;
}
diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
index c105ad7590e69..7932f8d1fc5be 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
@@ -302,7 +302,8 @@ define half @test_rootn_f16_neg1(half %x) {
define half @test_rootn_f16_neg2(half %x) {
; CHECK-LABEL: define half @test_rootn_f16_neg2(
; CHECK-SAME: half [[X:%.*]]) {
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call half @_Z5rsqrtDh(half [[X]])
+; CHECK-NEXT: [[TMP1:%.*]] = call contract half @llvm.sqrt.f16(half [[X]])
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv contract half 0xH3C00, [[TMP1]], !fpmath [[META0]]
; CHECK-NEXT: ret half [[__ROOTN2RSQRT]]
;
%call = tail call half @_Z5rootnDhi(half %x, i32 -2)
@@ -371,7 +372,8 @@ define <2 x half> @test_rootn_v2f16_neg1(<2 x half> %x) {
define <2 x half> @test_rootn_v2f16_neg2(<2 x half> %x) {
; CHECK-LABEL: define <2 x half> @test_rootn_v2f16_neg2(
; CHECK-SAME: <2 x half> [[X:%.*]]) {
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call <2 x half> @_Z5rsqrtDv2_Dh(<2 x half> [[X]])
+; CHECK-NEXT: [[TMP1:%.*]] = call contract <2 x half> @llvm.sqrt.v2f16(<2 x half> [[X]])
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv contract <2 x half> <half 0xH3C00, half 0xH3C00>, [[TMP1]], !fpmath [[META0]]
; CHECK-NEXT: ret <2 x half> [[__ROOTN2RSQRT]]
;
%call = tail call <2 x half> @_Z5rootnDv2_DhDv2_i(<2 x half> %x, <2 x i32> <i32 -2, i32 -2>)
@@ -865,7 +867,8 @@ define float @test_rootn_f32__y_neg2(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg2(
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]])
+; CHECK-NEXT: [[TMP0:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv contract float 1.000000e+00, [[TMP0]], !fpmath [[META0]]
; CHECK-NEXT: ret float [[__ROOTN2RSQRT]]
;
entry:
@@ -877,7 +880,8 @@ define float @test_rootn_f32__y_neg2__flags(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__flags(
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call nnan nsz float @_Z5rsqrtf(float [[X]])
+; CHECK-NEXT: [[TMP0:%.*]] = call nnan nsz contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv nnan nsz contract float 1.000000e+00, [[TMP0]], !fpmath [[META0]]
; CHECK-NEXT: ret float [[__ROOTN2RSQRT]]
;
entry:
@@ -889,7 +893,7 @@ define float @test_rootn_f32__y_neg2__strictfp(float %x) #1 {
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__strictfp(
; CHECK-SAME: float [[X:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]]) #[[ATTR0]]
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR0]]
; CHECK-NEXT: ret float [[__ROOTN2RSQRT]]
;
entry:
@@ -901,7 +905,7 @@ define float @test_rootn_f32__y_neg2__noinline(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__noinline(
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]])
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
; CHECK-NEXT: ret float [[__ROOTN2RSQRT]]
;
entry:
@@ -913,7 +917,7 @@ define float @test_rootn_f32__y_neg2__nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR4:[0-9]+]]
; CHECK-NEXT: ret float [[CALL]]
;
entry:
@@ -925,7 +929,8 @@ define <2 x float> @test_rootn_v2f32__y_neg2(<2 x float> %x) {
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2(
; CHECK-SAME: <2 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]])
+; CHECK-NEXT: [[TMP0:%.*]] = call contract <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv contract <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[TMP0]], !fpmath [[META0]]
; CHECK-NEXT: ret <2 x float> [[__ROOTN2RSQRT]]
;
entry:
@@ -937,7 +942,8 @@ define <2 x float> @test_rootn_v2f32__y_neg2__flags(<2 x float> %x) {
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2__flags(
; CHECK-SAME: <2 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call nnan nsz <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]])
+; CHECK-NEXT: [[TMP0:%.*]] = call nnan nsz contract <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv nnan nsz contract <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[TMP0]], !fpmath [[META0]]
; CHECK-NEXT: ret <2 x float> [[__ROOTN2RSQRT]]
;
entry:
@@ -949,7 +955,7 @@ define <2 x float> @test_rootn_v2f32__y_neg2__strictfp(<2 x float> %x) #1 {
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2__strictfp(
; CHECK-SAME: <2 x float> [[X:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]]) #[[ATTR0]]
+; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = tail call <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> [[X]], <2 x i32> <i32 -2, i32 -2>) #[[ATTR0]]
; CHECK-NEXT: ret <2 x float> [[__ROOTN2RSQRT]]
;
entry:
@@ -1125,7 +1131,7 @@ define float @test_rootn_fast_f32_nobuiltin(float %x, i32 %y) {
; CHECK-LABEL: define float @test_rootn_fast_f32_nobuiltin(
; CHECK-SAME: float [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR3]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR4]]
; CHECK-NEXT: ret float [[CALL]]
;
entry:
@@ -1420,7 +1426,7 @@ entry:
define float @test_rootn_f32__y_0_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_0_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR3]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR4]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 0) #0
@@ -1430,7 +1436,7 @@ define float @test_rootn_f32__y_0_nobuiltin(float %x) {
define float @test_rootn_f32__y_1_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_1_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR3]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR4]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 1) #0
@@ -1440,7 +1446,7 @@ define float @test_rootn_f32__y_1_nobuiltin(float %x) {
define float @test_rootn_f32__y_2_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_2_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR3]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR4]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 2) #0
@@ -1450,7 +1456,7 @@ define float @test_rootn_f32__y_2_nobuiltin(float %x) {
define float @test_rootn_f32__y_3_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_3_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR3]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR4]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 3) #0
@@ -1460,7 +1466,7 @@ define float @test_rootn_f32__y_3_nobuiltin(float %x) {
define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg1_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR3]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR4]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 -1) #0
@@ -1470,7 +1476,7 @@ define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
define float @test_rootn_f32__y_neg2_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg2_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR4]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 -2) #0
@@ -1487,7 +1493,8 @@ attributes #2 = { noinline }
; CHECK: attributes #[[ATTR0]] = { strictfp }
; CHECK: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind memory(read) }
-; CHECK: attributes #[[ATTR3]] = { nobuiltin }
+; CHECK: attributes #[[ATTR3]] = { noinline }
+; CHECK: attributes #[[ATTR4]] = { nobuiltin }
;.
; CHECK: [[META0]] = !{float 2.000000e+00}
; CHECK: [[META1]] = !{float 3.000000e+00}
diff --git a/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll b/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
index 152eba5dec946..5a241f85b2e2c 100644
--- a/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
+++ b/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
@@ -506,8 +506,8 @@ entry:
}
; GCN-LABEL: {{^}}define amdgpu_kernel void @test_rootn_m2
-; GCN-POSTLINK: call fast float @_Z5rootnfi(float %tmp, i32 -2)
-; GCN-PRELINK: %__rootn2rsqrt = tail call fast float @_Z5rsqrtf(float %tmp)
+; GCN: [[SQRT:%.+]] = tail call fast float @llvm.sqrt.f32(float %tmp)
+; GCN-NEXT: fdiv fast float 1.000000e+00, [[SQRT]]
define amdgpu_kernel void @test_rootn_m2(ptr addrspace(1) nocapture %a) {
entry:
%tmp = load float, ptr addrspace(1) %a, align 4
``````````
</details>
https://github.com/llvm/llvm-project/pull/92863
More information about the llvm-commits mailing list