[llvm] 7c4aa3b - AMDGPU: InstCombine amdgcn.rcp(amdgcn.sqrt) -> amdgcn.rsq
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 16 07:12:33 PDT 2023
Author: Matt Arsenault
Date: 2023-08-16T10:04:13-04:00
New Revision: 7c4aa3b37ed1c8ab276d67eb9757e56a2f26790c
URL: https://github.com/llvm/llvm-project/commit/7c4aa3b37ed1c8ab276d67eb9757e56a2f26790c
DIFF: https://github.com/llvm/llvm-project/commit/7c4aa3b37ed1c8ab276d67eb9757e56a2f26790c.diff
LOG: AMDGPU: InstCombine amdgcn.rcp(amdgcn.sqrt) -> amdgcn.rsq
We currently have some wrong combines in the backend that
approximately do this.
https://reviews.llvm.org/D158002
Added:
Modified:
llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 3c399e49722785..a0274aecfa3274 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -408,6 +408,13 @@ static Value *simplifyAMDGCNMemoryIntrinsicDemanded(InstCombiner &IC,
int DMaskIdx = -1,
bool IsLoad = true);
+/// Return true if it's legal to contract llvm.amdgcn.rcp(llvm.sqrt) to rsq.
+static bool canContractSqrtToRsq(const FPMathOperator *SqrtOp) {
+ return (SqrtOp->getType()->isFloatTy() && // float: afn or accuracy >= 1.0
+ (SqrtOp->hasApproxFunc() || SqrtOp->getFPAccuracy() >= 1.0f)) ||
+ SqrtOp->getType()->isHalfTy(); // half: always; any other type: never
+}
+
std::optional<Instruction *>
GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
Intrinsic::ID IID = II.getIntrinsicID();
@@ -437,6 +444,37 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
return IC.replaceInstUsesWith(II, ConstantFP::get(II.getContext(), Val));
}
+ FastMathFlags FMF = cast<FPMathOperator>(II).getFastMathFlags();
+ if (!FMF.allowContract())
+ break;
+ auto *SrcCI = dyn_cast<IntrinsicInst>(Src);
+ if (!SrcCI)
+ break;
+
+ auto IID = SrcCI->getIntrinsicID();
+ // llvm.amdgcn.rcp(llvm.amdgcn.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable
+ //
+ // llvm.amdgcn.rcp(llvm.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable and
+ // relaxed.
+ if (IID == Intrinsic::amdgcn_sqrt || IID == Intrinsic::sqrt) {
+ const FPMathOperator *SqrtOp = cast<FPMathOperator>(SrcCI);
+ FastMathFlags InnerFMF = SqrtOp->getFastMathFlags();
+ if (!InnerFMF.allowContract() || !SrcCI->hasOneUse())
+ break;
+
+ if (IID == Intrinsic::sqrt && !canContractSqrtToRsq(SqrtOp))
+ break;
+
+ Function *NewDecl = Intrinsic::getDeclaration(
+ SrcCI->getModule(), Intrinsic::amdgcn_rsq, {SrcCI->getType()});
+
+ InnerFMF |= FMF;
+ II.setFastMathFlags(InnerFMF);
+
+ II.setCalledFunction(NewDecl);
+ return IC.replaceOperand(II, 0, SrcCI->getArgOperand(0));
+ }
+
break;
}
case Intrinsic::amdgcn_sqrt:
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll b/llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
index e695433e586c5c..0be2c7aa85b27b 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
@@ -22,8 +22,7 @@ declare double @llvm.sqrt.f64(double)
define float @amdgcn_rcp_amdgcn_sqrt_f32_contract(float %x) {
; CHECK-LABEL: define float @amdgcn_rcp_amdgcn_sqrt_f32_contract
; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1:[0-9]+]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call contract float @llvm.amdgcn.sqrt.f32(float [[X]])
-; CHECK-NEXT: [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
; CHECK-NEXT: ret float [[RSQ]]
;
%sqrt = call contract float @llvm.amdgcn.sqrt.f32(float %x)
@@ -76,8 +75,7 @@ define float @amdgcn_rcp_amdgcn_sqrt_f32_contract_multi_use(float %x, ptr %ptr)
define float @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f32_contract(float %x) {
; CHECK-LABEL: define float @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f32_contract
; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call nnan contract float @llvm.amdgcn.sqrt.f32(float [[X]])
-; CHECK-NEXT: [[RSQ:%.*]] = call ninf contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call nnan ninf contract float @llvm.amdgcn.rsq.f32(float [[X]])
; CHECK-NEXT: ret float [[RSQ]]
;
%sqrt = call nnan contract float @llvm.amdgcn.sqrt.f32(float %x)
@@ -89,8 +87,7 @@ define float @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f32_contract(float %x) {
define half @amdgcn_rcp_amdgcn_sqrt_f16_contract(half %x) {
; CHECK-LABEL: define half @amdgcn_rcp_amdgcn_sqrt_f16_contract
; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call contract half @llvm.amdgcn.sqrt.f16(half [[X]])
-; CHECK-NEXT: [[RSQ:%.*]] = call contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call contract half @llvm.amdgcn.rsq.f16(half [[X]])
; CHECK-NEXT: ret half [[RSQ]]
;
%sqrt = call contract half @llvm.amdgcn.sqrt.f16(half %x)
@@ -143,8 +140,7 @@ define half @amdgcn_rcp_amdgcn_sqrt_f16_contract_multi_use(half %x, ptr %ptr) {
define half @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f16_contract(half %x) {
; CHECK-LABEL: define half @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f16_contract
; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call nnan contract half @llvm.amdgcn.sqrt.f16(half [[X]])
-; CHECK-NEXT: [[RSQ:%.*]] = call ninf contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call nnan ninf contract half @llvm.amdgcn.rsq.f16(half [[X]])
; CHECK-NEXT: ret half [[RSQ]]
;
%sqrt = call nnan contract half @llvm.amdgcn.sqrt.f16(half %x)
@@ -156,8 +152,7 @@ define half @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f16_contract(half %x) {
define double @amdgcn_rcp_amdgcn_sqrt_f64_contract(double %x) {
; CHECK-LABEL: define double @amdgcn_rcp_amdgcn_sqrt_f64_contract
; CHECK-SAME: (double [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call contract double @llvm.amdgcn.sqrt.f64(double [[X]])
-; CHECK-NEXT: [[RSQ:%.*]] = call contract double @llvm.amdgcn.rcp.f64(double [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call contract double @llvm.amdgcn.rsq.f64(double [[X]])
; CHECK-NEXT: ret double [[RSQ]]
;
%sqrt = call contract double @llvm.amdgcn.sqrt.f64(double %x)
@@ -210,8 +205,7 @@ define double @amdgcn_rcp_amdgcn_sqrt_f64_contract_multi_use(double %x, ptr %ptr
define double @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f64_contract(double %x) {
; CHECK-LABEL: define double @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f64_contract
; CHECK-SAME: (double [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call nnan contract double @llvm.amdgcn.sqrt.f64(double [[X]])
-; CHECK-NEXT: [[RSQ:%.*]] = call ninf contract double @llvm.amdgcn.rcp.f64(double [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call nnan ninf contract double @llvm.amdgcn.rsq.f64(double [[X]])
; CHECK-NEXT: ret double [[RSQ]]
;
%sqrt = call nnan contract double @llvm.amdgcn.sqrt.f64(double %x)
@@ -236,8 +230,7 @@ define float @amdgcn_rcp_sqrt_f32_contract(float %x) {
define half @amdgcn_rcp_sqrt_f16_contract(half %x) {
; CHECK-LABEL: define half @amdgcn_rcp_sqrt_f16_contract
; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call contract half @llvm.sqrt.f16(half [[X]])
-; CHECK-NEXT: [[RSQ:%.*]] = call contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call contract half @llvm.amdgcn.rsq.f16(half [[X]])
; CHECK-NEXT: ret half [[RSQ]]
;
%sqrt = call contract half @llvm.sqrt.f16(half %x)
@@ -261,8 +254,7 @@ define double @amdgcn_rcp_sqrt_f64_contract(double %x) {
define float @amdgcn_rcp_afn_sqrt_f32_contract(float %x) {
; CHECK-LABEL: define float @amdgcn_rcp_afn_sqrt_f32_contract
; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call contract afn float @llvm.sqrt.f32(float [[X]])
-; CHECK-NEXT: [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call contract afn float @llvm.amdgcn.rsq.f32(float [[X]])
; CHECK-NEXT: ret float [[RSQ]]
;
%sqrt = call afn contract float @llvm.sqrt.f32(float %x)
@@ -273,8 +265,7 @@ define float @amdgcn_rcp_afn_sqrt_f32_contract(float %x) {
define float @amdgcn_rcp_fpmath3_sqrt_f32_contract(float %x) {
; CHECK-LABEL: define float @amdgcn_rcp_fpmath3_sqrt_f32_contract
; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]), !fpmath !0
-; CHECK-NEXT: [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
; CHECK-NEXT: ret float [[RSQ]]
;
%sqrt = call contract float @llvm.sqrt.f32(float %x), !fpmath !0
@@ -285,8 +276,7 @@ define float @amdgcn_rcp_fpmath3_sqrt_f32_contract(float %x) {
define float @amdgcn_rcp_fpmath1_sqrt_f32_contract(float %x) {
; CHECK-LABEL: define float @amdgcn_rcp_fpmath1_sqrt_f32_contract
; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[SQRT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]), !fpmath !1
-; CHECK-NEXT: [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT: [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
; CHECK-NEXT: ret float [[RSQ]]
;
%sqrt = call contract float @llvm.sqrt.f32(float %x), !fpmath !1
More information about the llvm-commits mailing list