[llvm] r235970 - Masked gather and scatter: Added code for SelectionDAG.

Ahmed Bougacha ahmed.bougacha at gmail.com
Tue Apr 28 10:43:10 PDT 2015


Hi Elena,

I don't recall this being approved; was it?  In fact, I started
looking into it but got distracted. I'll send my first few comments;
sorry for being late.

Thanks!

-Ahmed


On Tue, Apr 28, 2015 at 12:57 AM, Elena Demikhovsky
<elena.demikhovsky at intel.com> wrote:
> Author: delena
> Date: Tue Apr 28 02:57:37 2015
> New Revision: 235970
>
> URL: http://llvm.org/viewvc/llvm-project?rev=235970&view=rev
> Log:
> Masked gather and scatter: Added code for SelectionDAG.
> All other patches, including tests, will follow.
>
> http://reviews.llvm.org/D7665
>
>
> Modified:
>     llvm/trunk/include/llvm/CodeGen/ISDOpcodes.h
>     llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
>     llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
>     llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
>     llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
>     llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
>
> Modified: llvm/trunk/include/llvm/CodeGen/ISDOpcodes.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/ISDOpcodes.h?rev=235970&r1=235969&r2=235970&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/CodeGen/ISDOpcodes.h (original)
> +++ llvm/trunk/include/llvm/CodeGen/ISDOpcodes.h Tue Apr 28 02:57:37 2015
> @@ -687,9 +687,16 @@ namespace ISD {
>      ATOMIC_LOAD_UMIN,
>      ATOMIC_LOAD_UMAX,
>
> -    // Masked load and store
> +    // Masked load and store - consecutive vector load and store operations
> +    // with additional mask operand that prevents memory accesses to the
> +    // masked-off lanes.
>      MLOAD, MSTORE,
>
> +    // Masked gather and scatter - load and store operations for a vector of
> +    // random addresses with additional mask operand that prevents memory
> +    // accesses to the masked-off lanes.
> +    MGATHER, MSCATTER,
> +
>      /// This corresponds to the llvm.lifetime.* intrinsics. The first operand
>      /// is the chain and the second operand is the alloca pointer.
>      LIFETIME_START, LIFETIME_END,
>
> Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAG.h?rev=235970&r1=235969&r2=235970&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/CodeGen/SelectionDAG.h (original)
> +++ llvm/trunk/include/llvm/CodeGen/SelectionDAG.h Tue Apr 28 02:57:37 2015
> @@ -856,6 +856,10 @@ public:
>    SDValue getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val,
>                           SDValue Ptr, SDValue Mask, EVT MemVT,
>                           MachineMemOperand *MMO, bool IsTrunc);
> +  /// Create (or reuse an identical existing) ISD::MGATHER node.
> +  /// Ops must be (Chain, PassThru, Mask, Base, Index); VT is recorded as the
> +  /// node's memory VT.
> +  SDValue getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
> +                          ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
> +  /// Create (or reuse an identical existing) ISD::MSCATTER node.
> +  /// Ops must be (Chain, Value, Mask, Base, Index); VT is recorded as the
> +  /// node's memory VT.
> +  SDValue getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
> +                           ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
>    /// Construct a node to track a Value* through the backend.
>    SDValue getSrcValue(const Value *v);
>
>
> Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h?rev=235970&r1=235969&r2=235970&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h (original)
> +++ llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h Tue Apr 28 02:57:37 2015
> @@ -1151,6 +1151,8 @@ public:
>             N->getOpcode() == ISD::ATOMIC_STORE        ||
>             N->getOpcode() == ISD::MLOAD               ||
>             N->getOpcode() == ISD::MSTORE              ||
> +           N->getOpcode() == ISD::MGATHER             ||
> +           N->getOpcode() == ISD::MSCATTER            ||
>             N->isMemIntrinsic()                        ||
>             N->isTargetMemoryOpcode();
>    }
> @@ -1987,6 +1989,82 @@ public:
>    }
>  };
>
> +/// This is a base class used to represent
> +/// MGATHER and MSCATTER nodes
> +///
> +class MaskedGatherScatterSDNode : public MemSDNode {
> +  // Operands
> +  SDUse Ops[5];
> +public:
> +  friend class SelectionDAG;
> +  MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl,
> +                            ArrayRef<SDValue> Operands, SDVTList VTs, EVT MemVT,
> +                            MachineMemOperand *MMO)
> +    : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
> +    assert(Operands.size() == 5 && "Incompatible number of operands");
> +    InitOperands(Ops, Operands.data(), Operands.size());
> +  }
> +
> +  // Both nodes share the same operand layout:
> +  // MaskedGatherSDNode  (Chain, src0, mask, base, index), src0 is a passthru value
> +  // MaskedScatterSDNode (Chain, value, mask, base, index)
> +  // i.e. Op1 is the passthru/stored value, Op2 is the mask, and the per-lane
> +  // addresses are described by the base (Op3) and index (Op4) operands.
> +  // Mask is a vector of i1 elements
> +  const SDValue &getBasePtr() const { return getOperand(3); }
> +  const SDValue &getIndex()   const { return getOperand(4); }
> +  const SDValue &getMask()    const { return getOperand(2); }
> +  const SDValue &getValue()   const { return getOperand(1); }
> +
> +  static bool classof(const SDNode *N) {
> +    return N->getOpcode() == ISD::MGATHER ||
> +           N->getOpcode() == ISD::MSCATTER;
> +  }
> +};
> +
> +/// This class is used to represent an MGATHER node
> +///
> +class MaskedGatherSDNode : public MaskedGatherScatterSDNode {
> +public:
> +  friend class SelectionDAG;
> +  MaskedGatherSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands,
> +                     SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
> +    : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, Operands, VTs, MemVT,
> +                                MMO) {
> +    // The passthru operand (Op1) must have the same type as the result, and
> +    // the mask must be an i1 vector with one lane per result element.
> +    assert(getValue().getValueType() == getValueType(0) &&
> +           "Incompatible type of the PassThru value in MaskedGatherSDNode");
> +    assert(getMask().getValueType().getVectorNumElements() ==
> +           getValueType(0).getVectorNumElements() &&
> +           "Vector width mismatch between mask and data");
> +    assert(getMask().getValueType().getScalarType() == MVT::i1 &&
> +           "Mask must be a vector of i1 elements");
> +  }
> +
> +  static bool classof(const SDNode *N) {
> +    return N->getOpcode() == ISD::MGATHER;
> +  }
> +};
> +
> +/// This class is used to represent an MSCATTER node
> +///
> +class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
> +public:
> +  friend class SelectionDAG;
> +  MaskedScatterSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands,
> +                      SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
> +    : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, Operands, VTs, MemVT,
> +                                MMO) {
> +    // The mask must be an i1 vector with one lane per element of the stored
> +    // value (Op1).
> +    assert(getMask().getValueType().getVectorNumElements() ==
> +           getValue().getValueType().getVectorNumElements() &&
> +           "Vector width mismatch between mask and data");
> +    assert(getMask().getValueType().getScalarType() == MVT::i1 &&
> +           "Mask must be a vector of i1 elements");
> +  }
> +
> +  static bool classof(const SDNode *N) {
> +    return N->getOpcode() == ISD::MSCATTER;
> +  }
> +};
> +
>  /// An SDNode that represents everything that will be needed
>  /// to construct a MachineInstr. These nodes are created during the
>  /// instruction selection proper phase.
> @@ -2078,7 +2156,7 @@ template <> struct GraphTraits<SDNode*>
>  };
>
>  /// The largest SDNode class.
> -typedef AtomicSDNode LargestSDNode;
> +using LargestSDNode = MaskedGatherScatterSDNode;
>
>  /// The SDNode class with the greatest alignment requirement.
>  typedef GlobalAddressSDNode MostAlignedSDNode;
>
> Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=235970&r1=235969&r2=235970&view=diff
> ==============================================================================
> --- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
> +++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Tue Apr 28 02:57:37 2015
> @@ -5097,6 +5097,55 @@ SDValue SelectionDAG::getMaskedStore(SDV
>    return SDValue(N, 0);
>  }
>
> +SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
> +                                      ArrayRef<SDValue> Ops,
> +                                      MachineMemOperand *MMO) {
> +  // CSE key: opcode, result types and operands, followed by the memory VT,
> +  // the memory-operand flags and the address space, so nodes that differ in
> +  // any of these are never unified.
> +  FoldingSetNodeID NodeID;
> +  AddNodeIDNode(NodeID, ISD::MGATHER, VTs, Ops);
> +  NodeID.AddInteger(VT.getRawBits());
> +  const unsigned MemFlags =
> +      encodeMemSDNodeFlags(ISD::NON_EXTLOAD, ISD::UNINDEXED, MMO->isVolatile(),
> +                           MMO->isNonTemporal(), MMO->isInvariant());
> +  NodeID.AddInteger(MemFlags);
> +  NodeID.AddInteger(MMO->getPointerInfo().getAddrSpace());
> +
> +  void *InsertPos = nullptr;
> +  SDNode *Existing = CSEMap.FindNodeOrInsertPos(NodeID, InsertPos);
> +  if (Existing) {
> +    // An identical gather already exists: reuse it after refining its
> +    // alignment information from MMO.
> +    cast<MaskedGatherSDNode>(Existing)->refineAlignment(MMO);
> +    return SDValue(Existing, 0);
> +  }
> +
> +  SDNode *Gather = new (NodeAllocator) MaskedGatherSDNode(
> +      dl.getIROrder(), dl.getDebugLoc(), Ops, VTs, VT, MMO);
> +  CSEMap.InsertNode(Gather, InsertPos);
> +  InsertNode(Gather);
> +  return SDValue(Gather, 0);
> +}
> +
> +SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
> +                                       ArrayRef<SDValue> Ops,
> +                                       MachineMemOperand *MMO) {
> +  // CSE key: opcode, result types and operands, followed by the memory VT,
> +  // the memory-operand flags and the address space.
> +  FoldingSetNodeID NodeID;
> +  AddNodeIDNode(NodeID, ISD::MSCATTER, VTs, Ops);
> +  NodeID.AddInteger(VT.getRawBits());
> +  const unsigned MemFlags =
> +      encodeMemSDNodeFlags(false, ISD::UNINDEXED, MMO->isVolatile(),
> +                           MMO->isNonTemporal(), MMO->isInvariant());
> +  NodeID.AddInteger(MemFlags);
> +  NodeID.AddInteger(MMO->getPointerInfo().getAddrSpace());
> +
> +  void *InsertPos = nullptr;
> +  SDNode *Existing = CSEMap.FindNodeOrInsertPos(NodeID, InsertPos);
> +  if (Existing) {
> +    // An identical scatter already exists: reuse it after refining its
> +    // alignment information from MMO.
> +    cast<MaskedScatterSDNode>(Existing)->refineAlignment(MMO);
> +    return SDValue(Existing, 0);
> +  }
> +
> +  SDNode *Scatter = new (NodeAllocator) MaskedScatterSDNode(
> +      dl.getIROrder(), dl.getDebugLoc(), Ops, VTs, VT, MMO);
> +  CSEMap.InsertNode(Scatter, InsertPos);
> +  InsertNode(Scatter);
> +  return SDValue(Scatter, 0);
> +}
> +
>  SDValue SelectionDAG::getVAArg(EVT VT, SDLoc dl,
>                                 SDValue Chain, SDValue Ptr,
>                                 SDValue SV,
>
> Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp?rev=235970&r1=235969&r2=235970&view=diff
> ==============================================================================
> --- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (original)
> +++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Tue Apr 28 02:57:37 2015
> @@ -1059,6 +1059,12 @@ SDValue SelectionDAGBuilder::getValue(co
>    return Val;
>  }
>
> +// Return true if V has an SDValue recorded in NodeMap or an entry in
> +// FuncInfo.ValueMap.
> +bool SelectionDAGBuilder::findValue(const Value *V) const {
> +  if (NodeMap.find(V) != NodeMap.end())
> +    return true;
> +  return FuncInfo.ValueMap.find(V) != FuncInfo.ValueMap.end();
> +}
> +
>  /// getNonRegisterValue - Return an SDValue for the given Value, but
>  /// don't look in FuncInfo.ValueMap for a virtual register.
>  SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) {
> @@ -3026,6 +3032,92 @@ void SelectionDAGBuilder::visitMaskedSto
>    setValue(&I, StoreNode);
>  }
>
> +// Gather/scatter receive a vector of pointers.
> +// This vector of pointers may be represented as a base pointer + vector of
> +// indices; whether that is possible depends on the GEP and on the
> +// instructions preceding the GEP that compute its operands. The recognized
> +// pattern is:
> +//   %vec   = insertelement <N x T*> %any, T* %ptr, i32 0
> +//   %splat = shufflevector %vec, %any2, zeroinitializer
> +//   %ptrs  = getelementptr %splat, %index
> +// On success, Ptr is replaced by the scalar base pointer, Base and Index are
> +// set, and true is returned. On failure, Ptr is left unchanged.
> +static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
> +                           SelectionDAGBuilder* SDB) {
> +  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
> +  GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
> +  // Only a GEP with exactly one index maps onto base + index.
> +  if (!Gep || Gep->getNumOperands() != 2)
> +    return false;
> +  ShuffleVectorInst *ShuffleInst =
> +    dyn_cast<ShuffleVectorInst>(Gep->getPointerOperand());
> +  if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue())
> +    return false;
> +
> +  // The all-zero mask splats lane 0 of the first shuffle operand, so the
> +  // scalar pointer must have been inserted into lane 0. The shuffle source
> +  // need not be an instruction at all (e.g. a constant), so use dyn_cast.
> +  InsertElementInst *InsertElt =
> +    dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
> +  if (!InsertElt)
> +    return false;
> +  ConstantInt *InsertIdx = dyn_cast<ConstantInt>(InsertElt->getOperand(2));
> +  if (!InsertIdx || !InsertIdx->isZero())
> +    return false;
> +
> +  Value *ScalarPtr = InsertElt->getOperand(1);
> +
> +  SelectionDAG& DAG = SDB->DAG;
> +  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
> +  // Use the scalar pointer directly if it is already available to the
> +  // builder (NodeMap or FuncInfo.ValueMap). Otherwise, if the splat itself
> +  // is available, extract lane 0 from it.
> +  if (SDB->findValue(ScalarPtr))
> +    Base = SDB->getValue(ScalarPtr);
> +  else if (SDB->findValue(ShuffleInst)) {
> +    SDValue ShuffleNode = SDB->getValue(ShuffleInst);
> +    Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(ShuffleNode),
> +                       ShuffleNode.getValueType().getScalarType(), ShuffleNode,
> +                       DAG.getConstant(0, TLI.getVectorIdxTy()));
> +    SDB->setValue(ScalarPtr, Base);
> +  }
> +  else
> +    return false;
> +
> +  Value *IndexVal = Gep->getOperand(1);
> +  if (!SDB->findValue(IndexVal))
> +    return false;
> +
> +  Index = SDB->getValue(IndexVal);
> +  // If the index is a sign extension whose source is available, use the
> +  // narrower source value as the index instead.
> +  if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
> +    IndexVal = Sext->getOperand(0);
> +    if (SDB->findValue(IndexVal))
> +      Index = SDB->getValue(IndexVal);
> +  }
> +  Ptr = ScalarPtr;
> +  return true;
> +}
> +
> +void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
> +  SDLoc sdl = getCurSDLoc();
> +
> +  // llvm.masked.scatter.*(Src0, Ptrs, alignment, Mask)
> +  Value *Ptr = I.getArgOperand(1);
> +  SDValue Src0 = getValue(I.getArgOperand(0));
> +  SDValue Mask = getValue(I.getArgOperand(3));
> +  EVT VT = Src0.getValueType();
> +  unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(2)))->getZExtValue();
> +  // A zero alignment operand falls back to the DAG's alignment for VT.
> +  if (!Alignment)
> +    Alignment = DAG.getEVTAlignment(VT);
> +  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
> +
> +  AAMDNodes AAInfo;
> +  I.getAAMetadata(AAInfo);
> +
> +  SDValue Base;
> +  SDValue Index;
> +  Value *BasePtr = Ptr;
> +  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
> +
> +  // Only a recognized scalar base pointer is attached to the memory operand.
> +  Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr;
> +  MachineMemOperand *MMO = DAG.getMachineFunction().
> +    getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
> +                         MachineMemOperand::MOStore,  VT.getStoreSize(),
> +                         Alignment, AAInfo);
> +  if (!UniformBase) {
> +    // No uniform base: use a zero base and the whole pointer vector as index.
> +    Base = DAG.getTargetConstant(0, TLI.getPointerTy());
> +    Index = getValue(Ptr);
> +  }
> +  SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
> +  SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
> +                                         Ops, MMO);
> +  DAG.setRoot(Scatter);
> +  setValue(&I, Scatter);
> +}
> +
>  void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
>    SDLoc sdl = getCurSDLoc();
>
> @@ -3067,6 +3159,60 @@ void SelectionDAGBuilder::visitMaskedLoa
>    setValue(&I, Load);
>  }
>
> +void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
> +  SDLoc sdl = getCurSDLoc();
> +
> +  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
> +  Value *Ptr = I.getArgOperand(0);
> +  SDValue Src0 = getValue(I.getArgOperand(3));
> +  SDValue Mask = getValue(I.getArgOperand(2));
> +
> +  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
> +  EVT VT = TLI.getValueType(I.getType());
> +  unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(1)))->getZExtValue();
> +  // A zero alignment operand falls back to the DAG's alignment for VT.
> +  if (!Alignment)
> +    Alignment = DAG.getEVTAlignment(VT);
> +
> +  AAMDNodes AAInfo;
> +  I.getAAMetadata(AAInfo);
> +  const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
> +
> +  SDValue Root = DAG.getRoot();
> +  SDValue Base;
> +  SDValue Index;
> +  Value *BasePtr = Ptr;
> +  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
> +  bool ConstantMemory = false;
> +  if (UniformBase &&
> +      AA->pointsToConstantMemory(
> +          AliasAnalysis::Location(BasePtr, AA->getTypeStoreSize(I.getType()),
> +                                  AAInfo))) {
> +    // Do not serialize (non-volatile) loads of constant memory with anything.
> +    Root = DAG.getEntryNode();
> +    ConstantMemory = true;
> +  }
> +
> +  // Only a recognized scalar base pointer is attached to the memory operand.
> +  MachineMemOperand *MMO =
> +    DAG.getMachineFunction().
> +    getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : nullptr),
> +                         MachineMemOperand::MOLoad,  VT.getStoreSize(),
> +                         Alignment, AAInfo, Ranges);
> +
> +  if (!UniformBase) {
> +    // No uniform base: use a zero base and the whole pointer vector as index.
> +    Base = DAG.getTargetConstant(0, TLI.getPointerTy());
> +    Index = getValue(Ptr);
> +  }
> +
> +  SDValue Ops[] = { Root, Src0, Mask, Base, Index };
> +  SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
> +                                       Ops, MMO);
> +
> +  SDValue OutChain = Gather.getValue(1);
> +  // Constant-memory gathers hang off the entry node and need no ordering.
> +  if (!ConstantMemory)
> +    PendingLoads.push_back(OutChain);
> +  setValue(&I, Gather);
> +}
> +
>  void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) {
>    SDLoc dl = getCurSDLoc();
>    AtomicOrdering SuccessOrder = I.getSuccessOrdering();
> @@ -4216,9 +4362,13 @@ SelectionDAGBuilder::visitIntrinsicCall(
>      return nullptr;
>    }
>
> +  // Each masked intrinsic is lowered exactly once; without the returns the
> +  // gather/scatter cases would fall through and lower the call again as a
> +  // masked load/store.
> +  case Intrinsic::masked_gather:
> +    visitMaskedGather(I);
> +    return nullptr;
>    case Intrinsic::masked_load:
>      visitMaskedLoad(I);
>      return nullptr;
> +  case Intrinsic::masked_scatter:
> +    visitMaskedScatter(I);
> +    return nullptr;
>    case Intrinsic::masked_store:
>      visitMaskedStore(I);
>      return nullptr;
>
> Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h?rev=235970&r1=235969&r2=235970&view=diff
> ==============================================================================
> --- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h (original)
> +++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h Tue Apr 28 02:57:37 2015
> @@ -667,6 +667,8 @@ public:
>    // generate the debug data structures now that we've seen its definition.
>    void resolveDanglingDebugInfo(const Value *V, SDValue Val);
>    SDValue getValue(const Value *V);
> +  /// Return true if V has an SDValue recorded in NodeMap or an entry in
> +  /// FuncInfo.ValueMap.
> +  bool findValue(const Value *V) const;
> +
>    SDValue getNonRegisterValue(const Value *V);
>    SDValue getValueImpl(const Value *V);
>
> @@ -814,6 +816,8 @@ private:
>    void visitStore(const StoreInst &I);
>    void visitMaskedLoad(const CallInst &I);
>    void visitMaskedStore(const CallInst &I);
> +  /// Lower a call to llvm.masked.gather.* into an ISD::MGATHER node.
> +  void visitMaskedGather(const CallInst &I);
> +  /// Lower a call to llvm.masked.scatter.* into an ISD::MSCATTER node.
> +  void visitMaskedScatter(const CallInst &I);
>    void visitAtomicCmpXchg(const AtomicCmpXchgInst &I);
>    void visitAtomicRMW(const AtomicRMWInst &I);
>    void visitFence(const FenceInst &I);
>
>
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at cs.uiuc.edu
> http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits



More information about the llvm-commits mailing list