[llvm] [VectorCombine] foldShuffleOfBinops - add support for length changing shuffles (PR #88899)

David Green via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 22 00:48:04 PDT 2024


================
@@ -1394,55 +1394,85 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
-/// Try to convert "shuffle (binop), (binop)" with a shared binop operand into
-/// "binop (shuffle), (shuffle)".
+/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
 bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
-  auto *VecTy = cast<FixedVectorType>(I.getType());
   BinaryOperator *B0, *B1;
-  ArrayRef<int> Mask;
+  ArrayRef<int> OldMask;
   if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
-                           m_Mask(Mask))) ||
-      B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy)
+                           m_Mask(OldMask))))
     return false;
 
-  // Try to replace a binop with a shuffle if the shuffle is not costly.
-  // The new shuffle will choose from a single, common operand, so it may be
-  // cheaper than the existing two-operand shuffle.
-  SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size());
+  // TODO: Add support for addlike etc.
   Instruction::BinaryOps Opcode = B0->getOpcode();
-  InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
-  InstructionCost ShufCost = TTI.getShuffleCost(
-      TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask);
-  if (ShufCost > BinopCost)
+  if (Opcode != B1->getOpcode())
+    return false;
+
+  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
+  auto *BinOpTy = dyn_cast<FixedVectorType>(B0->getType());
+  if (!ShuffleDstTy || !BinOpTy)
     return false;
 
+  unsigned NumSrcElts = BinOpTy->getNumElements();
+
   // If we have something like "add X, Y" and "add Z, X", swap ops to match.
   Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
   Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
   if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W)
     std::swap(X, Y);
 
-  Value *Shuf0, *Shuf1;
+  auto ConvertToUnary = [NumSrcElts](int &M) {
+    if (M >= (int)NumSrcElts)
+      M -= NumSrcElts;
+  };
+
+  SmallVector<int> NewMask0(OldMask.begin(), OldMask.end());
+  TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
   if (X == Z) {
-    // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W)
-    Shuf0 = Builder.CreateShuffleVector(X, UnaryMask);
-    Shuf1 = Builder.CreateShuffleVector(Y, W, Mask);
-  } else if (Y == W) {
-    // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y)
-    Shuf0 = Builder.CreateShuffleVector(X, Z, Mask);
-    Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask);
-  } else {
-    return false;
+    llvm::for_each(NewMask0, ConvertToUnary);
+    SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
+    Z = PoisonValue::get(BinOpTy);
+  }
+
+  SmallVector<int> NewMask1(OldMask.begin(), OldMask.end());
+  TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
+  if (Y == W) {
+    llvm::for_each(NewMask1, ConvertToUnary);
+    SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
+    W = PoisonValue::get(BinOpTy);
   }
 
+  // Try to replace a binop with a shuffle if the shuffle is not costly.
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  InstructionCost OldCost =
+      TTI.getArithmeticInstrCost(B0->getOpcode(), BinOpTy) +
+      TTI.getArithmeticInstrCost(B1->getOpcode(), BinOpTy) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy,
+                         OldMask, CostKind, 0, nullptr, std::nullopt, &I);
+
+  InstructionCost NewCost =
+      TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind) +
+      TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind) +
+      TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy);
+
+  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
+                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
+  if (NewCost > OldCost)
----------------
davemgreen wrote:

Would `>=` be better than `>` in this comparison?

https://github.com/llvm/llvm-project/pull/88899


More information about the llvm-commits mailing list