[llvm] Fix scalar overload name constructed by ReplaceWithVeclib.cpp (PR #111095)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 3 20:48:41 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-backend-aarch64

Author: Tex Riddell (tex3d)

Changes

ReplaceWithVeclib.cpp constructed the scalar overload name from all of the intrinsic's arguments, but the overload name should be built only from the types at positions for which isVectorIntrinsicWithOverloadTypeAtArg returns true, including the return type (index -1).

Fixes #111093
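
To make the failure mode concrete, here is a minimal standalone sketch (hypothetical demo code, not part of this patch; it assumes you can compile and link against the LLVM libraries) contrasting the scalar name the pass should look up for `llvm.pow` with the one the pre-fix code effectively built:

```cpp
// demo.cpp: illustrate overload-name mangling for llvm.pow.
// pow is overloaded only on its return type, so the scalar name is
// "llvm.pow.f64". Appending a suffix per argument instead yields
// "llvm.pow.f64.f64", a name with no TLI mapping, so the pass
// silently failed to replace the call.
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include <iostream>

int main() {
  llvm::LLVMContext Ctx;
  llvm::Module M("demo", Ctx);
  llvm::Type *F64 = llvm::Type::getDoubleTy(Ctx);

  // Correct: one overload type, the return type (index -1).
  std::cout << llvm::Intrinsic::getName(llvm::Intrinsic::pow, {F64}, &M)
            << "\n"; // llvm.pow.f64

  // What the old code effectively requested: a suffix per argument.
  std::cout << llvm::Intrinsic::getName(llvm::Intrinsic::pow, {F64, F64}, &M)
            << "\n"; // llvm.pow.f64.f64 (wrong)
}
```

The patch below applies the same filtering the intrinsic emitter uses: it gathers overload types only where isVectorIntrinsicWithOverloadTypeAtArg says they participate in the mangled name.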

---
Full diff: https://github.com/llvm/llvm-project/pull/111095.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/ReplaceWithVeclib.cpp (+15-1) 
- (modified) llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll (+5-12) 


``````````diff
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 9fbb7b461364b1..551210db85713a 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -108,8 +108,22 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   // all vector operands match the previously found EC.
   SmallVector<Type *, 8> ScalarArgTypes;
   Intrinsic::ID IID = II->getIntrinsicID();
+
+  // OloadTys collects types used in scalar intrinsic overload name.
+  SmallVector<Type *, 3> OloadTys;
+  if (VTy && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
+    OloadTys.push_back(VTy->getElementType());
+
   for (auto Arg : enumerate(II->args())) {
     auto *ArgTy = Arg.value()->getType();
+    // Gather type if it is used in the overload name.
+    if (isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index())) {
+      if (!isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index()) && isa<VectorType>(ArgTy))
+        OloadTys.push_back(cast<VectorType>(ArgTy)->getElementType());
+      else
+        OloadTys.push_back(ArgTy);
+    }
+
     if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
       ScalarArgTypes.push_back(ArgTy);
     } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -129,7 +143,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   // using scalar argument types.
   std::string ScalarName =
       Intrinsic::isOverloaded(IID)
-          ? Intrinsic::getName(IID, ScalarArgTypes, II->getModule())
+          ? Intrinsic::getName(IID, OloadTys, II->getModule())
           : Intrinsic::getName(IID).str();
 
   // Try to find the mapping for the scalar version of this intrinsic and the
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index f7e95008b71237..7b173bda561553 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -15,7 +15,7 @@ declare <vscale x 2 x double> @llvm.cos.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.cos.nxv4f32(<vscale x 4 x float>)
 
 ;.
-; CHECK: @llvm.compiler.used = appending global [60 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vtanq_f64, ptr @armpl_vtanq_f32, ptr @armpl_svtan_f64_x, ptr @armpl_svtan_f32_x, ptr @armpl_vacosq_f64, ptr @armpl_vacosq_f32, ptr @armpl_svacos_f64_x, ptr @armpl_svacos_f32_x, ptr @armpl_vasinq_f64, ptr @armpl_vasinq_f32, ptr @armpl_svasin_f64_x, ptr @armpl_svasin_f32_x, ptr @armpl_vatanq_f64, ptr @armpl_vatanq_f32, ptr @armpl_svatan_f64_x, ptr @armpl_svatan_f32_x, ptr @armpl_vcoshq_f64, ptr @armpl_vcoshq_f32, ptr @armpl_svcosh_f64_x, ptr @armpl_svcosh_f32_x, ptr @armpl_vsinhq_f64, ptr @armpl_vsinhq_f32, ptr @armpl_svsinh_f64_x, ptr @armpl_svsinh_f32_x, ptr @armpl_vtanhq_f64, ptr @armpl_vtanhq_f32, ptr @armpl_svtanh_f64_x, ptr @armpl_svtanh_f32_x], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [64 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vpowq_f64, ptr @armpl_vpowq_f32, ptr @armpl_svpow_f64_x, ptr @armpl_svpow_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vtanq_f64, ptr @armpl_vtanq_f32, ptr @armpl_svtan_f64_x, ptr @armpl_svtan_f32_x, ptr @armpl_vacosq_f64, ptr @armpl_vacosq_f32, ptr @armpl_svacos_f64_x, ptr @armpl_svacos_f32_x, ptr @armpl_vasinq_f64, ptr @armpl_vasinq_f32, ptr @armpl_svasin_f64_x, ptr @armpl_svasin_f32_x, ptr @armpl_vatanq_f64, ptr @armpl_vatanq_f32, ptr @armpl_svatan_f64_x, ptr @armpl_svatan_f32_x, ptr @armpl_vcoshq_f64, ptr @armpl_vcoshq_f32, ptr @armpl_svcosh_f64_x, ptr @armpl_svcosh_f32_x, ptr @armpl_vsinhq_f64, ptr @armpl_vsinhq_f32, ptr @armpl_svsinh_f64_x, ptr @armpl_svsinh_f32_x, ptr @armpl_vtanhq_f64, ptr @armpl_vtanhq_f32, ptr @armpl_svtanh_f64_x, ptr @armpl_svtanh_f32_x], section "llvm.metadata"
 
 ;.
 define <2 x double> @llvm_cos_f64(<2 x double> %in) {
@@ -333,17 +333,10 @@ declare <4 x float> @llvm.pow.v4f32(<4 x float>, <4 x float>)
 declare <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
 
-;
-; There is a bug in the replace-with-veclib pass, and for intrinsics which take
-; more than one arguments, but has just one overloaded type, it incorrectly
-; reconstructs the scalar name, for pow specifically it is searching for:
-; llvm.pow.f64.f64 and llvm.pow.f32.f32
-;
-
 define <2 x double> @llvm_pow_f64(<2 x double> %in, <2 x double> %power) {
 ; CHECK-LABEL: define <2 x double> @llvm_pow_f64
 ; CHECK-SAME: (<2 x double> [[IN:%.*]], <2 x double> [[POWER:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <2 x double> @llvm.pow.v2f64(<2 x double> [[IN]], <2 x double> [[POWER]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <2 x double> @armpl_vpowq_f64(<2 x double> [[IN]], <2 x double> [[POWER]])
 ; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   %1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %in, <2 x double> %power)
@@ -353,7 +346,7 @@ define <2 x double> @llvm_pow_f64(<2 x double> %in, <2 x double> %power) {
 define <4 x float> @llvm_pow_f32(<4 x float> %in, <4 x float> %power) {
 ; CHECK-LABEL: define <4 x float> @llvm_pow_f32
 ; CHECK-SAME: (<4 x float> [[IN:%.*]], <4 x float> [[POWER:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <4 x float> @llvm.pow.v4f32(<4 x float> [[IN]], <4 x float> [[POWER]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <4 x float> @armpl_vpowq_f32(<4 x float> [[IN]], <4 x float> [[POWER]])
 ; CHECK-NEXT:    ret <4 x float> [[TMP1]]
 ;
   %1 = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %in, <4 x float> %power)
@@ -363,7 +356,7 @@ define <4 x float> @llvm_pow_f32(<4 x float> %in, <4 x float> %power) {
 define <vscale x 2 x double> @llvm_pow_vscale_f64(<vscale x 2 x double> %in, <vscale x 2 x double> %power) #0 {
 ; CHECK-LABEL: define <vscale x 2 x double> @llvm_pow_vscale_f64
 ; CHECK-SAME: (<vscale x 2 x double> [[IN:%.*]], <vscale x 2 x double> [[POWER:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[POWER]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @armpl_svpow_f64_x(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[POWER]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
 ;
   %1 = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %in, <vscale x 2 x double> %power)
@@ -373,7 +366,7 @@ define <vscale x 2 x double> @llvm_pow_vscale_f64(<vscale x 2 x double> %in, <vs
 define <vscale x 4 x float> @llvm_pow_vscale_f32(<vscale x 4 x float> %in, <vscale x 4 x float> %power) #0 {
 ; CHECK-LABEL: define <vscale x 4 x float> @llvm_pow_vscale_f32
 ; CHECK-SAME: (<vscale x 4 x float> [[IN:%.*]], <vscale x 4 x float> [[POWER:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[POWER]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @armpl_svpow_f32_x(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[POWER]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
 ;
   %1 = call fast <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float> %in, <vscale x 4 x float> %power)

``````````



https://github.com/llvm/llvm-project/pull/111095


More information about the llvm-commits mailing list