[llvm] [SLP] Allow bitcast/bswap-based reductions for types larger than the total strided size (PR #184018)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 5 14:56:38 PST 2026
https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/184018
>From 4b865fa395a242e745d5e8baee5e35fccad21f93 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Sun, 1 Mar 2026 09:45:39 -0800
Subject: [PATCH 1/2] [spr] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Created using spr 1.3.7
---
.../Transforms/Vectorize/SLPVectorizer.cpp | 67 ++++++++++++++-----
.../X86/disjoint-or-reductions.ll | 10 ++-
2 files changed, 53 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 14b36a8619eb5..745bb5d8394f6 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13422,7 +13422,7 @@ bool BoUpSLP::matchesShlZExt(const TreeEntry &TE, OrdersType &Order,
return false;
Order.clear();
unsigned CurrentValue = 0;
- // Rhs should be (0, Stride, 2 * Stride, ..., Sz-Stride).
+ // Rhs should be (0, Stride, 2 * Stride, ..., N-Stride), where N <= Sz.
if (all_of(RhsTE->Scalars,
[&](Value *V) {
CurrentValue += Stride;
@@ -13433,7 +13433,7 @@ bool BoUpSLP::matchesShlZExt(const TreeEntry &TE, OrdersType &Order,
return false;
return C->getUniqueInteger() == CurrentValue - Stride;
}) &&
- CurrentValue == Sz) {
+ CurrentValue <= Sz) {
Order.clear();
} else {
const unsigned VF = RhsTE->getVectorFactor();
@@ -13441,8 +13441,8 @@ bool BoUpSLP::matchesShlZExt(const TreeEntry &TE, OrdersType &Order,
// Track which logical positions we've seen; reject duplicate shift amounts.
SmallBitVector SeenPositions(VF);
// Check if need to reorder Rhs to make it in form (0, Stride, 2 * Stride,
- // ..., Sz-Stride).
- if (VF * Stride != Sz)
+ // ..., N-Stride), where N <= Sz.
+ if (VF * Stride > Sz)
return false;
for (const auto [Idx, V] : enumerate(RhsTE->Scalars)) {
if (isa<UndefValue>(V))
@@ -15821,7 +15821,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
ScalarCost += TTI.getInstructionCost(ZExt, CostKind);
return ScalarCost;
};
- auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
+ auto GetVectorCost = [&, &TTI = *TTI](InstructionCost) {
const TreeEntry *LhsTE = getOperandEntry(E, /*Idx=*/0);
TTI::CastContextHint CastCtx =
getCastContextHint(*getOperandEntry(LhsTE, /*Idx=*/0));
@@ -15830,14 +15830,21 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
InstructionCost BitcastCost = TTI.getCastInstrCost(
Instruction::BitCast, ScalarTy, SrcVecTy, CastCtx, CostKind);
if (ShuffleOrOp == TreeEntry::ReducedBitcastBSwap) {
- auto *OrigScalarTy = E->getMainOp()->getType();
+ auto *OrigScalarTy = IntegerType::getIntNTy(
+ ScalarTy->getContext(),
+ DL->getTypeSizeInBits(SrcScalarTy) * EntryVF);
IntrinsicCostAttributes CostAttrs(Intrinsic::bswap, OrigScalarTy,
{OrigScalarTy});
InstructionCost IntrinsicCost =
TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
BitcastCost += IntrinsicCost;
+ if (OrigScalarTy != ScalarTy) {
+ BitcastCost +=
+ TTI.getCastInstrCost(Instruction::ZExt, ScalarTy, OrigScalarTy,
+ TTI::CastContextHint::None, CostKind);
+ }
}
- return BitcastCost + CommonCost;
+ return BitcastCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
@@ -15860,11 +15867,13 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
ScalarCost += TTI.getInstructionCost(Load, CostKind);
return ScalarCost;
};
- auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
+ auto GetVectorCost = [&, &TTI = *TTI](InstructionCost) {
const TreeEntry *LhsTE = getOperandEntry(E, /*Idx=*/0);
const TreeEntry *LoadTE = getOperandEntry(LhsTE, /*Idx=*/0);
auto *LI0 = cast<LoadInst>(LoadTE->getMainOp());
- auto *OrigScalarTy = E->getMainOp()->getType();
+ auto *OrigScalarTy = IntegerType::getIntNTy(
+ ScalarTy->getContext(),
+ DL->getTypeSizeInBits(LI0->getType()) * EntryVF);
InstructionCost LoadCost =
TTI.getMemoryOpCost(Instruction::Load, OrigScalarTy, LI0->getAlign(),
LI0->getPointerAddressSpace(), CostKind);
@@ -15874,8 +15883,13 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
InstructionCost IntrinsicCost =
TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
LoadCost += IntrinsicCost;
+ if (OrigScalarTy != ScalarTy) {
+ LoadCost +=
+ TTI.getCastInstrCost(Instruction::ZExt, ScalarTy, OrigScalarTy,
+ TTI::CastContextHint::None, CostKind);
+ }
}
- return LoadCost + CommonCost;
+ return LoadCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
@@ -21634,17 +21648,25 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
Const->VectorizedValue = PoisonValue::get(getWidenedType(
Const->Scalars.front()->getType(), Const->getVectorFactor()));
Value *Op = vectorizeOperand(ZExt, 0);
+ auto *SrcType = IntegerType::get(
+ Op->getContext(),
+ DL->getTypeSizeInBits(
+ cast<CastInst>(ZExt->getMainOp())->getSrcTy()) *
+ E->getVectorFactor());
+ auto *OrigScalarTy = ScalarTy;
// Set the scalar type properly to avoid casting to the extending type.
ScalarTy = cast<CastInst>(ZExt->getMainOp())->getSrcTy();
Op = FinalShuffle(Op, E);
- auto *V = Builder.CreateBitCast(
- Op, IntegerType::get(
- Op->getContext(),
- DL->getTypeSizeInBits(ZExt->getMainOp()->getType())));
- if (ShuffleOrOp == TreeEntry::ReducedBitcastBSwap)
+ auto *V = Builder.CreateBitCast(Op, SrcType);
+ if (ShuffleOrOp == TreeEntry::ReducedBitcastBSwap) {
V = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, V);
+ ++NumVectorInstructions;
+ }
+ if (SrcType != OrigScalarTy) {
+ V = Builder.CreateIntCast(V, OrigScalarTy, /*isSigned=*/false);
+ ++NumVectorInstructions;
+ }
E->VectorizedValue = V;
- ++NumVectorInstructions;
return V;
}
case TreeEntry::ReducedBitcastLoads:
@@ -21662,13 +21684,22 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
Load->getMainOp()->getType(), Load->getVectorFactor()));
LoadInst *LI = cast<LoadInst>(Load->getMainOp());
Value *PO = LI->getPointerOperand();
- Type *ScalarTy = ZExt->getMainOp()->getType();
- Value *V = Builder.CreateAlignedLoad(ScalarTy, PO, LI->getAlign());
+ auto *SrcType = IntegerType::get(
+ ScalarTy->getContext(),
+ DL->getTypeSizeInBits(cast<CastInst>(ZExt->getMainOp())->getSrcTy()) *
+ E->getVectorFactor());
+ auto *OrigScalarTy = ScalarTy;
+ ScalarTy = ZExt->getMainOp()->getType();
+ Value *V = Builder.CreateAlignedLoad(OrigScalarTy, PO, LI->getAlign());
++NumVectorInstructions;
if (ShuffleOrOp == TreeEntry::ReducedBitcastBSwapLoads) {
V = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, V);
++NumVectorInstructions;
}
+ if (SrcType != OrigScalarTy) {
+ V = Builder.CreateIntCast(V, OrigScalarTy, /*isSigned=*/false);
+ ++NumVectorInstructions;
+ }
E->VectorizedValue = V;
return V;
}
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/disjoint-or-reductions.ll b/llvm/test/Transforms/SLPVectorizer/X86/disjoint-or-reductions.ll
index 93bbe58c5197e..1d606a74ed5ca 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/disjoint-or-reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/disjoint-or-reductions.ll
@@ -260,9 +260,8 @@ define i64 @bswap_i32(ptr noalias %p, ptr noalias %p1) {
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i8>, ptr [[P:%.*]], align 1
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i8>, ptr [[P1:%.*]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = add <4 x i8> [[TMP1]], [[TMP2]]
-; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i32>
-; CHECK-NEXT: [[TMP5:%.*]] = shl <4 x i32> [[TMP4]], <i32 24, i32 16, i32 8, i32 0>
-; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to i32
+; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.bswap.i32(i32 [[TMP4]])
; CHECK-NEXT: [[TMP7:%.*]] = zext i32 [[TMP6]] to i64
; CHECK-NEXT: ret i64 [[TMP7]]
;
@@ -310,9 +309,8 @@ define i64 @reorder_i32(ptr noalias %p, ptr noalias %p1) {
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i8>, ptr [[P:%.*]], align 1
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i8>, ptr [[P1:%.*]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = add <4 x i8> [[TMP1]], [[TMP2]]
-; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i32>
-; CHECK-NEXT: [[TMP5:%.*]] = shl <4 x i32> [[TMP4]], <i32 16, i32 24, i32 0, i32 8>
-; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <4 x i32> <i32 2, i32 3, i32 0, i32 1>
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <4 x i8> [[TMP4]] to i32
; CHECK-NEXT: [[TMP7:%.*]] = zext i32 [[TMP6]] to i64
; CHECK-NEXT: ret i64 [[TMP7]]
;
>From 707a55c14e93594a0b8c34c2f3f2063d1d0cbfbb Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Sun, 1 Mar 2026 09:48:45 -0800
Subject: [PATCH 2/2] Fix formatting
Created using spr 1.3.7
---
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 745bb5d8394f6..ff912232c7c39 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -21650,8 +21650,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
Value *Op = vectorizeOperand(ZExt, 0);
auto *SrcType = IntegerType::get(
Op->getContext(),
- DL->getTypeSizeInBits(
- cast<CastInst>(ZExt->getMainOp())->getSrcTy()) *
+ DL->getTypeSizeInBits(cast<CastInst>(ZExt->getMainOp())->getSrcTy()) *
E->getVectorFactor());
auto *OrigScalarTy = ScalarTy;
// Set the scalar type properly to avoid casting to the extending type.
More information about the llvm-commits
mailing list