[llvm] [SLP] Improve cost model for i1 select-as-or/and patterns (PR #188572)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 26 05:37:37 PDT 2026
https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/188572
>From 4ef84ee13b619af0a6e53f371e3f23df2e1276d4 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Wed, 25 Mar 2026 12:24:23 -0700
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Created using spr 1.3.7
---
.../Transforms/Vectorize/SLPVectorizer.cpp | 96 ++++++++++++++-----
.../X86/select-logical-or-and-i1-vector.ll | 44 ++++-----
2 files changed, 91 insertions(+), 49 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 6473cc0449af2..63a670413db0b 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -15999,19 +15999,40 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
CmpPredicate CurrentPred = ScalarTy->isFloatingPointTy()
? CmpInst::BAD_FCMP_PREDICATE
: CmpInst::BAD_ICMP_PREDICATE;
+ Value *LHS = nullptr, *RHS = nullptr;
auto MatchCmp = m_Cmp(CurrentPred, m_Value(), m_Value());
- if ((!match(VI, m_Select(MatchCmp, m_Value(), m_Value())) &&
- !match(VI, MatchCmp)) ||
+ bool IsSelect = ShuffleOrOp == Instruction::Select &&
+ match(VI, m_Select(MatchCmp, m_Value(LHS), m_Value(RHS)));
+ if ((!IsSelect && !match(VI, MatchCmp)) ||
(CurrentPred != static_cast<CmpInst::Predicate>(VecPred) &&
CurrentPred != static_cast<CmpInst::Predicate>(SwappedVecPred)))
VecPred = SwappedVecPred = ScalarTy->isFloatingPointTy()
? CmpInst::BAD_FCMP_PREDICATE
: CmpInst::BAD_ICMP_PREDICATE;
- InstructionCost ScalarCost = TTI->getCmpSelInstrCost(
- E->getOpcode(), OrigScalarTy, Builder.getInt1Ty(), CurrentPred,
- CostKind, getOperandInfo(VI->getOperand(0)),
- getOperandInfo(VI->getOperand(1)), VI);
+ // Check if operands are of i1 types, like a condition expression.
+ InstructionCost ScalarCost = InstructionCost::getInvalid();
+ if (IsSelect && LHS->getType() == VI->getOperand(0)->getType()) {
+ assert(LHS->getType() == RHS->getType() &&
+ "Expected same type for LHS/RHS");
+ // select i1 v, i1 true, i1 b -> or i1 v, i1 b
+ if (match(LHS, m_One())) {
+ ScalarCost = TTI->getArithmeticInstrCost(
+ Instruction::Or, LHS->getType(), CostKind,
+ getOperandInfo(VI->getOperand(0)), getOperandInfo(RHS));
+ } else if (match(RHS, m_Zero())) {
+ // select i1 v, i1 b, i1 false -> and i1 v, i1 b
+ ScalarCost = TTI->getArithmeticInstrCost(
+ Instruction::And, LHS->getType(), CostKind,
+ getOperandInfo(VI->getOperand(0)), getOperandInfo(LHS));
+ }
+ }
+ if (!ScalarCost.isValid()) {
+ ScalarCost = TTI->getCmpSelInstrCost(
+ E->getOpcode(), OrigScalarTy, Builder.getInt1Ty(), CurrentPred,
+ CostKind, getOperandInfo(VI->getOperand(0)),
+ getOperandInfo(VI->getOperand(1)), VI);
+ }
InstructionCost IntrinsicCost = GetMinMaxCost(OrigScalarTy, VI);
if (IntrinsicCost.isValid())
ScalarCost = IntrinsicCost;
@@ -16021,25 +16042,50 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
auto GetVectorCost = [&](InstructionCost CommonCost) {
auto *MaskTy = getWidenedType(Builder.getInt1Ty(), VL.size());
- InstructionCost VecCost =
- TTI->getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy, VecPred,
- CostKind, getOperandInfo(E->getOperand(0)),
- getOperandInfo(E->getOperand(1)), VL0);
- if (auto *SI = dyn_cast<SelectInst>(VL0)) {
- auto *CondType =
- getWidenedType(SI->getCondition()->getType(), VL.size());
- unsigned CondNumElements = CondType->getNumElements();
- unsigned VecTyNumElements = getNumElements(VecTy);
- assert(VecTyNumElements >= CondNumElements &&
- VecTyNumElements % CondNumElements == 0 &&
- "Cannot vectorize Instruction::Select");
- if (CondNumElements != VecTyNumElements) {
- // When the return type is i1 but the source is fixed vector type, we
- // need to duplicate the condition value.
- VecCost += ::getShuffleCost(
- *TTI, TTI::SK_PermuteSingleSrc, CondType,
- createReplicatedMask(VecTyNumElements / CondNumElements,
- CondNumElements));
+ InstructionCost VecCost = InstructionCost::getInvalid();
+ if (ShuffleOrOp == Instruction::Select) {
+ ArrayRef<Value *> Cond = E->getOperand(0);
+ ArrayRef<Value *> LHS = E->getOperand(1);
+ ArrayRef<Value *> RHS = E->getOperand(2);
+ // select <VF x i1>, <VF x i1>, <VF x i1>?
+ if (Cond.front()->getType() == LHS.front()->getType()) {
+ // select <VF x i1> v, <VF x i1> true, <VF x i1> b -> or <VF x i1> v,
+ // <VF x i1> b
+ if (all_of(LHS, [&](Value *V) { return match(V, m_One()); })) {
+ VecCost = TTI->getArithmeticInstrCost(
+ Instruction::Or, VecTy, CostKind, getOperandInfo(Cond),
+ getOperandInfo(RHS));
+ } else if (all_of(RHS,
+ [&](Value *V) { return match(V, m_Zero()); })) {
+ // select <VF x i1> v, <VF x i1> b, <VF x i1> false -> and <VF x i1>
+ // v, <VF x i1> b
+ VecCost = TTI->getArithmeticInstrCost(
+ Instruction::And, VecTy, CostKind, getOperandInfo(Cond),
+ getOperandInfo(LHS));
+ }
+ }
+ }
+ if (!VecCost.isValid()) {
+ VecCost =
+ TTI->getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy, VecPred,
+ CostKind, getOperandInfo(E->getOperand(0)),
+ getOperandInfo(E->getOperand(1)), VL0);
+ if (auto *SI = dyn_cast<SelectInst>(VL0)) {
+ auto *CondType =
+ getWidenedType(SI->getCondition()->getType(), VL.size());
+ unsigned CondNumElements = CondType->getNumElements();
+ unsigned VecTyNumElements = getNumElements(VecTy);
+ assert(VecTyNumElements >= CondNumElements &&
+ VecTyNumElements % CondNumElements == 0 &&
+ "Cannot vectorize Instruction::Select");
+ if (CondNumElements != VecTyNumElements) {
+ // When the return type is i1 but the source is fixed vector type,
+ // we need to duplicate the condition value.
+ VecCost += ::getShuffleCost(
+ *TTI, TTI::SK_PermuteSingleSrc, CondType,
+ createReplicatedMask(VecTyNumElements / CondNumElements,
+ CondNumElements));
+ }
}
}
return VecCost + CommonCost;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll b/llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll
index 3b2b6f6a37df1..1c532ef31bc39 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll
@@ -12,18 +12,16 @@ define void @select_logical_or_i1(ptr %dst,
; CHECK-LABEL: define void @select_logical_or_i1(
; CHECK-SAME: ptr [[DST:%.*]], float [[D0:%.*]], float [[D1:%.*]], float [[D2:%.*]], float [[D3:%.*]], float [[THRESHOLD:%.*]], float [[HPHB_VAL:%.*]], i1 [[SCALAR_COND:%.*]], float [[Y0:%.*]], float [[Y1:%.*]], float [[Y2:%.*]], float [[Y3:%.*]], float [[E0:%.*]], float [[E1:%.*]], float [[E2:%.*]], float [[E3:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: [[CMP0:%.*]] = fcmp fast uge float [[D0]], [[THRESHOLD]]
-; CHECK-NEXT: [[CMP1:%.*]] = fcmp fast uge float [[D1]], [[THRESHOLD]]
-; CHECK-NEXT: [[CMP2:%.*]] = fcmp fast uge float [[D2]], [[THRESHOLD]]
-; CHECK-NEXT: [[CMP3:%.*]] = fcmp fast uge float [[D3]], [[THRESHOLD]]
-; CHECK-NEXT: [[OR3:%.*]] = select i1 [[CMP3]], i1 true, i1 [[SCALAR_COND]]
-; CHECK-NEXT: [[OR2:%.*]] = select i1 [[CMP2]], i1 true, i1 [[SCALAR_COND]]
-; CHECK-NEXT: [[OR1:%.*]] = select i1 [[CMP1]], i1 true, i1 [[SCALAR_COND]]
-; CHECK-NEXT: [[OR0:%.*]] = select i1 [[CMP0]], i1 true, i1 [[SCALAR_COND]]
-; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x i1> poison, i1 [[OR0]], i32 0
-; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i1> [[TMP0]], i1 [[OR1]], i32 1
-; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i1> [[TMP1]], i1 [[OR2]], i32 2
-; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x i1> [[TMP2]], i1 [[OR3]], i32 3
+; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x float> poison, float [[D0]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x float> [[TMP0]], float [[D1]], i32 1
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x float> [[TMP1]], float [[D2]], i32 2
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[D3]], i32 3
+; CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x float> poison, float [[THRESHOLD]], i32 0
+; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x float> [[TMP4]], <4 x float> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP6:%.*]] = fcmp fast uge <4 x float> [[TMP3]], [[TMP5]]
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i1> poison, i1 [[SCALAR_COND]], i32 0
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i1> [[TMP7]], <4 x i1> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP9:%.*]] = select <4 x i1> [[TMP6]], <4 x i1> splat (i1 true), <4 x i1> [[TMP8]]
; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x float> poison, float [[HPHB_VAL]], i32 0
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x float> [[TMP10]], <4 x float> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP12:%.*]] = select <4 x i1> [[TMP9]], <4 x float> zeroinitializer, <4 x float> [[TMP11]]
@@ -86,18 +84,16 @@ define void @select_logical_and_i1(ptr %dst,
; CHECK-LABEL: define void @select_logical_and_i1(
; CHECK-SAME: ptr [[DST:%.*]], float [[D0:%.*]], float [[D1:%.*]], float [[D2:%.*]], float [[D3:%.*]], float [[THRESHOLD:%.*]], float [[HPHB_VAL:%.*]], i1 [[SCALAR_COND:%.*]], float [[Y0:%.*]], float [[Y1:%.*]], float [[Y2:%.*]], float [[Y3:%.*]], float [[E0:%.*]], float [[E1:%.*]], float [[E2:%.*]], float [[E3:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: [[CMP0:%.*]] = fcmp fast uge float [[D0]], [[THRESHOLD]]
-; CHECK-NEXT: [[CMP1:%.*]] = fcmp fast uge float [[D1]], [[THRESHOLD]]
-; CHECK-NEXT: [[CMP2:%.*]] = fcmp fast uge float [[D2]], [[THRESHOLD]]
-; CHECK-NEXT: [[CMP3:%.*]] = fcmp fast uge float [[D3]], [[THRESHOLD]]
-; CHECK-NEXT: [[AND3:%.*]] = select i1 [[CMP3]], i1 [[SCALAR_COND]], i1 false
-; CHECK-NEXT: [[AND2:%.*]] = select i1 [[CMP2]], i1 [[SCALAR_COND]], i1 false
-; CHECK-NEXT: [[AND1:%.*]] = select i1 [[CMP1]], i1 [[SCALAR_COND]], i1 false
-; CHECK-NEXT: [[AND0:%.*]] = select i1 [[CMP0]], i1 [[SCALAR_COND]], i1 false
-; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x i1> poison, i1 [[AND0]], i32 0
-; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i1> [[TMP0]], i1 [[AND1]], i32 1
-; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i1> [[TMP1]], i1 [[AND2]], i32 2
-; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x i1> [[TMP2]], i1 [[AND3]], i32 3
+; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x float> poison, float [[D0]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x float> [[TMP0]], float [[D1]], i32 1
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x float> [[TMP1]], float [[D2]], i32 2
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[D3]], i32 3
+; CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x float> poison, float [[THRESHOLD]], i32 0
+; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x float> [[TMP4]], <4 x float> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP6:%.*]] = fcmp fast uge <4 x float> [[TMP3]], [[TMP5]]
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i1> poison, i1 [[SCALAR_COND]], i32 0
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i1> [[TMP7]], <4 x i1> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP9:%.*]] = select <4 x i1> [[TMP6]], <4 x i1> [[TMP8]], <4 x i1> zeroinitializer
; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x float> poison, float [[HPHB_VAL]], i32 0
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x float> [[TMP10]], <4 x float> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP12:%.*]] = select <4 x i1> [[TMP9]], <4 x float> zeroinitializer, <4 x float> [[TMP11]]
More information about the llvm-commits
mailing list