[llvm] [AArch64][LV][SLP] Vectorizers use getFRemInstrCost for frem costs (PR #82488)

Paschalis Mpeis via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 11 10:40:24 PDT 2024


https://github.com/paschalis-mpeis updated https://github.com/llvm/llvm-project/pull/82488

>From abe1b4e71e9fe57be4a3962e81c58ce22e313024 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 21 Feb 2024 11:53:00 +0000
Subject: [PATCH 1/6] SLP cannot vectorize frem instructions on AArch64.

It needs updated costs when vector library functions are available for
the given VF and type.
---
 .../SLPVectorizer/AArch64/slp-frem.ll         | 71 +++++++++++++++++++
 1 file changed, 71 insertions(+)
 create mode 100644 llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll

diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
new file mode 100644
index 00000000000000..45f667f5657889
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
@@ -0,0 +1,71 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -S -mtriple=aarch64 -vector-library=ArmPL -passes=slp-vectorizer | FileCheck %s
+
+@a = common global ptr null, align 8
+
+define void @frem_v2double() {
+; CHECK-LABEL: define void @frem_v2double() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A0:%.*]] = load double, ptr @a, align 8
+; CHECK-NEXT:    [[A1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[B0:%.*]] = load double, ptr @a, align 8
+; CHECK-NEXT:    [[B1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[R0:%.*]] = frem double [[A0]], [[B0]]
+; CHECK-NEXT:    [[R1:%.*]] = frem double [[A1]], [[B1]]
+; CHECK-NEXT:    store double [[R0]], ptr @a, align 8
+; CHECK-NEXT:    store double [[R1]], ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  %a1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  %b0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  %b1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  %r0 = frem double %a0, %b0
+  %r1 = frem double %a1, %b1
+  store double %r0, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8
+  store double %r1, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+  ret void
+}
+
+define void @frem_v4float() {
+; CHECK-LABEL: define void @frem_v4float() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A0:%.*]] = load float, ptr @a, align 8
+; CHECK-NEXT:    [[A1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[A2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    [[A3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[B0:%.*]] = load float, ptr @a, align 8
+; CHECK-NEXT:    [[B1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[B2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    [[B3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[R0:%.*]] = frem float [[A0]], [[B0]]
+; CHECK-NEXT:    [[R1:%.*]] = frem float [[A1]], [[B1]]
+; CHECK-NEXT:    [[R2:%.*]] = frem float [[A2]], [[B2]]
+; CHECK-NEXT:    [[R3:%.*]] = frem float [[A3]], [[B3]]
+; CHECK-NEXT:    store float [[R0]], ptr @a, align 8
+; CHECK-NEXT:    store float [[R1]], ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+; CHECK-NEXT:    store float [[R2]], ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+; CHECK-NEXT:    store float [[R3]], ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  %a1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  %a2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  %a3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  %b0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  %b1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  %b2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  %b3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  %r0 = frem float %a0, %b0
+  %r1 = frem float %a1, %b1
+  %r2 = frem float %a2, %b2
+  %r3 = frem float %a3, %b3
+  store float %r0, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8
+  store float %r1, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
+  store float %r2, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
+  store float %r3, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+  ret void
+}
+

>From 36ce5eb8f1d26f984e46c9da930a1c15085e1dd9 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Mon, 29 Jan 2024 14:10:30 +0000
Subject: [PATCH 2/6] [AArch64] SLP can vectorize frem

When a vector library function is available for frem, given its type and
vector length, the SLP vectorizer now uses an updated cost that amounts
to a call, matching the LoopVectorizer's behaviour.

This enables superword-level (SLP) vectorization of frem; the resulting
vector frem can then be converted into a vector library call by later passes.

Add tests that vectorize code that contains 2x double and 4x float frem
instructions.
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 17 ++++++++--
 .../SLPVectorizer/AArch64/slp-frem.ll         | 32 +++++--------------
 2 files changed, 22 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7b99c3ac8c55a5..2a962745dbf988 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8851,9 +8851,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
       TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
       TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
-      return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
-                                         Op2Info) +
-             CommonCost;
+      auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
+                                                 Op1Info, Op2Info);
+      // Some targets can replace frem with vector library calls.
+      if (ShuffleOrOp == Instruction::FRem) {
+        LibFunc Func;
+        if (TLI->getLibFunc(ShuffleOrOp, ScalarTy, Func) &&
+            TLI->isFunctionVectorizable(TLI->getName(Func),
+                                        VecTy->getElementCount())) {
+          auto VecCallCost = TTI->getCallInstrCost(
+              nullptr, VecTy, {ScalarTy, ScalarTy}, CostKind);
+          VecCost = std::min(VecCost, VecCallCost);
+        }
+      }
+      return VecCost + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
index 45f667f5657889..a38f4bdc4640e9 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll
@@ -6,14 +6,10 @@
 define void @frem_v2double() {
 ; CHECK-LABEL: define void @frem_v2double() {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[A0:%.*]] = load double, ptr @a, align 8
-; CHECK-NEXT:    [[A1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[B0:%.*]] = load double, ptr @a, align 8
-; CHECK-NEXT:    [[B1:%.*]] = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[R0:%.*]] = frem double [[A0]], [[B0]]
-; CHECK-NEXT:    [[R1:%.*]] = frem double [[A1]], [[B1]]
-; CHECK-NEXT:    store double [[R0]], ptr @a, align 8
-; CHECK-NEXT:    store double [[R1]], ptr getelementptr inbounds (double, ptr @a, i64 1), align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <2 x double>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x double>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = frem <2 x double> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr @a, align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -31,22 +27,10 @@ entry:
 define void @frem_v4float() {
 ; CHECK-LABEL: define void @frem_v4float() {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[A0:%.*]] = load float, ptr @a, align 8
-; CHECK-NEXT:    [[A1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[A2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    [[A3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
-; CHECK-NEXT:    [[B0:%.*]] = load float, ptr @a, align 8
-; CHECK-NEXT:    [[B1:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    [[B2:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    [[B3:%.*]] = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
-; CHECK-NEXT:    [[R0:%.*]] = frem float [[A0]], [[B0]]
-; CHECK-NEXT:    [[R1:%.*]] = frem float [[A1]], [[B1]]
-; CHECK-NEXT:    [[R2:%.*]] = frem float [[A2]], [[B2]]
-; CHECK-NEXT:    [[R3:%.*]] = frem float [[A3]], [[B3]]
-; CHECK-NEXT:    store float [[R0]], ptr @a, align 8
-; CHECK-NEXT:    store float [[R1]], ptr getelementptr inbounds (float, ptr @a, i64 1), align 8
-; CHECK-NEXT:    store float [[R2]], ptr getelementptr inbounds (float, ptr @a, i64 2), align 8
-; CHECK-NEXT:    store float [[R3]], ptr getelementptr inbounds (float, ptr @a, i64 3), align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x float>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, ptr @a, align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = frem <4 x float> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    store <4 x float> [[TMP2]], ptr @a, align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:

>From 3ed8acc7591ee4a52fa39e54c6013ae6da12e807 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Thu, 22 Feb 2024 09:25:16 +0000
Subject: [PATCH 3/6] Address reviewer comments

---
 llvm/include/llvm/Analysis/VectorUtils.h        | 10 +++++++++-
 llvm/lib/Analysis/VectorUtils.cpp               | 17 +++++++++++++++++
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 14 +++-----------
 3 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index c6eb66cc9660ca..0bdc8007900e60 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -16,6 +16,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/VFABIDemangler.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
@@ -119,7 +120,6 @@ template <typename InstTy> class InterleaveGroup;
 class IRBuilderBase;
 class Loop;
 class ScalarEvolution;
-class TargetTransformInfo;
 class Type;
 class Value;
 
@@ -415,6 +415,14 @@ bool maskContainsAllOneOrUndef(Value *Mask);
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
 
+/// Returns the cost of a call when a target has a vector library function for
+/// the given \p VecTy, otherwise an invalid cost.
+InstructionCost getVecLibCallCost(const Instruction *I,
+                                  const TargetTransformInfo *TTI,
+                                  const TargetLibraryInfo *TLI,
+                                  VectorType *VecTy,
+                                  TargetTransformInfo::TargetCostKind CostKind);
+
 /// The group of interleaved loads/stores sharing the same stride and
 /// close to each other.
 ///
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index bf7bc0ba84a033..bff1b2dc035e4b 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
@@ -1056,6 +1057,22 @@ APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
   return DemandedElts;
 }
 
+InstructionCost
+llvm::getVecLibCallCost(const Instruction *I, const TargetTransformInfo *TTI,
+                        const TargetLibraryInfo *TLI, VectorType *VecTy,
+                        TargetTransformInfo::TargetCostKind CostKind) {
+  SmallVector<Type *, 4> OpTypes;
+  for (auto &Op : I->operands())
+    OpTypes.push_back(Op->getType());
+
+  LibFunc Func;
+  if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
+      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
+    return TTI->getCallInstrCost(nullptr, VecTy, OpTypes, CostKind);
+
+  return InstructionCost::getInvalid();
+}
+
 bool InterleavedAccessInfo::isStrided(int Stride) {
   unsigned Factor = std::abs(Stride);
   return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2a962745dbf988..ee728d6b2c24c8 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8854,17 +8854,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
                                                  Op1Info, Op2Info);
       // Some targets can replace frem with vector library calls.
-      if (ShuffleOrOp == Instruction::FRem) {
-        LibFunc Func;
-        if (TLI->getLibFunc(ShuffleOrOp, ScalarTy, Func) &&
-            TLI->isFunctionVectorizable(TLI->getName(Func),
-                                        VecTy->getElementCount())) {
-          auto VecCallCost = TTI->getCallInstrCost(
-              nullptr, VecTy, {ScalarTy, ScalarTy}, CostKind);
-          VecCost = std::min(VecCost, VecCallCost);
-        }
-      }
-      return VecCost + CommonCost;
+      InstructionCost VecCallCost =
+          getVecLibCallCost(VL0, TTI, TLI, VecTy, CostKind);
+      return std::min(VecInstrCost, VecCallCost) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }

>From 34bbbf876b23bd55212c895b13c458b969fb2170 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 23 Feb 2024 10:36:27 +0000
Subject: [PATCH 4/6] [LV][SLP] Vectorizers now use getFRemInstrCost for frem
 costs

SLP vectorization of frem now happens when a vector library function is
available for its type and vector length, because the updated cost models
frem as a call.

Add tests that do SLP vectorization for code that contains 2x double and
4x float frem instructions.

LoopVectorizer now also uses getFRemInstrCost.
---
 .../llvm/Analysis/TargetTransformInfo.h       | 12 ++++++++
 llvm/include/llvm/Analysis/VectorUtils.h      | 10 +------
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 19 +++++++++++++
 llvm/lib/Analysis/VectorUtils.cpp             | 17 -----------
 .../Transforms/Vectorize/LoopVectorize.cpp    | 28 ++++++-------------
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 16 +++++++----
 6 files changed, 51 insertions(+), 51 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4eab357f1b33b6..f472abd41a6f83 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1255,6 +1255,18 @@ class TargetTransformInfo {
       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
       const Instruction *CxtI = nullptr) const;
 
+  /// Returns the cost of a vector instruction based on the assumption that frem
+  /// will be later transformed (by ReplaceWithVecLib) into a call to a
+  /// platform specific frem vector math function.
+  /// If unsupported, it will return cost using getArithmeticInstrCost.
+  InstructionCost getFRemInstrCost(
+      const TargetLibraryInfo *TLI, unsigned Opcode, Type *Ty,
+      TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
+      TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
+      TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None},
+      ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
+      const Instruction *CxtI = nullptr) const;
+
   /// Returns the cost estimation for alternating opcode pattern that can be
   /// lowered to a single instruction on the target. In X86 this is for the
   /// addsub instruction which corrsponds to a Shuffle + Fadd + FSub pattern in
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 0bdc8007900e60..c6eb66cc9660ca 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -16,7 +16,6 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/VFABIDemangler.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
@@ -120,6 +119,7 @@ template <typename InstTy> class InterleaveGroup;
 class IRBuilderBase;
 class Loop;
 class ScalarEvolution;
+class TargetTransformInfo;
 class Type;
 class Value;
 
@@ -415,14 +415,6 @@ bool maskContainsAllOneOrUndef(Value *Mask);
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
 
-/// Returns the cost of a call when a target has a vector library function for
-/// the given \p VecTy, otherwise an invalid cost.
-InstructionCost getVecLibCallCost(const Instruction *I,
-                                  const TargetTransformInfo *TTI,
-                                  const TargetLibraryInfo *TLI,
-                                  VectorType *VecTy,
-                                  TargetTransformInfo::TargetCostKind CostKind);
-
 /// The group of interleaved loads/stores sharing the same stride and
 /// close to each other.
 ///
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 15311be4dba277..2b61bb136f54ce 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/LoopIterator.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfoImpl.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
@@ -883,6 +884,24 @@ InstructionCost TargetTransformInfo::getArithmeticInstrCost(
   return Cost;
 }
 
+InstructionCost TargetTransformInfo::getFRemInstrCost(
+    const TargetLibraryInfo *TLI, unsigned Opcode, Type *Ty,
+    TTI::TargetCostKind CostKind, OperandValueInfo Op1Info,
+    OperandValueInfo Op2Info, ArrayRef<const Value *> Args,
+    const Instruction *CxtI) const {
+  assert(Opcode == Instruction::FRem && "Instruction must be frem");
+
+  VectorType *VecTy = dyn_cast<VectorType>(Ty);
+  Type *ScalarTy = VecTy ? VecTy->getScalarType() : Ty;
+  LibFunc Func;
+  if (VecTy && TLI->getLibFunc(Opcode, ScalarTy, Func) &&
+      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
+    return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind);
+
+  return getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, Op2Info, Args,
+                                CxtI);
+}
+
 InstructionCost TargetTransformInfo::getAltInstrCost(
     VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
     const SmallBitVector &OpcodeMask, TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index bff1b2dc035e4b..bf7bc0ba84a033 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -18,7 +18,6 @@
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
@@ -1057,22 +1056,6 @@ APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
   return DemandedElts;
 }
 
-InstructionCost
-llvm::getVecLibCallCost(const Instruction *I, const TargetTransformInfo *TTI,
-                        const TargetLibraryInfo *TLI, VectorType *VecTy,
-                        TargetTransformInfo::TargetCostKind CostKind) {
-  SmallVector<Type *, 4> OpTypes;
-  for (auto &Op : I->operands())
-    OpTypes.push_back(Op->getType());
-
-  LibFunc Func;
-  if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
-      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
-    return TTI->getCallInstrCost(nullptr, VecTy, OpTypes, CostKind);
-
-  return InstructionCost::getInvalid();
-}
-
 bool InterleavedAccessInfo::isStrided(int Stride) {
   unsigned Factor = std::abs(Stride);
   return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index edaad4d033bdf0..edf0b95f6d32e8 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6911,25 +6911,15 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
       Op2Info.Kind = TargetTransformInfo::OK_UniformValue;
 
     SmallVector<const Value *, 4> Operands(I->operand_values());
-    auto InstrCost = TTI.getArithmeticInstrCost(
-        I->getOpcode(), VectorTy, CostKind,
-        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
-        Op2Info, Operands, I);
-
-    // Some targets can replace frem with vector library calls.
-    InstructionCost VecCallCost = InstructionCost::getInvalid();
-    if (I->getOpcode() == Instruction::FRem) {
-      LibFunc Func;
-      if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
-          TLI->isFunctionVectorizable(TLI->getName(Func), VF)) {
-        SmallVector<Type *, 4> OpTypes;
-        for (auto &Op : I->operands())
-          OpTypes.push_back(Op->getType());
-        VecCallCost =
-            TTI.getCallInstrCost(nullptr, VectorTy, OpTypes, CostKind);
-      }
-    }
-    return std::min(InstrCost, VecCallCost);
+    TTI::OperandValueInfo Op1Info{TargetTransformInfo::OK_AnyValue,
+                                  TargetTransformInfo::OP_None};
+    // Some targets replace frem with vector library calls.
+    if (I->getOpcode() == Instruction::FRem)
+      return TTI.getFRemInstrCost(TLI, I->getOpcode(), VectorTy, CostKind,
+                                  Op1Info, Op2Info, Operands, I);
+
+    return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, CostKind,
+                                      Op1Info, Op2Info, Operands, I);
   }
   case Instruction::FNeg: {
     return TTI.getArithmeticInstrCost(
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ee728d6b2c24c8..d92434cd2ab2d3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8851,12 +8851,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
       TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
       TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
-      auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
-                                                 Op1Info, Op2Info);
-      // Some targets can replace frem with vector library calls.
-      InstructionCost VecCallCost =
-          getVecLibCallCost(VL0, TTI, TLI, VecTy, CostKind);
-      return std::min(VecInstrCost, VecCallCost) + CommonCost;
+
+      // Some targets replace frem with vector library calls.
+      if (ShuffleOrOp == Instruction::FRem)
+        return TTI->getFRemInstrCost(TLI, ShuffleOrOp, VecTy, CostKind, Op1Info,
+                                     Op2Info) +
+               CommonCost;
+
+      return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
+                                         Op2Info) +
+             CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }

>From ecd7da705ab614914b0a5f1afd092f0530369617 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Tue, 27 Feb 2024 12:36:24 +0000
Subject: [PATCH 5/6] Address reviewer comments (2)

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h |  5 +++--
 llvm/lib/Analysis/TargetTransformInfo.cpp        | 16 ++++++----------
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp  |  5 ++---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp  |  5 +----
 4 files changed, 12 insertions(+), 19 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f472abd41a6f83..d07b9d2e77d273 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1258,9 +1258,10 @@ class TargetTransformInfo {
   /// Returns the cost of a vector instruction based on the assumption that frem
   /// will be later transformed (by ReplaceWithVecLib) into a call to a
   /// platform specific frem vector math function.
-  /// If unsupported, it will return cost using getArithmeticInstrCost.
+  /// Returns the same cost as getArithmeticInstrCost when no math function is
+  /// available.
   InstructionCost getFRemInstrCost(
-      const TargetLibraryInfo *TLI, unsigned Opcode, Type *Ty,
+      const TargetLibraryInfo *TLI, Type *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
       TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 2b61bb136f54ce..7c2d871f671d8e 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -885,21 +885,17 @@ InstructionCost TargetTransformInfo::getArithmeticInstrCost(
 }
 
 InstructionCost TargetTransformInfo::getFRemInstrCost(
-    const TargetLibraryInfo *TLI, unsigned Opcode, Type *Ty,
-    TTI::TargetCostKind CostKind, OperandValueInfo Op1Info,
-    OperandValueInfo Op2Info, ArrayRef<const Value *> Args,
-    const Instruction *CxtI) const {
-  assert(Opcode == Instruction::FRem && "Instruction must be frem");
-
+    const TargetLibraryInfo *TLI, Type *Ty, TTI::TargetCostKind CostKind,
+    OperandValueInfo Op1Info, OperandValueInfo Op2Info,
+    ArrayRef<const Value *> Args, const Instruction *CxtI) const {
   VectorType *VecTy = dyn_cast<VectorType>(Ty);
-  Type *ScalarTy = VecTy ? VecTy->getScalarType() : Ty;
   LibFunc Func;
-  if (VecTy && TLI->getLibFunc(Opcode, ScalarTy, Func) &&
+  if (VecTy && TLI->getLibFunc(Instruction::FRem, Ty->getScalarType(), Func) &&
       TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
     return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind);
 
-  return getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, Op2Info, Args,
-                                CxtI);
+  return getArithmeticInstrCost(Instruction::FRem, Ty, CostKind, Op1Info,
+                                Op2Info, Args, CxtI);
 }
 
 InstructionCost TargetTransformInfo::getAltInstrCost(
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index edf0b95f6d32e8..e33090ebac41cb 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6913,10 +6913,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
     SmallVector<const Value *, 4> Operands(I->operand_values());
     TTI::OperandValueInfo Op1Info{TargetTransformInfo::OK_AnyValue,
                                   TargetTransformInfo::OP_None};
-    // Some targets replace frem with vector library calls.
     if (I->getOpcode() == Instruction::FRem)
-      return TTI.getFRemInstrCost(TLI, I->getOpcode(), VectorTy, CostKind,
-                                  Op1Info, Op2Info, Operands, I);
+      return TTI.getFRemInstrCost(TLI, VectorTy, CostKind, Op1Info, Op2Info,
+                                  Operands, I);
 
     return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, CostKind,
                                       Op1Info, Op2Info, Operands, I);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index d92434cd2ab2d3..aeee52a59801a5 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8851,11 +8851,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
       TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
       TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
-
-      // Some targets replace frem with vector library calls.
       if (ShuffleOrOp == Instruction::FRem)
-        return TTI->getFRemInstrCost(TLI, ShuffleOrOp, VecTy, CostKind, Op1Info,
-                                     Op2Info) +
+        return TTI->getFRemInstrCost(TLI, VecTy, CostKind, Op1Info, Op2Info) +
                CommonCost;
 
       return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,

>From 6d508fb09fb3a90aa323772e919713b37223ec08 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Mon, 11 Mar 2024 17:19:20 +0000
Subject: [PATCH 6/6] [AArch64][LV][SLP] Vectorizers use call cost for
 vectorized frem

Both LoopVectorizer and SLPVectorizer use getArithmeticInstrCost to
compute the cost of frem; on AArch64 this becomes a call cost when TLI
provides a vector library function.

Add tests that do SLP vectorization for code that contains 2x double and
4x float frem instructions.
---
 .../llvm/Analysis/TargetTransformInfo.h       | 19 +++---------
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 31 ++++++++++---------
 .../Transforms/Vectorize/LoopVectorize.cpp    | 12 +++----
 .../Transforms/Vectorize/SLPVectorizer.cpp    |  7 ++---
 4 files changed, 27 insertions(+), 42 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index d07b9d2e77d273..cc1996fac42b2b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1247,26 +1247,17 @@ class TargetTransformInfo {
   /// cases or optimizations based on those values.
   /// \p CxtI is the optional original context instruction, if one exists, to
   /// provide even more information.
+  /// \p TLibInfo is used to search for platform specific vector library
+  /// functions for instructions that might be converted to calls. The only
+  /// known case currently is frem.
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
       TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None},
       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
-      const Instruction *CxtI = nullptr) const;
-
-  /// Returns the cost of a vector instruction based on the assumption that frem
-  /// will be later transformed (by ReplaceWithVecLib) into a call to a
-  /// platform specific frem vector math function.
-  /// Returns the same cost as getArithmeticInstrCost when no math function is
-  /// available.
-  InstructionCost getFRemInstrCost(
-      const TargetLibraryInfo *TLI, Type *Ty,
-      TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
-      TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
-      TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None},
-      ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
-      const Instruction *CxtI = nullptr) const;
+      const Instruction *CxtI = nullptr,
+      const TargetLibraryInfo *TLibInfo = nullptr) const;
 
   /// Returns the cost estimation for alternating opcode pattern that can be
   /// lowered to a single instruction on the target. In X86 this is for the
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7c2d871f671d8e..2e0bd843396596 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -875,7 +875,22 @@ TargetTransformInfo::getOperandInfo(const Value *V) {
 InstructionCost TargetTransformInfo::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
     OperandValueInfo Op1Info, OperandValueInfo Op2Info,
-    ArrayRef<const Value *> Args, const Instruction *CxtI) const {
+    ArrayRef<const Value *> Args, const Instruction *CxtI,
+    const TargetLibraryInfo *TLibInfo) const {
+
+  // Use call cost for frem instructions that have platform specific vector
+  // math functions, as those will be replaced with calls later by SelectionDAG
+  // or the ReplaceWithVecLib pass.
+  if (TLibInfo && Opcode == Instruction::FRem) {
+    VectorType *VecTy = dyn_cast<VectorType>(Ty);
+    LibFunc Func;
+    if (VecTy &&
+        TLibInfo->getLibFunc(Instruction::FRem, Ty->getScalarType(), Func) &&
+        TLibInfo->isFunctionVectorizable(TLibInfo->getName(Func),
+                                         VecTy->getElementCount()))
+      return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind);
+  }
+
   InstructionCost Cost =
       TTIImpl->getArithmeticInstrCost(Opcode, Ty, CostKind,
                                       Op1Info, Op2Info,
@@ -884,20 +899,6 @@ InstructionCost TargetTransformInfo::getArithmeticInstrCost(
   return Cost;
 }
 
-InstructionCost TargetTransformInfo::getFRemInstrCost(
-    const TargetLibraryInfo *TLI, Type *Ty, TTI::TargetCostKind CostKind,
-    OperandValueInfo Op1Info, OperandValueInfo Op2Info,
-    ArrayRef<const Value *> Args, const Instruction *CxtI) const {
-  VectorType *VecTy = dyn_cast<VectorType>(Ty);
-  LibFunc Func;
-  if (VecTy && TLI->getLibFunc(Instruction::FRem, Ty->getScalarType(), Func) &&
-      TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
-    return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind);
-
-  return getArithmeticInstrCost(Instruction::FRem, Ty, CostKind, Op1Info,
-                                Op2Info, Args, CxtI);
-}
-
 InstructionCost TargetTransformInfo::getAltInstrCost(
     VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
     const SmallBitVector &OpcodeMask, TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e33090ebac41cb..52b992b19e4b04 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6911,14 +6911,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
       Op2Info.Kind = TargetTransformInfo::OK_UniformValue;
 
     SmallVector<const Value *, 4> Operands(I->operand_values());
-    TTI::OperandValueInfo Op1Info{TargetTransformInfo::OK_AnyValue,
-                                  TargetTransformInfo::OP_None};
-    if (I->getOpcode() == Instruction::FRem)
-      return TTI.getFRemInstrCost(TLI, VectorTy, CostKind, Op1Info, Op2Info,
-                                  Operands, I);
-
-    return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, CostKind,
-                                      Op1Info, Op2Info, Operands, I);
+    return TTI.getArithmeticInstrCost(
+        I->getOpcode(), VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Op2Info, Operands, I, TLI);
   }
   case Instruction::FNeg: {
     return TTI.getArithmeticInstrCost(
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index aeee52a59801a5..75730dd3c4969e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8851,12 +8851,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
       TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
       TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
-      if (ShuffleOrOp == Instruction::FRem)
-        return TTI->getFRemInstrCost(TLI, VecTy, CostKind, Op1Info, Op2Info) +
-               CommonCost;
-
       return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
-                                         Op2Info) +
+                                         Op2Info, ArrayRef<const Value *>(),
+                                         nullptr, TLI) +
              CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);



More information about the llvm-commits mailing list