[llvm] 7c4aa3b - AMDGPU: InstCombine amdgcn.rcp(amdgcn.sqrt) -> amdgcn.rsq

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 16 07:12:33 PDT 2023


Author: Matt Arsenault
Date: 2023-08-16T10:04:13-04:00
New Revision: 7c4aa3b37ed1c8ab276d67eb9757e56a2f26790c

URL: https://github.com/llvm/llvm-project/commit/7c4aa3b37ed1c8ab276d67eb9757e56a2f26790c
DIFF: https://github.com/llvm/llvm-project/commit/7c4aa3b37ed1c8ab276d67eb9757e56a2f26790c.diff

LOG: AMDGPU: InstCombine amdgcn.rcp(amdgcn.sqrt) -> amdgcn.rsq

We currently have some wrong combines in the backend that
approximately do this.

https://reviews.llvm.org/D158002

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
    llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 3c399e49722785..a0274aecfa3274 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -408,6 +408,13 @@ static Value *simplifyAMDGCNMemoryIntrinsicDemanded(InstCombiner &IC,
                                                     int DMaskIdx = -1,
                                                     bool IsLoad = true);
 
+/// Return true if llvm.amdgcn.rcp(llvm.sqrt(x)) may be contracted to rsq(x).
+static bool canContractSqrtToRsq(const FPMathOperator *SqrtOp) {
+  return (SqrtOp->getType()->isFloatTy() && // f32: only if afn or !fpmath >= 1.0
+          (SqrtOp->hasApproxFunc() || SqrtOp->getFPAccuracy() >= 1.0f)) ||
+         SqrtOp->getType()->isHalfTy(); // f16: always; other types: never
+}
+
 std::optional<Instruction *>
 GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
   Intrinsic::ID IID = II.getIntrinsicID();
@@ -437,6 +444,37 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
       return IC.replaceInstUsesWith(II, ConstantFP::get(II.getContext(), Val));
     }
 
+    FastMathFlags FMF = cast<FPMathOperator>(II).getFastMathFlags();
+    if (!FMF.allowContract())
+      break;
+    auto *SrcCI = dyn_cast<IntrinsicInst>(Src);
+    if (!SrcCI)
+      break;
+
+    auto IID = SrcCI->getIntrinsicID();
+    // llvm.amdgcn.rcp(llvm.amdgcn.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable
+    //
+    // llvm.amdgcn.rcp(llvm.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable and
+    // relaxed.
+    if (IID == Intrinsic::amdgcn_sqrt || IID == Intrinsic::sqrt) {
+      const FPMathOperator *SqrtOp = cast<FPMathOperator>(SrcCI);
+      FastMathFlags InnerFMF = SqrtOp->getFastMathFlags();
+      if (!InnerFMF.allowContract() || !SrcCI->hasOneUse())
+        break;
+
+      if (IID == Intrinsic::sqrt && !canContractSqrtToRsq(SqrtOp))
+        break;
+
+      Function *NewDecl = Intrinsic::getDeclaration(
+          SrcCI->getModule(), Intrinsic::amdgcn_rsq, {SrcCI->getType()});
+
+      InnerFMF |= FMF;
+      II.setFastMathFlags(InnerFMF);
+
+      II.setCalledFunction(NewDecl);
+      return IC.replaceOperand(II, 0, SrcCI->getArgOperand(0));
+    }
+
     break;
   }
   case Intrinsic::amdgcn_sqrt:

