[llvm] [VectorCombine] Add foldShuffleToIdentity (PR #88693)

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 24 07:17:42 PDT 2024


================
@@ -1547,6 +1548,147 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   return true;
 }
 
+// Starting from a shuffle, look up through operands tracking the shuffled index
+// of each lane. If we can simplify away the shuffles to identities then
+// do so.
+bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
+  auto *Ty = dyn_cast<FixedVectorType>(I.getType());
+  if (!Ty || !isa<Instruction>(I.getOperand(0)) ||
+      !isa<Instruction>(I.getOperand(1)))
+    return false;
+
+  using InstLane = std::pair<Value *, int>;
+
+  auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
+    while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
+      unsigned NumElts =
+          cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
+      int M = SV->getMaskValue(Lane);
+      if (M < 0)
+        return {nullptr, PoisonMaskElem};
+      else if (M < (int)NumElts) {
+        V = SV->getOperand(0);
+        Lane = M;
+      } else {
+        V = SV->getOperand(1);
+        Lane = M - NumElts;
+      }
+    }
+    return InstLane{V, Lane};
+  };
+
+  auto GenerateInstLaneVectorFromOperand =
+      [&LookThroughShuffles](ArrayRef<InstLane> Item, int Op) {
+        SmallVector<InstLane> NItem;
+        for (InstLane V : Item) {
+          NItem.emplace_back(
+              !V.first
+                  ? InstLane{nullptr, PoisonMaskElem}
+                  : LookThroughShuffles(
+                        cast<Instruction>(V.first)->getOperand(Op), V.second));
+        }
+        return NItem;
+      };
+
+  SmallVector<InstLane> Start;
+  for (unsigned M = 0; M < Ty->getNumElements(); ++M)
+    Start.push_back(LookThroughShuffles(&I, M));
+
+  SmallVector<SmallVector<InstLane>> Worklist;
+  Worklist.push_back(Start);
+  SmallPtrSet<Value *, 4> IdentityLeafs, SplatLeafs;
+  unsigned NumVisited = 0;
+
+  while (!Worklist.empty()) {
+    SmallVector<InstLane> Item = Worklist.pop_back_val();
+    if (++NumVisited > MaxInstrsToScan)
+      return false;
+
+    // If we found an undef first lane then bail out to keep things simple.
+    if (!Item[0].first)
+      return false;
+
+    // Look for an identity value.
+    if (Item[0].second == 0 && Item[0].first->getType() == Ty &&
+        all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
+          return !E.value().first || (E.value().first == Item[0].first &&
+                                      E.value().second == (int)E.index());
+        })) {
+      IdentityLeafs.insert(Item[0].first);
+      continue;
+    }
+    // Look for a splat value.
+    if (all_of(drop_begin(Item), [&](InstLane &IL) {
+          return !IL.first ||
+                 (IL.first == Item[0].first && IL.second == Item[0].second);
+        })) {
+      SplatLeafs.insert(Item[0].first);
+      continue;
+    }
+
+    // We need each element to be the same type of value, and check that each
+    // element has a single use.
+    if (!all_of(drop_begin(Item), [&](InstLane IL) {
+          if (!IL.first)
+            return true;
+          if (auto *I = dyn_cast<Instruction>(IL.first); I && !I->hasOneUse())
+            return false;
+          return IL.first->getValueID() == Item[0].first->getValueID() &&
+                 (!isa<IntrinsicInst>(IL.first) ||
+                  cast<IntrinsicInst>(IL.first)->getIntrinsicID() ==
+                      cast<IntrinsicInst>(Item[0].first)->getIntrinsicID());
+        }))
+      return false;
+
+    // Check the operator is one that we support.
+    if (isa<BinaryOperator>(Item[0].first)) {
+      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
+      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 1));
+    } else if (isa<UnaryOperator>(Item[0].first)) {
+      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
+    } else {
+      return false;
+    }
+  }
+
+  // If we got this far, we know the shuffles are superfluous and can be
+  // removed. Scan through again and generate the new tree of instructions.
+  std::function<Value *(ArrayRef<InstLane>)> generate =
+      [&](ArrayRef<InstLane> Item) -> Value * {
+    if (IdentityLeafs.contains(Item[0].first) &&
+        all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
+          return !E.value().first || (E.value().first == Item[0].first &&
+                                      E.value().second == (int)E.index());
+        })) {
+      return Item[0].first;
+    } else if (SplatLeafs.contains(Item[0].first)) {
+      if (auto ILI = dyn_cast<Instruction>(Item[0].first))
+        Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
+      else if (isa<Argument>(Item[0].first))
+        Builder.SetInsertPointPastAllocas(I.getParent()->getParent());
+      SmallVector<int, 16> Mask(Ty->getNumElements(), Item[0].second);
+      return Builder.CreateShuffleVector(Item[0].first, Mask);
+    }
+
+    auto *I = cast<Instruction>(Item[0].first);
+    SmallVector<Value *> Ops;
+    unsigned E = I->getNumOperands();
+    for (unsigned Idx = 0; Idx < E; Idx++)
+      Ops.push_back(generate(GenerateInstLaneVectorFromOperand(Item, Idx)));
+    Builder.SetInsertPoint(I);
+    if (auto BI = dyn_cast<BinaryOperator>(I))
+      return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
+                                 Ops[0], Ops[1]);
+    if (auto UI = dyn_cast<UnaryOperator>(I))
+      return Builder.CreateUnOp((Instruction::UnaryOps)UI->getOpcode(), Ops[0]);
+    llvm_unreachable("Unhandled instruction in generate");
----------------
davemgreen wrote:

I was planning to add a number of other types here in follow-up patches. This one just contains the basics to make sure the general idea is working OK. I can make it so that UnaryInstructions are left as the else (even if I personally prefer the structure of keeping them all consistent :) I find it more readable.)

https://github.com/llvm/llvm-project/pull/88693


More information about the llvm-commits mailing list