[clang] 6fd2fdc - [VectorCombine] foldShuffleOfCastops - extend shuffle(bitcast(x),bitcast(y)) -> bitcast(shuffle(x,y)) support

Simon Pilgrim via cfe-commits cfe-commits at lists.llvm.org
Thu Apr 11 06:07:20 PDT 2024


Author: Simon Pilgrim
Date: 2024-04-11T14:02:56+01:00
New Revision: 6fd2fdccf2f28fc155f614eec41f785492aad618

URL: https://github.com/llvm/llvm-project/commit/6fd2fdccf2f28fc155f614eec41f785492aad618
DIFF: https://github.com/llvm/llvm-project/commit/6fd2fdccf2f28fc155f614eec41f785492aad618.diff

LOG: [VectorCombine] foldShuffleOfCastops - extend shuffle(bitcast(x),bitcast(y)) -> bitcast(shuffle(x,y)) support

Handle shuffle mask scaling for cases where the bitcast src/dst element counts are different.

Added: 
    

Modified: 
    clang/test/CodeGen/X86/avx-shuffle-builtins.c
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp
    llvm/test/Transforms/PhaseOrdering/X86/pr67803.ll
    llvm/test/Transforms/VectorCombine/X86/shuffle-of-casts.ll
    llvm/test/Transforms/VectorCombine/X86/shuffle.ll

Removed: 
    


################################################################################
diff  --git a/clang/test/CodeGen/X86/avx-shuffle-builtins.c b/clang/test/CodeGen/X86/avx-shuffle-builtins.c
index 49a56e73230d7d..d184d28f3e07aa 100644
--- a/clang/test/CodeGen/X86/avx-shuffle-builtins.c
+++ b/clang/test/CodeGen/X86/avx-shuffle-builtins.c
@@ -61,8 +61,7 @@ __m256 test_mm256_permute2f128_ps(__m256 a, __m256 b) {
 
 __m256i test_mm256_permute2f128_si256(__m256i a, __m256i b) {
   // CHECK-LABEL: test_mm256_permute2f128_si256
-  // X64: shufflevector{{.*}}<i32 0, i32 1, i32 4, i32 5>
-  // X86: shufflevector{{.*}}<i32 0, i32 1, i32 2, i32 3, i32 8, i32 9, i32 10, i32 11>
+  // CHECK: shufflevector{{.*}}<i32 0, i32 1, i32 4, i32 5>
   return _mm256_permute2f128_si256(a, b, 0x20);
 }
 

diff  --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index b74fdf27d213a1..658e8e74fe5b80 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1448,9 +1448,9 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
 /// into "castop (shuffle)".
 bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   Value *V0, *V1;
-  ArrayRef<int> Mask;
+  ArrayRef<int> OldMask;
   if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
-                           m_Mask(Mask))))
+                           m_Mask(OldMask))))
     return false;
 
   auto *C0 = dyn_cast<CastInst>(V0);
@@ -1473,12 +1473,32 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
   auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
   auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
-  if (!ShuffleDstTy || !CastDstTy || !CastSrcTy ||
-      CastDstTy->getElementCount() != CastSrcTy->getElementCount())
+  if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
     return false;
 
+  unsigned NumSrcElts = CastSrcTy->getNumElements();
+  unsigned NumDstElts = CastDstTy->getNumElements();
+  assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
+         "Only bitcasts expected to alter src/dst element counts");
+
+  SmallVector<int, 16> NewMask;
+  if (NumSrcElts >= NumDstElts) {
+    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
+    // always be expanded to the equivalent form choosing narrower elements.
+    assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
+    unsigned ScaleFactor = NumSrcElts / NumDstElts;
+    narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
+  } else {
+    // The bitcast is from narrow elements to wide elements. The shuffle mask
+    // must choose consecutive elements to allow casting first.
+    assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
+    unsigned ScaleFactor = NumDstElts / NumSrcElts;
+    if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
+      return false;
+  }
+
   auto *NewShuffleDstTy =
-      FixedVectorType::get(CastSrcTy->getScalarType(), Mask.size());
+      FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());
 
   // Try to replace a castop with a shuffle if the shuffle is not costly.
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
@@ -1489,11 +1509,11 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
       TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
                            TTI::CastContextHint::None, CostKind);
   OldCost +=
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy, Mask,
-                         CostKind, 0, nullptr, std::nullopt, &I);
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy,
+                         OldMask, CostKind, 0, nullptr, std::nullopt, &I);
 
   InstructionCost NewCost = TTI.getShuffleCost(
-      TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, Mask, CostKind);
+      TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind);
   NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
                                   TTI::CastContextHint::None, CostKind);
 
@@ -1503,8 +1523,8 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   if (NewCost > OldCost)
     return false;
 
-  Value *Shuf =
-      Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0), Mask);
+  Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
+                                            C1->getOperand(0), NewMask);
   Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
 
   // Intersect flags from the old casts.

diff  --git a/llvm/test/Transforms/PhaseOrdering/X86/pr67803.ll b/llvm/test/Transforms/PhaseOrdering/X86/pr67803.ll
index 45e411d733169a..36535264bd0f29 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/pr67803.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/pr67803.ll
@@ -17,7 +17,6 @@ define <4 x i64> @PR67803(<4 x i64> %x, <4 x i64> %y, <4 x i64> %a, <4 x i64> %b
 ; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <8 x i32> [[TMP3]] to <32 x i8>
 ; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <32 x i8> [[TMP9]], <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
 ; CHECK-NEXT:    [[TMP11:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[TMP6]], <16 x i8> [[TMP8]], <16 x i8> [[TMP10]])
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <16 x i8> [[TMP11]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP13:%.*]] = bitcast <4 x i64> [[A]] to <32 x i8>
 ; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <32 x i8> [[TMP13]], <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
 ; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <4 x i64> [[B]] to <32 x i8>