diff  --git a/llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll b/llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
index e695433e586c5c..0be2c7aa85b27b 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
@@ -22,8 +22,7 @@ declare double @llvm.sqrt.f64(double)
 define float @amdgcn_rcp_amdgcn_sqrt_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_amdgcn_sqrt_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1:[0-9]+]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract float @llvm.amdgcn.sqrt.f32(float [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call contract float @llvm.amdgcn.sqrt.f32(float %x)
@@ -76,8 +75,7 @@ define float @amdgcn_rcp_amdgcn_sqrt_f32_contract_multi_use(float %x, ptr %ptr)
 define float @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call nnan contract float @llvm.amdgcn.sqrt.f32(float [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call ninf contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call nnan ninf contract float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call nnan contract float @llvm.amdgcn.sqrt.f32(float %x)
@@ -89,8 +87,7 @@ define float @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f32_contract(float %x) {
 define half @amdgcn_rcp_amdgcn_sqrt_f16_contract(half %x) {
 ; CHECK-LABEL: define half @amdgcn_rcp_amdgcn_sqrt_f16_contract
 ; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract half @llvm.amdgcn.sqrt.f16(half [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract half @llvm.amdgcn.rsq.f16(half [[X]])
 ; CHECK-NEXT:    ret half [[RSQ]]
 ;
   %sqrt = call contract half @llvm.amdgcn.sqrt.f16(half %x)
@@ -143,8 +140,7 @@ define half @amdgcn_rcp_amdgcn_sqrt_f16_contract_multi_use(half %x, ptr %ptr) {
 define half @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f16_contract(half %x) {
 ; CHECK-LABEL: define half @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f16_contract
 ; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call nnan contract half @llvm.amdgcn.sqrt.f16(half [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call ninf contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call nnan ninf contract half @llvm.amdgcn.rsq.f16(half [[X]])
 ; CHECK-NEXT:    ret half [[RSQ]]
 ;
   %sqrt = call nnan contract half @llvm.amdgcn.sqrt.f16(half %x)
@@ -156,8 +152,7 @@ define half @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f16_contract(half %x) {
 define double @amdgcn_rcp_amdgcn_sqrt_f64_contract(double %x) {
 ; CHECK-LABEL: define double @amdgcn_rcp_amdgcn_sqrt_f64_contract
 ; CHECK-SAME: (double [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract double @llvm.amdgcn.sqrt.f64(double [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract double @llvm.amdgcn.rcp.f64(double [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract double @llvm.amdgcn.rsq.f64(double [[X]])
 ; CHECK-NEXT:    ret double [[RSQ]]
 ;
   %sqrt = call contract double @llvm.amdgcn.sqrt.f64(double %x)
@@ -210,8 +205,7 @@ define double @amdgcn_rcp_amdgcn_sqrt_f64_contract_multi_use(double %x, ptr %ptr
 define double @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f64_contract(double %x) {
 ; CHECK-LABEL: define double @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f64_contract
 ; CHECK-SAME: (double [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call nnan contract double @llvm.amdgcn.sqrt.f64(double [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call ninf contract double @llvm.amdgcn.rcp.f64(double [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call nnan ninf contract double @llvm.amdgcn.rsq.f64(double [[X]])
 ; CHECK-NEXT:    ret double [[RSQ]]
 ;
   %sqrt = call nnan contract double @llvm.amdgcn.sqrt.f64(double %x)
@@ -236,8 +230,7 @@ define float @amdgcn_rcp_sqrt_f32_contract(float %x) {
 define half @amdgcn_rcp_sqrt_f16_contract(half %x) {
 ; CHECK-LABEL: define half @amdgcn_rcp_sqrt_f16_contract
 ; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract half @llvm.sqrt.f16(half [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract half @llvm.amdgcn.rsq.f16(half [[X]])
 ; CHECK-NEXT:    ret half [[RSQ]]
 ;
   %sqrt = call contract half @llvm.sqrt.f16(half %x)
@@ -261,8 +254,7 @@ define double @amdgcn_rcp_sqrt_f64_contract(double %x) {
 define float @amdgcn_rcp_afn_sqrt_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_afn_sqrt_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract afn float @llvm.sqrt.f32(float [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract afn float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call afn contract float @llvm.sqrt.f32(float %x)
@@ -273,8 +265,7 @@ define float @amdgcn_rcp_afn_sqrt_f32_contract(float %x) {
 define float @amdgcn_rcp_fpmath3_sqrt_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_fpmath3_sqrt_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]), !fpmath !0
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call contract float @llvm.sqrt.f32(float %x), !fpmath !0
@@ -285,8 +276,7 @@ define float @amdgcn_rcp_fpmath3_sqrt_f32_contract(float %x) {
 define float @amdgcn_rcp_fpmath1_sqrt_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_fpmath1_sqrt_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]), !fpmath !1
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call contract float @llvm.sqrt.f32(float %x), !fpmath !1


        


More information about the llvm-commits mailing list