[llvm] [AArch64][LV][SLP] Vectorizers use getFRemInstrCost for frem costs (PR #82488)

Paschalis Mpeis via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 23 06:44:08 PST 2024


https://github.com/paschalis-mpeis updated https://github.com/llvm/llvm-project/pull/82488

>From 53ed688d0ac0933161c625b7470b912cb0770ae3 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 21 Feb 2024 11:53:00 +0000
Subject: [PATCH 1/4] SLP cannot vectorize frem calls in AArch64.

SLP needs updated costs that account for the vector library functions
available for the given VF and type.
---
 .../SLPVectorizer/AArch64/slp-frem.ll         | 71 +++++++++++++++++++
 1 file changed, 71 insertions(+)
 create mode 100644 llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll

diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
new file mode 100644
index 00000000000000..45f667f5657889
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
@@ -0,0 +1,71 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -S -mtriple=aarch64 -vector-library=ArmPL -passes=slp-vectorizer | FileCheck %s
+
+@a = common global ptr null, align 8
+
+define void @frem_v2double() {
+; CHECK-LABEL: define void @frem_v2double() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A0:%.*]] = load double, ptr @a, align 8
+; CHECK-NEXT:    [[A1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[B0:%.*]] = load double, ptr @a, align 8
+; CHECK-NEXT:    [[B1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[R0:%.*]] = frem double [[A0]], [[B0]]
+; CHECK-NEXT:    [[R1:%.*]] = frem double [[A1]], [[B1]]
+; CHECK-NEXT:    store double [[R0]], ptr @a, align 8
+; CHECK-NEXT:    store double [[R1]], ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  %a1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  %b0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  %b1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  %r0 = frem double %a0, %b0
+  %r1 = frem double %a1, %b1
+  store double %r0, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  store double %r1, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  ret void
+}
+
+define void @frem_v4float() {
+; CHECK-LABEL: define void @frem_v4float() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A0:%.*]] = load float, ptr @a, align 8
+; CHECK-NEXT:    [[A1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[A2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    [[A3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[B0:%.*]] = load float, ptr @a, align 8
+; CHECK-NEXT:    [[B1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[B2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    [[B3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[R0:%.*]] = frem float [[A0]], [[B0]]
+; CHECK-NEXT:    [[R1:%.*]] = frem float [[A1]], [[B1]]
+; CHECK-NEXT:    [[R2:%.*]] = frem float [[A2]], [[B2]]
+; CHECK-NEXT:    [[R3:%.*]] = frem float [[A3]], [[B3]]
+; CHECK-NEXT:    store float [[R0]], ptr @a, align 8
+; CHECK-NEXT:    store float [[R1]], ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    store float [[R2]], ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    store float [[R3]], ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  %a1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  %a2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  %a3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  %b0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  %b1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  %b2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  %b3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  %r0 = frem float %a0, %b0
+  %r1 = frem float %a1, %b1
+  %r2 = frem float %a2, %b2
+  %r3 = frem float %a3, %b3
+  store float %r0, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  store float %r1, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  store float %r2, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  store float %r3, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  ret void
+}
+

>From d11c1c4ffb8d47f31c1c26796f116c386c7dcf1f Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Mon, 29 Jan 2024 14:10:30 +0000
Subject: [PATCH 2/4] [AArch64] SLP can vectorize frem

When vector library calls are available for frem, given its type and
vector length, the SLP vectorizer uses updated costs that amount to a
call, matching LoopVectorizer's functionality.

This allows 'superword-level' vectorization, which can be converted to
a vector lib call by later passes.

Add tests that vectorize code that contains 2x double and 4x float frem
instructions.
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 17 ++++++++--
 .../SLPVectorizer/AArch64/slp-frem.ll         | 32 +++++--------------
 2 files changed, 22 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index de4e56ff80659a..051f3fcc74a614 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8616,9 +8616,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
       TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
       TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
-      return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
-                                         Op2Info) +
-             CommonCost;
+      auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
+                                                 Op1Info, Op2Info);
+      // Some targets can replace frem with vector library calls.
+      if (ShuffleOrOp == Instruction::FRem) {
+        LibFunc Func;
+        if (TLI->getLibFunc(ShuffleOrOp, ScalarTy, Func) &&
+            TLI->isFunctionVectorizable(TLI->getName(Func),
+                                        VecTy->getElementCount())) {
+          auto VecCallCost = TTI->getCallInstrCost(
+              nullptr, VecTy, {ScalarTy, ScalarTy}, CostKind);
+          VecCost = std::min(VecCost, VecCallCost);
+        }
+      }
+      return VecCost + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
index 45f667f5657889..a38f4bdc4640e9 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
@@ -6,14 +6,10 @@
 define void @frem_v2double() {
 ; CHECK-LABEL: define void @frem_v2double() {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[A0:%.*]] = load double, ptr @a, align 8
-; CHECK-NEXT:    [[A1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[B0:%.*]] = load double, ptr @a, align 8
-; CHECK-NEXT:    [[B1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[R0:%.*]] = frem double [[A0]], [[B0]]
-; CHECK-NEXT:    [[R1:%.*]] = frem double [[A1]], [[B1]]
-; CHECK-NEXT:    store double [[R0]], ptr @a, align 8
-; CHECK-NEXT:    store double [[R1]], ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <2 x double>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x double>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = frem <2 x double> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr @a, align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -31,22 +27,10 @@ entry:
 define void @frem_v4float() {
 ; CHECK-LABEL: define void @frem_v4float() {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[A0:%.*]] = load float, ptr @a, align 8
-; CHECK-NEXT:    [[A1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[A2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    [[A3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
-; CHECK-NEXT:    [[B0:%.*]] = load float, ptr @a, align 8
-; CHECK-NEXT:    [[B1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[B2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    [[B3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
-; CHECK-NEXT:    [[R0:%.*]] = frem float [[A0]], [[B0]]
-; CHECK-NEXT:    [[R1:%.*]] = frem float [[A1]], [[B1]]
-; CHECK-NEXT:    [[R2:%.*]] = frem float [[A2]], [[B2]]
-; CHECK-NEXT:    [[R3:%.*]] = frem float [[A3]], [[B3]]
-; CHECK-NEXT:    store float [[R0]], ptr @a, align 8
-; CHECK-NEXT:    store float [[R1]], ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    store float [[R2]], ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    store float [[R3]], ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x float>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = frem <4 x float> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    store <4 x float> [[TMP2]], ptr @a, align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:

>From ce534dcfed0fcbd2bdb8474b739e0d82b28058f9 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Thu, 22 Feb 2024 09:25:16 +0000
Subject: [PATCH 3/4] Addressing reviewers

---
 llvm/include/llvm/Analysis/VectorUtils.h        | 10 +++++++++-
 llvm/lib/Analysis/VectorUtils.cpp               | 17 +++++++++++++++++
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 14 +++-----------
 3 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7a92e62b53c53d..d68c5a81ad11bb 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -16,6 +16,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/VFABIDemangler.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
@@ -119,7 +120,6 @@ template <typename InstTy> class InterleaveGroup;
 class IRBuilderBase;
 class Loop;
 class ScalarEvolution;
-class TargetTransformInfo;
 class Type;
 class Value;
 
@@ -410,6 +410,14 @@ bool maskIsAllOneOrUndef(Value *Mask);
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
 
+/// Returns the cost of a call when a target has a vector library function for
+/// the given \p VecTy, otherwise an invalid cost.
+InstructionCost getVecLibCallCost(const Instruction *I,
+                                  const TargetTransformInfo *TTI,
+                                  const TargetLibraryInfo *TLI,
+                                  VectorType *VecTy,
+                                  TargetTransformInfo::TargetCostKind CostKind);
+
 /// The group of interleaved loads/stores sharing the same stride and
 /// close to each other.
 ///
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 73facc76a92b2c..dcd1d072139b66 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
@@ -1031,6 +1032,22 @@ APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
   return DemandedElts;
 }
 
+InstructionCost
+llvm::getVecLibCallCost(const Instruction *I, const TargetTransformInfo *TTI,
+                        const TargetLibraryInfo *TLI, VectorType *VecTy,
+                        TargetTransformInfo::TargetCostKind CostKind) {
+  SmallVector<Type *, 4> OpTypes;
+  for (auto &Op : I->operands())
+    OpTypes.push_back(Op->getType());
+
+  LibFunc Func;
+  if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
+      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
+    return TTI->getCallInstrCost(nullptr, VecTy, OpTypes, CostKind);
+
+  return InstructionCost::getInvalid();
+}
+
 bool InterleavedAccessInfo::isStrided(int Stride) {
   unsigned Factor = std::abs(Stride);
   return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 051f3fcc74a614..9613e1963909a1 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8619,17 +8619,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
                                                  Op1Info, Op2Info);
       // Some targets can replace frem with vector library calls.
-      if (ShuffleOrOp == Instruction::FRem) {
-        LibFunc Func;
-        if (TLI->getLibFunc(ShuffleOrOp, ScalarTy, Func) &&
-            TLI->isFunctionVectorizable(TLI->getName(Func),
-                                        VecTy->getElementCount())) {
-          auto VecCallCost = TTI->getCallInstrCost(
-              nullptr, VecTy, {ScalarTy, ScalarTy}, CostKind);
-          VecCost = std::min(VecCost, VecCallCost);
-        }
-      }
-      return VecCost + CommonCost;
+      InstructionCost VecCallCost =
+          getVecLibCallCost(VL0, TTI, TLI, VecTy, CostKind);
+      return std::min(VecInstrCost, VecCallCost) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }

>From 71d6952bc4bdfc8ad679af0e876693557120ac5e Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 23 Feb 2024 10:36:27 +0000
Subject: [PATCH 4/4] [LV][SLP] Vectorizers now use getFRemInstrCost for frem
 costs

SLP vectorization for frem now happens when vector library calls are
available, given its type and vector length. This is due to using the
updated cost that amounts to a call.

Add tests that do SLP vectorization for code that contains 2x double and
4x float frem instructions.

LoopVectorizer now also uses getFRemInstrCost.
---
 .../llvm/Analysis/TargetTransformInfo.h       | 12 ++++++++
 llvm/include/llvm/Analysis/VectorUtils.h      | 10 +------
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 19 +++++++++++++
 llvm/lib/Analysis/VectorUtils.cpp             | 17 -----------
 .../Transforms/Vectorize/LoopVectorize.cpp    | 28 ++++++-------------
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 16 +++++++----
 6 files changed, 51 insertions(+), 51 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58577a6b6eb5c0..e396e56d111e0d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1255,6 +1255,18 @@ class TargetTransformInfo {
       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
       const Instruction *CxtI = nullptr) const;
 
+  /// Returns the cost of a vector instruction based on the assumption that frem
+  /// will be later transformed (by ReplaceWithVecLib) into a call to a
+  /// platform specific frem vector math function.
+  /// Otherwise, it returns the cost given by getArithmeticInstrCost.
+  InstructionCost getFRemInstrCost(
+      const TargetLibraryInfo *TLI, unsigned Opcode, Type *Ty,
+      TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
+      TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
+      TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None},
+      ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
+      const Instruction *CxtI = nullptr) const;
+
   /// Returns the cost estimation for alternating opcode pattern that can be
   /// lowered to a single instruction on the target. In X86 this is for the
   /// addsub instruction which corrsponds to a Shuffle + Fadd + FSub pattern in
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index d68c5a81ad11bb..7a92e62b53c53d 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -16,7 +16,6 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/VFABIDemangler.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
@@ -120,6 +119,7 @@ template <typename InstTy> class InterleaveGroup;
 class IRBuilderBase;
 class Loop;
 class ScalarEvolution;
+class TargetTransformInfo;
 class Type;
 class Value;
 
@@ -410,14 +410,6 @@ bool maskIsAllOneOrUndef(Value *Mask);
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
 
-/// Returns the cost of a call when a target has a vector library function for
-/// the given \p VecTy, otherwise an invalid cost.
-InstructionCost getVecLibCallCost(const Instruction *I,
-                                  const TargetTransformInfo *TTI,
-                                  const TargetLibraryInfo *TLI,
-                                  VectorType *VecTy,
-                                  TargetTransformInfo::TargetCostKind CostKind);
-
 /// The group of interleaved loads/stores sharing the same stride and
 /// close to each other.
 ///
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1f11f0d7dd620e..f595923249cb54 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/LoopIterator.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfoImpl.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
@@ -881,6 +882,24 @@ InstructionCost TargetTransformInfo::getArithmeticInstrCost(
   return Cost;
 }
 
+InstructionCost TargetTransformInfo::getFRemInstrCost(
+    const TargetLibraryInfo *TLI, unsigned Opcode, Type *Ty,
+    TTI::TargetCostKind CostKind, OperandValueInfo Op1Info,
+    OperandValueInfo Op2Info, ArrayRef<const Value *> Args,
+    const Instruction *CxtI) const {
+  assert(Opcode == Instruction::FRem && "Instruction must be frem");
+
+  VectorType *VecTy = dyn_cast<VectorType>(Ty);
+  Type *ScalarTy = VecTy ? VecTy->getScalarType() : Ty;
+  LibFunc Func;
+  if (VecTy && TLI->getLibFunc(Opcode, ScalarTy, Func) &&
+      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
+    return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind);
+
+  return getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, Op2Info, Args,
+                                CxtI);
+}
+
 InstructionCost TargetTransformInfo::getAltInstrCost(
     VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
     const SmallBitVector &OpcodeMask, TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index dcd1d072139b66..73facc76a92b2c 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -18,7 +18,6 @@
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
@@ -1032,22 +1031,6 @@ APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
   return DemandedElts;
 }
 
-InstructionCost
-llvm::getVecLibCallCost(const Instruction *I, const TargetTransformInfo *TTI,
-                        const TargetLibraryInfo *TLI, VectorType *VecTy,
-                        TargetTransformInfo::TargetCostKind CostKind) {
-  SmallVector<Type *, 4> OpTypes;
-  for (auto &Op : I->operands())
-    OpTypes.push_back(Op->getType());
-
-  LibFunc Func;
-  if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
-      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
-    return TTI->getCallInstrCost(nullptr, VecTy, OpTypes, CostKind);
-
-  return InstructionCost::getInvalid();
-}
-
 bool InterleavedAccessInfo::isStrided(int Stride) {
   unsigned Factor = std::abs(Stride);
   return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 373ea751d568ae..55e85faf4f7605 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6899,25 +6899,15 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
       Op2Info.Kind = TargetTransformInfo::OK_UniformValue;
 
     SmallVector<const Value *, 4> Operands(I->operand_values());
-    auto InstrCost = TTI.getArithmeticInstrCost(
-        I->getOpcode(), VectorTy, CostKind,
-        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Op2Info, Operands, I);
-
-    // Some targets can replace frem with vector library calls.
-    InstructionCost VecCallCost = InstructionCost::getInvalid();
-    if (I->getOpcode() == Instruction::FRem) {
-      LibFunc Func;
-      if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
-          TLI->isFunctionVectorizable(TLI->getName(Func), VF)) {
-        SmallVector<Type *, 4> OpTypes;
-        for (auto &Op : I->operands())
-          OpTypes.push_back(Op->getType());
-        VecCallCost =
-            TTI.getCallInstrCost(nullptr, VectorTy, OpTypes, CostKind);
-      }
-    }
-    return std::min(InstrCost, VecCallCost);
+    TTI::OperandValueInfo Op1Info{TargetTransformInfo::OK_AnyValue,
+                                  TargetTransformInfo::OP_None};
+    // Some targets replace frem with vector library calls.
+    if (I->getOpcode() == Instruction::FRem)
+      return TTI.getFRemInstrCost(TLI, I->getOpcode(), VectorTy, CostKind,
+                                  Op1Info, Op2Info, Operands, I);
+
+    return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, CostKind,
+                                      Op1Info, Op2Info, Operands, I);
   }
   case Instruction::FNeg: {
     return TTI.getArithmeticInstrCost(
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 9613e1963909a1..75b2d839410fd7 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8616,12 +8616,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
       TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
       TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
-      auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
-                                                 Op1Info, Op2Info);
-      // Some targets can replace frem with vector library calls.
-      InstructionCost VecCallCost =
-          getVecLibCallCost(VL0, TTI, TLI, VecTy, CostKind);
-      return std::min(VecInstrCost, VecCallCost) + CommonCost;
+
+      // Some targets replace frem with vector library calls.
+      if (ShuffleOrOp == Instruction::FRem)
+        return TTI->getFRemInstrCost(TLI, ShuffleOrOp, VecTy, CostKind, Op1Info,
+                                     Op2Info) +
+               CommonCost;
+
+      return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
+                                         Op2Info) +
+             CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }



More information about the llvm-commits mailing list