[llvm] 285bc69 - [SLP]Fix PR80027: Fix costs processing for minbitwidth types.
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 30 10:36:32 PST 2024
Author: Alexey Bataev
Date: 2024-01-30T10:32:55-08:00
New Revision: 285bc69846e76af805cd106ea3ea538a12f5c9b6
URL: https://github.com/llvm/llvm-project/commit/285bc69846e76af805cd106ea3ea538a12f5c9b6
DIFF: https://github.com/llvm/llvm-project/commit/285bc69846e76af805cd106ea3ea538a12f5c9b6.diff
LOG: [SLP]Fix PR80027: Fix costs processing for minbitwidth types.
The source and destination types were passed in the wrong order: getCastInstrCost
expects the destination type first and the source type second.
Added:
llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e0a779467b1fa..bde65717ac1d4 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7888,7 +7888,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy);
unsigned VecOpcode;
- auto *SrcVecTy =
+ auto *UserVecTy =
FixedVectorType::get(UserScalarTy, E->getVectorFactor());
if (BWSz > SrcBWSz)
VecOpcode = Instruction::Trunc;
@@ -7896,11 +7896,10 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
VecOpcode =
It->second.second ? Instruction::SExt : Instruction::ZExt;
TTI::CastContextHint CCH = GetCastContextHint(VL0);
- VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH,
+ VecCost += TTI->getCastInstrCost(VecOpcode, UserVecTy, VecTy, CCH,
CostKind);
- ScalarCost +=
- Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy,
- CCH, CostKind);
+ ScalarCost += Sz * TTI->getCastInstrCost(VecOpcode, UserScalarTy,
+ ScalarTy, CCH, CostKind);
}
}
}
@@ -8981,7 +8980,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers;
SmallVector<APInt> DemandedElts;
SmallDenseSet<Value *, 4> UsedInserts;
- DenseSet<Value *> VectorCasts;
+ DenseSet<std::pair<const TreeEntry *, Type *>> VectorCasts;
for (ExternalUser &EU : ExternalUses) {
// We only add extract cost once for the same scalar.
if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
@@ -9051,11 +9050,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
DemandedElts.push_back(APInt::getZero(FTy->getNumElements()));
VecId = FirstUsers.size() - 1;
auto It = MinBWs.find(ScalarTE);
- if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) {
+ if (It != MinBWs.end() &&
+ VectorCasts
+ .insert(std::make_pair(ScalarTE, FTy->getElementType()))
+ .second) {
unsigned BWSz = It->second.second;
- unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType());
+ unsigned DstBWSz = DL->getTypeSizeInBits(FTy->getElementType());
unsigned VecOpcode;
- if (BWSz < SrcBWSz)
+ if (DstBWSz < BWSz)
VecOpcode = Instruction::Trunc;
else
VecOpcode =
@@ -9108,17 +9110,20 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
}
// Add reduced value cost, if resized.
if (!VectorizedVals.empty()) {
- auto BWIt = MinBWs.find(VectorizableTree.front().get());
+ const TreeEntry &Root = *VectorizableTree.front().get();
+ auto BWIt = MinBWs.find(&Root);
if (BWIt != MinBWs.end()) {
- Type *DstTy = VectorizableTree.front()->Scalars.front()->getType();
+ Type *DstTy = Root.Scalars.front()->getType();
unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
- unsigned Opcode = Instruction::Trunc;
- if (OriginalSz < BWIt->second.first)
- Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
- Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
- Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
- TTI::CastContextHint::None,
- TTI::TCK_RecipThroughput);
+ if (OriginalSz != BWIt->second.first) {
+ unsigned Opcode = Instruction::Trunc;
+ if (OriginalSz < BWIt->second.first)
+ Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
+ Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
+ Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
+ TTI::CastContextHint::None,
+ TTI::TCK_RecipThroughput);
+ }
}
}
@@ -11419,9 +11424,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
VecOpcode = Instruction::BitCast;
} else if (BWSz < SrcBWSz) {
VecOpcode = Instruction::Trunc;
- } else if (It != MinBWs.end()) {
+ } else if (SrcIt != MinBWs.end()) {
assert(BWSz > SrcBWSz && "Invalid cast!");
- VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+ VecOpcode =
+ SrcIt->second.second ? Instruction::SExt : Instruction::ZExt;
}
}
Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast)
@@ -11929,7 +11935,7 @@ Value *BoUpSLP::vectorizeTree(
// basic block. Only one extractelement per block should be emitted.
DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
SmallDenseSet<Value *, 4> UsedInserts;
- DenseMap<Value *, Value *> VectorCasts;
+ DenseMap<std::pair<Value *, Type *>, Value *> VectorCasts;
SmallDenseSet<Value *, 4> ScalarsWithNullptrUser;
// Extract all of the elements with the external uses.
for (const auto &ExternalUse : ExternalUses) {
@@ -12050,7 +12056,9 @@ Value *BoUpSLP::vectorizeTree(
// Need to use original vector, if the root is truncated.
auto BWIt = MinBWs.find(E);
if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) {
- auto VecIt = VectorCasts.find(Scalar);
+ auto *ScalarTy = FTy->getElementType();
+ auto Key = std::make_pair(Vec, ScalarTy);
+ auto VecIt = VectorCasts.find(Key);
if (VecIt == VectorCasts.end()) {
IRBuilder<>::InsertPointGuard Guard(Builder);
if (auto *IVec = dyn_cast<Instruction>(Vec))
@@ -12058,10 +12066,10 @@ Value *BoUpSLP::vectorizeTree(
Vec = Builder.CreateIntCast(
Vec,
FixedVectorType::get(
- cast<VectorType>(VU->getType())->getElementType(),
+ ScalarTy,
cast<FixedVectorType>(Vec->getType())->getNumElements()),
BWIt->second.second);
- VectorCasts.try_emplace(Scalar, Vec);
+ VectorCasts.try_emplace(Key, Vec);
} else {
Vec = VecIt->second;
}
diff --git a/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll b/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll
new file mode 100644
index 0000000000000..e1942eb326079
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll
@@ -0,0 +1,48 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S --passes=slp-vectorizer -mtriple=s390x-unknown-linux -mcpu=z14 < %s | FileCheck %s
+
+define void @test() {
+; CHECK-LABEL: define void @test(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[TMP1:%.*]] = zext i8 0 to i32
+; CHECK-NEXT: [[TMP2:%.*]] = zext i8 0 to i32
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i32> <i32 0, i32 poison, i32 0, i32 0>, i32 [[TMP2]], i32 1
+; CHECK-NEXT: [[TMP4:%.*]] = select <4 x i1> zeroinitializer, <4 x i32> zeroinitializer, <4 x i32> [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = select i1 false, i32 0, i32 0
+; CHECK-NEXT: [[TMP6:%.*]] = select i1 false, i32 0, i32 [[TMP1]]
+; CHECK-NEXT: [[TMP7:%.*]] = select i1 false, i32 0, i32 [[TMP2]]
+; CHECK-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP4]])
+; CHECK-NEXT: [[OP_RDX:%.*]] = xor i32 [[TMP8]], [[TMP5]]
+; CHECK-NEXT: [[OP_RDX1:%.*]] = xor i32 [[TMP6]], [[TMP7]]
+; CHECK-NEXT: [[OP_RDX2:%.*]] = xor i32 [[OP_RDX]], [[OP_RDX1]]
+; CHECK-NEXT: [[TMP9:%.*]] = trunc i32 [[OP_RDX2]] to i16
+; CHECK-NEXT: store i16 [[TMP9]], ptr null, align 2
+; CHECK-NEXT: ret void
+;
+ %1 = zext i8 0 to i32
+ %.not = icmp sgt i32 0, %1
+ %2 = zext i8 0 to i32
+ %3 = select i1 %.not, i32 0, i32 0
+ %4 = zext i8 0 to i32
+ %.not.1 = icmp sgt i32 0, %4
+ %5 = zext i8 0 to i32
+ %6 = select i1 %.not.1, i32 0, i32 %5
+ %7 = xor i32 %6, %3
+ %8 = zext i8 0 to i32
+ %.not.2 = icmp sgt i32 0, %8
+ %9 = select i1 %.not.2, i32 0, i32 0
+ %10 = xor i32 %9, %7
+ %11 = zext i8 0 to i32
+ %.not.3 = icmp sgt i32 0, %11
+ %12 = select i1 %.not.3, i32 0, i32 0
+ %13 = xor i32 %12, %10
+ %14 = select i1 false, i32 0, i32 0
+ %15 = xor i32 %14, %13
+ %16 = select i1 false, i32 0, i32 %2
+ %17 = xor i32 %16, %15
+ %18 = select i1 false, i32 0, i32 %5
+ %19 = xor i32 %18, %17
+ %20 = trunc i32 %19 to i16
+ store i16 %20, ptr null, align 2
+ ret void
+}
More information about the llvm-commits
mailing list