[llvm] [VectorCombine] foldShuffleOfCastops - handle unary shuffles (PR #160009)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 23 06:14:28 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: Chaitanya Koparkar (ckoparkar)

<details>
<summary>Changes</summary>

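Fixes #156853. `foldShuffleOfCastops` now also folds unary shuffles, where the second shuffle operand is `poison`. It rewrites `shuffle(cast(x), poison, M)` as `cast(shuffle(x, poison, M))` when the cost model finds the new form no more expensive.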

---
Full diff: https://github.com/llvm/llvm-project/pull/160009.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+44-23) 
- (modified) llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll (+8-9) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 0ef933f596604..526a4add2a89a 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2487,21 +2487,28 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
     return false;
 
+  // Check whether this is a unary shuffle.
+  // TODO: check if this can be extended to match undef or unused values,
+  // perhaps using ShuffleVectorInst::isSingleSource.
+  bool IsBinaryShuffle = !isa<PoisonValue>(V1);
+
   auto *C0 = dyn_cast<CastInst>(V0);
   auto *C1 = dyn_cast<CastInst>(V1);
-  if (!C0 || !C1)
+  if (!C0 || (IsBinaryShuffle && !C1))
     return false;
 
   Instruction::CastOps Opcode = C0->getOpcode();
-  if (C0->getSrcTy() != C1->getSrcTy())
-    return false;
 
-  // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
-  if (Opcode != C1->getOpcode()) {
-    if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
-      Opcode = Instruction::SExt;
-    else
+  if (IsBinaryShuffle) {
+    if (C0->getSrcTy() != C1->getSrcTy())
       return false;
+    // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
+    if (Opcode != C1->getOpcode()) {
+      if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
+        Opcode = Instruction::SExt;
+      else
+        return false;
+    }
   }
 
   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
@@ -2544,23 +2551,31 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   InstructionCost CostC0 =
       TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
                            TTI::CastContextHint::None, CostKind);
-  InstructionCost CostC1 =
-      TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
-                           TTI::CastContextHint::None, CostKind);
-  InstructionCost OldCost = CostC0 + CostC1;
-  OldCost +=
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
-                         CastDstTy, OldMask, CostKind, 0, nullptr, {}, &I);
 
-  InstructionCost NewCost =
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
-                         CastSrcTy, NewMask, CostKind);
+  TargetTransformInfo::ShuffleKind ShuffleKind;
+  if (IsBinaryShuffle)
+    ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
+  else
+    ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
+
+  InstructionCost OldCost = CostC0;
+  OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
+                                CostKind, 0, nullptr, {}, &I);
+
+  InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
+                                               CastSrcTy, NewMask, CostKind);
   NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
                                   TTI::CastContextHint::None, CostKind);
   if (!C0->hasOneUse())
     NewCost += CostC0;
-  if (!C1->hasOneUse())
-    NewCost += CostC1;
+  if (IsBinaryShuffle) {
+    InstructionCost CostC1 =
+        TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
+                             TTI::CastContextHint::None, CostKind);
+    OldCost += CostC1;
+    if (!C1->hasOneUse())
+      NewCost += CostC1;
+  }
 
   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
@@ -2568,14 +2583,20 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   if (NewCost > OldCost)
     return false;
 
-  Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
-                                            C1->getOperand(0), NewMask);
+  Value *Shuf;
+  if (IsBinaryShuffle)
+    Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
+                                       NewMask);
+  else
+    Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);
+
   Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
 
   // Intersect flags from the old casts.
   if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
     NewInst->copyIRFlags(C0);
-    NewInst->andIRFlags(C1);
+    if (IsBinaryShuffle)
+      NewInst->andIRFlags(C1);
   }
 
   Worklist.pushValue(Shuf);
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
index acbc836ffcab0..ed29719d49493 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
@@ -205,8 +205,8 @@ define <8 x i8> @abs_different(<8 x i8> %a) {
 define <4 x i32> @poison_intrinsic(<2 x i16> %l256) {
 ; CHECK-LABEL: @poison_intrinsic(
 ; CHECK-NEXT:    [[L266:%.*]] = call <2 x i16> @llvm.abs.v2i16(<2 x i16> [[L256:%.*]], i1 false)
-; CHECK-NEXT:    [[L267:%.*]] = zext <2 x i16> [[L266]] to <2 x i32>
-; CHECK-NEXT:    [[L271:%.*]] = shufflevector <2 x i32> [[L267]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[L267:%.*]] = shufflevector <2 x i16> [[L266]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[L271:%.*]] = zext <4 x i16> [[L267]] to <4 x i32>
 ; CHECK-NEXT:    ret <4 x i32> [[L271]]
 ;
   %l266 = call <2 x i16> @llvm.abs.v2i16(<2 x i16> %l256, i1 false)
@@ -534,9 +534,9 @@ define <4 x i64> @single_zext(<4 x i32> %x) {
 
 define <4 x i64> @not_zext(<4 x i32> %x) {
 ; CHECK-LABEL: @not_zext(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <4 x i32> [[X:%.*]] to <4 x i64>
-; CHECK-NEXT:    [[REVSHUF:%.*]] = shufflevector <4 x i64> [[ZEXT]], <4 x i64> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    ret <4 x i64> [[REVSHUF]]
+; CHECK-NEXT:    [[REVSHUF:%.*]] = shufflevector <4 x i32> [[X]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <4 x i32> [[REVSHUF:%.*]] to <4 x i64>
+; CHECK-NEXT:    ret <4 x i64> [[ZEXT]]
 ;
   %zext = zext <4 x i32> %x to <4 x i64>
   %revshuf = shufflevector <4 x i64> %zext, <4 x i64> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -922,10 +922,9 @@ define <4 x i8> @singleop(<4 x i8> %a, <4 x i8> %b) {
 
 define <4 x i64> @cast_mismatched_types(<4 x i32> %x) {
 ; CHECK-LABEL: @cast_mismatched_types(
-; CHECK-NEXT:    [[SHUF:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[SHUF]] to <2 x i64>
-; CHECK-NEXT:    [[EXTSHUF:%.*]] = shufflevector <2 x i64> [[ZEXT]], <2 x i64> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
-; CHECK-NEXT:    ret <4 x i64> [[EXTSHUF]]
+; CHECK-SAME: <4 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <4 x i32> [[X]] to <4 x i64>
+; CHECK-NEXT:    ret <4 x i64> [[ZEXT]]
 ;
   %shuf = shufflevector <4 x i32> %x, <4 x i32> poison, <2 x i32> <i32 0, i32 2>
   %zext = zext <2 x i32> %shuf to <2 x i64>

``````````

</details>


https://github.com/llvm/llvm-project/pull/160009


More information about the llvm-commits mailing list