[llvm] 66b76fa - AMDGPU: Directly emit sqrt intrinsic when folding rootn(x, 2) (#92598)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 20 22:57:07 PDT 2024


Author: Matt Arsenault
Date: 2024-05-21T07:57:04+02:00
New Revision: 66b76faffb211b3cb2d58e3ab9401e6396447de9

URL: https://github.com/llvm/llvm-project/commit/66b76faffb211b3cb2d58e3ab9401e6396447de9
DIFF: https://github.com/llvm/llvm-project/commit/66b76faffb211b3cb2d58e3ab9401e6396447de9.diff

LOG: AMDGPU: Directly emit sqrt intrinsic when folding rootn(x, 2) (#92598)

This avoids depending on pre/post link runs.

Depends on #92595

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
    llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
    llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index 47de1791dae31..aab79ceb57f22 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include <cmath>
@@ -1175,17 +1176,30 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
     return true;
   }
 
-  Module *M = Parent->getParent();
-  if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
-    if (FunctionCallee FPExpr =
-            getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
-      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
-                        << ")\n");
-      Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
-      replaceCall(FPOp, nval);
-      return true;
-    }
-  } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
+  Module *M = B.GetInsertBlock()->getModule();
+
+  CallInst *CI = cast<CallInst>(FPOp);
+  if (ci_opr1 == 2 &&
+      shouldReplaceLibcallWithIntrinsic(CI,
+                                        /*AllowMinSizeF32=*/true,
+                                        /*AllowF64=*/true)) {
+    // rootn(x, 2) = sqrt(x)
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0 << ")\n");
+
+    CallInst *NewCall = B.CreateUnaryIntrinsic(Intrinsic::sqrt, opr0, CI);
+    NewCall->takeName(CI);
+
+    // OpenCL rootn has a looser ulp of 2 requirement than sqrt, so add some
+    // metadata.
+    MDBuilder MDHelper(M->getContext());
+    MDNode *FPMD = MDHelper.createFPMath(std::max(FPOp->getFPAccuracy(), 2.0f));
+    NewCall->setMetadata(LLVMContext::MD_fpmath, FPMD);
+
+    replaceCall(CI, NewCall);
+    return true;
+  }
+
+  if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
     if (FunctionCallee FPExpr =
             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0

diff  --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
index d75517cb26875..c105ad7590e69 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
@@ -272,8 +272,8 @@ define half @test_rootn_f16_1(half %x) {
 define half @test_rootn_f16_2(half %x) {
 ; CHECK-LABEL: define half @test_rootn_f16_2(
 ; CHECK-SAME: half [[X:%.*]]) {
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call half @_Z4sqrtDh(half [[X]])
-; CHECK-NEXT:    ret half [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call half @llvm.sqrt.f16(half [[X]]), !fpmath [[META0:![0-9]+]]
+; CHECK-NEXT:    ret half [[CALL]]
 ;
   %call = tail call half @_Z5rootnDhi(half %x, i32 2)
   ret half %call
@@ -351,8 +351,8 @@ define <2 x half> @test_rootn_v2f16_1(<2 x half> %x) {
 define <2 x half> @test_rootn_v2f16_2(<2 x half> %x) {
 ; CHECK-LABEL: define <2 x half> @test_rootn_v2f16_2(
 ; CHECK-SAME: <2 x half> [[X:%.*]]) {
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call <2 x half> @_Z4sqrtDv2_Dh(<2 x half> [[X]])
-; CHECK-NEXT:    ret <2 x half> [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call <2 x half> @llvm.sqrt.v2f16(<2 x half> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret <2 x half> [[CALL]]
 ;
   %call = tail call <2 x half> @_Z5rootnDv2_DhDv2_i(<2 x half> %x, <2 x i32> <i32 2, i32 2>)
   ret <2 x half> %call
@@ -612,8 +612,8 @@ define float @test_rootn_f32__y_2(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_2(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call float @_Z4sqrtf(float [[X]])
-; CHECK-NEXT:    ret float [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call float @llvm.sqrt.f32(float [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret float [[CALL]]
 ;
 entry:
   %call = tail call float @_Z5rootnfi(float %x, i32 2)
@@ -624,8 +624,8 @@ define float @test_rootn_f32__y_2_flags(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_2_flags(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call nnan nsz float @_Z4sqrtf(float [[X]])
-; CHECK-NEXT:    ret float [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call nnan nsz float @llvm.sqrt.f32(float [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret float [[CALL]]
 ;
 entry:
   %call = tail call nnan nsz float @_Z5rootnfi(float %x, i32 2)
@@ -637,8 +637,8 @@ define float @test_rootn_f32__y_2_fpmath_3(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_2_fpmath_3(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call nnan nsz float @_Z4sqrtf(float [[X]])
-; CHECK-NEXT:    ret float [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call nnan nsz float @llvm.sqrt.f32(float [[X]]), !fpmath [[META1:![0-9]+]]
+; CHECK-NEXT:    ret float [[CALL]]
 ;
 entry:
   %call = tail call nnan nsz float @_Z5rootnfi(float %x, i32 2), !fpmath !0
@@ -649,8 +649,8 @@ define <2 x float> @test_rootn_v2f32__y_2_flags(<2 x float> %x) {
 ; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_2_flags(
 ; CHECK-SAME: <2 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call nnan nsz <2 x float> @_Z4sqrtDv2_f(<2 x float> [[X]])
-; CHECK-NEXT:    ret <2 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call nnan nsz <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret <2 x float> [[CALL]]
 ;
 entry:
   %call = tail call nnan nsz <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> %x, <2 x i32> <i32 2, i32 2>)
@@ -661,8 +661,8 @@ define <3 x float> @test_rootn_v3f32__y_2(<3 x float> %x) {
 ; CHECK-LABEL: define <3 x float> @test_rootn_v3f32__y_2(
 ; CHECK-SAME: <3 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call <3 x float> @_Z4sqrtDv3_f(<3 x float> [[X]])
-; CHECK-NEXT:    ret <3 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call <3 x float> @llvm.sqrt.v3f32(<3 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret <3 x float> [[CALL]]
 ;
 entry:
   %call = tail call <3 x float> @_Z5rootnDv3_fDv3_i(<3 x float> %x, <3 x i32> <i32 2, i32 2, i32 2>)
@@ -673,8 +673,8 @@ define <3 x float> @test_rootn_v3f32__y_2_undef(<3 x float> %x) {
 ; CHECK-LABEL: define <3 x float> @test_rootn_v3f32__y_2_undef(
 ; CHECK-SAME: <3 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call <3 x float> @_Z4sqrtDv3_f(<3 x float> [[X]])
-; CHECK-NEXT:    ret <3 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call <3 x float> @llvm.sqrt.v3f32(<3 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret <3 x float> [[CALL]]
 ;
 entry:
   %call = tail call <3 x float> @_Z5rootnDv3_fDv3_i(<3 x float> %x, <3 x i32> <i32 2, i32 poison, i32 2>)
@@ -685,8 +685,8 @@ define <4 x float> @test_rootn_v4f32__y_2(<4 x float> %x) {
 ; CHECK-LABEL: define <4 x float> @test_rootn_v4f32__y_2(
 ; CHECK-SAME: <4 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call <4 x float> @_Z4sqrtDv4_f(<4 x float> [[X]])
-; CHECK-NEXT:    ret <4 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call <4 x float> @llvm.sqrt.v4f32(<4 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret <4 x float> [[CALL]]
 ;
 entry:
   %call = tail call <4 x float> @_Z5rootnDv4_fDv4_i(<4 x float> %x, <4 x i32> <i32 2, i32 2, i32 2, i32 2>)
@@ -697,8 +697,8 @@ define <8 x float> @test_rootn_v8f32__y_2(<8 x float> %x) {
 ; CHECK-LABEL: define <8 x float> @test_rootn_v8f32__y_2(
 ; CHECK-SAME: <8 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call <8 x float> @_Z4sqrtDv8_f(<8 x float> [[X]])
-; CHECK-NEXT:    ret <8 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call <8 x float> @llvm.sqrt.v8f32(<8 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret <8 x float> [[CALL]]
 ;
 entry:
   %call = tail call <8 x float> @_Z5rootnDv8_fDv8_i(<8 x float> %x, <8 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>)
@@ -709,8 +709,8 @@ define <16 x float> @test_rootn_v16f32__y_2(<16 x float> %x) {
 ; CHECK-LABEL: define <16 x float> @test_rootn_v16f32__y_2(
 ; CHECK-SAME: <16 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call <16 x float> @_Z4sqrtDv16_f(<16 x float> [[X]])
-; CHECK-NEXT:    ret <16 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call <16 x float> @llvm.sqrt.v16f32(<16 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret <16 x float> [[CALL]]
 ;
 entry:
   %call = tail call <16 x float> @_Z5rootnDv16_fDv16_i(<16 x float> %x, <16 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>)
@@ -757,8 +757,8 @@ define <2 x float> @test_rootn_v2f32__y_nonsplat_2_poison(<2 x float> %x) {
 ; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_nonsplat_2_poison(
 ; CHECK-SAME: <2 x float> [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[__ROOTN2SQRT:%.*]] = call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[X]])
-; CHECK-NEXT:    ret <2 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT:    [[CALL:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT:    ret <2 x float> [[CALL]]
 ;
 entry:
   %call = tail call <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> %x, <2 x i32> <i32 2, i32 poison>)
@@ -913,7 +913,7 @@ define float @test_rootn_f32__y_neg2__nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg2__nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR2:[0-9]+]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
 entry:
@@ -1125,7 +1125,7 @@ define float @test_rootn_fast_f32_nobuiltin(float %x, i32 %y) {
 ; CHECK-LABEL: define float @test_rootn_fast_f32_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]], i32 [[Y:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR2]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
 entry:
@@ -1420,7 +1420,7 @@ entry:
 define float @test_rootn_f32__y_0_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_0_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR2]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 0) #0
@@ -1430,7 +1430,7 @@ define float @test_rootn_f32__y_0_nobuiltin(float %x) {
 define float @test_rootn_f32__y_1_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_1_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR2]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 1) #0
@@ -1440,7 +1440,7 @@ define float @test_rootn_f32__y_1_nobuiltin(float %x) {
 define float @test_rootn_f32__y_2_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_2_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR2]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 2) #0
@@ -1450,7 +1450,7 @@ define float @test_rootn_f32__y_2_nobuiltin(float %x) {
 define float @test_rootn_f32__y_3_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_3_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR2]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 3) #0
@@ -1460,7 +1460,7 @@ define float @test_rootn_f32__y_3_nobuiltin(float %x) {
 define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg1_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR2]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 -1) #0
@@ -1470,7 +1470,7 @@ define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
 define float @test_rootn_f32__y_neg2_nobuiltin(float %x) {
 ; CHECK-LABEL: define float @test_rootn_f32__y_neg2_nobuiltin(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR2]]
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[CALL]]
 ;
   %call = tail call float @_Z5rootnfi(float %x, i32 -2) #0
@@ -1485,6 +1485,10 @@ attributes #2 = { noinline }
 !0 = !{float 3.0}
 ;.
 ; CHECK: attributes #[[ATTR0]] = { strictfp }
-; CHECK: attributes #[[ATTR1:[0-9]+]] = { nounwind memory(read) }
-; CHECK: attributes #[[ATTR2]] = { nobuiltin }
+; CHECK: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind memory(read) }
+; CHECK: attributes #[[ATTR3]] = { nobuiltin }
+;.
+; CHECK: [[META0]] = !{float 2.000000e+00}
+; CHECK: [[META1]] = !{float 3.000000e+00}
 ;.

diff  --git a/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll b/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
index 54ca33401ccf4..152eba5dec946 100644
--- a/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
+++ b/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
@@ -475,8 +475,7 @@ entry:
 declare float @_Z5rootnfi(float, i32)
 
 ; GCN-LABEL: {{^}}define amdgpu_kernel void @test_rootn_2
-; GCN-POSTLINK: call fast float @_Z5rootnfi(float %tmp, i32 2)
-; GCN-PRELINK: %__rootn2sqrt = tail call fast float @llvm.sqrt.f32(float %tmp)
+; GCN: call fast float @llvm.sqrt.f32(float %tmp)
 define amdgpu_kernel void @test_rootn_2(ptr addrspace(1) nocapture %a) {
 entry:
   %tmp = load float, ptr addrspace(1) %a, align 4


        


More information about the llvm-commits mailing list