[llvm] [SLP] Improve cost model for i1 select-as-or/and patterns (PR #188572)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 26 05:37:37 PDT 2026


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/188572

>From 4ef84ee13b619af0a6e53f371e3f23df2e1276d4 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Wed, 25 Mar 2026 12:24:23 -0700
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.7
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 96 ++++++++++++++-----
 .../X86/select-logical-or-and-i1-vector.ll    | 44 ++++-----
 2 files changed, 91 insertions(+), 49 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 6473cc0449af2..63a670413db0b 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -15999,19 +15999,40 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       CmpPredicate CurrentPred = ScalarTy->isFloatingPointTy()
                                      ? CmpInst::BAD_FCMP_PREDICATE
                                      : CmpInst::BAD_ICMP_PREDICATE;
+      Value *LHS = nullptr, *RHS = nullptr;
       auto MatchCmp = m_Cmp(CurrentPred, m_Value(), m_Value());
-      if ((!match(VI, m_Select(MatchCmp, m_Value(), m_Value())) &&
-           !match(VI, MatchCmp)) ||
+      bool IsSelect = ShuffleOrOp == Instruction::Select &&
+                      match(VI, m_Select(MatchCmp, m_Value(LHS), m_Value(RHS)));
+      if ((!IsSelect && !match(VI, MatchCmp)) ||
           (CurrentPred != static_cast<CmpInst::Predicate>(VecPred) &&
            CurrentPred != static_cast<CmpInst::Predicate>(SwappedVecPred)))
         VecPred = SwappedVecPred = ScalarTy->isFloatingPointTy()
                                        ? CmpInst::BAD_FCMP_PREDICATE
                                        : CmpInst::BAD_ICMP_PREDICATE;
 
-      InstructionCost ScalarCost = TTI->getCmpSelInstrCost(
-          E->getOpcode(), OrigScalarTy, Builder.getInt1Ty(), CurrentPred,
-          CostKind, getOperandInfo(VI->getOperand(0)),
-          getOperandInfo(VI->getOperand(1)), VI);
+      // Check if operands are of i1 types, like a condition expression.
+      InstructionCost ScalarCost = InstructionCost::getInvalid();
+      if (IsSelect && LHS->getType() == VI->getOperand(0)->getType()) {
+        assert(LHS->getType() == RHS->getType() &&
+               "Expected same type for LHS/RHS");
+        // select i1 v, i1 true, i1 b -> or i1 v, i1 b
+        if (match(LHS, m_One())) {
+          ScalarCost = TTI->getArithmeticInstrCost(
+              Instruction::Or, LHS->getType(), CostKind,
+              getOperandInfo(VI->getOperand(0)), getOperandInfo(RHS));
+        } else if (match(RHS, m_Zero())) {
+          // select i1 v, i1 b, i1 false -> and i1 v, i1 b
+          ScalarCost = TTI->getArithmeticInstrCost(
+              Instruction::And, LHS->getType(), CostKind,
+              getOperandInfo(VI->getOperand(0)), getOperandInfo(LHS));
+        }
+      }
+      if (!ScalarCost.isValid()) {
+        ScalarCost = TTI->getCmpSelInstrCost(
+            E->getOpcode(), OrigScalarTy, Builder.getInt1Ty(), CurrentPred,
+            CostKind, getOperandInfo(VI->getOperand(0)),
+            getOperandInfo(VI->getOperand(1)), VI);
+      }
       InstructionCost IntrinsicCost = GetMinMaxCost(OrigScalarTy, VI);
       if (IntrinsicCost.isValid())
         ScalarCost = IntrinsicCost;
@@ -16021,25 +16042,50 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto GetVectorCost = [&](InstructionCost CommonCost) {
       auto *MaskTy = getWidenedType(Builder.getInt1Ty(), VL.size());
 
-      InstructionCost VecCost =
-          TTI->getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy, VecPred,
-                                  CostKind, getOperandInfo(E->getOperand(0)),
-                                  getOperandInfo(E->getOperand(1)), VL0);
-      if (auto *SI = dyn_cast<SelectInst>(VL0)) {
-        auto *CondType =
-            getWidenedType(SI->getCondition()->getType(), VL.size());
-        unsigned CondNumElements = CondType->getNumElements();
-        unsigned VecTyNumElements = getNumElements(VecTy);
-        assert(VecTyNumElements >= CondNumElements &&
-               VecTyNumElements % CondNumElements == 0 &&
-               "Cannot vectorize Instruction::Select");
-        if (CondNumElements != VecTyNumElements) {
-          // When the return type is i1 but the source is fixed vector type, we
-          // need to duplicate the condition value.
-          VecCost += ::getShuffleCost(
-              *TTI, TTI::SK_PermuteSingleSrc, CondType,
-              createReplicatedMask(VecTyNumElements / CondNumElements,
-                                   CondNumElements));
+      InstructionCost VecCost = InstructionCost::getInvalid();
+      if (ShuffleOrOp == Instruction::Select) {
+        ArrayRef<Value *> Cond = E->getOperand(0);
+        ArrayRef<Value *> LHS = E->getOperand(1);
+        ArrayRef<Value *> RHS = E->getOperand(2);
+        // select <VF x i1>, <VF x i1>, <VF x i1>?
+        if (Cond.front()->getType() == LHS.front()->getType()) {
+          // select <VF x i1> v, <VF x i1> true, <VF x i1> b -> or <VF x i1> v,
+          // <VF x i1> b
+          if (all_of(LHS, [&](Value *V) { return match(V, m_One()); })) {
+            VecCost = TTI->getArithmeticInstrCost(
+                Instruction::Or, VecTy, CostKind, getOperandInfo(Cond),
+                getOperandInfo(RHS));
+          } else if (all_of(RHS,
+                            [&](Value *V) { return match(V, m_Zero()); })) {
+            // select <VF x i1> v, <VF x i1> b, <VF x i1> false -> and <VF x i1>
+            // v, <VF x i1> b
+            VecCost = TTI->getArithmeticInstrCost(
+                Instruction::And, VecTy, CostKind, getOperandInfo(Cond),
+                getOperandInfo(LHS));
+          }
+        }
+      }
+      if (!VecCost.isValid()) {
+        VecCost =
+            TTI->getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy, VecPred,
+                                    CostKind, getOperandInfo(E->getOperand(0)),
+                                    getOperandInfo(E->getOperand(1)), VL0);
+        if (auto *SI = dyn_cast<SelectInst>(VL0)) {
+          auto *CondType =
+              getWidenedType(SI->getCondition()->getType(), VL.size());
+          unsigned CondNumElements = CondType->getNumElements();
+          unsigned VecTyNumElements = getNumElements(VecTy);
+          assert(VecTyNumElements >= CondNumElements &&
+                 VecTyNumElements % CondNumElements == 0 &&
+                 "Cannot vectorize Instruction::Select");
+          if (CondNumElements != VecTyNumElements) {
+            // When the return type is i1 but the source is fixed vector type,
+            // we need to duplicate the condition value.
+            VecCost += ::getShuffleCost(
+                *TTI, TTI::SK_PermuteSingleSrc, CondType,
+                createReplicatedMask(VecTyNumElements / CondNumElements,
+                                     CondNumElements));
+          }
         }
       }
       return VecCost + CommonCost;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll b/llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll
index 3b2b6f6a37df1..1c532ef31bc39 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/select-logical-or-and-i1-vector.ll
@@ -12,18 +12,16 @@ define void @select_logical_or_i1(ptr %dst,
 ; CHECK-LABEL: define void @select_logical_or_i1(
 ; CHECK-SAME: ptr [[DST:%.*]], float [[D0:%.*]], float [[D1:%.*]], float [[D2:%.*]], float [[D3:%.*]], float [[THRESHOLD:%.*]], float [[HPHB_VAL:%.*]], i1 [[SCALAR_COND:%.*]], float [[Y0:%.*]], float [[Y1:%.*]], float [[Y2:%.*]], float [[Y3:%.*]], float [[E0:%.*]], float [[E1:%.*]], float [[E2:%.*]], float [[E3:%.*]]) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[CMP0:%.*]] = fcmp fast uge float [[D0]], [[THRESHOLD]]
-; CHECK-NEXT:    [[CMP1:%.*]] = fcmp fast uge float [[D1]], [[THRESHOLD]]
-; CHECK-NEXT:    [[CMP2:%.*]] = fcmp fast uge float [[D2]], [[THRESHOLD]]
-; CHECK-NEXT:    [[CMP3:%.*]] = fcmp fast uge float [[D3]], [[THRESHOLD]]
-; CHECK-NEXT:    [[OR3:%.*]] = select i1 [[CMP3]], i1 true, i1 [[SCALAR_COND]]
-; CHECK-NEXT:    [[OR2:%.*]] = select i1 [[CMP2]], i1 true, i1 [[SCALAR_COND]]
-; CHECK-NEXT:    [[OR1:%.*]] = select i1 [[CMP1]], i1 true, i1 [[SCALAR_COND]]
-; CHECK-NEXT:    [[OR0:%.*]] = select i1 [[CMP0]], i1 true, i1 [[SCALAR_COND]]
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x i1> poison, i1 [[OR0]], i32 0
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i1> [[TMP0]], i1 [[OR1]], i32 1
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x i1> [[TMP1]], i1 [[OR2]], i32 2
-; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x i1> [[TMP2]], i1 [[OR3]], i32 3
+; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x float> poison, float [[D0]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x float> [[TMP0]], float [[D1]], i32 1
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x float> [[TMP1]], float [[D2]], i32 2
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[D3]], i32 3
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x float> poison, float [[THRESHOLD]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x float> [[TMP4]], <4 x float> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = fcmp fast uge <4 x float> [[TMP3]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i1> poison, i1 [[SCALAR_COND]], i32 0
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i1> [[TMP7]], <4 x i1> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9:%.*]] = select <4 x i1> [[TMP6]], <4 x i1> splat (i1 true), <4 x i1> [[TMP8]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <4 x float> poison, float [[HPHB_VAL]], i32 0
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x float> [[TMP10]], <4 x float> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP12:%.*]] = select <4 x i1> [[TMP9]], <4 x float> zeroinitializer, <4 x float> [[TMP11]]
@@ -86,18 +84,16 @@ define void @select_logical_and_i1(ptr %dst,
 ; CHECK-LABEL: define void @select_logical_and_i1(
 ; CHECK-SAME: ptr [[DST:%.*]], float [[D0:%.*]], float [[D1:%.*]], float [[D2:%.*]], float [[D3:%.*]], float [[THRESHOLD:%.*]], float [[HPHB_VAL:%.*]], i1 [[SCALAR_COND:%.*]], float [[Y0:%.*]], float [[Y1:%.*]], float [[Y2:%.*]], float [[Y3:%.*]], float [[E0:%.*]], float [[E1:%.*]], float [[E2:%.*]], float [[E3:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[CMP0:%.*]] = fcmp fast uge float [[D0]], [[THRESHOLD]]
-; CHECK-NEXT:    [[CMP1:%.*]] = fcmp fast uge float [[D1]], [[THRESHOLD]]
-; CHECK-NEXT:    [[CMP2:%.*]] = fcmp fast uge float [[D2]], [[THRESHOLD]]
-; CHECK-NEXT:    [[CMP3:%.*]] = fcmp fast uge float [[D3]], [[THRESHOLD]]
-; CHECK-NEXT:    [[AND3:%.*]] = select i1 [[CMP3]], i1 [[SCALAR_COND]], i1 false
-; CHECK-NEXT:    [[AND2:%.*]] = select i1 [[CMP2]], i1 [[SCALAR_COND]], i1 false
-; CHECK-NEXT:    [[AND1:%.*]] = select i1 [[CMP1]], i1 [[SCALAR_COND]], i1 false
-; CHECK-NEXT:    [[AND0:%.*]] = select i1 [[CMP0]], i1 [[SCALAR_COND]], i1 false
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x i1> poison, i1 [[AND0]], i32 0
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i1> [[TMP0]], i1 [[AND1]], i32 1
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x i1> [[TMP1]], i1 [[AND2]], i32 2
-; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x i1> [[TMP2]], i1 [[AND3]], i32 3
+; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x float> poison, float [[D0]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x float> [[TMP0]], float [[D1]], i32 1
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x float> [[TMP1]], float [[D2]], i32 2
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x float> [[TMP2]], float [[D3]], i32 3
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x float> poison, float [[THRESHOLD]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x float> [[TMP4]], <4 x float> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = fcmp fast uge <4 x float> [[TMP3]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i1> poison, i1 [[SCALAR_COND]], i32 0
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i1> [[TMP7]], <4 x i1> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9:%.*]] = select <4 x i1> [[TMP6]], <4 x i1> [[TMP8]], <4 x i1> zeroinitializer
 ; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <4 x float> poison, float [[HPHB_VAL]], i32 0
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x float> [[TMP10]], <4 x float> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP12:%.*]] = select <4 x i1> [[TMP9]], <4 x float> zeroinitializer, <4 x float> [[TMP11]]



More information about the llvm-commits mailing list