[llvm] a403ad9 - [VectorCombine] foldBitcastShuffle - limit bitcast(shuffle(x,y)) -> shuffle(bitcast(x),bitcast(y))

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 11 03:43:29 PDT 2024


Author: Simon Pilgrim
Date: 2024-04-11T11:43:11+01:00
New Revision: a403ad9336a24c459ee79d2cb7675c4b1f32bfd9

URL: https://github.com/llvm/llvm-project/commit/a403ad9336a24c459ee79d2cb7675c4b1f32bfd9
DIFF: https://github.com/llvm/llvm-project/commit/a403ad9336a24c459ee79d2cb7675c4b1f32bfd9.diff

LOG: [VectorCombine] foldBitcastShuffle - limit bitcast(shuffle(x,y)) -> shuffle(bitcast(x),bitcast(y))

Only fold bitcast(shuffle(x,y)) -> shuffle(bitcast(x),bitcast(y)) if doing so won't increase the number of bitcasts (i.e. x or y is already a bitcast from a vector whose element type matches the destination element type).

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp
    llvm/test/Transforms/VectorCombine/X86/shuffle.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 633b46e2dc8ba6..2f9767538e6cb1 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -713,6 +713,18 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
   if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
     return false;
 
+  bool IsUnary = isa<UndefValue>(V1);
+
+  // For binary shuffles, only fold bitcast(shuffle(X,Y))
+  // if it won't increase the number of bitcasts.
+  if (!IsUnary) {
+    auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType());
+    auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType());
+    if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
+        !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
+      return false;
+  }
+
   SmallVector<int, 16> NewMask;
   if (DestEltSize <= SrcEltSize) {
     // The bitcast is from wide to narrow/equal elements. The shuffle mask can
@@ -736,7 +748,6 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
       FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
   auto *OldShuffleTy =
       FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
-  bool IsUnary = isa<UndefValue>(V1);
   unsigned NumOps = IsUnary ? 1 : 2;
 
   // The new shuffle must not cost more than the old shuffle.

diff  --git a/llvm/test/Transforms/VectorCombine/X86/shuffle.ll b/llvm/test/Transforms/VectorCombine/X86/shuffle.ll
index 5020d37f86f565..3d47f373ab77c2 100644
--- a/llvm/test/Transforms/VectorCombine/X86/shuffle.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/shuffle.ll
@@ -149,13 +149,12 @@ define <8 x i32> @bitcast_shuf_one_bitcast(<4 x i32> %a0, <2 x i64> %a1) {
   ret <8 x i32> %r
 }
 
-; TODO - Negative test - shuffle of 2 operands must not increase bitcasts
+; Negative test - shuffle of 2 operands must not increase bitcasts
 
 define <8 x i32> @bitcast_shuf_too_many_bitcasts(<2 x i64> %a0, <2 x i64> %a1) {
 ; CHECK-LABEL: @bitcast_shuf_too_many_bitcasts(
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i64> [[A0:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i64> [[A1:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[R:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> [[TMP2]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SHUF:%.*]] = shufflevector <2 x i64> [[A0:%.*]], <2 x i64> [[A1:%.*]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i64> [[SHUF]] to <8 x i32>
 ; CHECK-NEXT:    ret <8 x i32> [[R]]
 ;
   %shuf = shufflevector <2 x i64> %a0, <2 x i64> %a1, <4 x i32> <i32 0, i32 1, i32 2, i32 3>


        


More information about the llvm-commits mailing list