[llvm] fe2119a - [VectorCombine] foldBitcastShuffle - include the cost of bitcasts in the comparison

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 20 03:56:48 PDT 2024


Author: Simon Pilgrim
Date: 2024-03-20T10:56:38Z
New Revision: fe2119a7b08b6e468b2a67768904ea85b1bf0a45

URL: https://github.com/llvm/llvm-project/commit/fe2119a7b08b6e468b2a67768904ea85b1bf0a45
DIFF: https://github.com/llvm/llvm-project/commit/fe2119a7b08b6e468b2a67768904ea85b1bf0a45.diff

LOG: [VectorCombine] foldBitcastShuffle - include the cost of bitcasts in the comparison

This makes no real difference currently, as we only fold unary shuffles, but I'm hoping to handle binary shuffles in a future patch.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 85c8d3996bba51..0b16a8b7676923 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -684,10 +684,10 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) {
 /// destination type followed by shuffle. This can enable further transforms by
 /// moving bitcasts or shuffles together.
 bool VectorCombine::foldBitcastShuffle(Instruction &I) {
-  Value *V;
+  Value *V0;
   ArrayRef<int> Mask;
-  if (!match(&I, m_BitCast(
-                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
+  if (!match(&I, m_BitCast(m_OneUse(
+                     m_Shuffle(m_Value(V0), m_Undef(), m_Mask(Mask))))))
     return false;
 
   // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
@@ -696,7 +696,7 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
   // 2) Disallow non-vector casts.
   // TODO: We could allow any shuffle.
   auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
-  auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
+  auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
   if (!DestTy || !SrcTy)
     return false;
 
@@ -724,20 +724,31 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
   // Bitcast the shuffle src - keep its original width but using the destination
   // scalar type.
   unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
-  auto *ShuffleTy = FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
-
-  // The new shuffle must not cost more than the old shuffle. The bitcast is
-  // moved ahead of the shuffle, so assume that it has the same cost as before.
-  InstructionCost DestCost = TTI.getShuffleCost(
-      TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask);
+  auto *NewShuffleTy =
+      FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
+  auto *OldShuffleTy =
+      FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
+
+  // The new shuffle must not cost more than the old shuffle.
+  TargetTransformInfo::TargetCostKind CK =
+      TargetTransformInfo::TCK_RecipThroughput;
+  TargetTransformInfo::ShuffleKind SK =
+      TargetTransformInfo::SK_PermuteSingleSrc;
+
+  InstructionCost DestCost =
+      TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CK) +
+      TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
+                           TargetTransformInfo::CastContextHint::None, CK);
   InstructionCost SrcCost =
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
+      TTI.getShuffleCost(SK, SrcTy, Mask, CK) +
+      TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
+                           TargetTransformInfo::CastContextHint::None, CK);
   if (DestCost > SrcCost || !DestCost.isValid())
     return false;
 
-  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
+  // bitcast (shuf V0, MaskC) --> shuf (bitcast V0), MaskC'
   ++NumShufOfBitcast;
-  Value *CastV = Builder.CreateBitCast(V, ShuffleTy);
+  Value *CastV = Builder.CreateBitCast(V0, NewShuffleTy);
   Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
   replaceValue(I, *Shuf);
   return true;


        


More information about the llvm-commits mailing list