[llvm] [SLP]Add cost estimation for gather node reshuffling (PR #115201)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 15 03:03:35 PST 2024


================
@@ -13077,26 +13093,149 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
   // Pair.first is the offset to the vector, while Pair.second is the index of
   // scalar in the list.
   for (const std::pair<unsigned, int> &Pair : EntryLanes) {
-    unsigned Idx = Part * VL.size() + Pair.second;
+    int Idx = Part * VL.size() + Pair.second;
     Mask[Idx] =
         Pair.first * VF +
         (ForOrder ? std::distance(
                         Entries[Pair.first]->Scalars.begin(),
                         find(Entries[Pair.first]->Scalars, VL[Pair.second]))
                   : Entries[Pair.first]->findLaneForValue(VL[Pair.second]));
-    IsIdentity &= Mask[Idx] == Pair.second;
+    IsIdentity &= Mask[Idx] % VL.size() == Idx % VL.size();
   }
-  switch (Entries.size()) {
-  case 1:
-    if (IsIdentity || EntryLanes.size() > 1 || VL.size() <= 2)
-      return TargetTransformInfo::SK_PermuteSingleSrc;
-    break;
-  case 2:
-    if (EntryLanes.size() > 2 || VL.size() <= 2)
-      return TargetTransformInfo::SK_PermuteTwoSrc;
-    break;
-  default:
-    break;
+  if (ForOrder || IsIdentity || Entries.empty()) {
+    switch (Entries.size()) {
+    case 1:
+      if (IsIdentity || EntryLanes.size() > 1 || VL.size() <= 2)
+        return TargetTransformInfo::SK_PermuteSingleSrc;
+      break;
+    case 2:
+      if (EntryLanes.size() > 2 || VL.size() <= 2)
+        return TargetTransformInfo::SK_PermuteTwoSrc;
+      break;
+    default:
+      break;
+    }
+  } else if (!isa<VectorType>(VL.front()->getType()) &&
+             (EntryLanes.size() > Entries.size() || VL.size() <= 2)) {
+    // Do the cost estimation if shuffle beneficial than buildvector.
+    SmallVector<int> SubMask(std::next(Mask.begin(), Part * VL.size()),
+                             std::next(Mask.begin(), (Part + 1) * VL.size()));
+    int MinElement = SubMask.front(), MaxElement = SubMask.front();
+    for (int Idx : SubMask) {
+      if (Idx == PoisonMaskElem)
+        continue;
+      if (MinElement == PoisonMaskElem || MinElement % VF > Idx % VF)
+        MinElement = Idx;
+      if (MaxElement == PoisonMaskElem || MaxElement % VF < Idx % VF)
+        MaxElement = Idx;
+    }
+    assert(MaxElement >= 0 && MinElement >= 0 &&
+           "Expected at least single element.");
----------------
RKSimon wrote:

assert MaxElement >= MinElement?

https://github.com/llvm/llvm-project/pull/115201


More information about the llvm-commits mailing list