[llvm] af8d27a - [DAG] visitVECTOR_SHUFFLE - pull out shuffle merging code into lambda helper. NFCI.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 14 03:05:33 PST 2021


Author: Simon Pilgrim
Date: 2021-01-14T11:05:19Z
New Revision: af8d27a7a8266b89916b5e4db2b2fd97eb7d84e5

URL: https://github.com/llvm/llvm-project/commit/af8d27a7a8266b89916b5e4db2b2fd97eb7d84e5
DIFF: https://github.com/llvm/llvm-project/commit/af8d27a7a8266b89916b5e4db2b2fd97eb7d84e5.diff

LOG: [DAG] visitVECTOR_SHUFFLE - pull out shuffle merging code into lambda helper. NFCI.

Make it easier to reuse in a future patch.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 24bc7fe7e0ad..f4c9b814b806 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -20823,30 +20823,19 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
       return DAG.getCommutedVectorShuffle(*SVN);
   }
 
-  // Try to fold according to rules:
-  //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
-  //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
-  //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
-  // Don't try to fold shuffles with illegal type.
-  // Only fold if this shuffle is the only user of the other shuffle.
-  if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) &&
-      Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
-    ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0);
-
+  // Compute the combined shuffle mask for a shuffle with SV0 as the first
+  // operand, and SV1 as the second operand.
+  // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask).
+  auto MergeInnerShuffle = [NumElts](ShuffleVectorSDNode *SVN,
+                                     ShuffleVectorSDNode *OtherSVN, SDValue N1,
+                                     SDValue &SV0, SDValue &SV1,
+                                     SmallVectorImpl<int> &Mask) -> bool {
     // Don't try to fold splats; they're likely to simplify somehow, or they
     // might be free.
-    if (OtherSV->isSplat())
-      return SDValue();
-
-    // The incoming shuffle must be of the same type as the result of the
-    // current shuffle.
-    assert(OtherSV->getOperand(0).getValueType() == VT &&
-           "Shuffle types don't match");
+    if (OtherSVN->isSplat())
+      return false;
 
-    SDValue SV0, SV1;
-    SmallVector<int, 4> Mask;
-    // Compute the combined shuffle mask for a shuffle with SV0 as the first
-    // operand, and SV1 as the second operand.
+    Mask.clear();
     for (unsigned i = 0; i != NumElts; ++i) {
       int Idx = SVN->getMaskElt(i);
       if (Idx < 0) {
@@ -20859,15 +20848,14 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
       if (Idx < (int)NumElts) {
         // This shuffle index refers to the inner shuffle N0. Lookup the inner
         // shuffle mask to identify which vector is actually referenced.
-        Idx = OtherSV->getMaskElt(Idx);
+        Idx = OtherSVN->getMaskElt(Idx);
         if (Idx < 0) {
           // Propagate Undef.
           Mask.push_back(Idx);
           continue;
         }
-
-        CurrentVec = (Idx < (int) NumElts) ? OtherSV->getOperand(0)
-                                           : OtherSV->getOperand(1);
+        CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
+                                          : OtherSVN->getOperand(1);
       } else {
         // This shuffle index references an element within N1.
         CurrentVec = N1;
@@ -20892,31 +20880,52 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
 
       // Bail out if we cannot convert the shuffle pair into a single shuffle.
       if (SV1.getNode() && SV1 != CurrentVec)
-        return SDValue();
+        return false;
 
       // Ok. CurrentVec is the right hand side.
       // Update the mask accordingly.
       SV1 = CurrentVec;
       Mask.push_back(Idx + NumElts);
     }
+    return true;
+  };
 
-    // Check if all indices in Mask are Undef. In case, propagate Undef.
-    if (llvm::all_of(Mask, [](int M) { return M < 0; }))
-      return DAG.getUNDEF(VT);
+  // Try to fold according to rules:
+  //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
+  //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
+  //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
+  // Don't try to fold shuffles with illegal type.
+  // Only fold if this shuffle is the only user of the other shuffle.
+  if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) &&
+      Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
+    ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0);
+
+    // The incoming shuffle must be of the same type as the result of the
+    // current shuffle.
+    assert(OtherSV->getOperand(0).getValueType() == VT &&
+           "Shuffle types don't match");
 
-    if (!SV0.getNode())
-      SV0 = DAG.getUNDEF(VT);
-    if (!SV1.getNode())
-      SV1 = DAG.getUNDEF(VT);
-
-    // Avoid introducing shuffles with illegal mask.
-    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
-    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
-    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
-    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
-    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
-    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
-    return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG);
+    SDValue SV0, SV1;
+    SmallVector<int, 4> Mask;
+    if (MergeInnerShuffle(SVN, OtherSV, N1, SV0, SV1, Mask)) {
+      // Check if all indices in Mask are Undef. If so, propagate Undef.
+      if (llvm::all_of(Mask, [](int M) { return M < 0; }))
+        return DAG.getUNDEF(VT);
+
+      if (!SV0.getNode())
+        SV0 = DAG.getUNDEF(VT);
+      if (!SV1.getNode())
+        SV1 = DAG.getUNDEF(VT);
+
+      // Avoid introducing shuffles with illegal mask.
+      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
+      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
+      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
+      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
+      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
+      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
+      return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG);
+    }
   }
 
   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))


        


More information about the llvm-commits mailing list