[llvm] d225de6 - Revert "[X86][AVX] Add getBROADCAST_LOAD helper function. NFCI."

Tres Popp via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 27 07:56:06 PDT 2021


Author: Tres Popp
Date: 2021-07-27T16:55:50+02:00
New Revision: d225de60c933d8feffde917c49a5bbb72530ae28

URL: https://github.com/llvm/llvm-project/commit/d225de60c933d8feffde917c49a5bbb72530ae28
DIFF: https://github.com/llvm/llvm-project/commit/d225de60c933d8feffde917c49a5bbb72530ae28.diff

LOG: Revert "[X86][AVX] Add getBROADCAST_LOAD helper function. NFCI."

This reverts commit 1cfecf4fc4278afb0005923f6dff595cd372da5c.

This commit broke LLVM code generated through XLA by removing a
conditional on Ld->getExtensionType() == ISD::NON_EXTLOAD.

This is not a perfect revert: the new getBROADCAST_LOAD helper function is
kept because other uses of it now exist.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 344bf73b2c5e0..48c3713d3704a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -16084,12 +16084,21 @@ static SDValue lowerV2X128Shuffle(const SDLoc &DL, MVT VT, SDValue V1,
     bool SplatHi = isShuffleEquivalent(Mask, {2, 3, 2, 3}, V1);
     if ((SplatLo || SplatHi) && !Subtarget.hasAVX512() && V1.hasOneUse() &&
         MayFoldLoad(peekThroughOneUseBitcasts(V1))) {
-      MVT MemVT = VT.getHalfNumVectorElementsVT();
-      unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
       auto *Ld = cast<LoadSDNode>(peekThroughOneUseBitcasts(V1));
-      if (SDValue BcstLd = getBROADCAST_LOAD(X86ISD::SUBV_BROADCAST_LOAD, DL,
-                                             VT, MemVT, Ld, Ofs, DAG))
-        return BcstLd;
+      if (!Ld->isNonTemporal()) {
+        MVT MemVT = VT.getHalfNumVectorElementsVT();
+        unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
+        SDVTList Tys = DAG.getVTList(VT, MVT::Other);
+        SDValue Ptr = DAG.getMemBasePlusOffset(Ld->getBasePtr(),
+                                               TypeSize::Fixed(Ofs), DL);
+        SDValue Ops[] = {Ld->getChain(), Ptr};
+        SDValue BcastLd = DAG.getMemIntrinsicNode(
+            X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops, MemVT,
+            DAG.getMachineFunction().getMachineMemOperand(
+                Ld->getMemOperand(), Ofs, MemVT.getStoreSize()));
+        DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1));
+        return BcastLd;
+      }
     }
 
     // With AVX2, use VPERMQ/VPERMPD for unary shuffles to allow memory folding.
@@ -38011,7 +38020,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
       return Res;
 
     // Fold vperm2x128 subvector shuffle with an inner concat pattern.
-    // vperm2x128(concat(X,Y),concat(Z,W)) --> concat X,Y etc.  
+    // vperm2x128(concat(X,Y),concat(Z,W)) --> concat X,Y etc.
     auto FindSubVector128 = [&](unsigned Idx) {
       if (Idx > 3)
         return SDValue();
@@ -38992,10 +39001,10 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
     }
       // Subvector broadcast.
     case X86ISD::SUBV_BROADCAST_LOAD: {
-      SDLoc DL(Op);
       auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
       EVT MemVT = MemIntr->getMemoryVT();
       if (ExtSizeInBits == MemVT.getStoreSizeInBits()) {
+        SDLoc DL(Op);
         SDValue Ld =
             TLO.DAG.getLoad(MemVT, DL, MemIntr->getChain(),
                             MemIntr->getBasePtr(), MemIntr->getMemOperand());
@@ -39004,13 +39013,18 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
         return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Ld, 0,
                                                  TLO.DAG, DL, ExtSizeInBits));
       } else if ((ExtSizeInBits % MemVT.getStoreSizeInBits()) == 0) {
+        SDLoc DL(Op);
         EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(),
                                       ExtSizeInBits / VT.getScalarSizeInBits());
