[llvm] Fix scalar overload name constructed by ReplaceWithVeclib.cpp (PR #111095)
Tex Riddell via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 3 20:48:10 PDT 2024
https://github.com/tex3d created https://github.com/llvm/llvm-project/pull/111095
ReplaceWithVeclib.cpp would construct overload name using all the arguments in the intrinsic, but overloads should only be constructed from arguments for which isVectorIntrinsicWithOverloadTypeAtArg returns true, including the return (-1).
Fixes #111093
>From 8227c918a8f99ac081e630fb2a1bfa14be6e82c6 Mon Sep 17 00:00:00 2001
From: Tex Riddell <texr at microsoft.com>
Date: Thu, 3 Oct 2024 20:46:41 -0700
Subject: [PATCH] Fix scalar overload name constructed by ReplaceWithVeclib.cpp
ReplaceWithVeclib.cpp would construct overload name using all the arguments in the intrinsic, but overloads should only be constructed from arguments for which isVectorIntrinsicWithOverloadTypeAtArg returns true, including the return (-1).
Fixes #111093
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 16 +++++++++++++++-
.../AArch64/replace-with-veclib-armpl.ll | 17 +++++------------
2 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 9fbb7b461364b1..551210db85713a 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -108,8 +108,22 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
// all vector operands match the previously found EC.
SmallVector<Type *, 8> ScalarArgTypes;
Intrinsic::ID IID = II->getIntrinsicID();
+
+ // OloadTys collects types used in scalar intrinsic overload name.
+ SmallVector<Type *, 3> OloadTys;
+ if (VTy && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
+ OloadTys.push_back(VTy->getElementType());
+
for (auto Arg : enumerate(II->args())) {
auto *ArgTy = Arg.value()->getType();
+ // Gather type if it is used in the overload name.
+ if (isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index())) {
+ if (!isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index()) && isa<VectorType>(ArgTy))
+ OloadTys.push_back(cast<VectorType>(ArgTy)->getElementType());
+ else
+ OloadTys.push_back(ArgTy);
+ }
+
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -129,7 +143,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
// using scalar argument types.
std::string ScalarName =
Intrinsic::isOverloaded(IID)
- ? Intrinsic::getName(IID, ScalarArgTypes, II->getModule())
+ ? Intrinsic::getName(IID, OloadTys, II->getModule())
: Intrinsic::getName(IID).str();
// Try to find the mapping for the scalar version of this intrinsic and the
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index f7e95008b71237..7b173bda561553 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -15,7 +15,7 @@ declare <vscale x 2 x double> @llvm.cos.nxv2f64(<vscale x 2 x double>)
declare <vscale x 4 x float> @llvm.cos.nxv4f32(<vscale x 4 x float>)
;.
-; CHECK: @llvm.compiler.used = appending global [60 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vtanq_f64, ptr @armpl_vtanq_f32, ptr @armpl_svtan_f64_x, ptr @armpl_svtan_f32_x, ptr @armpl_vacosq_f64, ptr @armpl_vacosq_f32, ptr @armpl_svacos_f64_x, ptr @armpl_svacos_f32_x, ptr @armpl_vasinq_f64, ptr @armpl_vasinq_f32, ptr @armpl_svasin_f64_x, ptr @armpl_svasin_f32_x, ptr @armpl_vatanq_f64, ptr @armpl_vatanq_f32, ptr @armpl_svatan_f64_x, ptr @armpl_svatan_f32_x, ptr @armpl_vcoshq_f64, ptr @armpl_vcoshq_f32, ptr @armpl_svcosh_f64_x, ptr @armpl_svcosh_f32_x, ptr @armpl_vsinhq_f64, ptr @armpl_vsinhq_f32, ptr @armpl_svsinh_f64_x, ptr @armpl_svsinh_f32_x, ptr @armpl_vtanhq_f64, ptr @armpl_vtanhq_f32, ptr @armpl_svtanh_f64_x, ptr @armpl_svtanh_f32_x], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [64 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vpowq_f64, ptr @armpl_vpowq_f32, ptr @armpl_svpow_f64_x, ptr @armpl_svpow_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vtanq_f64, ptr @armpl_vtanq_f32, ptr @armpl_svtan_f64_x, ptr @armpl_svtan_f32_x, ptr @armpl_vacosq_f64, ptr @armpl_vacosq_f32, ptr @armpl_svacos_f64_x, ptr @armpl_svacos_f32_x, ptr @armpl_vasinq_f64, ptr @armpl_vasinq_f32, ptr @armpl_svasin_f64_x, ptr @armpl_svasin_f32_x, ptr @armpl_vatanq_f64, ptr @armpl_vatanq_f32, ptr @armpl_svatan_f64_x, ptr @armpl_svatan_f32_x, ptr @armpl_vcoshq_f64, ptr @armpl_vcoshq_f32, ptr @armpl_svcosh_f64_x, ptr @armpl_svcosh_f32_x, ptr @armpl_vsinhq_f64, ptr @armpl_vsinhq_f32, ptr @armpl_svsinh_f64_x, ptr @armpl_svsinh_f32_x, ptr @armpl_vtanhq_f64, ptr @armpl_vtanhq_f32, ptr @armpl_svtanh_f64_x, ptr @armpl_svtanh_f32_x], section "llvm.metadata"
;.
define <2 x double> @llvm_cos_f64(<2 x double> %in) {
@@ -333,17 +333,10 @@ declare <4 x float> @llvm.pow.v4f32(<4 x float>, <4 x float>)
declare <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
declare <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
-;
-; There is a bug in the replace-with-veclib pass, and for intrinsics which take
-; more than one arguments, but has just one overloaded type, it incorrectly
-; reconstructs the scalar name, for pow specifically it is searching for:
-; llvm.pow.f64.f64 and llvm.pow.f32.f32
-;
-
define <2 x double> @llvm_pow_f64(<2 x double> %in, <2 x double> %power) {
; CHECK-LABEL: define <2 x double> @llvm_pow_f64
; CHECK-SAME: (<2 x double> [[IN:%.*]], <2 x double> [[POWER:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call fast <2 x double> @llvm.pow.v2f64(<2 x double> [[IN]], <2 x double> [[POWER]])
+; CHECK-NEXT: [[TMP1:%.*]] = call fast <2 x double> @armpl_vpowq_f64(<2 x double> [[IN]], <2 x double> [[POWER]])
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %in, <2 x double> %power)
@@ -353,7 +346,7 @@ define <2 x double> @llvm_pow_f64(<2 x double> %in, <2 x double> %power) {
define <4 x float> @llvm_pow_f32(<4 x float> %in, <4 x float> %power) {
; CHECK-LABEL: define <4 x float> @llvm_pow_f32
; CHECK-SAME: (<4 x float> [[IN:%.*]], <4 x float> [[POWER:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call fast <4 x float> @llvm.pow.v4f32(<4 x float> [[IN]], <4 x float> [[POWER]])
+; CHECK-NEXT: [[TMP1:%.*]] = call fast <4 x float> @armpl_vpowq_f32(<4 x float> [[IN]], <4 x float> [[POWER]])
; CHECK-NEXT: ret <4 x float> [[TMP1]]
;
%1 = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %in, <4 x float> %power)
@@ -363,7 +356,7 @@ define <4 x float> @llvm_pow_f32(<4 x float> %in, <4 x float> %power) {
define <vscale x 2 x double> @llvm_pow_vscale_f64(<vscale x 2 x double> %in, <vscale x 2 x double> %power) #0 {
; CHECK-LABEL: define <vscale x 2 x double> @llvm_pow_vscale_f64
; CHECK-SAME: (<vscale x 2 x double> [[IN:%.*]], <vscale x 2 x double> [[POWER:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[TMP1:%.*]] = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[POWER]])
+; CHECK-NEXT: [[TMP1:%.*]] = call fast <vscale x 2 x double> @armpl_svpow_f64_x(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[POWER]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 2 x double> [[TMP1]]
;
%1 = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %in, <vscale x 2 x double> %power)
@@ -373,7 +366,7 @@ define <vscale x 2 x double> @llvm_pow_vscale_f64(<vscale x 2 x double> %in, <vs
define <vscale x 4 x float> @llvm_pow_vscale_f32(<vscale x 4 x float> %in, <vscale x 4 x float> %power) #0 {
; CHECK-LABEL: define <vscale x 4 x float> @llvm_pow_vscale_f32
; CHECK-SAME: (<vscale x 4 x float> [[IN:%.*]], <vscale x 4 x float> [[POWER:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT: [[TMP1:%.*]] = call fast <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[POWER]])
+; CHECK-NEXT: [[TMP1:%.*]] = call fast <vscale x 4 x float> @armpl_svpow_f32_x(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[POWER]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 4 x float> [[TMP1]]
;
%1 = call fast <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float> %in, <vscale x 4 x float> %power)
More information about the llvm-commits
mailing list