[llvm] AMDGPU: Directly emit sqrt intrinsic when folding rootn(x, 2) (PR #92598)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon May 20 13:24:43 PDT 2024
https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/92598
>From d666b68a257a8cd36d0731741013e06119f63f03 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Sat, 2 Dec 2023 12:59:16 +0900
Subject: [PATCH] AMDGPU: Directly emit sqrt intrinsic when folding rootn(x, 2)
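Emitting the llvm.sqrt intrinsic directly, instead of a call to the sqrt library function, makes the fold independent of whether the pass runs before or after library linking.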
---
llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp | 36 +++++++---
.../AMDGPU/amdgpu-simplify-libcall-rootn.ll | 72 ++++++++++---------
llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll | 3 +-
3 files changed, 64 insertions(+), 47 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index 47de1791dae31..aab79ceb57f22 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -22,6 +22,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include <cmath>
@@ -1175,17 +1176,30 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
return true;
}
- Module *M = Parent->getParent();
- if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
- if (FunctionCallee FPExpr =
- getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
- LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
- << ")\n");
- Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
- replaceCall(FPOp, nval);
- return true;
- }
- } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
+ Module *M = B.GetInsertBlock()->getModule();
+
+ CallInst *CI = cast<CallInst>(FPOp);
+ if (ci_opr1 == 2 &&
+ shouldReplaceLibcallWithIntrinsic(CI,
+ /*AllowMinSizeF32=*/true,
+ /*AllowF64=*/true)) {
+ // rootn(x, 2) = sqrt(x)
+ LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0 << ")\n");
+
+ CallInst *NewCall = B.CreateUnaryIntrinsic(Intrinsic::sqrt, opr0, CI);
+ NewCall->takeName(CI);
+
+ // OpenCL rootn has a looser ulp of 2 requirement than sqrt, so add some
+ // metadata.
+ MDBuilder MDHelper(M->getContext());
+ MDNode *FPMD = MDHelper.createFPMath(std::max(FPOp->getFPAccuracy(), 2.0f));
+ NewCall->setMetadata(LLVMContext::MD_fpmath, FPMD);
+
+ replaceCall(CI, NewCall);
+ return true;
+ }
+
+ if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
if (FunctionCallee FPExpr =
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
index d75517cb26875..c105ad7590e69 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll
@@ -272,8 +272,8 @@ define half @test_rootn_f16_1(half %x) {
define half @test_rootn_f16_2(half %x) {
; CHECK-LABEL: define half @test_rootn_f16_2(
; CHECK-SAME: half [[X:%.*]]) {
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call half @_Z4sqrtDh(half [[X]])
-; CHECK-NEXT: ret half [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call half @llvm.sqrt.f16(half [[X]]), !fpmath [[META0:![0-9]+]]
+; CHECK-NEXT: ret half [[CALL]]
;
%call = tail call half @_Z5rootnDhi(half %x, i32 2)
ret half %call
@@ -351,8 +351,8 @@ define <2 x half> @test_rootn_v2f16_1(<2 x half> %x) {
define <2 x half> @test_rootn_v2f16_2(<2 x half> %x) {
; CHECK-LABEL: define <2 x half> @test_rootn_v2f16_2(
; CHECK-SAME: <2 x half> [[X:%.*]]) {
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <2 x half> @_Z4sqrtDv2_Dh(<2 x half> [[X]])
-; CHECK-NEXT: ret <2 x half> [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call <2 x half> @llvm.sqrt.v2f16(<2 x half> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret <2 x half> [[CALL]]
;
%call = tail call <2 x half> @_Z5rootnDv2_DhDv2_i(<2 x half> %x, <2 x i32> <i32 2, i32 2>)
ret <2 x half> %call
@@ -612,8 +612,8 @@ define float @test_rootn_f32__y_2(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_2(
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call float @_Z4sqrtf(float [[X]])
-; CHECK-NEXT: ret float [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call float @llvm.sqrt.f32(float [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret float [[CALL]]
;
entry:
%call = tail call float @_Z5rootnfi(float %x, i32 2)
@@ -624,8 +624,8 @@ define float @test_rootn_f32__y_2_flags(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_2_flags(
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call nnan nsz float @_Z4sqrtf(float [[X]])
-; CHECK-NEXT: ret float [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call nnan nsz float @llvm.sqrt.f32(float [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret float [[CALL]]
;
entry:
%call = tail call nnan nsz float @_Z5rootnfi(float %x, i32 2)
@@ -637,8 +637,8 @@ define float @test_rootn_f32__y_2_fpmath_3(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_2_fpmath_3(
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call nnan nsz float @_Z4sqrtf(float [[X]])
-; CHECK-NEXT: ret float [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call nnan nsz float @llvm.sqrt.f32(float [[X]]), !fpmath [[META1:![0-9]+]]
+; CHECK-NEXT: ret float [[CALL]]
;
entry:
%call = tail call nnan nsz float @_Z5rootnfi(float %x, i32 2), !fpmath !0
@@ -649,8 +649,8 @@ define <2 x float> @test_rootn_v2f32__y_2_flags(<2 x float> %x) {
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_2_flags(
; CHECK-SAME: <2 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call nnan nsz <2 x float> @_Z4sqrtDv2_f(<2 x float> [[X]])
-; CHECK-NEXT: ret <2 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call nnan nsz <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret <2 x float> [[CALL]]
;
entry:
%call = tail call nnan nsz <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> %x, <2 x i32> <i32 2, i32 2>)
@@ -661,8 +661,8 @@ define <3 x float> @test_rootn_v3f32__y_2(<3 x float> %x) {
; CHECK-LABEL: define <3 x float> @test_rootn_v3f32__y_2(
; CHECK-SAME: <3 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <3 x float> @_Z4sqrtDv3_f(<3 x float> [[X]])
-; CHECK-NEXT: ret <3 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call <3 x float> @llvm.sqrt.v3f32(<3 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret <3 x float> [[CALL]]
;
entry:
%call = tail call <3 x float> @_Z5rootnDv3_fDv3_i(<3 x float> %x, <3 x i32> <i32 2, i32 2, i32 2>)
@@ -673,8 +673,8 @@ define <3 x float> @test_rootn_v3f32__y_2_undef(<3 x float> %x) {
; CHECK-LABEL: define <3 x float> @test_rootn_v3f32__y_2_undef(
; CHECK-SAME: <3 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <3 x float> @_Z4sqrtDv3_f(<3 x float> [[X]])
-; CHECK-NEXT: ret <3 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call <3 x float> @llvm.sqrt.v3f32(<3 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret <3 x float> [[CALL]]
;
entry:
%call = tail call <3 x float> @_Z5rootnDv3_fDv3_i(<3 x float> %x, <3 x i32> <i32 2, i32 poison, i32 2>)
@@ -685,8 +685,8 @@ define <4 x float> @test_rootn_v4f32__y_2(<4 x float> %x) {
; CHECK-LABEL: define <4 x float> @test_rootn_v4f32__y_2(
; CHECK-SAME: <4 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <4 x float> @_Z4sqrtDv4_f(<4 x float> [[X]])
-; CHECK-NEXT: ret <4 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call <4 x float> @llvm.sqrt.v4f32(<4 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret <4 x float> [[CALL]]
;
entry:
%call = tail call <4 x float> @_Z5rootnDv4_fDv4_i(<4 x float> %x, <4 x i32> <i32 2, i32 2, i32 2, i32 2>)
@@ -697,8 +697,8 @@ define <8 x float> @test_rootn_v8f32__y_2(<8 x float> %x) {
; CHECK-LABEL: define <8 x float> @test_rootn_v8f32__y_2(
; CHECK-SAME: <8 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <8 x float> @_Z4sqrtDv8_f(<8 x float> [[X]])
-; CHECK-NEXT: ret <8 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call <8 x float> @llvm.sqrt.v8f32(<8 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret <8 x float> [[CALL]]
;
entry:
%call = tail call <8 x float> @_Z5rootnDv8_fDv8_i(<8 x float> %x, <8 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>)
@@ -709,8 +709,8 @@ define <16 x float> @test_rootn_v16f32__y_2(<16 x float> %x) {
; CHECK-LABEL: define <16 x float> @test_rootn_v16f32__y_2(
; CHECK-SAME: <16 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <16 x float> @_Z4sqrtDv16_f(<16 x float> [[X]])
-; CHECK-NEXT: ret <16 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call <16 x float> @llvm.sqrt.v16f32(<16 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret <16 x float> [[CALL]]
;
entry:
%call = tail call <16 x float> @_Z5rootnDv16_fDv16_i(<16 x float> %x, <16 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>)
@@ -757,8 +757,8 @@ define <2 x float> @test_rootn_v2f32__y_nonsplat_2_poison(<2 x float> %x) {
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_nonsplat_2_poison(
; CHECK-SAME: <2 x float> [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[X]])
-; CHECK-NEXT: ret <2 x float> [[__ROOTN2SQRT]]
+; CHECK-NEXT: [[CALL:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]]), !fpmath [[META0]]
+; CHECK-NEXT: ret <2 x float> [[CALL]]
;
entry:
%call = tail call <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> %x, <2 x i32> <i32 2, i32 poison>)
@@ -913,7 +913,7 @@ define float @test_rootn_f32__y_neg2__nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR2:[0-9]+]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
; CHECK-NEXT: ret float [[CALL]]
;
entry:
@@ -1125,7 +1125,7 @@ define float @test_rootn_fast_f32_nobuiltin(float %x, i32 %y) {
; CHECK-LABEL: define float @test_rootn_fast_f32_nobuiltin(
; CHECK-SAME: float [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR2]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR3]]
; CHECK-NEXT: ret float [[CALL]]
;
entry:
@@ -1420,7 +1420,7 @@ entry:
define float @test_rootn_f32__y_0_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_0_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR2]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR3]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 0) #0
@@ -1430,7 +1430,7 @@ define float @test_rootn_f32__y_0_nobuiltin(float %x) {
define float @test_rootn_f32__y_1_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_1_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR2]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR3]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 1) #0
@@ -1440,7 +1440,7 @@ define float @test_rootn_f32__y_1_nobuiltin(float %x) {
define float @test_rootn_f32__y_2_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_2_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR2]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR3]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 2) #0
@@ -1450,7 +1450,7 @@ define float @test_rootn_f32__y_2_nobuiltin(float %x) {
define float @test_rootn_f32__y_3_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_3_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR2]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR3]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 3) #0
@@ -1460,7 +1460,7 @@ define float @test_rootn_f32__y_3_nobuiltin(float %x) {
define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg1_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR2]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR3]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 -1) #0
@@ -1470,7 +1470,7 @@ define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
define float @test_rootn_f32__y_neg2_nobuiltin(float %x) {
; CHECK-LABEL: define float @test_rootn_f32__y_neg2_nobuiltin(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR2]]
+; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3]]
; CHECK-NEXT: ret float [[CALL]]
;
%call = tail call float @_Z5rootnfi(float %x, i32 -2) #0
@@ -1485,6 +1485,10 @@ attributes #2 = { noinline }
!0 = !{float 3.0}
;.
; CHECK: attributes #[[ATTR0]] = { strictfp }
-; CHECK: attributes #[[ATTR1:[0-9]+]] = { nounwind memory(read) }
-; CHECK: attributes #[[ATTR2]] = { nobuiltin }
+; CHECK: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind memory(read) }
+; CHECK: attributes #[[ATTR3]] = { nobuiltin }
+;.
+; CHECK: [[META0]] = !{float 2.000000e+00}
+; CHECK: [[META1]] = !{float 3.000000e+00}
;.
diff --git a/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll b/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
index 54ca33401ccf4..152eba5dec946 100644
--- a/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
+++ b/llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll
@@ -475,8 +475,7 @@ entry:
declare float @_Z5rootnfi(float, i32)
; GCN-LABEL: {{^}}define amdgpu_kernel void @test_rootn_2
-; GCN-POSTLINK: call fast float @_Z5rootnfi(float %tmp, i32 2)
-; GCN-PRELINK: %__rootn2sqrt = tail call fast float @llvm.sqrt.f32(float %tmp)
+; GCN: call fast float @llvm.sqrt.f32(float %tmp)
define amdgpu_kernel void @test_rootn_2(ptr addrspace(1) nocapture %a) {
entry:
%tmp = load float, ptr addrspace(1) %a, align 4
More information about the llvm-commits mailing list