@@ -25,8 +24,8 @@ define <4 x i64> @PR67803(<4 x i64> %x, <4 x i64> %y, <4 x i64> %a, <4 x i64> %b
 ; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <8 x i32> [[TMP3]] to <32 x i8>
 ; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <32 x i8> [[TMP17]], <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
 ; CHECK-NEXT:    [[TMP19:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[TMP14]], <16 x i8> [[TMP16]], <16 x i8> [[TMP18]])
-; CHECK-NEXT:    [[TMP20:%.*]] = bitcast <16 x i8> [[TMP19]] to <2 x i64>
-; CHECK-NEXT:    [[SHUFFLE_I23:%.*]] = shufflevector <2 x i64> [[TMP12]], <2 x i64> [[TMP20]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <16 x i8> [[TMP11]], <16 x i8> [[TMP19]], <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT:    [[SHUFFLE_I23:%.*]] = bitcast <32 x i8> [[TMP20]] to <4 x i64>
 ; CHECK-NEXT:    ret <4 x i64> [[SHUFFLE_I23]]
 ;
 entry:

diff  --git a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-casts.ll b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-casts.ll
index 97fceacd827580..4f8cfea146eade 100644
--- a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-casts.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-casts.ll
@@ -179,13 +179,12 @@ define <8 x float> @concat_bitcast_v4i32_v8f32(<4 x i32> %a0, <4 x i32> %a1) {
   ret <8 x float> %r
 }
 
-; TODO - bitcasts (lower element count)
+; bitcasts (lower element count)
 
 define <4 x double> @concat_bitcast_v8i16_v4f64(<8 x i16> %a0, <8 x i16> %a1) {
 ; CHECK-LABEL: @concat_bitcast_v8i16_v4f64(
-; CHECK-NEXT:    [[X0:%.*]] = bitcast <8 x i16> [[A0:%.*]] to <2 x double>
-; CHECK-NEXT:    [[X1:%.*]] = bitcast <8 x i16> [[A1:%.*]] to <2 x double>
-; CHECK-NEXT:    [[R:%.*]] = shufflevector <2 x double> [[X0]], <2 x double> [[X1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[R:%.*]] = bitcast <16 x i16> [[TMP1]] to <4 x double>
 ; CHECK-NEXT:    ret <4 x double> [[R]]
 ;
   %x0 = bitcast <8 x i16> %a0 to <2 x double>
@@ -194,13 +193,12 @@ define <4 x double> @concat_bitcast_v8i16_v4f64(<8 x i16> %a0, <8 x i16> %a1) {
   ret <4 x double> %r
 }
 
-; TODO - bitcasts (higher element count)
+; bitcasts (higher element count)
 
 define <16 x i16> @concat_bitcast_v4i32_v16i16(<4 x i32> %a0, <4 x i32> %a1) {
 ; CHECK-LABEL: @concat_bitcast_v4i32_v16i16(
-; CHECK-NEXT:    [[X0:%.*]] = bitcast <4 x i32> [[A0:%.*]] to <8 x i16>
-; CHECK-NEXT:    [[X1:%.*]] = bitcast <4 x i32> [[A1:%.*]] to <8 x i16>
-; CHECK-NEXT:    [[R:%.*]] = shufflevector <8 x i16> [[X0]], <8 x i16> [[X1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[A0:%.*]], <4 x i32> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[R:%.*]] = bitcast <8 x i32> [[TMP1]] to <16 x i16>
 ; CHECK-NEXT:    ret <16 x i16> [[R]]
 ;
   %x0 = bitcast <4 x i32> %a0 to <8 x i16>

diff  --git a/llvm/test/Transforms/VectorCombine/X86/shuffle.ll b/llvm/test/Transforms/VectorCombine/X86/shuffle.ll
index 3d47f373ab77c2..8337bb37bc549d 100644
--- a/llvm/test/Transforms/VectorCombine/X86/shuffle.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/shuffle.ll
@@ -122,11 +122,14 @@ define <16 x i8> @bitcast_shuf_uses(<4 x i32> %v) {
 }
 
 ; shuffle of 2 operands removes bitcasts
+; TODO - can we remove the empty bitcast(bitcast()) ?
 
 define <4 x i64> @bitcast_shuf_remove_bitcasts(<2 x i64> %a0, <2 x i64> %a1) {
 ; CHECK-LABEL: @bitcast_shuf_remove_bitcasts(
 ; CHECK-NEXT:    [[R:%.*]] = shufflevector <2 x i64> [[A0:%.*]], <2 x i64> [[A1:%.*]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    ret <4 x i64> [[R]]
+; CHECK-NEXT:    [[SHUF:%.*]] = bitcast <4 x i64> [[R]] to <8 x i32>
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <8 x i32> [[SHUF]] to <4 x i64>
+; CHECK-NEXT:    ret <4 x i64> [[R1]]
 ;
   %bc0 = bitcast <2 x i64> %a0 to <4 x i32>
   %bc1 = bitcast <2 x i64> %a1 to <4 x i32>


        


More information about the cfe-commits mailing list