[llvm] Fix scalar overload name constructed by ReplaceWithVeclib.cpp (PR #111095)

Tex Riddell via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 7 11:48:49 PDT 2024


https://github.com/tex3d updated https://github.com/llvm/llvm-project/pull/111095

>From 8227c918a8f99ac081e630fb2a1bfa14be6e82c6 Mon Sep 17 00:00:00 2001
From: Tex Riddell <texr at microsoft.com>
Date: Thu, 3 Oct 2024 20:46:41 -0700
Subject: [PATCH 1/2] Fix scalar overload name constructed by
 ReplaceWithVeclib.cpp

ReplaceWithVeclib.cpp constructed the scalar intrinsic's overload name from the types of all of the intrinsic's arguments. The overload name should be built only from the types at positions for which isVectorIntrinsicWithOverloadTypeAtArg returns true, including the return type (index -1).

Fixes #111093
---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp          | 16 +++++++++++++++-
 .../AArch64/replace-with-veclib-armpl.ll        | 17 +++++------------
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 9fbb7b461364b1..551210db85713a 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -108,8 +108,22 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   // all vector operands match the previously found EC.
   SmallVector<Type *, 8> ScalarArgTypes;
   Intrinsic::ID IID = II->getIntrinsicID();
+
+  // OloadTys collects types used in scalar intrinsic overload name.
+  SmallVector<Type *, 3> OloadTys;
+  if (VTy && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
+    OloadTys.push_back(VTy->getElementType());
+
   for (auto Arg : enumerate(II->args())) {
     auto *ArgTy = Arg.value()->getType();
+    // Gather type if it is used in the overload name.
+    if (isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index())) {
+      if (!isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index()) && isa<VectorType>(ArgTy))
+        OloadTys.push_back(cast<VectorType>(ArgTy)->getElementType());
+      else
+        OloadTys.push_back(ArgTy);
+    }
+
     if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
       ScalarArgTypes.push_back(ArgTy);
     } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -129,7 +143,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   // using scalar argument types.
   std::string ScalarName =
       Intrinsic::isOverloaded(IID)
-          ? Intrinsic::getName(IID, ScalarArgTypes, II->getModule())
+          ? Intrinsic::getName(IID, OloadTys, II->getModule())
           : Intrinsic::getName(IID).str();
 
   // Try to find the mapping for the scalar version of this intrinsic and the
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index f7e95008b71237..7b173bda561553 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -15,7 +15,7 @@ declare <vscale x 2 x double> @llvm.cos.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.cos.nxv4f32(<vscale x 4 x float>)
 
 ;.
