[llvm] e4e5c42 - [X86][SSE] isTargetShuffleEquivalent - ensure shuffle inputs are the correct size.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 4 07:32:32 PDT 2020


Author: Simon Pilgrim
Date: 2020-10-04T15:32:05+01:00
New Revision: e4e5c42896df5ed61a98926ea42f5b1ab734e1c4

URL: https://github.com/llvm/llvm-project/commit/e4e5c42896df5ed61a98926ea42f5b1ab734e1c4
DIFF: https://github.com/llvm/llvm-project/commit/e4e5c42896df5ed61a98926ea42f5b1ab734e1c4.diff

LOG: [X86][SSE] isTargetShuffleEquivalent - ensure shuffle inputs are the correct size.

Preliminary patch for the next stage of PR45974 - we don't want to create 'padded' vectors on-the-fly at all in combineX86ShufflesRecursively; instead, we should only pad the source inputs once we have a definite match inside combineX86ShuffleChain.

This means that the inputs to combineX86ShuffleChain might soon be smaller than the final root value type, so we should ensure that isTargetShuffleEquivalent only matches against the inputs when they are the correct size.
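
To illustrate the new guard, here is a minimal standalone sketch (hypothetical and simplified: the real code operates on MVT and SDValue, while plain bit widths stand in for both here). An input operand is dropped from value-based matching when its size disagrees with the shuffle mask type, so the mask-only checks still run but no conclusions are drawn from a mismatched, e.g. not-yet-padded, vector:

#include <cassert>
#include <iostream>
#include <optional>

// Hypothetical stand-ins for MVT/SDValue: the shuffle mask type is reduced
// to its bit width, and an input operand to an optional bit width
// (std::nullopt models a null SDValue).
using VTBits = unsigned;
using Input = std::optional<unsigned>;

// Mirrors the guard this patch adds to isTargetShuffleEquivalent: discard an
// input operand when its size doesn't match the shuffle mask type, rather
// than letting the per-element checks reason about a wrong-sized vector.
void dropMismatchedInputs(VTBits VT, Input &V1, Input &V2) {
  if (V1 && *V1 != VT)
    V1.reset();
  if (V2 && *V2 != VT)
    V2.reset();
}

int main() {
  // A 128-bit shuffle whose second input is still a 64-bit (unpadded) value.
  Input V1 = 128, V2 = 64;
  dropMismatchedInputs(128, V1, V2);
  assert(V1 && !V2); // V2 is ignored; mask-only matching proceeds as before.
  std::cout << "V1 kept, V2 dropped\n";
}

As in the patch, dropping the operand (rather than failing the whole match) keeps the purely mask-based equivalence checks working even when an input hasn't been widened to the root size yet.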

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 13de7bb75b8a..9b5412c945ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -10930,7 +10930,7 @@ static bool isShuffleEquivalent(ArrayRef<int> Mask, ArrayRef<int> ExpectedMask,
 ///
 /// SM_SentinelZero is accepted as a valid negative index but must match in
 /// both.
-static bool isTargetShuffleEquivalent(ArrayRef<int> Mask,
+static bool isTargetShuffleEquivalent(MVT VT, ArrayRef<int> Mask,
                                       ArrayRef<int> ExpectedMask,
                                       SDValue V1 = SDValue(),
                                       SDValue V2 = SDValue()) {
@@ -10944,6 +10944,12 @@ static bool isTargetShuffleEquivalent(ArrayRef<int> Mask,
   if (!isUndefOrZeroOrInRange(Mask, 0, 2 * Size))
     return false;
 
+  // Don't use V1/V2 if they're not the same size as the shuffle mask type.
+  if (V1 && V1.getValueSizeInBits() != VT.getSizeInBits())
+    V1 = SDValue();
+  if (V2 && V2.getValueSizeInBits() != VT.getSizeInBits())
+    V2 = SDValue();
+
   for (int i = 0; i < Size; ++i) {
     int MaskIdx = Mask[i];
     int ExpectedIdx = ExpectedMask[i];
@@ -11002,8 +11008,8 @@ static bool isUnpackWdShuffleMask(ArrayRef<int> Mask, MVT VT) {
   SmallVector<int, 8> Unpckhwd;
   createUnpackShuffleMask(MVT::v8i16, Unpckhwd, /* Lo = */ false,
                           /* Unary = */ false);
-  bool IsUnpackwdMask = (isTargetShuffleEquivalent(Mask, Unpcklwd) ||
-                         isTargetShuffleEquivalent(Mask, Unpckhwd));
+  bool IsUnpackwdMask = (isTargetShuffleEquivalent(VT, Mask, Unpcklwd) ||
+                         isTargetShuffleEquivalent(VT, Mask, Unpckhwd));
   return IsUnpackwdMask;
 }
 
@@ -11020,8 +11026,8 @@ static bool is128BitUnpackShuffleMask(ArrayRef<int> Mask) {
   for (unsigned i = 0; i != 4; ++i) {
     SmallVector<int, 16> UnpackMask;
     createUnpackShuffleMask(VT, UnpackMask, (i >> 1) % 2, i % 2);
-    if (isTargetShuffleEquivalent(Mask, UnpackMask) ||
-        isTargetShuffleEquivalent(CommutedMask, UnpackMask))
+    if (isTargetShuffleEquivalent(VT, Mask, UnpackMask) ||
+        isTargetShuffleEquivalent(VT, CommutedMask, UnpackMask))
       return true;
   }
   return false;
@@ -11214,7 +11220,7 @@ static bool matchShuffleWithUNPCK(MVT VT, SDValue &V1, SDValue &V2,
   // Attempt to match the target mask against the unpack lo/hi mask patterns.
   SmallVector<int, 64> Unpckl, Unpckh;
   createUnpackShuffleMask(VT, Unpckl, /* Lo = */ true, IsUnary);
-  if (isTargetShuffleEquivalent(TargetMask, Unpckl)) {
+  if (isTargetShuffleEquivalent(VT, TargetMask, Unpckl)) {
     UnpackOpcode = X86ISD::UNPCKL;
     V2 = (Undef2 ? DAG.getUNDEF(VT) : (IsUnary ? V1 : V2));
     V1 = (Undef1 ? DAG.getUNDEF(VT) : V1);
@@ -11222,7 +11228,7 @@ static bool matchShuffleWithUNPCK(MVT VT, SDValue &V1, SDValue &V2,
   }
 
   createUnpackShuffleMask(VT, Unpckh, /* Lo = */ false, IsUnary);
-  if (isTargetShuffleEquivalent(TargetMask, Unpckh)) {
+  if (isTargetShuffleEquivalent(VT, TargetMask, Unpckh)) {
     UnpackOpcode = X86ISD::UNPCKH;
     V2 = (Undef2 ? DAG.getUNDEF(VT) : (IsUnary ? V1 : V2));
     V1 = (Undef1 ? DAG.getUNDEF(VT) : V1);
@@ -11260,14 +11266,14 @@ static bool matchShuffleWithUNPCK(MVT VT, SDValue &V1, SDValue &V2,
   // If a binary shuffle, commute and try again.
   if (!IsUnary) {
     ShuffleVectorSDNode::commuteMask(Unpckl);
-    if (isTargetShuffleEquivalent(TargetMask, Unpckl)) {
+    if (isTargetShuffleEquivalent(VT, TargetMask, Unpckl)) {
       UnpackOpcode = X86ISD::UNPCKL;
       std::swap(V1, V2);
       return true;
     }
 
     ShuffleVectorSDNode::commuteMask(Unpckh);
-    if (isTargetShuffleEquivalent(TargetMask, Unpckh)) {
+    if (isTargetShuffleEquivalent(VT, TargetMask, Unpckh)) {
       UnpackOpcode = X86ISD::UNPCKH;
       std::swap(V1, V2);
       return true;
@@ -11638,14 +11644,14 @@ static bool matchShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1, SDValue &V2,
     // Try binary shuffle.
     SmallVector<int, 32> BinaryMask;
     createPackShuffleMask(VT, BinaryMask, false, NumStages);
-    if (isTargetShuffleEquivalent(TargetMask, BinaryMask, V1, V2))
+    if (isTargetShuffleEquivalent(VT, TargetMask, BinaryMask, V1, V2))
       if (MatchPACK(V1, V2, PackVT))
         return true;
 
     // Try unary shuffle.
     SmallVector<int, 32> UnaryMask;
     createPackShuffleMask(VT, UnaryMask, true, NumStages);
-    if (isTargetShuffleEquivalent(TargetMask, UnaryMask, V1))
+    if (isTargetShuffleEquivalent(VT, TargetMask, UnaryMask, V1))
       if (MatchPACK(V1, V1, PackVT))
         return true;
   }
@@ -34522,17 +34528,17 @@ static bool matchUnaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
   // instructions are no slower than UNPCKLPD but has the option to
   // fold the input operand into even an unaligned memory load.
   if (MaskVT.is128BitVector() && Subtarget.hasSSE3() && AllowFloatDomain) {
-    if (isTargetShuffleEquivalent(Mask, {0, 0}, V1)) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0}, V1)) {
       Shuffle = X86ISD::MOVDDUP;
       SrcVT = DstVT = MVT::v2f64;
       return true;
     }
-    if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2}, V1)) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2}, V1)) {
       Shuffle = X86ISD::MOVSLDUP;
       SrcVT = DstVT = MVT::v4f32;
       return true;
     }
-    if (isTargetShuffleEquivalent(Mask, {1, 1, 3, 3}, V1)) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1, 3, 3}, V1)) {
       Shuffle = X86ISD::MOVSHDUP;
       SrcVT = DstVT = MVT::v4f32;
       return true;
@@ -34541,17 +34547,17 @@ static bool matchUnaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
 
   if (MaskVT.is256BitVector() && AllowFloatDomain) {
     assert(Subtarget.hasAVX() && "AVX required for 256-bit vector shuffles");
-    if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2}, V1)) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2}, V1)) {
       Shuffle = X86ISD::MOVDDUP;
       SrcVT = DstVT = MVT::v4f64;
       return true;
     }
-    if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1)) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1)) {
       Shuffle = X86ISD::MOVSLDUP;
       SrcVT = DstVT = MVT::v8f32;
       return true;
     }
-    if (isTargetShuffleEquivalent(Mask, {1, 1, 3, 3, 5, 5, 7, 7}, V1)) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1, 3, 3, 5, 5, 7, 7}, V1)) {
       Shuffle = X86ISD::MOVSHDUP;
       SrcVT = DstVT = MVT::v8f32;
       return true;
@@ -34561,19 +34567,21 @@ static bool matchUnaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
   if (MaskVT.is512BitVector() && AllowFloatDomain) {
     assert(Subtarget.hasAVX512() &&
            "AVX512 required for 512-bit vector shuffles");
-    if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1)) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1)) {
       Shuffle = X86ISD::MOVDDUP;
       SrcVT = DstVT = MVT::v8f64;
       return true;
     }
     if (isTargetShuffleEquivalent(
-            Mask, {0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}, V1)) {
+            MaskVT, Mask,
+            {0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}, V1)) {
       Shuffle = X86ISD::MOVSLDUP;
       SrcVT = DstVT = MVT::v16f32;
       return true;
     }
     if (isTargetShuffleEquivalent(
-            Mask, {1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}, V1)) {
+            MaskVT, Mask,
+            {1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}, V1)) {
       Shuffle = X86ISD::MOVSHDUP;
       SrcVT = DstVT = MVT::v16f32;
       return true;
@@ -34732,27 +34740,27 @@ static bool matchBinaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
   unsigned EltSizeInBits = MaskVT.getScalarSizeInBits();
 
   if (MaskVT.is128BitVector()) {
-    if (isTargetShuffleEquivalent(Mask, {0, 0}) && AllowFloatDomain) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0}) && AllowFloatDomain) {
       V2 = V1;
       V1 = (SM_SentinelUndef == Mask[0] ? DAG.getUNDEF(MVT::v4f32) : V1);
       Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKL : X86ISD::MOVLHPS;
       SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32;
       return true;
     }
-    if (isTargetShuffleEquivalent(Mask, {1, 1}) && AllowFloatDomain) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1}) && AllowFloatDomain) {
       V2 = V1;
       Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKH : X86ISD::MOVHLPS;
       SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32;
       return true;
     }
-    if (isTargetShuffleEquivalent(Mask, {0, 3}) && Subtarget.hasSSE2() &&
-        (AllowFloatDomain || !Subtarget.hasSSE41())) {
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 3}) &&
+        Subtarget.hasSSE2() && (AllowFloatDomain || !Subtarget.hasSSE41())) {
       std::swap(V1, V2);
       Shuffle = X86ISD::MOVSD;
       SrcVT = DstVT = MVT::v2f64;
       return true;
     }
-    if (isTargetShuffleEquivalent(Mask, {4, 1, 2, 3}) &&
+    if (isTargetShuffleEquivalent(MaskVT, Mask, {4, 1, 2, 3}) &&
         (AllowFloatDomain || !Subtarget.hasSSE41())) {
       Shuffle = X86ISD::MOVSS;
       SrcVT = DstVT = MVT::v4f32;
@@ -35325,7 +35333,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   // from a scalar.
   // TODO: Handle other insertions here as well?
   if (!UnaryShuffle && AllowFloatDomain && RootSizeInBits == 128 &&
-      Subtarget.hasSSE41() && !isTargetShuffleEquivalent(Mask, {4, 1, 2, 3})) {
+      Subtarget.hasSSE41() &&
+      !isTargetShuffleEquivalent(MaskVT, Mask, {4, 1, 2, 3})) {
     if (MaskEltSizeInBits == 32) {
       SDValue SrcV1 = V1, SrcV2 = V2;
       if (matchShuffleAsInsertPS(SrcV1, SrcV2, PermuteImm, Zeroable, Mask,
@@ -35340,7 +35349,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         return DAG.getBitcast(RootVT, Res);
       }
     }
-    if (MaskEltSizeInBits == 64 && isTargetShuffleEquivalent(Mask, {0, 2}) &&
+    if (MaskEltSizeInBits == 64 &&
+        isTargetShuffleEquivalent(MaskVT, Mask, {0, 2}) &&
         V2.getOpcode() == ISD::SCALAR_TO_VECTOR &&
         V2.getScalarValueSizeInBits() <= 32) {
       if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTPS)

