[llvm] InstSimplify: lookthru casts, binops in folding shuffles (PR #92668)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Sat May 18 12:29:14 PDT 2024


================
@@ -5315,55 +5315,104 @@ Value *llvm::simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
   return ::simplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit);
 }
 
+using ReplacementTy = std::optional<std::pair<Value *, Value *>>;
+
 /// For the given destination element of a shuffle, peek through shuffles to
 /// match a root vector source operand that contains that element in the same
 /// vector lane (ie, the same mask index), so we can eliminate the shuffle(s).
-static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1,
-                                   int MaskVal, Value *RootVec,
-                                   unsigned MaxRecurse) {
+static std::pair<Value *, ReplacementTy>
+foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, int MaskVal,
+                     Value *RootVec, unsigned MaxRecurse) {
   if (!MaxRecurse--)
-    return nullptr;
+    return {nullptr, std::nullopt};
 
   // Bail out if any mask value is undefined. That kind of shuffle may be
   // simplified further based on demanded bits or other folds.
   if (MaskVal == -1)
-    return nullptr;
+    return {nullptr, std::nullopt};
 
   // The mask value chooses which source operand we need to look at next.
-  int InVecNumElts = cast<FixedVectorType>(Op0->getType())->getNumElements();
+  unsigned InVecNumElts =
+      cast<FixedVectorType>(Op0->getType())->getNumElements();
   int RootElt = MaskVal;
   Value *SourceOp = Op0;
-  if (MaskVal >= InVecNumElts) {
+  if (MaskVal > -1 && static_cast<unsigned>(MaskVal) >= InVecNumElts) {
     RootElt = MaskVal - InVecNumElts;
     SourceOp = Op1;
   }
 
+  // The next RootVec preseves an existing RootVec for casts and binops, so that
+  // the final RootVec is the last cast or binop in the chain.
+  Value *NextRootVec = RootVec ? RootVec : SourceOp;
+
+  // Look through a cast instruction that preserves number of elements in the
+  // vector. Set the RootVec to the cast, the SourceOp to the operand and
+  // recurse. If, in a later stack frame, an appropriate ShuffleVector is
+  // matched, the example will reduce.
+  if (auto *SourceCast = dyn_cast<CastInst>(SourceOp))
+    if (auto *CastTy = dyn_cast<FixedVectorType>(SourceCast->getSrcTy()))
+      if (CastTy->getNumElements() == InVecNumElts)
+        return foldIdentityShuffles(DestElt, SourceCast->getOperand(0), Op1,
+                                    MaskVal, NextRootVec, MaxRecurse);
+
+  // Look through a binary operator, with two identical operands. Set the
+  // RootVec to the binop, the SourceOp to the operand, and recurse. If, in a
+  // later stack frame, an appropriate ShuffleVector is matched, the example
+  // will reduce.
+  Value *BinOpLHS = nullptr, *BinOpRHS = nullptr;
+  if (match(SourceOp, m_BinOp(m_Value(BinOpLHS), m_Value(BinOpRHS))) &&
+      BinOpLHS == BinOpRHS)
+    return foldIdentityShuffles(DestElt, BinOpLHS, Op1, MaskVal, NextRootVec,
+                                MaxRecurse);
+
   // If the source operand is a shuffle itself, look through it to find the
   // matching root vector.
   if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) {
+    // Here, we use RootVec, because there is no requirement for finding the
+    // last shuffle in a chain. In fact, the zeroth operand of the first shuffle
+    // in the chain will be used as the RootVec for folding.
     return foldIdentityShuffles(
         DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1),
         SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse);
   }
 
-  // TODO: Look through bitcasts? What if the bitcast changes the vector element
-  // size?
-
-  // The source operand is not a shuffle. Initialize the root vector value for
-  // this shuffle if that has not been done yet.
-  if (!RootVec)
-    RootVec = SourceOp;
-
-  // Give up as soon as a source operand does not match the existing root value.
-  if (RootVec != SourceOp)
-    return nullptr;
-
   // The element must be coming from the same lane in the source vector
   // (although it may have crossed lanes in intermediate shuffles).
   if (RootElt != DestElt)
-    return nullptr;
+    return {nullptr, std::nullopt};
+
+  // If NextRootVec is equal to SourceOp, no replacements are required. Just
+  // return NextRootVec as the leaf value of the recursion.
+  if (NextRootVec == SourceOp)
+    return {NextRootVec, std::nullopt};
+
+  // We again have to match the condition for a vector-num-element-preserving
+  // cast or binop with equal operands, as we are not assured of the recursion
+  // happening from the call after the previous match. The RootVec was set to
+  // the last cast or binop in the chain.
+  if ((isa<CastInst>(RootVec) &&
+       isa<FixedVectorType>(cast<CastInst>(RootVec)->getSrcTy()) &&
----------------
arsenm wrote:

dyn_cast instead of isa+cast

https://github.com/llvm/llvm-project/pull/92668


More information about the llvm-commits mailing list