[llvm] [VectorCombine] foldShuffleOfShuffles - fold "shuffle (shuffle x, y, m1), (shuffle y, x, m2)" -> "shuffle x, y, m3" (PR #120959)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 23 04:47:22 PST 2024


================
@@ -1883,78 +1884,104 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
     return false;
 
   ArrayRef<int> InnerMask0, InnerMask1;
-  Value *V0 = nullptr, *V1 = nullptr;
-  UndefValue *U0 = nullptr, *U1 = nullptr;
-  bool Match0 = match(
-      OuterV0, m_Shuffle(m_Value(V0), m_UndefValue(U0), m_Mask(InnerMask0)));
-  bool Match1 = match(
-      OuterV1, m_Shuffle(m_Value(V1), m_UndefValue(U1), m_Mask(InnerMask1)));
+  Value *X0, *X1, *Y0, *Y1;
+  bool Match0 =
+      match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0)));
+  bool Match1 =
+      match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1)));
   if (!Match0 && !Match1)
     return false;
 
-  V0 = Match0 ? V0 : OuterV0;
-  V1 = Match1 ? V1 : OuterV1;
+  X0 = Match0 ? X0 : OuterV0;
+  Y0 = Match0 ? Y0 : OuterV0;
+  X1 = Match1 ? X1 : OuterV1;
+  Y1 = Match1 ? Y1 : OuterV1;
   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
-  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
-  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
+  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType());
+  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType());
   if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
-      V0->getType() != V1->getType())
+      X0->getType() != X1->getType())
     return false;
 
   unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
   unsigned NumImmElts = ShuffleImmTy->getNumElements();
 
-  // Bail if either inner masks reference a RHS undef arg.
-  if ((Match0 && !isa<PoisonValue>(U0) &&
-       any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) ||
-      (Match1 && !isa<PoisonValue>(U1) &&
-       any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; })))
-    return false;
-
-  // Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem,
+  // Attempt to merge shuffles, matching up to 2 source operands.
+  // Replace indices that reference a poison arg with PoisonMaskElem.
+  // Bail if either inner mask references an undef arg.
   SmallVector<int, 16> NewMask(OuterMask);
+  Value *NewX = nullptr, *NewY = nullptr;
   for (int &M : NewMask) {
+    Value *Src = nullptr;
     if (0 <= M && M < (int)NumImmElts) {
-      if (Match0)
-        M = (InnerMask0[M] >= (int)NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
+      Src = OuterV0;
+      if (Match0) {
+        M = InnerMask0[M];
+        Src = M >= (int)NumSrcElts ? Y0 : X0;
+        M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
+      }
     } else if (M >= (int)NumImmElts) {
+      Src = OuterV1;
+      M -= NumImmElts;
       if (Match1) {
-        if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts)
-          M = PoisonMaskElem;
-        else
-          M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
+        M = InnerMask1[M];
+        Src = M >= (int)NumSrcElts ? Y1 : X1;
+        M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
       }
     }
+    if (Src && M != PoisonMaskElem) {
+      assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
+      if (isa<UndefValue>(Src)) {
+        // We've referenced an undef element - if it's poison, update the
+        // shuffle mask, else bail.
+        if (!isa<PoisonValue>(Src))
+          return false;
+        M = PoisonMaskElem;
+        continue;
+      } else if (!NewX || NewX == Src) {
+        NewX = Src;
+        continue;
+      } else if (!NewY || NewY == Src) {
----------------
alexey-bataev wrote:

Drop else here

https://github.com/llvm/llvm-project/pull/120959


More information about the llvm-commits mailing list