[llvm] a1dfd6d - [X86] Add helper function to reduce some code duplication when shrinking a vector load to a vzext_load.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed May 27 01:32:32 PDT 2020


Author: Craig Topper
Date: 2020-05-27T01:32:13-07:00
New Revision: a1dfd6d828ac4f8e11e8013b952f0ef080890dcf

URL: https://github.com/llvm/llvm-project/commit/a1dfd6d828ac4f8e11e8013b952f0ef080890dcf
DIFF: https://github.com/llvm/llvm-project/commit/a1dfd6d828ac4f8e11e8013b952f0ef080890dcf.diff

LOG: [X86] Add helper function to reduce some code duplication when shrinking a vector load to a vzext_load.

There's more code for calling CombineTo and replacing the nodes
that I'd like to share, but it's complicated by the getNode call
in the middle that needs to be specific to each opcode.

While there, also make sure we recursively delete the load
we're replacing. It eventually gets removed by a RemoveDeadNodes
call at the end of DAG combine, but we should be more eager about
it. We were inconsistently doing this in some places but not all.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6bf61af00590..e086c65c40cb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -33596,6 +33596,21 @@ SDValue X86TargetLowering::unwrapAddress(SDValue N) const {
   return N;
 }
 
+// Helper to look for a normal load that can be narrowed into a vzload with the
+// specified VT and memory VT. Returns SDValue() on failure.
+static SDValue narrowLoadToVZLoad(LoadSDNode *LN, MVT MemVT, MVT VT,
+                                  SelectionDAG &DAG) {
+  // Can't if the load is volatile or atomic.
+  if (!LN->isSimple())
+    return SDValue();
+
+  SDVTList Tys = DAG.getVTList(VT, MVT::Other);
+  SDValue Ops[] = {LN->getChain(), LN->getBasePtr()};
+  return DAG.getMemIntrinsicNode(X86ISD::VZEXT_LOAD, SDLoc(LN), Tys, Ops, MemVT,
+                                 LN->getPointerInfo(), LN->getOriginalAlign(),
+                                 LN->getMemOperand()->getFlags());
+}
+
 // Attempt to match a combined shuffle mask against supported unary shuffle
 // instructions.
 // TODO: Investigate sharing more of this with shuffle lowering.
