[llvm-branch-commits] [llvm] [AArch64] SLP can vectorize frem (PR #82488)

Paschalis Mpeis via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 22 01:33:30 PST 2024


https://github.com/paschalis-mpeis updated https://github.com/llvm/llvm-project/pull/82488

From 641aaf7c13d520bef52b092726f8346bfecb1c8d Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 21 Feb 2024 11:53:00 +0000
Subject: [PATCH 1/4] SLP cannot vectorize frem instructions on AArch64.

It needs updated costs when vector library functions are available for the
given VF and type.
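
For reference, here is a minimal sketch (not part of the patch; the helper
name is illustrative) of what that availability check means in
TargetLibraryInfo terms, using the same queries the later patches rely on:

  #include "llvm/Analysis/TargetLibraryInfo.h"
  #include "llvm/IR/DerivedTypes.h"
  #include "llvm/IR/Instruction.h"

  using namespace llvm;

  // Illustrative helper: does the configured vector library (e.g.
  // -vector-library=ArmPL) provide a routine that a widened frem of type
  // VecTy could be lowered to?
  static bool hasVectorFRemLibCall(const TargetLibraryInfo &TLI,
                                   VectorType *VecTy) {
    LibFunc Func;
    // Map the scalar frem to its libcall (fmod for double, fmodf for float).
    if (!TLI.getLibFunc(Instruction::FRem, VecTy->getScalarType(), Func))
      return false;
    // Ask whether a vectorized variant exists for this element count.
    return TLI.isFunctionVectorizable(TLI.getName(Func),
                                      VecTy->getElementCount());
  }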
---
 .../SLPVectorizer/AArch64/slp-frem.ll         | 71 +++++++++++++++++++
 1 file changed, 71 insertions(+)
 create mode 100644 llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll

diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
new file mode 100644
index 00000000000000..45f667f5657889
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
@@ -0,0 +1,71 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -S -mtriple=aarch64 -vector-library=ArmPL -passes=slp-vectorizer | FileCheck %s
+
+@a = common global ptr null, align 8
+
+define void @frem_v2double() {
+; CHECK-LABEL: define void @frem_v2double() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A0:%.*]] = load double, ptr @a, align 8
+; CHECK-NEXT:    [[A1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[B0:%.*]] = load double, ptr @a, align 8
+; CHECK-NEXT:    [[B1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[R0:%.*]] = frem double [[A0]], [[B0]]
+; CHECK-NEXT:    [[R1:%.*]] = frem double [[A1]], [[B1]]
+; CHECK-NEXT:    store double [[R0]], ptr @a, align 8
+; CHECK-NEXT:    store double [[R1]], ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  %a1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  %b0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  %b1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  %r0 = frem double %a0, %b0
+  %r1 = frem double %a1, %b1
+  store double %r0, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  store double %r1, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  ret void
+}
+
+define void @frem_v4float() {
+; CHECK-LABEL: define void @frem_v4float() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A0:%.*]] = load float, ptr @a, align 8
+; CHECK-NEXT:    [[A1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[A2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    [[A3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[B0:%.*]] = load float, ptr @a, align 8
+; CHECK-NEXT:    [[B1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[B2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    [[B3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[R0:%.*]] = frem float [[A0]], [[B0]]
+; CHECK-NEXT:    [[R1:%.*]] = frem float [[A1]], [[B1]]
+; CHECK-NEXT:    [[R2:%.*]] = frem float [[A2]], [[B2]]
+; CHECK-NEXT:    [[R3:%.*]] = frem float [[A3]], [[B3]]
+; CHECK-NEXT:    store float [[R0]], ptr @a, align 8
+; CHECK-NEXT:    store float [[R1]], ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    store float [[R2]], ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    store float [[R3]], ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  %a1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  %a2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  %a3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  %b0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  %b1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  %b2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  %b3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  %r0 = frem float %a0, %b0
+  %r1 = frem float %a1, %b1
+  %r2 = frem float %a2, %b2
+  %r3 = frem float %a3, %b3
+  store float %r0, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  store float %r1, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  store float %r2, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  store float %r3, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  ret void
+}
+

From 29ae086478e3d4bae6b6250670f87273359626d7 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Mon, 29 Jan 2024 14:10:30 +0000
Subject: [PATCH 2/4] [AArch64] SLP can vectorize frem

When a vector library call is available for frem, given its type and vector
length, the SLP vectorizer now uses updated costs that amount to the cost of
a call, matching the LoopVectorizer's functionality.

This allows superword-level vectorization of frem, which later passes can
convert to a vector library call.

Add tests showing that code containing 2 x double and 4 x float frem
instructions gets vectorized.
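
In cost-model terms, the decision this enables looks roughly as follows.
This is an illustrative sketch only (the wrapper function, its name, and the
HasVecLibRoutine parameter are assumptions, not code from the patch); the
actual change is the small hunk in SLPVectorizer.cpp below:

  #include "llvm/Analysis/TargetTransformInfo.h"
  #include "llvm/IR/DerivedTypes.h"
  #include "llvm/IR/Instruction.h"
  #include <algorithm>

  using namespace llvm;

  // Illustrative only: compare the summed cost of the scalar frem
  // instructions (each ultimately a libcall to fmod/fmodf on AArch64)
  // against the cost of one widened frem. Capping the vector cost at the
  // cost of a single vector-library call is what lets the vector side win.
  static bool shouldWidenFRem(const TargetTransformInfo &TTI,
                              bool HasVecLibRoutine, FixedVectorType *VecTy,
                              TargetTransformInfo::TargetCostKind CostKind) {
    Type *ScalarTy = VecTy->getScalarType();

    InstructionCost ScalarCost = 0;
    for (unsigned I = 0, E = VecTy->getNumElements(); I != E; ++I)
      ScalarCost +=
          TTI.getArithmeticInstrCost(Instruction::FRem, ScalarTy, CostKind);

    InstructionCost VecCost =
        TTI.getArithmeticInstrCost(Instruction::FRem, VecTy, CostKind);
    if (HasVecLibRoutine)
      // frem takes two operands of the element type, hence {ScalarTy, ScalarTy}.
      VecCost = std::min(VecCost,
                         TTI.getCallInstrCost(nullptr, VecTy,
                                              {ScalarTy, ScalarTy}, CostKind));
    return VecCost < ScalarCost;
  }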
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 17 ++++++++--
 .../SLPVectorizer/AArch64/slp-frem.ll         | 32 +++++--------------
 2 files changed, 22 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 4e334748c95934..effe52fe2c4e31 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8362,9 +8362,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
       TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
       TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
-      return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
-                                         Op2Info) +
-             CommonCost;
+      auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
+                                                 Op1Info, Op2Info);
+      // Some targets can replace frem with vector library calls.
+      if (ShuffleOrOp == Instruction::FRem) {
+        LibFunc Func;
+        if (TLI->getLibFunc(ShuffleOrOp, ScalarTy, Func) &&
+            TLI->isFunctionVectorizable(TLI->getName(Func),
+                                        VecTy->getElementCount())) {
+          auto VecCallCost = TTI->getCallInstrCost(
+              nullptr, VecTy, {ScalarTy, ScalarTy}, CostKind);
+          VecCost = std::min(VecCost, VecCallCost);
+        }
+      }
+      return VecCost + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
index 45f667f5657889..a38f4bdc4640e9 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
@@ -6,14 +6,10 @@
 define void @frem_v2double() {
 ; CHECK-LABEL: define void @frem_v2double() {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[A0:%.*]] = load double, ptr @a, align 8
-; CHECK-NEXT:    [[A1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[B0:%.*]] = load double, ptr @a, align 8
-; CHECK-NEXT:    [[B1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[R0:%.*]] = frem double [[A0]], [[B0]]
-; CHECK-NEXT:    [[R1:%.*]] = frem double [[A1]], [[B1]]
-; CHECK-NEXT:    store double [[R0]], ptr @a, align 8
-; CHECK-NEXT:    store double [[R1]], ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <2 x double>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x double>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = frem <2 x double> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr @a, align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -31,22 +27,10 @@ entry:
 define void @frem_v4float() {
 ; CHECK-LABEL: define void @frem_v4float() {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[A0:%.*]] = load float, ptr @a, align 8
-; CHECK-NEXT:    [[A1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[A2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    [[A3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
-; CHECK-NEXT:    [[B0:%.*]] = load float, ptr @a, align 8
-; CHECK-NEXT:    [[B1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[B2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    [[B3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
-; CHECK-NEXT:    [[R0:%.*]] = frem float [[A0]], [[B0]]
-; CHECK-NEXT:    [[R1:%.*]] = frem float [[A1]], [[B1]]
-; CHECK-NEXT:    [[R2:%.*]] = frem float [[A2]], [[B2]]
-; CHECK-NEXT:    [[R3:%.*]] = frem float [[A3]], [[B3]]
-; CHECK-NEXT:    store float [[R0]], ptr @a, align 8
-; CHECK-NEXT:    store float [[R1]], ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    store float [[R2]], ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    store float [[R3]], ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x float>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = frem <4 x float> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    store <4 x float> [[TMP2]], ptr @a, align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:

From b4a7eed279a092c5d83b019788373aee93540db6 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 21 Feb 2024 18:02:52 +0000
Subject: [PATCH 3/4] Add 'getVecLibCallCost' to TTI.

Unfortunately, TLI (TargetLibraryInfo) is not available in TTI, and changing
the signature of 'getArithmeticInstrCost' would require changes in a large
number of places.

As a compromise, 'getVecLibCallCost' takes TLI as a parameter and returns the
cost of a call when a vector library function exists for the given target and
vector type, otherwise an invalid cost.
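
On the caller side (the SLP hunk in this patch), the helper composes with
std::min because an invalid InstructionCost compares as more expensive than
any valid cost, so the plain vector-instruction cost wins whenever no library
routine exists. A usage sketch of the interface proposed here, with an
illustrative wrapper function (the helper moves to VectorUtils in the next
patch):

  #include "llvm/Analysis/TargetLibraryInfo.h"
  #include "llvm/Analysis/TargetTransformInfo.h"
  #include "llvm/IR/DerivedTypes.h"
  #include "llvm/IR/Instruction.h"
  #include <algorithm>

  using namespace llvm;

  // Illustrative wrapper around the interface added in this patch.
  static InstructionCost
  getWidenedFRemCost(TargetTransformInfo &TTI, const TargetLibraryInfo *TLI,
                     VectorType *VecTy,
                     TargetTransformInfo::TargetCostKind CostKind) {
    InstructionCost VecInstrCost =
        TTI.getArithmeticInstrCost(Instruction::FRem, VecTy, CostKind);
    // Invalid when no vector library routine exists; invalid compares as
    // larger than any valid cost, so std::min degrades gracefully.
    InstructionCost VecCallCost =
        TTI.getVecLibCallCost(Instruction::FRem, TLI, VecTy, CostKind);
    return std::min(VecInstrCost, VecCallCost);
  }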
---
 .../llvm/Analysis/TargetTransformInfo.h        |  7 +++++++
 llvm/lib/Analysis/TargetTransformInfo.cpp      | 13 +++++++++++++
 .../lib/Transforms/Vectorize/SLPVectorizer.cpp | 18 +++++-------------
 3 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58577a6b6eb5c0..bd331693745267 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1255,6 +1255,13 @@ class TargetTransformInfo {
       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
       const Instruction *CxtI = nullptr) const;
 
+  /// Returns the cost of a call when a target has a vector library function for
+  /// the given \p VecTy, otherwise an invalid cost.
+  InstructionCost getVecLibCallCost(const int OpCode,
+                                    const TargetLibraryInfo *TLI,
+                                    VectorType *VecTy,
+                                    TTI::TargetCostKind CostKind);
+
   /// Returns the cost estimation for alternating opcode pattern that can be
   /// lowered to a single instruction on the target. In X86 this is for the
   /// addsub instruction which corrsponds to a Shuffle + Fadd + FSub pattern in
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1f11f0d7dd620e..58d39069aa740f 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/LoopIterator.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfoImpl.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
@@ -869,6 +870,18 @@ TargetTransformInfo::getOperandInfo(const Value *V) {
   return {OpInfo, OpProps};
 }
 
+InstructionCost TargetTransformInfo::getVecLibCallCost(
+    const int OpCode, const TargetLibraryInfo *TLI, VectorType *VecTy,
+    TTI::TargetCostKind CostKind) {
+  Type *ScalarTy = VecTy->getScalarType();
+  LibFunc Func;
+  if (TLI->getLibFunc(OpCode, ScalarTy, Func) &&
+      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
+    return getCallInstrCost(nullptr, VecTy, {ScalarTy, ScalarTy}, CostKind);
+
+  return InstructionCost::getInvalid();
+}
+
 InstructionCost TargetTransformInfo::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
     OperandValueInfo Op1Info, OperandValueInfo Op2Info,
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index effe52fe2c4e31..40958258565c81 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8362,20 +8362,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
       TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
       TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
-      auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
-                                                 Op1Info, Op2Info);
+      InstructionCost VecInstrCost = TTI->getArithmeticInstrCost(
+          ShuffleOrOp, VecTy, CostKind, Op1Info, Op2Info);
       // Some targets can replace frem with vector library calls.
-      if (ShuffleOrOp == Instruction::FRem) {
-        LibFunc Func;
-        if (TLI->getLibFunc(ShuffleOrOp, ScalarTy, Func) &&
-            TLI->isFunctionVectorizable(TLI->getName(Func),
-                                        VecTy->getElementCount())) {
-          auto VecCallCost = TTI->getCallInstrCost(
-              nullptr, VecTy, {ScalarTy, ScalarTy}, CostKind);
-          VecCost = std::min(VecCost, VecCallCost);
-        }
-      }
-      return VecCost + CommonCost;
+      InstructionCost VecCallCost =
+          TTI->getVecLibCallCost(ShuffleOrOp, TLI, VecTy, CostKind);
+      return std::min(VecInstrCost, VecCallCost) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }

From 0a77f3a0bfd0197012aa8ed48f8252863e93ad9d Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Thu, 22 Feb 2024 09:25:16 +0000
Subject: [PATCH 4/4] Address reviewer feedback

---
 .../include/llvm/Analysis/TargetTransformInfo.h |  7 -------
 llvm/include/llvm/Analysis/VectorUtils.h        | 10 +++++++++-
 llvm/lib/Analysis/TargetTransformInfo.cpp       | 13 -------------
 llvm/lib/Analysis/VectorUtils.cpp               | 17 +++++++++++++++++
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp |  2 +-
 5 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index bd331693745267..58577a6b6eb5c0 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1255,13 +1255,6 @@ class TargetTransformInfo {
       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
       const Instruction *CxtI = nullptr) const;
 
-  /// Returns the cost of a call when a target has a vector library function for
-  /// the given \p VecTy, otherwise an invalid cost.
-  InstructionCost getVecLibCallCost(const int OpCode,
-                                    const TargetLibraryInfo *TLI,
-                                    VectorType *VecTy,
-                                    TTI::TargetCostKind CostKind);
-
   /// Returns the cost estimation for alternating opcode pattern that can be
   /// lowered to a single instruction on the target. In X86 this is for the
   /// addsub instruction which corrsponds to a Shuffle + Fadd + FSub pattern in
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7a92e62b53c53d..d68c5a81ad11bb 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -16,6 +16,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/VFABIDemangler.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
@@ -119,7 +120,6 @@ template <typename InstTy> class InterleaveGroup;
 class IRBuilderBase;
 class Loop;
 class ScalarEvolution;
-class TargetTransformInfo;
 class Type;
 class Value;
 
@@ -410,6 +410,14 @@ bool maskIsAllOneOrUndef(Value *Mask);
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
 
+/// Returns the cost of a call when a target has a vector library function for
+/// the given \p VecTy, otherwise an invalid cost.
+InstructionCost getVecLibCallCost(const Instruction *I,
+                                  const TargetTransformInfo *TTI,
+                                  const TargetLibraryInfo *TLI,
+                                  VectorType *VecTy,
+                                  TargetTransformInfo::TargetCostKind CostKind);
+
 /// The group of interleaved loads/stores sharing the same stride and
 /// close to each other.
 ///
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 58d39069aa740f..1f11f0d7dd620e 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -9,7 +9,6 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/LoopIterator.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfoImpl.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
@@ -870,18 +869,6 @@ TargetTransformInfo::getOperandInfo(const Value *V) {
   return {OpInfo, OpProps};
 }
 
-InstructionCost TargetTransformInfo::getVecLibCallCost(
-    const int OpCode, const TargetLibraryInfo *TLI, VectorType *VecTy,
-    TTI::TargetCostKind CostKind) {
-  Type *ScalarTy = VecTy->getScalarType();
-  LibFunc Func;
-  if (TLI->getLibFunc(OpCode, ScalarTy, Func) &&
-      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
-    return getCallInstrCost(nullptr, VecTy, {ScalarTy, ScalarTy}, CostKind);
-
-  return InstructionCost::getInvalid();
-}
-
 InstructionCost TargetTransformInfo::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
     OperandValueInfo Op1Info, OperandValueInfo Op2Info,
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 73facc76a92b2c..dcd1d072139b66 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
@@ -1031,6 +1032,22 @@ APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
   return DemandedElts;
 }
 
+InstructionCost
+llvm::getVecLibCallCost(const Instruction *I, const TargetTransformInfo *TTI,
+                        const TargetLibraryInfo *TLI, VectorType *VecTy,
+                        TargetTransformInfo::TargetCostKind CostKind) {
+  SmallVector<Type *, 4> OpTypes;
+  for (auto &Op : I->operands())
+    OpTypes.push_back(Op->getType());
+
+  LibFunc Func;
+  if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
+      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
+    return TTI->getCallInstrCost(nullptr, VecTy, OpTypes, CostKind);
+
+  return InstructionCost::getInvalid();
+}
+
 bool InterleavedAccessInfo::isStrided(int Stride) {
   unsigned Factor = std::abs(Stride);
   return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 40958258565c81..99255c272be829 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8366,7 +8366,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
           ShuffleOrOp, VecTy, CostKind, Op1Info, Op2Info);
       // Some targets can replace frem with vector library calls.
       InstructionCost VecCallCost =
-          TTI->getVecLibCallCost(ShuffleOrOp, TLI, VecTy, CostKind);
+          getVecLibCallCost(VL0, TTI, TLI, VecTy, CostKind);
       return std::min(VecInstrCost, VecCallCost) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);


