[llvm] AMDGPU: Emit 1/llvm.sqrt(x) instead of rsqrt calls in libcall handling (PR #92863)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 20 23:17:19 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

<details>
<summary>Changes</summary>

With the contract flag set, the resulting 1/sqrt(x) should still be lowered to the rsqrt instruction, or to the denormal-corrected rsqrt sequence present in the library.

---
Full diff: https://github.com/llvm/llvm-project/pull/92863.diff


3 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp (+29-9) 
- (modified) llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll (+25-18) 
- (modified) llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll (+2-2) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index aab79ceb57f22..c515138d95a2a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -1215,16 +1215,36 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                "__rootn2div");
     replaceCall(FPOp, nval);
     return true;
-  } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
-    if (FunctionCallee FPExpr =
-            getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
-      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
-                        << ")\n");
-      Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
-      replaceCall(FPOp, nval);
-      return true;
-    }
   }
+
+  if (ci_opr1 == -2 &&
+      shouldReplaceLibcallWithIntrinsic(CI,
+                                        /*AllowMinSizeF32=*/true,
+                                        /*AllowF64=*/true)) {
+    // rootn(x, -2) = rsqrt(x)
+
+    // The original rootn had looser ulp requirements than the resultant sqrt
+    // and fdiv.
+    MDBuilder MDHelper(M->getContext());
+    MDNode *FPMD = MDHelper.createFPMath(std::max(FPOp->getFPAccuracy(), 2.0f));
+
+    // TODO: Could handle strictfp but need to fix strict sqrt emission
+    FastMathFlags FMF = FPOp->getFastMathFlags();
+    FMF.setAllowContract(true);
+
+    CallInst *Sqrt = B.CreateUnaryIntrinsic(Intrinsic::sqrt, opr0, CI);
+    Instruction *RSqrt = cast<Instruction>(
+        B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0), Sqrt));
+    Sqrt->setFastMathFlags(FMF);
+    RSqrt->setFastMathFlags(FMF);
+    RSqrt->setMetadata(LLVMContext::MD_fpmath, FPMD);
+
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
+                      << ")\n");
+    replaceCall(CI, RSqrt);
+    return true;
+  }
+
   return false;
 }
 
diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
index c105ad7590e69..7932f8d1fc5be 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
@@ -302,7 +302,8 @@ define half @test_rootn_f16_neg1(half %x) {
 define half @test_rootn_f16_neg2(half %x) {
 ; CHECK-LABEL: define half @test_rootn_f16_neg2(
 ; CHECK-SAME: half [[X:%.*]]) {
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call half @_Z5rsqrtDh(half [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract half @llvm.sqrt.f16(half [[X]])
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = fdiv contract half 0xH3C00, [[TMP1]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret half [[__ROOTN2RSQRT]]
 ;
   %call = tail call half @_Z5rootnDhi(half %x, i32 -2)
@@ -371,7 +372,8 @@ define <2 x half> @test_rootn_v2f16_neg1(<2 x half> %x) {
 define <2 x half> @test_rootn_v2f16_neg2(<2 x half> %x) {
 ; CHECK-LABEL: define <2 x half> @test_rootn_v2f16_neg2(
 ; CHECK-SAME: <2 x half> [[X:%.*]]) {
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call <2 x half> @_Z5rsqrtDv2_Dh(<2 x half> [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract <2 x half> @llvm.sqrt.v2f16(<2 x half> [[X]])
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = fdiv contract <2 x half> <half 0xH3C00, half 0xH3C00>, [[TMP1]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret <2 x half> [[__ROOTN2RSQRT]]
 ;
   %call = tail call <2 x half> @_Z5rootnDv2_DhDv2_i(<2 x half> %x, <2 x i32> <i32 -2, i32 -2>)
@@ -865,7 +867,8 @@ define float @test_rootn_f32__y_neg2(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg2(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]])
+; CHECK-NEXT:    [[TMP0:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = fdiv contract float 1.000000e+00, [[TMP0]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret float [[__ROOTN2RSQRT]]
 ;
 entry:
@@ -877,7 +880,8 @@ define float @test_rootn_f32__y_neg2__flags(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg2__flags(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call nnan nsz float @_Z5rsqrtf(float [[X]])
+; CHECK-NEXT:    [[TMP0:%.*]] = call nnan nsz contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = fdiv nnan nsz contract float 1.000000e+00, [[TMP0]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret float [[__ROOTN2RSQRT]]
 ;
 entry:
@@ -889,7 +893,7 @@ define float @test_rootn_f32__y_neg2__strictfp(float %x) #1 {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg2__strictfp(
 ; CHECK-SAME: float [[X:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]]) #[[ATTR0]]
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR0]]
 ; CHECK-NEXT:    ret float [[__ROOTN2RSQRT]]
 ;
 entry:
@@ -901,7 +905,7 @@ define float @test_rootn_f32__y_neg2__noinline(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg2__noinline(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]])
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
 ; CHECK-NEXT:    ret float [[__ROOTN2RSQRT]]
 ;
 entry:
@@ -913,7 +917,7 @@ define float @test_rootn_f32__y_neg2__nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg2__nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR4:[0-9]+]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
 entry:
@@ -925,7 +929,8 @@ define <2 x float> @test_rootn_v2f32__y_neg2(<2 x float> %x) {
 ; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2(
 ; CHECK-SAME: <2 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]])
+; CHECK-NEXT:    [[TMP0:%.*]] = call contract <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = fdiv contract <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[TMP0]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret <2 x float> [[__ROOTN2RSQRT]]
 ;
 entry:
@@ -937,7 +942,8 @@ define <2 x float> @test_rootn_v2f32__y_neg2__flags(<2 x float> %x) {
 ; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2__flags(
 ; CHECK-SAME: <2 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call nnan nsz <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]])
+; CHECK-NEXT:    [[TMP0:%.*]] = call nnan nsz contract <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = fdiv nnan nsz contract <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[TMP0]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret <2 x float> [[__ROOTN2RSQRT]]
 ;
 entry:
@@ -949,7 +955,7 @@ define <2 x float> @test_rootn_v2f32__y_neg2__strictfp(<2 x float> %x) #1 {
 ; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2__strictfp(
 ; CHECK-SAME: <2 x float> [[X:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = call <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]]) #[[ATTR0]]
+; CHECK-NEXT:    [[__ROOTN2RSQRT:%.*]] = tail call <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> [[X]], <2 x i32> <i32 -2, i32 -2>) #[[ATTR0]]
 ; CHECK-NEXT:    ret <2 x float> [[__ROOTN2RSQRT]]
 ;
 entry:
@@ -1125,7 +1131,7 @@ define float @test_rootn_fast_f32_nobuiltin(float %x, i32 %y) {
 ; CHECK-LABEL: define float @test_rootn_fast_f32_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]], i32 [[Y:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR3]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR4]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
 entry:
@@ -1420,7 +1426,7 @@ entry:
 define float @test_rootn_f32__y_0_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_0_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR3]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR4]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 0) #0
@@ -1430,7 +1436,7 @@ define float @test_rootn_f32__y_0_nobuiltin(float %x) {
 define float @test_rootn_f32__y_1_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_1_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR3]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR4]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 1) #0
@@ -1440,7 +1446,7 @@ define float @test_rootn_f32__y_1_nobuiltin(float %x) {
 define float @test_rootn_f32__y_2_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_2_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR3]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR4]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 2) #0
@@ -1450,7 +1456,7 @@ define float @test_rootn_f32__y_2_nobuiltin(float %x) {
 define float @test_rootn_f32__y_3_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_3_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR3]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR4]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 3) #0
@@ -1460,7 +1466,7 @@ define float @test_rootn_f32__y_3_nobuiltin(float %x) {
 define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg1_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR3]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR4]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 -1) #0
@@ -1470,7 +1476,7 @@ define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
 define float @test_rootn_f32__y_neg2_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg2_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR4]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 -2) #0
@@ -1487,7 +1493,8 @@ attributes #2 = { noinline }
 ; CHECK: attributes #[[ATTR0]] = { strictfp }
 ; CHECK: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
 ; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind memory(read) }
-; CHECK: attributes #[[ATTR3]] = { nobuiltin }
+; CHECK: attributes #[[ATTR3]] = { noinline }
+; CHECK: attributes #[[ATTR4]] = { nobuiltin }
 ;.
 ; CHECK: [[META0]] = !{float 2.000000e+00}
 ; CHECK: [[META1]] = !{float 3.000000e+00}
diff --git a/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll b/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
index 152eba5dec946..5a241f85b2e2c 100644
--- a/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
+++ b/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
@@ -506,8 +506,8 @@ entry:
 }
 
 ; GCN-LABEL: {{^}}define amdgpu_kernel void @test_rootn_m2
-; GCN-POSTLINK: call fast float @_Z5rootnfi(float %tmp, i32 -2)
-; GCN-PRELINK: %__rootn2rsqrt = tail call fast float @_Z5rsqrtf(float %tmp)
+; GCN: [[SQRT:%.+]] = tail call fast float @llvm.sqrt.f32(float %tmp)
+; GCN-NEXT: fdiv fast float 1.000000e+00, [[SQRT]]
 define amdgpu_kernel void @test_rootn_m2(ptr addrspace(1) nocapture %a) {
 entry:
   %tmp = load float, ptr addrspace(1) %a, align 4

``````````

</details>


https://github.com/llvm/llvm-project/pull/92863


More information about the llvm-commits mailing list