[llvm] d7bbb12 - Revert "[X86][AVX] Add getBROADCAST_LOAD helper function. NFCI."

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 27 08:28:48 PDT 2021


Is there a reproducer for the breakage?

On Tue, Jul 27, 2021 at 5:24 PM Tres Popp via llvm-commits
<llvm-commits at lists.llvm.org> wrote:
>
>
> Author: Tres Popp
> Date: 2021-07-27T16:22:25+02:00
> New Revision: d7bbb1230a94cb239aa4a8cb896c45571444675d
>
> URL: https://github.com/llvm/llvm-project/commit/d7bbb1230a94cb239aa4a8cb896c45571444675d
> DIFF: https://github.com/llvm/llvm-project/commit/d7bbb1230a94cb239aa4a8cb896c45571444675d.diff
>
> LOG: Revert "[X86][AVX] Add getBROADCAST_LOAD helper function. NFCI."
>
> This reverts commit 1cfecf4fc4278afb0005923f6dff595cd372da5c.
>
> This commit broke LLVM code generated through XLA by removing the
> check that Ld->getExtensionType() == ISD::NON_EXTLOAD.
>
> Added:
>
>
> Modified:
>     llvm/lib/Target/X86/X86ISelLowering.cpp
>
> Removed:
>
>
>
> ################################################################################
> diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
> index 344bf73b2c5e0..067b56e205e8e 100644
> --- a/llvm/lib/Target/X86/X86ISelLowering.cpp
> +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
> @@ -7988,30 +7988,6 @@ static bool getTargetShuffleInputs(SDValue Op, SmallVectorImpl<SDValue> &Inputs,
>                                  KnownZero, DAG, Depth, ResolveKnownElts);
>  }
>
> -// Attempt to create a scalar/subvector broadcast from the base MemSDNode.
> -static SDValue getBROADCAST_LOAD(unsigned Opcode, const SDLoc &DL, EVT VT,
> -                                 EVT MemVT, MemSDNode *Mem, unsigned Offset,
> -                                 SelectionDAG &DAG) {
> -  assert((Opcode == X86ISD::VBROADCAST_LOAD ||
> -          Opcode == X86ISD::SUBV_BROADCAST_LOAD) &&
> -         "Unknown broadcast load type");
> -
> -  // Ensure this is a simple (non-atomic, non-voltile), temporal read memop.
> -  if (!Mem || !Mem->readMem() || !Mem->isSimple() || Mem->isNonTemporal())
> -    return SDValue();
> -
> -  SDValue Ptr =
> -      DAG.getMemBasePlusOffset(Mem->getBasePtr(), TypeSize::Fixed(Offset), DL);
> -  SDVTList Tys = DAG.getVTList(VT, MVT::Other);
> -  SDValue Ops[] = {Mem->getChain(), Ptr};
> -  SDValue BcstLd = DAG.getMemIntrinsicNode(
> -      Opcode, DL, Tys, Ops, MemVT,
> -      DAG.getMachineFunction().getMachineMemOperand(
> -          Mem->getMemOperand(), Offset, MemVT.getStoreSize()));
> -  DAG.makeEquivalentMemoryOrdering(SDValue(Mem, 1), BcstLd.getValue(1));
> -  return BcstLd;
> -}
> -
>  /// Returns the scalar element that will make up the i'th
>  /// element of the result of the vector shuffle.
>  static SDValue getShuffleScalarElt(SDValue Op, unsigned Index,
> @@ -16084,12 +16060,21 @@ static SDValue lowerV2X128Shuffle(const SDLoc &DL, MVT VT, SDValue V1,
>      bool SplatHi = isShuffleEquivalent(Mask, {2, 3, 2, 3}, V1);
>      if ((SplatLo || SplatHi) && !Subtarget.hasAVX512() && V1.hasOneUse() &&
>          MayFoldLoad(peekThroughOneUseBitcasts(V1))) {
> -      MVT MemVT = VT.getHalfNumVectorElementsVT();
> -      unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
>        auto *Ld = cast<LoadSDNode>(peekThroughOneUseBitcasts(V1));
> -      if (SDValue BcstLd = getBROADCAST_LOAD(X86ISD::SUBV_BROADCAST_LOAD, DL,
> -                                             VT, MemVT, Ld, Ofs, DAG))
> -        return BcstLd;
> +      if (!Ld->isNonTemporal()) {
> +        MVT MemVT = VT.getHalfNumVectorElementsVT();
> +        unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
> +        SDVTList Tys = DAG.getVTList(VT, MVT::Other);
> +        SDValue Ptr = DAG.getMemBasePlusOffset(Ld->getBasePtr(),
> +                                               TypeSize::Fixed(Ofs), DL);
> +        SDValue Ops[] = {Ld->getChain(), Ptr};
> +        SDValue BcastLd = DAG.getMemIntrinsicNode(
> +            X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops, MemVT,
> +            DAG.getMachineFunction().getMachineMemOperand(
> +                Ld->getMemOperand(), Ofs, MemVT.getStoreSize()));
> +        DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1));
> +        return BcastLd;
> +      }
>      }
>
>      // With AVX2, use VPERMQ/VPERMPD for unary shuffles to allow memory folding.
> @@ -38992,10 +38977,10 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
>      }
>        // Subvector broadcast.
>      case X86ISD::SUBV_BROADCAST_LOAD: {
> -      SDLoc DL(Op);
>        auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
>        EVT MemVT = MemIntr->getMemoryVT();
>        if (ExtSizeInBits == MemVT.getStoreSizeInBits()) {
> +        SDLoc DL(Op);
>          SDValue Ld =
>              TLO.DAG.getLoad(MemVT, DL, MemIntr->getChain(),
>                              MemIntr->getBasePtr(), MemIntr->getMemOperand());
> @@ -39004,13 +38989,18 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
>          return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Ld, 0,
>                                                   TLO.DAG, DL, ExtSizeInBits));
>        } else if ((ExtSizeInBits % MemVT.getStoreSizeInBits()) == 0) {
> +        SDLoc DL(Op);
>          EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(),
>                                        ExtSizeInBits / VT.getScalarSizeInBits());
> -        if (SDValue BcstLd =
> -                getBROADCAST_LOAD(Opc, DL, BcstVT, MemVT, MemIntr, 0, TLO.DAG))
> -          return TLO.CombineTo(Op,
> -                               insertSubVector(TLO.DAG.getUNDEF(VT), BcstLd, 0,
> -                                               TLO.DAG, DL, ExtSizeInBits));
> +        SDVTList Tys = TLO.DAG.getVTList(BcstVT, MVT::Other);
> +        SDValue Ops[] = {MemIntr->getOperand(0), MemIntr->getOperand(1)};
> +        SDValue Bcst =
> +            TLO.DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys,
> +                                        Ops, MemVT, MemIntr->getMemOperand());
> +        TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1),
> +                                             Bcst.getValue(1));
> +        return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Bcst, 0,
> +                                                 TLO.DAG, DL, ExtSizeInBits));
>        }
>        break;
>      }
> @@ -50083,21 +50073,36 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
>      if (Op0.getOpcode() == X86ISD::VBROADCAST)
>        return DAG.getNode(Op0.getOpcode(), DL, VT, Op0.getOperand(0));
>
> -    // If this simple subvector or scalar/subvector broadcast_load is inserted
> -    // into both halves, use a larger broadcast_load. Update other uses to use
> -    // an extracted subvector.
> -    if (Op0.getOpcode() == ISD::LOAD ||
> -        Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
> +    // If this scalar/subvector broadcast_load is inserted into both halves, use
> +    // a larger broadcast_load. Update other uses to use an extracted subvector.
> +    if (Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
>          Op0.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) {
> -      auto *Mem = cast<MemSDNode>(Op0);
> -      unsigned Opcode = Op0.getOpcode() == X86ISD::VBROADCAST_LOAD
> -                            ? X86ISD::VBROADCAST_LOAD
> -                            : X86ISD::SUBV_BROADCAST_LOAD;
> -      if (SDValue BcastLd = getBROADCAST_LOAD(
> -              Opcode, DL, VT, Mem->getMemoryVT(), Mem, 0, DAG)) {
> +      auto *MemIntr = cast<MemIntrinsicSDNode>(Op0);
> +      SDVTList Tys = DAG.getVTList(VT, MVT::Other);
> +      SDValue Ops[] = {MemIntr->getChain(), MemIntr->getBasePtr()};
> +      SDValue BcastLd = DAG.getMemIntrinsicNode(Op0.getOpcode(), DL, Tys, Ops,
> +                                                MemIntr->getMemoryVT(),
> +                                                MemIntr->getMemOperand());
> +      DAG.ReplaceAllUsesOfValueWith(
> +          Op0, extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()));
> +      DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
> +      return BcastLd;
> +    }
> +
> +    // If this is a simple subvector load repeated across multiple lanes, then
> +    // broadcast the load. Update other uses to use an extracted subvector.
> +    if (auto *Ld = dyn_cast<LoadSDNode>(Op0)) {
> +      if (Ld->isSimple() && !Ld->isNonTemporal() &&
> +          Ld->getExtensionType() == ISD::NON_EXTLOAD) {
> +        SDVTList Tys = DAG.getVTList(VT, MVT::Other);
> +        SDValue Ops[] = {Ld->getChain(), Ld->getBasePtr()};
> +        SDValue BcastLd =
> +            DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops,
> +                                    Ld->getMemoryVT(), Ld->getMemOperand());
>          DAG.ReplaceAllUsesOfValueWith(
>              Op0,
>              extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()));
> +        DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1));
>          return BcastLd;
>        }
>      }
> @@ -50461,8 +50466,14 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
>    if (Vec.isUndef() && IdxVal != 0 && SubVec.hasOneUse() &&
>        SubVec.getOpcode() == X86ISD::VBROADCAST_LOAD) {
>      auto *MemIntr = cast<MemIntrinsicSDNode>(SubVec);
> -    return getBROADCAST_LOAD(X86ISD::VBROADCAST_LOAD, dl, OpVT,
> -                             MemIntr->getMemoryVT(), MemIntr, 0, DAG);
> +    SDVTList Tys = DAG.getVTList(OpVT, MVT::Other);
> +    SDValue Ops[] = { MemIntr->getChain(), MemIntr->getBasePtr() };
> +    SDValue BcastLd =
> +        DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops,
> +                                MemIntr->getMemoryVT(),
> +                                MemIntr->getMemOperand());
> +    DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
> +    return BcastLd;
>    }
>
>    // If we're splatting the lower half subvector of a full vector load into the
>
>
>
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits


More information about the llvm-commits mailing list