-        if (SDValue BcstLd =
-                getBROADCAST_LOAD(Opc, DL, BcstVT, MemVT, MemIntr, 0, TLO.DAG))
-          return TLO.CombineTo(Op,
-                               insertSubVector(TLO.DAG.getUNDEF(VT), BcstLd, 0,
-                                               TLO.DAG, DL, ExtSizeInBits));
+        SDVTList Tys = TLO.DAG.getVTList(BcstVT, MVT::Other);
+        SDValue Ops[] = {MemIntr->getOperand(0), MemIntr->getOperand(1)};
+        SDValue Bcst =
+            TLO.DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys,
+                                        Ops, MemVT, MemIntr->getMemOperand());
+        TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1),
+                                             Bcst.getValue(1));
+        return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Bcst, 0,
+                                                 TLO.DAG, DL, ExtSizeInBits));
       }
       break;
     }
@@ -50083,21 +50097,36 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
     if (Op0.getOpcode() == X86ISD::VBROADCAST)
       return DAG.getNode(Op0.getOpcode(), DL, VT, Op0.getOperand(0));
 
-    // If this simple subvector or scalar/subvector broadcast_load is inserted
-    // into both halves, use a larger broadcast_load. Update other uses to use
-    // an extracted subvector.
-    if (Op0.getOpcode() == ISD::LOAD ||
-        Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
+    // If this scalar/subvector broadcast_load is inserted into both halves, use
+    // a larger broadcast_load. Update other uses to use an extracted subvector.
+    if (Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
         Op0.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) {
-      auto *Mem = cast<MemSDNode>(Op0);
-      unsigned Opcode = Op0.getOpcode() == X86ISD::VBROADCAST_LOAD
-                            ? X86ISD::VBROADCAST_LOAD
-                            : X86ISD::SUBV_BROADCAST_LOAD;
-      if (SDValue BcastLd = getBROADCAST_LOAD(
-              Opcode, DL, VT, Mem->getMemoryVT(), Mem, 0, DAG)) {
+      auto *MemIntr = cast<MemIntrinsicSDNode>(Op0);
+      SDVTList Tys = DAG.getVTList(VT, MVT::Other);
+      SDValue Ops[] = {MemIntr->getChain(), MemIntr->getBasePtr()};
+      SDValue BcastLd = DAG.getMemIntrinsicNode(Op0.getOpcode(), DL, Tys, Ops,
+                                                MemIntr->getMemoryVT(),
+                                                MemIntr->getMemOperand());
+      DAG.ReplaceAllUsesOfValueWith(
+          Op0, extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()));
+      DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
+      return BcastLd;
+    }
+
+    // If this is a simple subvector load repeated across multiple lanes, then
+    // broadcast the load. Update other uses to use an extracted subvector.
+    if (auto *Ld = dyn_cast<LoadSDNode>(Op0)) {
+      if (Ld->isSimple() && !Ld->isNonTemporal() &&
+          Ld->getExtensionType() == ISD::NON_EXTLOAD) {
+        SDVTList Tys = DAG.getVTList(VT, MVT::Other);
+        SDValue Ops[] = {Ld->getChain(), Ld->getBasePtr()};
+        SDValue BcastLd =
+            DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops,
+                                    Ld->getMemoryVT(), Ld->getMemOperand());
         DAG.ReplaceAllUsesOfValueWith(
             Op0,
             extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()));
+        DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1));
         return BcastLd;
       }
     }
@@ -50461,8 +50490,14 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
   if (Vec.isUndef() && IdxVal != 0 && SubVec.hasOneUse() &&
       SubVec.getOpcode() == X86ISD::VBROADCAST_LOAD) {
     auto *MemIntr = cast<MemIntrinsicSDNode>(SubVec);
-    return getBROADCAST_LOAD(X86ISD::VBROADCAST_LOAD, dl, OpVT,
-                             MemIntr->getMemoryVT(), MemIntr, 0, DAG);
+    SDVTList Tys = DAG.getVTList(OpVT, MVT::Other);
+    SDValue Ops[] = { MemIntr->getChain(), MemIntr->getBasePtr() };
+    SDValue BcastLd =
+        DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops,
+                                MemIntr->getMemoryVT(),
+                                MemIntr->getMemOperand());
+    DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
+    return BcastLd;
   }
 
   // If we're splatting the lower half subvector of a full vector load into the


        


More information about the llvm-commits mailing list