[llvm] [SLP] Allow bitcast/bswap based reductions for types larger than the total strided size (PR #184018)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 5 14:56:38 PST 2026


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/184018

>From 4b865fa395a242e745d5e8baee5e35fccad21f93 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Sun, 1 Mar 2026 09:45:39 -0800
Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?=
 =?UTF-8?q?itial=20version?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.7
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 67 ++++++++++++++-----
 .../X86/disjoint-or-reductions.ll             | 10 ++-
 2 files changed, 53 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 14b36a8619eb5..745bb5d8394f6 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13422,7 +13422,7 @@ bool BoUpSLP::matchesShlZExt(const TreeEntry &TE, OrdersType &Order,
     return false;
   Order.clear();
   unsigned CurrentValue = 0;
-  // Rhs should be (0, Stride, 2 * Stride, ..., Sz-Stride).
+  // Rhs should be (0, Stride, 2 * Stride, ..., N-Stride), where N <= Sz.
   if (all_of(RhsTE->Scalars,
              [&](Value *V) {
                CurrentValue += Stride;
@@ -13433,7 +13433,7 @@ bool BoUpSLP::matchesShlZExt(const TreeEntry &TE, OrdersType &Order,
                  return false;
                return C->getUniqueInteger() == CurrentValue - Stride;
              }) &&
-      CurrentValue == Sz) {
+      CurrentValue <= Sz) {
     Order.clear();
   } else {
     const unsigned VF = RhsTE->getVectorFactor();
@@ -13441,8 +13441,8 @@ bool BoUpSLP::matchesShlZExt(const TreeEntry &TE, OrdersType &Order,
     // Track which logical positions we've seen; reject duplicate shift amounts.
     SmallBitVector SeenPositions(VF);
     // Check if need to reorder Rhs to make it in form (0, Stride, 2 * Stride,
-    // ..., Sz-Stride).
-    if (VF * Stride != Sz)
+    // ..., N-Stride), where N <= Sz.
+    if (VF * Stride > Sz)
       return false;
     for (const auto [Idx, V] : enumerate(RhsTE->Scalars)) {
       if (isa<UndefValue>(V))
@@ -15821,7 +15821,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       ScalarCost += TTI.getInstructionCost(ZExt, CostKind);
       return ScalarCost;
     };
-    auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
+    auto GetVectorCost = [&, &TTI = *TTI](InstructionCost) {
       const TreeEntry *LhsTE = getOperandEntry(E, /*Idx=*/0);
       TTI::CastContextHint CastCtx =
           getCastContextHint(*getOperandEntry(LhsTE, /*Idx=*/0));
@@ -15830,14 +15830,21 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       InstructionCost BitcastCost = TTI.getCastInstrCost(
           Instruction::BitCast, ScalarTy, SrcVecTy, CastCtx, CostKind);
       if (ShuffleOrOp == TreeEntry::ReducedBitcastBSwap) {
-        auto *OrigScalarTy = E->getMainOp()->getType();
+        auto *OrigScalarTy = IntegerType::getIntNTy(
+            ScalarTy->getContext(),
+            DL->getTypeSizeInBits(SrcScalarTy) * EntryVF);
         IntrinsicCostAttributes CostAttrs(Intrinsic::bswap, OrigScalarTy,
                                           {OrigScalarTy});
         InstructionCost IntrinsicCost =
             TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
         BitcastCost += IntrinsicCost;
+        if (OrigScalarTy != ScalarTy) {
+          BitcastCost +=
+              TTI.getCastInstrCost(Instruction::ZExt, ScalarTy, OrigScalarTy,
+                                   TTI::CastContextHint::None, CostKind);
+        }
       }
-      return BitcastCost + CommonCost;
+      return BitcastCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
@@ -15860,11 +15867,13 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       ScalarCost += TTI.getInstructionCost(Load, CostKind);
       return ScalarCost;
     };
-    auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
+    auto GetVectorCost = [&, &TTI = *TTI](InstructionCost) {
       const TreeEntry *LhsTE = getOperandEntry(E, /*Idx=*/0);
       const TreeEntry *LoadTE = getOperandEntry(LhsTE, /*Idx=*/0);
       auto *LI0 = cast<LoadInst>(LoadTE->getMainOp());
-      auto *OrigScalarTy = E->getMainOp()->getType();
+      auto *OrigScalarTy = IntegerType::getIntNTy(
+          ScalarTy->getContext(),
+          DL->getTypeSizeInBits(LI0->getType()) * EntryVF);
       InstructionCost LoadCost =
           TTI.getMemoryOpCost(Instruction::Load, OrigScalarTy, LI0->getAlign(),
                               LI0->getPointerAddressSpace(), CostKind);
@@ -15874,8 +15883,13 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         InstructionCost IntrinsicCost =
             TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
         LoadCost += IntrinsicCost;
+        if (OrigScalarTy != ScalarTy) {
+          LoadCost +=
+              TTI.getCastInstrCost(Instruction::ZExt, ScalarTy, OrigScalarTy,
+                                   TTI::CastContextHint::None, CostKind);
+        }
       }
-      return LoadCost + CommonCost;
+      return LoadCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
@@ -21634,17 +21648,25 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
       Const->VectorizedValue = PoisonValue::get(getWidenedType(
           Const->Scalars.front()->getType(), Const->getVectorFactor()));
       Value *Op = vectorizeOperand(ZExt, 0);
+      auto *SrcType = IntegerType::get(
+          Op->getContext(),
+          DL->getTypeSizeInBits(
+              cast<CastInst>(ZExt->getMainOp())->getSrcTy()) *
+              E->getVectorFactor());
+      auto *OrigScalarTy = ScalarTy;
       // Set the scalar type properly to avoid casting to the extending type.
       ScalarTy = cast<CastInst>(ZExt->getMainOp())->getSrcTy();
       Op = FinalShuffle(Op, E);
-      auto *V = Builder.CreateBitCast(
-          Op, IntegerType::get(
-                  Op->getContext(),
-                  DL->getTypeSizeInBits(ZExt->getMainOp()->getType())));
-      if (ShuffleOrOp == TreeEntry::ReducedBitcastBSwap)
+      auto *V = Builder.CreateBitCast(Op, SrcType);
+      if (ShuffleOrOp == TreeEntry::ReducedBitcastBSwap) {
         V = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, V);
+        ++NumVectorInstructions;
+      }
+      if (SrcType != OrigScalarTy) {
+        V = Builder.CreateIntCast(V, OrigScalarTy, /*isSigned=*/false);
+        ++NumVectorInstructions;
+      }
       E->VectorizedValue = V;
-      ++NumVectorInstructions;
       return V;
     }
     case TreeEntry::ReducedBitcastLoads:
@@ -21662,13 +21684,22 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
           Load->getMainOp()->getType(), Load->getVectorFactor()));
       LoadInst *LI = cast<LoadInst>(Load->getMainOp());
       Value *PO = LI->getPointerOperand();
-      Type *ScalarTy = ZExt->getMainOp()->getType();
-      Value *V = Builder.CreateAlignedLoad(ScalarTy, PO, LI->getAlign());
+      auto *SrcType = IntegerType::get(
+          ScalarTy->getContext(),
+          DL->getTypeSizeInBits(cast<CastInst>(ZExt->getMainOp())->getSrcTy()) *
+              E->getVectorFactor());
+      auto *OrigScalarTy = ScalarTy;
+      ScalarTy = ZExt->getMainOp()->getType();
+      Value *V = Builder.CreateAlignedLoad(OrigScalarTy, PO, LI->getAlign());
       ++NumVectorInstructions;
       if (ShuffleOrOp == TreeEntry::ReducedBitcastBSwapLoads) {
         V = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, V);
         ++NumVectorInstructions;
       }
