[llvm] [VectorCombine] foldShuffleOfCastops - handle unary shuffles (PR #160009)

Chaitanya Koparkar via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 23 05:59:25 PDT 2025


https://github.com/ckoparkar updated https://github.com/llvm/llvm-project/pull/160009

>From ed4c1dfb83a5d3d33f185559a4820b7d3fefb45e Mon Sep 17 00:00:00 2001
From: Chaitanya Koparkar <ckoparkar at gmail.com>
Date: Sun, 21 Sep 2025 16:01:33 -0400
Subject: [PATCH 1/2] [VectorCombine] foldShuffleOfCastops - handle unary
 shuffles

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 67 ++++++++++++-------
 1 file changed, 44 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 17cb18a22336a..c2a4353b3eb62 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2477,21 +2477,28 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
     return false;
 
+  // Check whether this is a unary shuffle.
+  // TODO: should this be extended to match undef or unused values.
+  bool IsBinaryShuffle = !isa<PoisonValue>(V1);
+  LLVM_DEBUG(dbgs() << "Is binary shuffle: " << IsBinaryShuffle << "\n");
+
   auto *C0 = dyn_cast<CastInst>(V0);
   auto *C1 = dyn_cast<CastInst>(V1);
-  if (!C0 || !C1)
+  if (!C0 || (IsBinaryShuffle && !C1))
     return false;
 
   Instruction::CastOps Opcode = C0->getOpcode();
-  if (C0->getSrcTy() != C1->getSrcTy())
-    return false;
 
-  // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
-  if (Opcode != C1->getOpcode()) {
-    if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
-      Opcode = Instruction::SExt;
-    else
+  if (IsBinaryShuffle) {
+    if (C0->getSrcTy() != C1->getSrcTy())
       return false;
+    // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
+    if (Opcode != C1->getOpcode()) {
+      if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
+        Opcode = Instruction::SExt;
+      else
+        return false;
+    }
   }
 
   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
@@ -2534,23 +2541,31 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   InstructionCost CostC0 =
       TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
                            TTI::CastContextHint::None, CostKind);
-  InstructionCost CostC1 =
-      TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
-                           TTI::CastContextHint::None, CostKind);
-  InstructionCost OldCost = CostC0 + CostC1;
-  OldCost +=
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
-                         CastDstTy, OldMask, CostKind, 0, nullptr, {}, &I);
 
-  InstructionCost NewCost =
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
-                         CastSrcTy, NewMask, CostKind);
+  TargetTransformInfo::ShuffleKind ShuffleKind;
+  if (IsBinaryShuffle)
+    ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
+  else
+    ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
+
+  InstructionCost OldCost = CostC0;
+  OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
+                                CostKind, 0, nullptr, {}, &I);
+
+  InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
+                                               CastSrcTy, NewMask, CostKind);
   NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
                                   TTI::CastContextHint::None, CostKind);
   if (!C0->hasOneUse())
     NewCost += CostC0;
-  if (!C1->hasOneUse())
-    NewCost += CostC1;
+  if (IsBinaryShuffle) {
+    InstructionCost CostC1 =
+        TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
+                             TTI::CastContextHint::None, CostKind);
+    OldCost += CostC1;
+    if (!C1->hasOneUse())
+      NewCost += CostC1;
+  }
 
   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
@@ -2558,14 +2573,20 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   if (NewCost > OldCost)
     return false;
 
-  Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
-                                            C1->getOperand(0), NewMask);
+  Value *Shuf;
+  if (IsBinaryShuffle)
+    Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
+                                       NewMask);
+  else
+    Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);
+
   Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
 
   // Intersect flags from the old casts.
   if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
     NewInst->copyIRFlags(C0);
-    NewInst->andIRFlags(C1);
+    if (IsBinaryShuffle)
+      NewInst->andIRFlags(C1);
   }
 
   Worklist.pushValue(Shuf);

