[llvm] [VectorCombine] foldShuffleOfCastops - handle unary shuffles (PR #160009)

Chaitanya Koparkar via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 21 13:19:38 PDT 2025


https://github.com/ckoparkar created https://github.com/llvm/llvm-project/pull/160009

Fixes #156853.

From ed4c1dfb83a5d3d33f185559a4820b7d3fefb45e Mon Sep 17 00:00:00 2001
From: Chaitanya Koparkar <ckoparkar at gmail.com>
Date: Sun, 21 Sep 2025 16:01:33 -0400
Subject: [PATCH] [VectorCombine] foldShuffleOfCastops - handle unary shuffles

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 67 ++++++++++++-------
 1 file changed, 44 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 17cb18a22336a..c2a4353b3eb62 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2477,21 +2477,28 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
     return false;
 
+  // Check whether this is a unary shuffle.
+  // TODO: Should this be extended to match undef or unused values?
+  bool IsBinaryShuffle = !isa<PoisonValue>(V1);
+  LLVM_DEBUG(dbgs() << "Is binary shuffle: " << IsBinaryShuffle << "\n");
+
   auto *C0 = dyn_cast<CastInst>(V0);
   auto *C1 = dyn_cast<CastInst>(V1);
-  if (!C0 || !C1)
+  if (!C0 || (IsBinaryShuffle && !C1))
     return false;
 
   Instruction::CastOps Opcode = C0->getOpcode();
-  if (C0->getSrcTy() != C1->getSrcTy())
-    return false;
 
-  // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
-  if (Opcode != C1->getOpcode()) {
-    if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
-      Opcode = Instruction::SExt;
-    else
+  if (IsBinaryShuffle) {
+    if (C0->getSrcTy() != C1->getSrcTy())
       return false;
+    // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
+    if (Opcode != C1->getOpcode()) {
+      if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
+        Opcode = Instruction::SExt;
+      else
+        return false;
+    }
   }
 
   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
@@ -2534,23 +2541,31 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   InstructionCost CostC0 =
       TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
                            TTI::CastContextHint::None, CostKind);
-  InstructionCost CostC1 =
-      TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
-                           TTI::CastContextHint::None, CostKind);
-  InstructionCost OldCost = CostC0 + CostC1;
-  OldCost +=
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
-                         CastDstTy, OldMask, CostKind, 0, nullptr, {}, &I);
 
-  InstructionCost NewCost =
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
-                         CastSrcTy, NewMask, CostKind);
+  TargetTransformInfo::ShuffleKind ShuffleKind;
+  if (IsBinaryShuffle)
+    ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
+  else
+    ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
+
+  InstructionCost OldCost = CostC0;
+  OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
+                                CostKind, 0, nullptr, {}, &I);
+
+  InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
+                                               CastSrcTy, NewMask, CostKind);
   NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
                                   TTI::CastContextHint::None, CostKind);
   if (!C0->hasOneUse())
     NewCost += CostC0;
-  if (!C1->hasOneUse())
-    NewCost += CostC1;
+  if (IsBinaryShuffle) {
+    InstructionCost CostC1 =
+        TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
+                             TTI::CastContextHint::None, CostKind);
+    OldCost += CostC1;
+    if (!C1->hasOneUse())
+      NewCost += CostC1;
+  }
 
   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
@@ -2558,14 +2573,20 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   if (NewCost > OldCost)
     return false;
 
-  Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
-                                            C1->getOperand(0), NewMask);
+  Value *Shuf;
+  if (IsBinaryShuffle)
+    Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
+                                       NewMask);
+  else
+    Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);
+
   Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
 
   // Intersect flags from the old casts.
   if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
     NewInst->copyIRFlags(C0);
-    NewInst->andIRFlags(C1);
+    if (IsBinaryShuffle)
+      NewInst->andIRFlags(C1);
   }
 
   Worklist.pushValue(Shuf);



More information about the llvm-commits mailing list