+      if (SrcType != OrigScalarTy) {
+        V = Builder.CreateIntCast(V, OrigScalarTy, /*isSigned=*/false);
+        ++NumVectorInstructions;
+      }
       E->VectorizedValue = V;
       return V;
     }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/disjoint-or-reductions.ll b/llvm/test/Transforms/SLPVectorizer/X86/disjoint-or-reductions.ll
index 93bbe58c5197e..1d606a74ed5ca 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/disjoint-or-reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/disjoint-or-reductions.ll
@@ -260,9 +260,8 @@ define i64 @bswap_i32(ptr noalias %p, ptr noalias %p1) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, ptr [[P:%.*]], align 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i8>, ptr [[P1:%.*]], align 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = add <4 x i8> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = shl <4 x i32> [[TMP4]], <i32 24, i32 16, i32 8, i32 0>
-; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP5]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to i32
+; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.bswap.i32(i32 [[TMP4]])
 ; CHECK-NEXT:    [[TMP7:%.*]] = zext i32 [[TMP6]] to i64
 ; CHECK-NEXT:    ret i64 [[TMP7]]
 ;
@@ -310,9 +309,8 @@ define i64 @reorder_i32(ptr noalias %p, ptr noalias %p1) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, ptr [[P:%.*]], align 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i8>, ptr [[P1:%.*]], align 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = add <4 x i8> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = shl <4 x i32> [[TMP4]], <i32 16, i32 24, i32 0, i32 8>
-; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP5]])
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <4 x i32> <i32 2, i32 3, i32 0, i32 1>
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <4 x i8> [[TMP4]] to i32
 ; CHECK-NEXT:    [[TMP7:%.*]] = zext i32 [[TMP6]] to i64
 ; CHECK-NEXT:    ret i64 [[TMP7]]
 ;

>From 707a55c14e93594a0b8c34c2f3f2063d1d0cbfbb Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Sun, 1 Mar 2026 09:48:45 -0800
Subject: [PATCH 2/2] Fix formatting

Created using spr 1.3.7
---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 745bb5d8394f6..ff912232c7c39 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -21650,8 +21650,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
       Value *Op = vectorizeOperand(ZExt, 0);
       auto *SrcType = IntegerType::get(
           Op->getContext(),
-          DL->getTypeSizeInBits(
-              cast<CastInst>(ZExt->getMainOp())->getSrcTy()) *
+          DL->getTypeSizeInBits(cast<CastInst>(ZExt->getMainOp())->getSrcTy()) *
               E->getVectorFactor());
       auto *OrigScalarTy = ScalarTy;
       // Set the scalar type properly to avoid casting to the extending type.



More information about the llvm-commits mailing list