>From 8ae981639ffb9b4e48b7452819d003e0c3b3dab9 Mon Sep 17 00:00:00 2001
From: Chaitanya Koparkar <ckoparkar at gmail.com>
Date: Tue, 23 Sep 2025 08:50:25 -0400
Subject: [PATCH 2/2] Fix shuffletoidentity test

---
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp |  4 ++--
 .../VectorCombine/AArch64/shuffletoidentity.ll  | 17 ++++++++---------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 7c3c219b94190..526a4add2a89a 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2488,9 +2488,9 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
     return false;
 
   // Check whether this is a unary shuffle.
-  // TODO: should this be extended to match undef or unused values.
+  // TODO: check if this can be extended to match undef or unused values,
+  // perhaps using ShuffleVectorInst::isSingleSource.
   bool IsBinaryShuffle = !isa<PoisonValue>(V1);
-  LLVM_DEBUG(dbgs() << "Is binary shuffle: " << IsBinaryShuffle << "\n");
 
   auto *C0 = dyn_cast<CastInst>(V0);
   auto *C1 = dyn_cast<CastInst>(V1);
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
index acbc836ffcab0..ed29719d49493 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
@@ -205,8 +205,8 @@ define <8 x i8> @abs_different(<8 x i8> %a) {
 define <4 x i32> @poison_intrinsic(<2 x i16> %l256) {
 ; CHECK-LABEL: @poison_intrinsic(
 ; CHECK-NEXT:    [[L266:%.*]] = call <2 x i16> @llvm.abs.v2i16(<2 x i16> [[L256:%.*]], i1 false)
-; CHECK-NEXT:    [[L267:%.*]] = zext <2 x i16> [[L266]] to <2 x i32>
-; CHECK-NEXT:    [[L271:%.*]] = shufflevector <2 x i32> [[L267]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[L267:%.*]] = shufflevector <2 x i16> [[L266]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[L271:%.*]] = zext <4 x i16> [[L267]] to <4 x i32>
 ; CHECK-NEXT:    ret <4 x i32> [[L271]]
 ;
   %l266 = call <2 x i16> @llvm.abs.v2i16(<2 x i16> %l256, i1 false)
@@ -534,9 +534,9 @@ define <4 x i64> @single_zext(<4 x i32> %x) {
 
 define <4 x i64> @not_zext(<4 x i32> %x) {
 ; CHECK-LABEL: @not_zext(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <4 x i32> [[X:%.*]] to <4 x i64>
-; CHECK-NEXT:    [[REVSHUF:%.*]] = shufflevector <4 x i64> [[ZEXT]], <4 x i64> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    ret <4 x i64> [[REVSHUF]]
+; CHECK-NEXT:    [[REVSHUF:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <4 x i32> [[REVSHUF]] to <4 x i64>
+; CHECK-NEXT:    ret <4 x i64> [[ZEXT]]
 ;
   %zext = zext <4 x i32> %x to <4 x i64>
   %revshuf = shufflevector <4 x i64> %zext, <4 x i64> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -922,10 +922,9 @@ define <4 x i8> @singleop(<4 x i8> %a, <4 x i8> %b) {
 
 define <4 x i64> @cast_mismatched_types(<4 x i32> %x) {
 ; CHECK-LABEL: @cast_mismatched_types(
-; CHECK-NEXT:    [[SHUF:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[SHUF]] to <2 x i64>
-; CHECK-NEXT:    [[EXTSHUF:%.*]] = shufflevector <2 x i64> [[ZEXT]], <2 x i64> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
-; CHECK-NEXT:    ret <4 x i64> [[EXTSHUF]]
+; CHECK-SAME: <4 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <4 x i32> [[X]] to <4 x i64>
+; CHECK-NEXT:    ret <4 x i64> [[ZEXT]]
 ;
   %shuf = shufflevector <4 x i32> %x, <4 x i32> poison, <2 x i32> <i32 0, i32 2>
   %zext = zext <2 x i32> %shuf to <2 x i64>



More information about the llvm-commits mailing list