-; CHECK: @llvm.compiler.used = appending global [60 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vtanq_f64, ptr @armpl_vtanq_f32, ptr @armpl_svtan_f64_x, ptr @armpl_svtan_f32_x, ptr @armpl_vacosq_f64, ptr @armpl_vacosq_f32, ptr @armpl_svacos_f64_x, ptr @armpl_svacos_f32_x, ptr @armpl_vasinq_f64, ptr @armpl_vasinq_f32, ptr @armpl_svasin_f64_x, ptr @armpl_svasin_f32_x, ptr @armpl_vatanq_f64, ptr @armpl_vatanq_f32, ptr @armpl_svatan_f64_x, ptr @armpl_svatan_f32_x, ptr @armpl_vcoshq_f64, ptr @armpl_vcoshq_f32, ptr @armpl_svcosh_f64_x, ptr @armpl_svcosh_f32_x, ptr @armpl_vsinhq_f64, ptr @armpl_vsinhq_f32, ptr @armpl_svsinh_f64_x, ptr @armpl_svsinh_f32_x, ptr @armpl_vtanhq_f64, ptr @armpl_vtanhq_f32, ptr @armpl_svtanh_f64_x, ptr @armpl_svtanh_f32_x], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [64 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vpowq_f64, ptr @armpl_vpowq_f32, ptr @armpl_svpow_f64_x, ptr @armpl_svpow_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vtanq_f64, ptr @armpl_vtanq_f32, ptr @armpl_svtan_f64_x, ptr @armpl_svtan_f32_x, ptr @armpl_vacosq_f64, ptr @armpl_vacosq_f32, ptr @armpl_svacos_f64_x, ptr @armpl_svacos_f32_x, ptr @armpl_vasinq_f64, ptr @armpl_vasinq_f32, ptr @armpl_svasin_f64_x, ptr @armpl_svasin_f32_x, ptr @armpl_vatanq_f64, ptr @armpl_vatanq_f32, ptr @armpl_svatan_f64_x, ptr @armpl_svatan_f32_x, ptr @armpl_vcoshq_f64, ptr @armpl_vcoshq_f32, ptr @armpl_svcosh_f64_x, ptr @armpl_svcosh_f32_x, ptr @armpl_vsinhq_f64, ptr @armpl_vsinhq_f32, ptr @armpl_svsinh_f64_x, ptr @armpl_svsinh_f32_x, ptr @armpl_vtanhq_f64, ptr @armpl_vtanhq_f32, ptr @armpl_svtanh_f64_x, ptr @armpl_svtanh_f32_x], section "llvm.metadata"
 
 ;.
 define <2 x double> @llvm_cos_f64(<2 x double> %in) {
@@ -333,17 +333,10 @@ declare <4 x float> @llvm.pow.v4f32(<4 x float>, <4 x float>)
 declare <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
 
-;
-; There is a bug in the replace-with-veclib pass, and for intrinsics which take
-; more than one arguments, but has just one overloaded type, it incorrectly
-; reconstructs the scalar name, for pow specifically it is searching for:
-; llvm.pow.f64.f64 and llvm.pow.f32.f32
-;
-
 define <2 x double> @llvm_pow_f64(<2 x double> %in, <2 x double> %power) {
 ; CHECK-LABEL: define <2 x double> @llvm_pow_f64
 ; CHECK-SAME: (<2 x double> [[IN:%.*]], <2 x double> [[POWER:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <2 x double> @llvm.pow.v2f64(<2 x double> [[IN]], <2 x double> [[POWER]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <2 x double> @armpl_vpowq_f64(<2 x double> [[IN]], <2 x double> [[POWER]])
 ; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   %1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %in, <2 x double> %power)
@@ -353,7 +346,7 @@ define <2 x double> @llvm_pow_f64(<2 x double> %in, <2 x double> %power) {
 define <4 x float> @llvm_pow_f32(<4 x float> %in, <4 x float> %power) {
 ; CHECK-LABEL: define <4 x float> @llvm_pow_f32
 ; CHECK-SAME: (<4 x float> [[IN:%.*]], <4 x float> [[POWER:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <4 x float> @llvm.pow.v4f32(<4 x float> [[IN]], <4 x float> [[POWER]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <4 x float> @armpl_vpowq_f32(<4 x float> [[IN]], <4 x float> [[POWER]])
 ; CHECK-NEXT:    ret <4 x float> [[TMP1]]
 ;
   %1 = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %in, <4 x float> %power)
@@ -363,7 +356,7 @@ define <4 x float> @llvm_pow_f32(<4 x float> %in, <4 x float> %power) {
 define <vscale x 2 x double> @llvm_pow_vscale_f64(<vscale x 2 x double> %in, <vscale x 2 x double> %power) #0 {
 ; CHECK-LABEL: define <vscale x 2 x double> @llvm_pow_vscale_f64
 ; CHECK-SAME: (<vscale x 2 x double> [[IN:%.*]], <vscale x 2 x double> [[POWER:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[POWER]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @armpl_svpow_f64_x(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[POWER]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
 ;
   %1 = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %in, <vscale x 2 x double> %power)
@@ -373,7 +366,7 @@ define <vscale x 2 x double> @llvm_pow_vscale_f64(<vscale x 2 x double> %in, <vs
 define <vscale x 4 x float> @llvm_pow_vscale_f32(<vscale x 4 x float> %in, <vscale x 4 x float> %power) #0 {
 ; CHECK-LABEL: define <vscale x 4 x float> @llvm_pow_vscale_f32
 ; CHECK-SAME: (<vscale x 4 x float> [[IN:%.*]], <vscale x 4 x float> [[POWER:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[POWER]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @armpl_svpow_f32_x(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[POWER]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
 ;
   %1 = call fast <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float> %in, <vscale x 4 x float> %power)

>From f0557de28123377baff92fc12d712244451b3708 Mon Sep 17 00:00:00 2001
From: Tex Riddell <texr at microsoft.com>
Date: Mon, 7 Oct 2024 11:48:16 -0700
Subject: [PATCH 2/2] Simplify logic a bit, handle scalar return overload type.

---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 25 +++++++++----------------
 1 file changed, 9 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 551210db85713a..740712d17fe68a 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -104,6 +104,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   // a void type.
   auto *VTy = dyn_cast<VectorType>(II->getType());
   ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+  Type *ScalarRetTy = II->getType()->getScalarType();
   // Compute the argument types of the corresponding scalar call and check that
   // all vector operands match the previously found EC.
   SmallVector<Type *, 8> ScalarArgTypes;
@@ -111,30 +112,23 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
 
   // OloadTys collects types used in scalar intrinsic overload name.
   SmallVector<Type *, 3> OloadTys;
-  if (VTy && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
-    OloadTys.push_back(VTy->getElementType());
+  if (isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
+    OloadTys.push_back(ScalarRetTy);
 
   for (auto Arg : enumerate(II->args())) {
     auto *ArgTy = Arg.value()->getType();
-    // Gather type if it is used in the overload name.
-    if (isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index())) {
-      if (!isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index()) && isa<VectorType>(ArgTy))
-        OloadTys.push_back(cast<VectorType>(ArgTy)->getElementType());
-      else
-        OloadTys.push_back(ArgTy);
-    }
-
-    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
-      ScalarArgTypes.push_back(ArgTy);
-    } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
-      ScalarArgTypes.push_back(VectorArgTy->getElementType());
+    auto *ScalarArgTy = ArgTy->getScalarType();
+    ScalarArgTypes.push_back(ScalarArgTy);
+    if (isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index()))
+        OloadTys.push_back(ScalarArgTy);
+    if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
       // When return type is void, set EC to the first vector argument, and
       // disallow vector arguments with different ECs.
       if (EC.isZero())
         EC = VectorArgTy->getElementCount();
       else if (EC != VectorArgTy->getElementCount())
         return false;
-    } else
+    } else if (!isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index()))
       // Exit when it is supposed to be a vector argument but it isn't.
       return false;
   }
@@ -160,7 +154,6 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
 
   // Replace the call to the intrinsic with a call to the vector library
   // function.
-  Type *ScalarRetTy = II->getType()->getScalarType();
   FunctionType *ScalarFTy =
       FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
   const std::string MangledName = VD->getVectorFunctionABIVariantString();



More information about the llvm-commits mailing list