[llvm] [SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics. (PR #86135)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 27 07:53:46 PDT 2024
https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/86135
From 80a78476a3611e793a054d9aa2fe78fff2f5ceb1 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 21 Mar 2024 15:32:57 +0000
Subject: [PATCH 1/2] [𝘀𝗽𝗿] initial version
Created using spr 1.3.5
---
.../Transforms/Vectorize/SLPVectorizer.cpp | 123 ++++++++++++++++--
.../cmp-after-intrinsic-call-minbitwidth.ll | 12 +-
.../X86/store-abs-minbitwidth.ll | 9 +-
3 files changed, 124 insertions(+), 20 deletions(-)
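At a glance, the intended IR-level effect, sketched by hand from the updated tests below (the value names are illustrative, not copied from any specific run): an abs/smin/smax/umin/umax call whose result only feeds narrower uses is now emitted directly at the demoted width, instead of being computed wide and truncated afterwards.

  ; Before: computed at i32, then truncated to the stored width.
  %sub = sub <4 x i32> %a, %b
  %abs = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %sub, i1 true)
  %res = trunc <4 x i32> %abs to <4 x i16>

  ; After: the whole chain is demoted to i16 up front.
  %sub = sub <4 x i16> %a, %b
  %res = call <4 x i16> @llvm.abs.v4i16(<4 x i16> %sub, i1 true)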
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 36b446962c4a63..7f680b7af9b565 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6994,14 +6994,11 @@ bool BoUpSLP::areAllUsersVectorized(
static std::pair<InstructionCost, InstructionCost>
getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
- TargetTransformInfo *TTI, TargetLibraryInfo *TLI) {
+ TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
+ ArrayRef<Type *> VecTys) {
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
// Calculate the cost of the scalar and vector calls.
- SmallVector<Type *, 4> VecTys;
- for (Use &Arg : CI->args())
- VecTys.push_back(
- FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
FastMathFlags FMF;
if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
FMF = FPCI->getFastMathFlags();
@@ -9009,7 +9006,25 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
};
auto GetVectorCost = [=](InstructionCost CommonCost) {
auto *CI = cast<CallInst>(VL0);
- auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+ SmallVector<Type *> VecTys;
+ for (auto [Idx, Arg] : enumerate(CI->args())) {
+ if (ID != Intrinsic::not_intrinsic) {
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+ VecTys.push_back(Arg->getType());
+ continue;
+ }
+ if (It != MinBWs.end()) {
+ VecTys.push_back(FixedVectorType::get(
+ IntegerType::get(CI->getContext(), It->second.first),
+ VecTy->getNumElements()));
+ continue;
+ }
+ }
+ VecTys.push_back(
+ FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+ }
+ auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, VecTys);
return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -12462,7 +12477,24 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
- auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+ SmallVector<Type *> VecTys;
+ for (auto [Idx, Arg] : enumerate(CI->args())) {
+ if (ID != Intrinsic::not_intrinsic) {
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+ VecTys.push_back(Arg->getType());
+ continue;
+ }
+ if (It != MinBWs.end()) {
+ VecTys.push_back(FixedVectorType::get(
+ IntegerType::get(CI->getContext(), It->second.first),
+ VecTy->getNumElements()));
+ continue;
+ }
+ }
+ VecTys.push_back(
+ FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+ }
+ auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, VecTys);
bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
VecCallCosts.first <= VecCallCosts.second;
@@ -12471,14 +12503,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
SmallVector<Type *, 2> TysForDecl;
// Add return type if intrinsic is overloaded on it.
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
- TysForDecl.push_back(
- FixedVectorType::get(CI->getType(), E->Scalars.size()));
+ TysForDecl.push_back(VecTy);
+ auto *CEI = cast<CallInst>(VL0);
for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
ValueList OpVL;
// Some intrinsics have scalar arguments. This argument should not be
// vectorized.
if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
- CallInst *CEI = cast<CallInst>(VL0);
ScalarArg = CEI->getArgOperand(I);
OpVecs.push_back(CEI->getArgOperand(I));
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -12491,6 +12522,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
+ ScalarArg = CEI->getArgOperand(I);
+ if (cast<VectorType>(OpVec->getType())->getElementType() !=
+ ScalarArg->getType() && It == MinBWs.end()) {
+ auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
+ VecTy->getNumElements());
+ OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));
+ } else if (It != MinBWs.end()) {
+ OpVec = Builder.CreateIntCast(OpVec, VecTy, GetOperandSignedness(I));
+ }
LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
OpVecs.push_back(OpVec);
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -14195,6 +14235,69 @@ bool BoUpSLP::collectValuesToDemote(
break;
}
+ case Instruction::Call: {
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+ return false;
+ if (auto *IC = dyn_cast<IntrinsicInst>(I)) {
+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI);
+ if (ID == Intrinsic::abs || ID == Intrinsic::smin ||
+ ID == Intrinsic::smax || ID == Intrinsic::umin ||
+ ID == Intrinsic::umax) {
+ InstructionCost BestCost =
+ std::numeric_limits<InstructionCost::CostType>::max();
+ unsigned BestBitWidth = BitWidth;
+ unsigned VF = ITE->Scalars.size();
+ // Choose the best bitwidth based on cost estimations.
+ (void)AttemptCheckBitwidth(
+ [&](unsigned BitWidth, unsigned) {
+ SmallVector<Type *> VecTys;
+ auto *ITy =
+ IntegerType::get(IC->getContext(), PowerOf2Ceil(BitWidth));
+ for (auto [Idx, Arg] : enumerate(IC->args())) {
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+ VecTys.push_back(Arg->getType());
+ continue;
+ }
+ VecTys.push_back(FixedVectorType::get(ITy, VF));
+ }
+ auto VecCallCosts = getVectorCallCosts(
+ IC, FixedVectorType::get(ITy, VF), TTI, TLI, VecTys);
+ InstructionCost Cost =
+ std::min(VecCallCosts.first, VecCallCosts.second);
+ if (Cost < BestCost) {
+ BestCost = Cost;
+ BestBitWidth = BitWidth;
+ }
+ return false;
+ },
+ NeedToExit);
+ NeedToExit = false;
+ BitWidth = BestBitWidth;
+ switch (ID) {
+ case Intrinsic::abs:
+ End = 1;
+ if (!ProcessOperands(IC->getArgOperand(0), NeedToExit))
+ return false;
+ break;
+ case Intrinsic::smin:
+ case Intrinsic::smax:
+ case Intrinsic::umin:
+ case Intrinsic::umax:
+ End = 2;
+ if (!ProcessOperands({IC->getArgOperand(0), IC->getArgOperand(1)},
+ NeedToExit))
+ return false;
+ break;
+ default:
+ llvm_unreachable("Unexpected intrinsic.");
+ }
+ break;
+ }
+ }
+ MaxDepthLevel = 1;
+ return FinalAnalysis();
+ }
+
// Otherwise, conservatively give up.
default:
MaxDepthLevel = 1;
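The new Instruction::Call case above probes each candidate width through AttemptCheckBitwidth, keeps the width whose min(intrinsic cost, libcall cost) is cheapest, and then demotes only the vector operands, skipping positions reported by isVectorIntrinsicWithScalarOpAtArg (such as abs's i1 is-int-min-poison flag). A minimal hand-written illustration of the resulting call shapes for a VF of 2, with the i2 element type taken from the smin test below:

  ; Vector operand positions are demoted; scalar ones keep their type.
  %m = call <2 x i2> @llvm.smin.v2i2(<2 x i2> %x, <2 x i2> %y)
  %a = call <2 x i2> @llvm.abs.v2i2(<2 x i2> %v, i1 false)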
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
index a05d4fdd6315b5..9fa88084aaa0af 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
@@ -5,12 +5,14 @@ define void @test() {
; CHECK-LABEL: define void @test(
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP0:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> zeroinitializer, <2 x i32> zeroinitializer)
-; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i32> zeroinitializer, <2 x i32> [[TMP0]]
-; CHECK-NEXT: [[TMP2:%.*]] = or <2 x i32> [[TMP1]], zeroinitializer
-; CHECK-NEXT: [[ADD:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
+; CHECK-NEXT: [[TMP0:%.*]] = call <2 x i2> @llvm.smin.v2i2(<2 x i2> zeroinitializer, <2 x i2> zeroinitializer)
+; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i2> zeroinitializer, <2 x i2> [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = or <2 x i2> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x i2> [[TMP2]], i32 1
+; CHECK-NEXT: [[ADD:%.*]] = zext i2 [[TMP3]] to i32
; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[ADD]], 0
-; CHECK-NEXT: [[ADD45:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x i2> [[TMP2]], i32 0
+; CHECK-NEXT: [[ADD45:%.*]] = zext i2 [[TMP5]] to i32
; CHECK-NEXT: [[ADD152:%.*]] = or i32 [[ADD45]], [[ADD]]
; CHECK-NEXT: [[IDXPROM153:%.*]] = sext i32 [[ADD152]] to i64
; CHECK-NEXT: [[ARRAYIDX154:%.*]] = getelementptr i8, ptr null, i64 [[IDXPROM153]]
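Note the compensation code at the tree boundary in the checks above: scalar users outside the vectorized tree still expect i32, so each extract of the demoted <2 x i2> value is widened back, following the pattern

  %t = extractelement <2 x i2> %or, i32 1
  %w = zext i2 %t to i32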
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
index e8b854b7cea6cb..60bec6668d23ba 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
@@ -13,14 +13,13 @@ define i32 @test(ptr noalias %in, ptr noalias %inn, ptr %out) {
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x i8> [[TMP2]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i32>
+; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i16>
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <2 x i8> [[TMP1]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x i8> [[TMP4]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i32>
-; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i32> [[TMP12]], [[TMP8]]
-; CHECK-NEXT: [[TMP14:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP13]], i1 true)
-; CHECK-NEXT: [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16>
+; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16>
+; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]]
+; CHECK-NEXT: [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 true)
; CHECK-NEXT: store <4 x i16> [[TMP15]], ptr [[OUT:%.*]], align 2
; CHECK-NEXT: ret i32 undef
;
From eb6ccb0c1ab58df2bb32d216535ef8e8a14e37c4 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 21 Mar 2024 15:38:15 +0000
Subject: [PATCH 2/2] Fix formatting
Created using spr 1.3.5
---
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7f680b7af9b565..b878eb95ed3466 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -12524,7 +12524,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
}
ScalarArg = CEI->getArgOperand(I);
if (cast<VectorType>(OpVec->getType())->getElementType() !=
- ScalarArg->getType() && It == MinBWs.end()) {
+ ScalarArg->getType() &&
+ It == MinBWs.end()) {
auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
VecTy->getNumElements());
OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));