[llvm] [VectorCombine] foldShuffleOfCastops - handle unary shuffles (PR #160009)
Chaitanya Koparkar via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 23 05:59:25 PDT 2025
https://github.com/ckoparkar updated https://github.com/llvm/llvm-project/pull/160009
>From ed4c1dfb83a5d3d33f185559a4820b7d3fefb45e Mon Sep 17 00:00:00 2001
From: Chaitanya Koparkar <ckoparkar at gmail.com>
Date: Sun, 21 Sep 2025 16:01:33 -0400
Subject: [PATCH 1/2] [VectorCombine] foldShuffleOfCastops - handle unary
shuffles
---
.../Transforms/Vectorize/VectorCombine.cpp | 67 ++++++++++++-------
1 file changed, 44 insertions(+), 23 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 17cb18a22336a..c2a4353b3eb62 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2477,21 +2477,28 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
return false;
+ // Check whether this is a unary shuffle.
+ // TODO: should this be extended to match undef or unused values.
+ bool IsBinaryShuffle = !isa<PoisonValue>(V1);
+ LLVM_DEBUG(dbgs() << "Is binary shuffle: " << IsBinaryShuffle << "\n");
+
auto *C0 = dyn_cast<CastInst>(V0);
auto *C1 = dyn_cast<CastInst>(V1);
- if (!C0 || !C1)
+ if (!C0 || (IsBinaryShuffle && !C1))
return false;
Instruction::CastOps Opcode = C0->getOpcode();
- if (C0->getSrcTy() != C1->getSrcTy())
- return false;
- // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
- if (Opcode != C1->getOpcode()) {
- if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
- Opcode = Instruction::SExt;
- else
+ if (IsBinaryShuffle) {
+ if (C0->getSrcTy() != C1->getSrcTy())
return false;
+ // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
+ if (Opcode != C1->getOpcode()) {
+ if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
+ Opcode = Instruction::SExt;
+ else
+ return false;
+ }
}
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
@@ -2534,23 +2541,31 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
InstructionCost CostC0 =
TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
TTI::CastContextHint::None, CostKind);
- InstructionCost CostC1 =
- TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
- TTI::CastContextHint::None, CostKind);
- InstructionCost OldCost = CostC0 + CostC1;
- OldCost +=
- TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
- CastDstTy, OldMask, CostKind, 0, nullptr, {}, &I);
- InstructionCost NewCost =
- TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
- CastSrcTy, NewMask, CostKind);
+ TargetTransformInfo::ShuffleKind ShuffleKind;
+ if (IsBinaryShuffle)
+ ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
+ else
+ ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
+
+ InstructionCost OldCost = CostC0;
+ OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
+ CostKind, 0, nullptr, {}, &I);
+
+ InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
+ CastSrcTy, NewMask, CostKind);
NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
TTI::CastContextHint::None, CostKind);
if (!C0->hasOneUse())
NewCost += CostC0;
- if (!C1->hasOneUse())
- NewCost += CostC1;
+ if (IsBinaryShuffle) {
+ InstructionCost CostC1 =
+ TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
+ TTI::CastContextHint::None, CostKind);
+ OldCost += CostC1;
+ if (!C1->hasOneUse())
+ NewCost += CostC1;
+ }
LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
@@ -2558,14 +2573,20 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
if (NewCost > OldCost)
return false;
- Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
- C1->getOperand(0), NewMask);
+ Value *Shuf;
+ if (IsBinaryShuffle)
+ Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
+ NewMask);
+ else
+ Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);
+
Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
// Intersect flags from the old casts.
if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
NewInst->copyIRFlags(C0);
- NewInst->andIRFlags(C1);
+ if (IsBinaryShuffle)
+ NewInst->andIRFlags(C1);
}
Worklist.pushValue(Shuf);
>From 8ae981639ffb9b4e48b7452819d003e0c3b3dab9 Mon Sep 17 00:00:00 2001
From: Chaitanya Koparkar <ckoparkar at gmail.com>
Date: Tue, 23 Sep 2025 08:50:25 -0400
Subject: [PATCH 2/2] Fix shuffletoidentity test
---
llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 4 ++--
.../VectorCombine/AArch64/shuffletoidentity.ll | 17 ++++++++---------
2 files changed, 10 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 7c3c219b94190..526a4add2a89a 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2488,9 +2488,9 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
return false;
// Check whether this is a unary shuffle.
- // TODO: should this be extended to match undef or unused values.
+ // TODO: check if this can be extended to match undef or unused values,
+ // perhaps using ShuffleVectorInst::isSingleSource.
bool IsBinaryShuffle = !isa<PoisonValue>(V1);
- LLVM_DEBUG(dbgs() << "Is binary shuffle: " << IsBinaryShuffle << "\n");
auto *C0 = dyn_cast<CastInst>(V0);
auto *C1 = dyn_cast<CastInst>(V1);
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
index acbc836ffcab0..ed29719d49493 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
@@ -205,8 +205,8 @@ define <8 x i8> @abs_different(<8 x i8> %a) {
define <4 x i32> @poison_intrinsic(<2 x i16> %l256) {
; CHECK-LABEL: @poison_intrinsic(
; CHECK-NEXT: [[L266:%.*]] = call <2 x i16> @llvm.abs.v2i16(<2 x i16> [[L256:%.*]], i1 false)
-; CHECK-NEXT: [[L267:%.*]] = zext <2 x i16> [[L266]] to <2 x i32>
-; CHECK-NEXT: [[L271:%.*]] = shufflevector <2 x i32> [[L267]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT: [[L267:%.*]] = shufflevector <2 x i16> [[L266]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT: [[L271:%.*]] = zext <4 x i16> [[L267]] to <4 x i32>
; CHECK-NEXT: ret <4 x i32> [[L271]]
;
%l266 = call <2 x i16> @llvm.abs.v2i16(<2 x i16> %l256, i1 false)
@@ -534,9 +534,9 @@ define <4 x i64> @single_zext(<4 x i32> %x) {
define <4 x i64> @not_zext(<4 x i32> %x) {
; CHECK-LABEL: @not_zext(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <4 x i32> [[X:%.*]] to <4 x i64>
-; CHECK-NEXT: [[REVSHUF:%.*]] = shufflevector <4 x i64> [[ZEXT]], <4 x i64> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT: ret <4 x i64> [[REVSHUF]]
+; CHECK-NEXT: [[REVSHUF:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
+; CHECK-NEXT: [[ZEXT:%.*]] = zext <4 x i32> [[REVSHUF]] to <4 x i64>
+; CHECK-NEXT: ret <4 x i64> [[ZEXT]]
;
%zext = zext <4 x i32> %x to <4 x i64>
%revshuf = shufflevector <4 x i64> %zext, <4 x i64> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -922,10 +922,9 @@ define <4 x i8> @singleop(<4 x i8> %a, <4 x i8> %b) {
define <4 x i64> @cast_mismatched_types(<4 x i32> %x) {
; CHECK-LABEL: @cast_mismatched_types(
-; CHECK-NEXT: [[SHUF:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i32> [[SHUF]] to <2 x i64>
-; CHECK-NEXT: [[EXTSHUF:%.*]] = shufflevector <2 x i64> [[ZEXT]], <2 x i64> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
-; CHECK-NEXT: ret <4 x i64> [[EXTSHUF]]
+; CHECK-SAME: <4 x i32> [[X:%.*]]) {
+; CHECK-NEXT: [[ZEXT:%.*]] = zext <4 x i32> [[X]] to <4 x i64>
+; CHECK-NEXT: ret <4 x i64> [[ZEXT]]
;
%shuf = shufflevector <4 x i32> %x, <4 x i32> poison, <2 x i32> <i32 0, i32 2>
%zext = zext <2 x i32> %shuf to <2 x i64>
More information about the llvm-commits
mailing list