@@ -35598,13 +35613,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
     if (VT == MVT::v2f64 && Src.hasOneUse() &&
         ISD::isNormalLoad(Src.getNode())) {
       LoadSDNode *LN = cast<LoadSDNode>(Src);
-      // Unless the load is volatile or atomic.
-      if (LN->isSimple()) {
-        SDVTList Tys = DAG.getVTList(MVT::v2f64, MVT::Other);
-        SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
-        SDValue VZLoad = DAG.getMemIntrinsicNode(
-            X86ISD::VZEXT_LOAD, DL, Tys, Ops, MVT::f64, LN->getPointerInfo(),
-            LN->getOriginalAlign(), LN->getMemOperand()->getFlags());
+      if (SDValue VZLoad = narrowLoadToVZLoad(LN, MVT::f64, MVT::v2f64, DAG)) {
         SDValue Movddup = DAG.getNode(X86ISD::MOVDDUP, DL, MVT::v2f64, VZLoad);
         DCI.CombineTo(N.getNode(), Movddup);
         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
@@ -35786,7 +35795,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
         SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
         SDValue BcastLd = DAG.getMemIntrinsicNode(
             X86ISD::VBROADCAST_LOAD, DL, Tys, Ops, MVT::f64,
-            LN->getPointerInfo(), LN->getAlign(),
+            LN->getPointerInfo(), LN->getOriginalAlign(),
             LN->getMemOperand()->getFlags());
         DCI.CombineTo(N.getNode(), BcastLd);
         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
@@ -35804,13 +35813,8 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
     // the load is volatile.
     if (N0.hasOneUse() && ISD::isNormalLoad(N0.getNode())) {
       auto *LN = cast<LoadSDNode>(N0);
-      if (LN->isSimple()) {
-        SDVTList Tys = DAG.getVTList(VT, MVT::Other);
-        SDValue Ops[] = {LN->getChain(), LN->getBasePtr()};
-        SDValue VZLoad = DAG.getMemIntrinsicNode(
-            X86ISD::VZEXT_LOAD, DL, Tys, Ops, VT.getVectorElementType(),
-            LN->getPointerInfo(), LN->getAlign(),
-            LN->getMemOperand()->getFlags());
+      if (SDValue VZLoad =
+              narrowLoadToVZLoad(LN, VT.getVectorElementType(), VT, DAG)) {
         DCI.CombineTo(N.getNode(), VZLoad);
         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
         DCI.recursivelyDeleteUnusedNodes(LN);
@@ -44541,21 +44545,16 @@ static SDValue combineX86INT_TO_FP(SDNode *N, SelectionDAG &DAG,
       ISD::isNormalLoad(In.getNode()) && In.hasOneUse()) {
     assert(InVT.is128BitVector() && "Expected 128-bit input vector");
     LoadSDNode *LN = cast<LoadSDNode>(N->getOperand(0));
-    // Unless the load is volatile or atomic.
-    if (LN->isSimple()) {
+    unsigned NumBits = InVT.getScalarSizeInBits() * VT.getVectorNumElements();
+    MVT MemVT = MVT::getIntegerVT(NumBits);
+    MVT LoadVT = MVT::getVectorVT(MemVT, 128 / NumBits);
+    if (SDValue VZLoad = narrowLoadToVZLoad(LN, MemVT, LoadVT, DAG)) {
       SDLoc dl(N);
-      unsigned NumBits = InVT.getScalarSizeInBits() * VT.getVectorNumElements();
-      MVT MemVT = MVT::getIntegerVT(NumBits);
-      MVT LoadVT = MVT::getVectorVT(MemVT, 128 / NumBits);
-      SDVTList Tys = DAG.getVTList(LoadVT, MVT::Other);
-      SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
-      SDValue VZLoad = DAG.getMemIntrinsicNode(
-          X86ISD::VZEXT_LOAD, dl, Tys, Ops, MemVT, LN->getPointerInfo(),
-          LN->getOriginalAlign(), LN->getMemOperand()->getFlags());
       SDValue Convert = DAG.getNode(N->getOpcode(), dl, VT,
                                     DAG.getBitcast(InVT, VZLoad));
       DCI.CombineTo(N, Convert);
       DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
+      DCI.recursivelyDeleteUnusedNodes(LN);
       return SDValue(N, 0);
     }
   }
@@ -44575,21 +44574,16 @@ static SDValue combineCVTP2I_CVTTP2I(SDNode *N, SelectionDAG &DAG,
       ISD::isNormalLoad(In.getNode()) && In.hasOneUse()) {
     assert(InVT.is128BitVector() && "Expected 128-bit input vector");
     LoadSDNode *LN = cast<LoadSDNode>(In);
-    // Unless the load is volatile or atomic.
-    if (LN->isSimple()) {
+    unsigned NumBits = InVT.getScalarSizeInBits() * VT.getVectorNumElements();
+    MVT MemVT = MVT::getFloatingPointVT(NumBits);
+    MVT LoadVT = MVT::getVectorVT(MemVT, 128 / NumBits);
+    if (SDValue VZLoad = narrowLoadToVZLoad(LN, MemVT, LoadVT, DAG)) {
       SDLoc dl(N);
-      unsigned NumBits = InVT.getScalarSizeInBits() * VT.getVectorNumElements();
-      MVT MemVT = MVT::getFloatingPointVT(NumBits);
-      MVT LoadVT = MVT::getVectorVT(MemVT, 128 / NumBits);
-      SDVTList Tys = DAG.getVTList(LoadVT, MVT::Other);
-      SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
-      SDValue VZLoad = DAG.getMemIntrinsicNode(
-          X86ISD::VZEXT_LOAD, dl, Tys, Ops, MemVT, LN->getPointerInfo(),
-          LN->getOriginalAlign(), LN->getMemOperand()->getFlags());
       SDValue Convert = DAG.getNode(N->getOpcode(), dl, VT,
                                     DAG.getBitcast(InVT, VZLoad));
       DCI.CombineTo(N, Convert);
       DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
+      DCI.recursivelyDeleteUnusedNodes(LN);
       return SDValue(N, 0);
     }
   }
@@ -44665,18 +44659,13 @@ static SDValue combineCVTPH2PS(SDNode *N, SelectionDAG &DAG,
     // Convert a full vector load into vzload when not all bits are needed.
     if (ISD::isNormalLoad(Src.getNode()) && Src.hasOneUse()) {
       LoadSDNode *LN = cast<LoadSDNode>(N->getOperand(0));
-      // Unless the load is volatile or atomic.
-      if (LN->isSimple()) {
+      if (SDValue VZLoad = narrowLoadToVZLoad(LN, MVT::i64, MVT::v2i64, DAG)) {
         SDLoc dl(N);
-        SDVTList Tys = DAG.getVTList(MVT::v2i64, MVT::Other);
-        SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
-        SDValue VZLoad = DAG.getMemIntrinsicNode(
-            X86ISD::VZEXT_LOAD, dl, Tys, Ops, MVT::i64, LN->getPointerInfo(),
-            LN->getOriginalAlign(), LN->getMemOperand()->getFlags());
         SDValue Convert = DAG.getNode(N->getOpcode(), dl, MVT::v4f32,
                                       DAG.getBitcast(MVT::v8i16, VZLoad));
         DCI.CombineTo(N, Convert);
         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
+        DCI.recursivelyDeleteUnusedNodes(LN);
         return SDValue(N, 0);
       }
     }


        


More information about the llvm-commits mailing list