[llvm] [SLP]Use getExtendedReduction cost and fix reduction cost calculations (PR #117350)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 22 09:33:52 PST 2024
alexey-bataev created https://github.com/llvm/llvm-project/pull/117350
The patch uses getExtendedReductionCost for reductions over ext-based
nodes, adds cost estimation for ctpop-based i1 add reductions to the
base TTI implementation, and adds a matching RISC-V specific vcpop.m
cost estimate.
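For example (taken from the updated red_zext_ld_4xi64 test below), a
scalar chain of zext i8-to-i64 adds is now recognized as profitable and
emitted as a narrow extended reduction plus a single scalar extension:

  %0 = load <4 x i8>, ptr %ptr, align 1
  %1 = zext <4 x i8> %0 to <4 x i16>
  %2 = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %1)
  %add.3 = zext i16 %2 to i64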
From 6e8a8f667dd4e0a2d94f7c13c3ffa661a21ece7e Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Fri, 22 Nov 2024 17:33:41 +0000
Subject: [PATCH] [𝘀𝗽𝗿] initial version
Created using spr 1.3.5
---
llvm/include/llvm/CodeGen/BasicTTIImpl.h | 12 ++
.../Target/RISCV/RISCVTargetTransformInfo.cpp | 8 +
.../Transforms/Vectorize/SLPVectorizer.cpp | 151 ++++++++++++------
.../SLPVectorizer/RISCV/reductions.ll | 18 +--
.../remark-zext-incoming-for-neg-icmp.ll | 2 +-
5 files changed, 125 insertions(+), 66 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4c..d4bd504bf9ba31 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
Type *ResTy, VectorType *Ty,
FastMathFlags FMF,
TTI::TargetCostKind CostKind) {
+ if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
+ FTy && Opcode == Instruction::Add &&
+ FTy->getElementType() == IntegerType::getInt1Ty(Ty->getContext())) {
+ // Represent vector_reduce_add(ZExt(<n x i1>)) as
+ // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+ auto *IntTy =
+ IntegerType::get(ResTy->getContext(), FTy->getNumElements());
+ IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+ return thisT()->getCastInstrCost(Instruction::BitCast, IntTy, FTy,
+ TTI::CastContextHint::None, CostKind) +
+ thisT()->getIntrinsicInstrCost(ICA, CostKind);
+ }
// Without any native support, this is equivalent to the cost of
// vecreduce.opcode(ext(Ty A)).
VectorType *ExtTy = VectorType::get(ResTy, Ty);
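The new i1 special case above models an add reduction of a mask as a
popcount of its bits. Roughly, for an <8 x i1> input and an i32 result
(a sketch with illustrative names, mirroring the comment in the hunk):

  %bc  = bitcast <8 x i1> %m to i8
  %cnt = call i8 @llvm.ctpop.i8(i8 %bc)
  %res = zext i8 %cnt to i32

Note the returned cost covers only the bitcast and the ctpop; the final
ZExtOrTrunc to the result type is treated as free here.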
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2b16dcbcd8695b..026b1e694e6a64 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
+ if (Opcode == Instruction::Add && LT.second.isFixedLengthVector() &&
+ LT.second.getScalarType() == MVT::i1) {
+ // Represent vector_reduce_add(ZExt(<n x i1>)) as
+ // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+ return LT.first *
+ getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind);
+ }
+
if (ResTy->getScalarSizeInBits() != 2 * LT.second.getScalarSizeInBits())
return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
FMF, CostKind);
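On RISC-V the whole bitcast-plus-ctpop sequence collapses into vcpop.m,
which counts the set bits of a mask register, so the override charges
LT.first copies of that single instruction. A sketch of the kind of
reduction this covers (illustrative values; the mask typically comes
from a compare, as in the remark test updated below):

  %m = icmp slt <8 x i8> %x, %y
  %z = zext <8 x i1> %m to <8 x i32>
  %r = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %z)
  ; roughly a single vcpop.m on the comparison mask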
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8e0ca2677bf0a9..46ae908f57ab89 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,22 +1371,46 @@ class BoUpSLP {
return VectorizableTree.front()->Scalars;
}
+ /// Returns the type/is-signed info for the root node in the graph without
+ /// casting.
+ std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
+ const TreeEntry &Root = *VectorizableTree.front().get();
+ if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
+ !Root.Scalars.front()->getType()->isIntegerTy())
+ return std::nullopt;
+ auto It = MinBWs.find(&Root);
+ if (It != MinBWs.end())
+ return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
+ It->second.first),
+ It->second.second);
+ if (Root.getOpcode() == Instruction::ZExt ||
+ Root.getOpcode() == Instruction::SExt)
+ return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
+ Root.getOpcode() == Instruction::SExt);
+ return std::nullopt;
+ }
+
/// Checks if the root graph node can be emitted with narrower bitwidth at
/// codegen and returns its signedness, if so.
bool isSignedMinBitwidthRootNode() const {
return MinBWs.at(VectorizableTree.front().get()).second;
}
- /// Returns reduction bitwidth and signedness, if it does not match the
- /// original requested size.
- std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+ /// Returns reduction type after minbitwidth analysis.
+ FixedVectorType *getReductionType() const {
if (ReductionBitWidth == 0 ||
+ !VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
ReductionBitWidth >=
DL->getTypeSizeInBits(
VectorizableTree.front()->Scalars.front()->getType()))
- return std::nullopt;
- return std::make_pair(ReductionBitWidth,
- MinBWs.at(VectorizableTree.front().get()).second);
+ return getWidenedType(
+ VectorizableTree.front()->Scalars.front()->getType(),
+ VectorizableTree.front()->getVectorFactor());
+ return getWidenedType(
+ IntegerType::get(
+ VectorizableTree.front()->Scalars.front()->getContext(),
+ ReductionBitWidth),
+ VectorizableTree.front()->getVectorFactor());
}
/// Builds external uses of the vectorized scalars, i.e. the list of
@@ -11297,6 +11321,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
return CommonCost;
auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
+
+ bool IsArithmeticExtendedReduction =
+ E->Idx == 0 && UserIgnoreList &&
+ all_of(*UserIgnoreList, [](Value *V) {
+ auto *I = cast<Instruction>(V);
+ return is_contained({Instruction::Add, Instruction::FAdd,
+ Instruction::Mul, Instruction::FMul,
+ Instruction::And, Instruction::Or,
+ Instruction::Xor},
+ I->getOpcode());
+ });
+ if (IsArithmeticExtendedReduction &&
+ (VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
+ return CommonCost;
return CommonCost +
TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
VecOpcode == Opcode ? VI : nullptr);
@@ -12652,32 +12690,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
unsigned SrcSize = It->second.first;
unsigned DstSize = ReductionBitWidth;
unsigned Opcode = Instruction::Trunc;
- if (SrcSize < DstSize)
- Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
- auto *SrcVecTy =
- getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
- auto *DstVecTy =
- getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
- TTI::CastContextHint CCH = getCastContextHint(E);
- InstructionCost CastCost;
- switch (E.getOpcode()) {
- case Instruction::SExt:
- case Instruction::ZExt:
- case Instruction::Trunc: {
- const TreeEntry *OpTE = getOperandEntry(&E, 0);
- CCH = getCastContextHint(*OpTE);
- break;
- }
- default:
- break;
+ if (SrcSize < DstSize) {
+ bool IsArithmeticExtendedReduction =
+ all_of(*UserIgnoreList, [](Value *V) {
+ auto *I = cast<Instruction>(V);
+ return is_contained({Instruction::Add, Instruction::FAdd,
+ Instruction::Mul, Instruction::FMul,
+ Instruction::And, Instruction::Or,
+ Instruction::Xor},
+ I->getOpcode());
+ });
+ if (IsArithmeticExtendedReduction)
+ Opcode =
+ Instruction::BitCast; // Handled via getExtendedReductionCost
+ else
+ Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+ }
+ if (Opcode != Instruction::BitCast) {
+ auto *SrcVecTy =
+ getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
+ auto *DstVecTy =
+ getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
+ TTI::CastContextHint CCH = getCastContextHint(E);
+ InstructionCost CastCost;
+ switch (E.getOpcode()) {
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ case Instruction::Trunc: {
+ const TreeEntry *OpTE = getOperandEntry(&E, 0);
+ CCH = getCastContextHint(*OpTE);
+ break;
+ }
+ default:
+ break;
+ }
+ CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
+ TTI::TCK_RecipThroughput);
+ Cost += CastCost;
+ LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
+ << " for final resize for reduction from " << SrcVecTy
+ << " to " << DstVecTy << "\n";
+ dbgs() << "SLP: Current total cost = " << Cost << "\n");
}
- CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
- TTI::TCK_RecipThroughput);
- Cost += CastCost;
- LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
- << " for final resize for reduction from " << SrcVecTy
- << " to " << DstVecTy << "\n";
- dbgs() << "SLP: Current total cost = " << Cost << "\n");
}
}
@@ -19815,8 +19869,8 @@ class HorizontalReduction {
// Estimate cost.
InstructionCost TreeCost = V.getTreeCost(VL);
- InstructionCost ReductionCost = getReductionCost(
- TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
+ InstructionCost ReductionCost =
+ getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
InstructionCost Cost = TreeCost + ReductionCost;
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
<< " for reduction\n");
@@ -20107,14 +20161,14 @@ class HorizontalReduction {
private:
/// Calculate the cost of a reduction.
- InstructionCost getReductionCost(
- TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
- bool IsCmpSelMinMax, FastMathFlags FMF,
- const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
+ InstructionCost getReductionCost(TargetTransformInfo *TTI,
+ ArrayRef<Value *> ReducedVals,
+ bool IsCmpSelMinMax, FastMathFlags FMF,
+ const BoUpSLP &R) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *ScalarTy = ReducedVals.front()->getType();
unsigned ReduxWidth = ReducedVals.size();
- FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
+ FixedVectorType *VectorTy = R.getReductionType();
InstructionCost VectorCost = 0, ScalarCost;
// If all of the reduced values are constant, the vector cost is 0, since
// the reduction value can be calculated at the compile time.
@@ -20172,21 +20226,16 @@ class HorizontalReduction {
VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
/*Extract*/ false, TTI::TCK_RecipThroughput);
} else {
- auto [Bitwidth, IsSigned] =
- BitwidthAndSign.value_or(std::make_pair(0u, false));
- if (RdxKind == RecurKind::Add && Bitwidth == 1) {
- // Represent vector_reduce_add(ZExt(<n x i1>)) to
- // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
- auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
- IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
- VectorCost =
- TTI->getCastInstrCost(Instruction::BitCast, IntTy,
- getWidenedType(ScalarTy, ReduxWidth),
- TTI::CastContextHint::None, CostKind) +
- TTI->getIntrinsicInstrCost(ICA, CostKind);
- } else {
+ Type *RedTy = VectorTy->getElementType();
+ auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
+ std::make_pair(RedTy, true));
+ if (RType == RedTy) {
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
FMF, CostKind);
+ } else {
+ VectorCost = TTI->getExtendedReductionCost(
+ RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
+ FMF, CostKind);
}
}
}
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
index bc24a44cecbe39..85131758853b3d 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
@@ -877,20 +877,10 @@ entry:
define i64 @red_zext_ld_4xi64(ptr %ptr) {
; CHECK-LABEL: @red_zext_ld_4xi64(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
-; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
-; CHECK-NEXT: [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
-; CHECK-NEXT: [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
-; CHECK-NEXT: [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
-; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
-; CHECK-NEXT: [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
-; CHECK-NEXT: [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
-; CHECK-NEXT: [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
-; CHECK-NEXT: [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
-; CHECK-NEXT: [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
-; CHECK-NEXT: [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
-; CHECK-NEXT: [[ADD_3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
+; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
+; CHECK-NEXT: [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
+; CHECK-NEXT: [[ADD_3:%.*]] = zext i16 [[TMP2]] to i64
; CHECK-NEXT: ret i64 [[ADD_3]]
;
entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
index e4d20a6db8fa67..09c11bbefd4a35 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
@@ -8,7 +8,7 @@
; YAML-NEXT: Function: test
; YAML-NEXT: Args:
; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost '
-; YAML-NEXT: - Cost: '-1'
+; YAML-NEXT: - Cost: '-10'
; YAML-NEXT: - String: ' and with tree size '
; YAML-NEXT: - TreeSize: '8'
; YAML-NEXT:...