[llvm] 1fe4953 - [SVE] Remove custom lowering of scalable vector MGATHER & MSCATTER operations.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 2 03:28:31 PDT 2022


Author: Paul Walker
Date: 2022-06-02T11:19:52+01:00
New Revision: 1fe4953d8939ab4f8f0a5de060c0a35758d835a8

URL: https://github.com/llvm/llvm-project/commit/1fe4953d8939ab4f8f0a5de060c0a35758d835a8
DIFF: https://github.com/llvm/llvm-project/commit/1fe4953d8939ab4f8f0a5de060c0a35758d835a8.diff

LOG: [SVE] Remove custom lowering of scalable vector MGATHER & MSCATTER operations.

Differential Revision: https://reviews.llvm.org/D126255
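
For context, a scalable-vector masked gather typically originates from an indexed load such as the hypothetical C++ loop below (the function and names are purely illustrative, not part of this patch); when vectorized for SVE it becomes an llvm.masked.gather on a scalable type, which after this change is selected by the new TableGen patterns instead of being rewritten by custom DAG lowering.

  // Hypothetical example: an indexed load that a vectorizer may turn into a
  // scalable-vector masked gather of the kind affected by this patch.
  void gather(double *dst, const double *src, const long *idx, long n) {
    for (long i = 0; i < n; ++i)
      dst[i] = src[idx[i]]; // per-lane indexed load -> gather
  }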

Added: 
    

Modified: 
    llvm/include/llvm/Target/TargetSelectionDAG.td
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 7b2a25605acb0..47b686aca7b5d 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -238,6 +238,16 @@ def SDTMaskedLoad: SDTypeProfile<1, 4, [       // masked load
   SDTCisSameNumEltsAs<0, 3>
 ]>;
 
+def SDTMaskedGather : SDTypeProfile<1, 4, [
+  SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisVec<2>, SDTCisPtrTy<3>, SDTCisVec<4>,
+  SDTCisSameNumEltsAs<0, 2>, SDTCisSameNumEltsAs<0, 4>
+]>;
+
+def SDTMaskedScatter : SDTypeProfile<0, 4, [
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, SDTCisVec<3>,
+  SDTCisSameNumEltsAs<0, 1>, SDTCisSameNumEltsAs<0, 3>
+]>;
+
 def SDTVecShuffle : SDTypeProfile<1, 2, [
   SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>
 ]>;
@@ -652,6 +662,12 @@ def masked_st    : SDNode<"ISD::MSTORE",  SDTMaskedStore,
 def masked_ld    : SDNode<"ISD::MLOAD",  SDTMaskedLoad,
                        [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
 
+def masked_gather : SDNode<"ISD::MGATHER", SDTMaskedGather,
+                           [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
+def masked_scatter : SDNode<"ISD::MSCATTER", SDTMaskedScatter,
+                            [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
 // Do not use ld, st directly. Use load, extload, sextload, zextload, store,
 // and truncst (see below).
 def ld         : SDNode<"ISD::LOAD"       , SDTLoad,
@@ -1628,6 +1644,124 @@ def atomic_load_64 :
   let MemoryVT = i64;
 }
 
+def nonext_masked_gather :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  return cast<MaskedGatherSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+
+// Any extending masked gather fragments.
+def ext_masked_gather_i8 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::EXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def ext_masked_gather_i16 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::EXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def ext_masked_gather_i32 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::EXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
+// Sign extending masked gather fragments.
+def sext_masked_gather_i8 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::SEXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def sext_masked_gather_i16 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::SEXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def sext_masked_gather_i32 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::SEXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
+// Zero extending masked gather fragments.
+def zext_masked_gather_i8 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::ZEXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def zext_masked_gather_i16 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::ZEXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def zext_masked_gather_i32 :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  return MGN->getExtensionType() == ISD::ZEXTLOAD &&
+         MGN->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
+// Any/Zero extending masked gather fragments.
+def azext_masked_gather_i8 :
+  PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+           [(ext_masked_gather_i8 node:$def, node:$pred, node:$ptr, node:$idx),
+            (zext_masked_gather_i8 node:$def, node:$pred, node:$ptr, node:$idx)]>;
+def azext_masked_gather_i16 :
+  PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+           [(ext_masked_gather_i16 node:$def, node:$pred, node:$ptr, node:$idx),
+            (zext_masked_gather_i16 node:$def, node:$pred, node:$ptr, node:$idx)]>;
+def azext_masked_gather_i32 :
+  PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+           [(ext_masked_gather_i32 node:$def, node:$pred, node:$ptr, node:$idx),
+            (zext_masked_gather_i32 node:$def, node:$pred, node:$ptr, node:$idx)]>;
+
+def nontrunc_masked_scatter :
+  PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+          (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+  return !cast<MaskedScatterSDNode>(N)->isTruncatingStore();
+}]>;
+
+// Truncating masked scatter fragments.
+def trunc_masked_scatter_i8 :
+  PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+          (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+  auto MSN = cast<MaskedScatterSDNode>(N);
+  return MSN->isTruncatingStore() &&
+         MSN->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def trunc_masked_scatter_i16 :
+  PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+          (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+  auto MSN = cast<MaskedScatterSDNode>(N);
+  return MSN->isTruncatingStore() &&
+         MSN->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def trunc_masked_scatter_i32 :
+  PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+          (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+  auto MSN = cast<MaskedScatterSDNode>(N);
+  return MSN->isTruncatingStore() &&
+         MSN->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
 //===----------------------------------------------------------------------===//
 // Selection DAG Pattern Support.
 //

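As a rough summary of the C++ predicates embedded in the gather fragments above, the classification they perform amounts to choosing a fragment from the load's extension kind and memory element type. The sketch below is a self-contained illustration under that reading, with plain enums standing in for the SelectionDAG types; it is not LLVM code.

  #include <string>

  // Illustration only: ExtKind mirrors ISD::NON_EXTLOAD/EXTLOAD/ZEXTLOAD/SEXTLOAD.
  enum class ExtKind { NonExt, AnyExt, ZeroExt, SignExt };

  // Map a masked gather's extension kind and memory element width onto the
  // fragment names defined above.
  std::string gatherFragment(ExtKind Ext, unsigned MemEltBits) {
    if (Ext == ExtKind::NonExt)
      return "nonext_masked_gather";
    std::string Width = "_i" + std::to_string(MemEltBits); // i8, i16 or i32
    if (Ext == ExtKind::SignExt)
      return "sext_masked_gather" + Width;
    // Any-extending and zero-extending gathers are grouped together by the
    // azext_masked_gather_* PatFrags.
    return "azext_masked_gather" + Width;
  }
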
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 00c8381472183..6243aa9dc2bb5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4562,29 +4562,6 @@ unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
   return AddrModes.find(Key)->second;
 }
 
-unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
-  std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
-      {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
-       AArch64ISD::SST1_PRED},
-      {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
-       AArch64ISD::SST1_UXTW_PRED},
-      {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
-       AArch64ISD::SST1_PRED},
-      {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
-       AArch64ISD::SST1_SXTW_PRED},
-      {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
-       AArch64ISD::SST1_SCALED_PRED},
-      {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
-       AArch64ISD::SST1_UXTW_SCALED_PRED},
-      {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
-       AArch64ISD::SST1_SCALED_PRED},
-      {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
-       AArch64ISD::SST1_SXTW_SCALED_PRED},
-  };
-  auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
-  return AddrModes.find(Key)->second;
-}
-
 unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
   switch (Opcode) {
   default:
@@ -4607,53 +4584,6 @@ unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
   }
 }
 
-bool getGatherScatterIndexIsExtended(SDValue Index) {
-  // Ignore non-pointer sized indices.
-  if (Index.getValueType() != MVT::nxv2i64)
-    return false;
-
-  unsigned Opcode = Index.getOpcode();
-  if (Opcode == ISD::SIGN_EXTEND_INREG)
-    return cast<VTSDNode>(Index.getOperand(1))->getVT() == MVT::nxv2i32;
-
-  if (Opcode == ISD::AND) {
-    SDValue Splat = Index.getOperand(1);
-    if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
-      return false;
-    ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(Splat.getOperand(0));
-    if (!Mask || Mask->getZExtValue() != 0xFFFFFFFF)
-      return false;
-    return true;
-  }
-
-  return false;
-}
-
-// If the base pointer of a masked gather or scatter is constant, we
-// may be able to swap BasePtr & Index and use the vector + immediate addressing
-// mode, e.g.
-// VECTOR + IMMEDIATE:
-//    getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
-// -> getelementptr #x, <vscale x N x T> %indices
-void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
-                                 bool IsScaled, EVT MemVT, unsigned &Opcode,
-                                 bool IsGather, SelectionDAG &DAG) {
-  ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(BasePtr);
-  if (!Offset || IsScaled)
-    return;
-
-  uint64_t OffsetVal = Offset->getZExtValue();
-  unsigned ScalarSizeInBytes = MemVT.getScalarSizeInBits() / 8;
-
-  if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31)
-    return;
-
-  // Immediate is in range
-  Opcode =
-      IsGather ? AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED;
-  std::swap(BasePtr, Index);
-}
-
 SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
                                             SelectionDAG &DAG) const {
   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
@@ -4749,37 +4679,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
     return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
   }
 
-  bool IdxNeedsExtend =
-      getGatherScatterIndexIsExtended(Index) ||
-      Index.getSimpleValueType().getVectorElementType() == MVT::i32;
-
-  EVT IndexVT = Index.getSimpleValueType();
-  SDValue InputVT = DAG.getValueType(MemVT);
-
-  // Handle FP data by using an integer gather and casting the result.
-  if (VT.isFloatingPoint())
-    InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
-
-  SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
-
-  if (getGatherScatterIndexIsExtended(Index))
-    Index = Index.getOperand(0);
-
-  unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
-                              /*isGather=*/true, DAG);
-
-  if (ExtType == ISD::SEXTLOAD)
-    Opcode = getSignExtendedGatherOpcode(Opcode);
-
-  SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
-  SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
-  Chain = Result.getValue(1);
-
-  if (VT.isFloatingPoint())
-    Result = getSVESafeBitCast(VT, Result, DAG);
-
-  return DAG.getMergeValues({Result, Chain}, DL);
+  // Everything else is legal.
+  return Op;
 }
 
 SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
@@ -4859,29 +4760,8 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
                                 MSC->getMemOperand(), IndexType, Truncating);
   }
 
-  bool NeedsExtend =
-      getGatherScatterIndexIsExtended(Index) ||
-      Index.getSimpleValueType().getVectorElementType() == MVT::i32;
-
-  SDVTList VTs = DAG.getVTList(MVT::Other);
-  SDValue InputVT = DAG.getValueType(MemVT);
-
-  if (VT.isFloatingPoint()) {
-    // Handle FP data by casting the data so an integer scatter can be used.
-    EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
-    StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
-    InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
-  }
-
-  if (getGatherScatterIndexIsExtended(Index))
-    Index = Index.getOperand(0);
-
-  unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
-                              /*isGather=*/false, DAG);
-
-  SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
-  return DAG.getNode(Opcode, DL, VTs, Ops);
+  // Everything else is legal.
+  return Op;
 }
 
 SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {

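The removed helpers above did two jobs that the new TableGen patterns take over: getGatherScatterIndexIsExtended() recognised indices that were really 32-bit values sign- or zero-extended to nxv2i64, and selectGatherScatterAddrMode() swapped a constant base pointer into the vector-plus-immediate form when it was in range. Below is a minimal self-contained sketch of that in-range test, assuming the same element-aligned, at-most-31-elements rule as the deleted code; the helper name is hypothetical.

  #include <cassert>
  #include <cstdint>

  // Sketch of the check from the removed selectGatherScatterAddrMode(): a
  // constant base can become the immediate of the vector + immediate form only
  // if it is element-aligned and no more than 31 elements from zero.
  bool fitsVectorPlusImm(uint64_t BaseVal, unsigned ScalarSizeInBytes) {
    if (BaseVal % ScalarSizeInBytes != 0)
      return false;
    return BaseVal / ScalarSizeInBytes <= 31;
  }

  int main() {
    assert(fitsVectorPlusImm(248, 8));  // 31 doublewords: in range
    assert(!fitsVectorPlusImm(256, 8)); // 32 doublewords: out of range
    assert(!fitsVectorPlusImm(12, 8));  // not a multiple of the element size
  }
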
diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 003e2abf9ce59..3d42ac84a6266 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -443,6 +443,58 @@ def non_temporal_store :
          cast<MaskedStoreSDNode>(N)->isNonTemporal();
 }]>;
 
+multiclass masked_gather_scatter<PatFrags GatherScatterOp> {
+  // offsets = (signed)Index << sizeof(elt)
+  def NAME#_signed_scaled :
+    PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+            (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{
+    auto MGS = cast<MaskedGatherScatterSDNode>(N);
+    bool Signed = MGS->isIndexSigned() ||
+        MGS->getIndex().getValueType().getVectorElementType() == MVT::i64;
+    return Signed && MGS->isIndexScaled();
+  }]>;
+  // offsets = (signed)Index
+  def NAME#_signed_unscaled :
+    PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+            (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{
+    auto MGS = cast<MaskedGatherScatterSDNode>(N);
+    bool Signed = MGS->isIndexSigned() ||
+        MGS->getIndex().getValueType().getVectorElementType() == MVT::i64;
+    return Signed && !MGS->isIndexScaled();
+  }]>;
+  // offsets = (unsigned)Index << sizeof(elt)
+  def NAME#_unsigned_scaled :
+    PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+            (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{
+    auto MGS = cast<MaskedGatherScatterSDNode>(N);
+    bool Signed = MGS->isIndexSigned() ||
+        MGS->getIndex().getValueType().getVectorElementType() == MVT::i64;
+    return !Signed && MGS->isIndexScaled();
+  }]>;
+  // offsets = (unsigned)Index
+  def NAME#_unsigned_unscaled :
+    PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+            (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{
+    auto MGS = cast<MaskedGatherScatterSDNode>(N);
+    bool Signed = MGS->isIndexSigned() ||
+        MGS->getIndex().getValueType().getVectorElementType() == MVT::i64;
+    return !Signed && !MGS->isIndexScaled();
+  }]>;
+}
+
+defm nonext_masked_gather    : masked_gather_scatter<nonext_masked_gather>;
+defm azext_masked_gather_i8  : masked_gather_scatter<azext_masked_gather_i8>;
+defm azext_masked_gather_i16 : masked_gather_scatter<azext_masked_gather_i16>;
+defm azext_masked_gather_i32 : masked_gather_scatter<azext_masked_gather_i32>;
+defm sext_masked_gather_i8   : masked_gather_scatter<sext_masked_gather_i8>;
+defm sext_masked_gather_i16  : masked_gather_scatter<sext_masked_gather_i16>;
+defm sext_masked_gather_i32  : masked_gather_scatter<sext_masked_gather_i32>;
+
+defm nontrunc_masked_scatter  : masked_gather_scatter<nontrunc_masked_scatter>;
+defm trunc_masked_scatter_i8  : masked_gather_scatter<trunc_masked_scatter_i8>;
+defm trunc_masked_scatter_i16 : masked_gather_scatter<trunc_masked_scatter_i16>;
+defm trunc_masked_scatter_i32 : masked_gather_scatter<trunc_masked_scatter_i32>;
+
 // top16Zero - answer true if the upper 16 bits of $src are 0, false otherwise
 def top16Zero: PatLeaf<(i32 GPR32:$src), [{
   return SDValue(N,0)->getValueType(0) == MVT::i32 &&

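The four fragments produced by the multiclass above share one predicate: the index is treated as signed either when the node says so or when its elements are already 64 bits wide (and so need no extension), and that flag is combined with isIndexScaled(). A self-contained sketch of that decision follows, written as a plain function with illustrative names rather than LLVM code.

  #include <string>

  // Pick the NAME#_{signed,unsigned}_{scaled,unscaled} suffix used above from
  // the index's signedness flag, element width and scaling flag.
  std::string addressingSuffix(bool IndexSigned, unsigned IdxEltBits, bool IndexScaled) {
    bool Signed = IndexSigned || IdxEltBits == 64; // i64 indices need no extension
    std::string Suffix = Signed ? "_signed" : "_unsigned";
    Suffix += IndexScaled ? "_scaled" : "_unscaled";
    return Suffix;
  }
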
diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 5e60e0fe2cdb1..3262b3e20d81e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1036,6 +1036,92 @@ let Predicates = [HasSVE] in {
   defm GLDFF1W_D  : sve_mem_64b_gld_sv_32_scaled<0b1011, "ldff1w",  AArch64ldff1_gather_sxtw_scaled_z,  AArch64ldff1_gather_uxtw_scaled_z,  ZPR64ExtSXTW32, ZPR64ExtUXTW32, nxv2i32>;
   defm GLD1D      : sve_mem_64b_gld_sv_32_scaled<0b1110, "ld1d",    AArch64ld1_gather_sxtw_scaled_z,    AArch64ld1_gather_uxtw_scaled_z,    ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>;
   defm GLDFF1D    : sve_mem_64b_gld_sv_32_scaled<0b1111, "ldff1d",  AArch64ldff1_gather_sxtw_scaled_z,  AArch64ldff1_gather_uxtw_scaled_z,  ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>;
+
+  multiclass sve_masked_gather_x2_scaled<ValueType Ty, SDPatternOperator Load, string Inst> {
+    // base + vector of scaled offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))),
+              (!cast<Instruction>(Inst # _SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of signed 32bit scaled offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))),
+              (!cast<Instruction>(Inst # _SXTW_SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of unsigned 32bit scaled offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))),
+              (!cast<Instruction>(Inst # _UXTW_SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  multiclass sve_masked_gather_x2_unscaled<ValueType Ty, SDPatternOperator Load, string Inst, Operand ImmTy> {
+    // vector of pointers + immediate offset (includes zero)
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), (i64 ImmTy:$imm), (nxv2i64 ZPR:$ptrs))),
+              (!cast<Instruction>(Inst # _IMM) PPR:$gp, ZPR:$ptrs, ImmTy:$imm)>;
+    // base + vector of offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))),
+              (!cast<Instruction>(Inst) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of signed 32bit offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))),
+              (!cast<Instruction>(Inst # _SXTW) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of unsigned 32bit offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))),
+              (!cast<Instruction>(Inst # _UXTW) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  multiclass sve_masked_gather_x4<ValueType Ty, SDPatternOperator Load, Instruction Inst> {
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs))),
+              (Inst PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  defm : sve_masked_gather_x2_scaled<nxv2i64,  azext_masked_gather_i16_signed_scaled, "GLD1H_D">;
+  defm : sve_masked_gather_x2_scaled<nxv2i64,  sext_masked_gather_i16_signed_scaled,  "GLD1SH_D">;
+  defm : sve_masked_gather_x2_scaled<nxv2i64,  azext_masked_gather_i32_signed_scaled, "GLD1W_D">;
+  defm : sve_masked_gather_x2_scaled<nxv2i64,  sext_masked_gather_i32_signed_scaled,  "GLD1SW_D">;
+  defm : sve_masked_gather_x2_scaled<nxv2i64,  nonext_masked_gather_signed_scaled,    "GLD1D">;
+  defm : sve_masked_gather_x2_scaled<nxv2f16,  nonext_masked_gather_signed_scaled,    "GLD1H_D">;
+  defm : sve_masked_gather_x2_scaled<nxv2f32,  nonext_masked_gather_signed_scaled,    "GLD1W_D">;
+  defm : sve_masked_gather_x2_scaled<nxv2f64,  nonext_masked_gather_signed_scaled,    "GLD1D">;
+  defm : sve_masked_gather_x2_scaled<nxv2bf16, nonext_masked_gather_signed_scaled,    "GLD1H_D">;
+
+  defm : sve_masked_gather_x2_unscaled<nxv2i64,  azext_masked_gather_i8_signed_unscaled,  "GLD1B_D" , imm0_31>;
+  defm : sve_masked_gather_x2_unscaled<nxv2i64,  sext_masked_gather_i8_signed_unscaled,   "GLD1SB_D", imm0_31>;
+  defm : sve_masked_gather_x2_unscaled<nxv2i64,  azext_masked_gather_i16_signed_unscaled, "GLD1H_D",  uimm5s2>;
+  defm : sve_masked_gather_x2_unscaled<nxv2i64,  sext_masked_gather_i16_signed_unscaled,  "GLD1SH_D", uimm5s2>;
+  defm : sve_masked_gather_x2_unscaled<nxv2i64,  azext_masked_gather_i32_signed_unscaled, "GLD1W_D",  uimm5s4>;
+  defm : sve_masked_gather_x2_unscaled<nxv2i64,  sext_masked_gather_i32_signed_unscaled,  "GLD1SW_D", uimm5s4>;
+  defm : sve_masked_gather_x2_unscaled<nxv2i64,  nonext_masked_gather_signed_unscaled,    "GLD1D",    uimm5s8>;
+  defm : sve_masked_gather_x2_unscaled<nxv2f16,  nonext_masked_gather_signed_unscaled,    "GLD1H_D",  uimm5s2>;
+  defm : sve_masked_gather_x2_unscaled<nxv2f32,  nonext_masked_gather_signed_unscaled,    "GLD1W_D",  uimm5s4>;
+  defm : sve_masked_gather_x2_unscaled<nxv2f64,  nonext_masked_gather_signed_unscaled,    "GLD1D",    uimm5s8>;
+  defm : sve_masked_gather_x2_unscaled<nxv2bf16, nonext_masked_gather_signed_unscaled,    "GLD1H_D",  uimm5s2>;
+
+  defm : sve_masked_gather_x4<nxv4i32,  azext_masked_gather_i16_signed_scaled, GLD1H_S_SXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4i32,  sext_masked_gather_i16_signed_scaled,  GLD1SH_S_SXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4i32,  nonext_masked_gather_signed_scaled,    GLD1W_SXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4f16,  nonext_masked_gather_signed_scaled,    GLD1H_S_SXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4f32,  nonext_masked_gather_signed_scaled,    GLD1W_SXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4bf16, nonext_masked_gather_signed_scaled,    GLD1H_S_SXTW_SCALED>;
+
+  defm : sve_masked_gather_x4<nxv4i32,  azext_masked_gather_i8_signed_unscaled,  GLD1B_S_SXTW>;
+  defm : sve_masked_gather_x4<nxv4i32,  sext_masked_gather_i8_signed_unscaled,   GLD1SB_S_SXTW>;
+  defm : sve_masked_gather_x4<nxv4i32,  azext_masked_gather_i16_signed_unscaled, GLD1H_S_SXTW>;
+  defm : sve_masked_gather_x4<nxv4i32,  sext_masked_gather_i16_signed_unscaled,  GLD1SH_S_SXTW>;
+  defm : sve_masked_gather_x4<nxv4i32,  nonext_masked_gather_signed_unscaled,    GLD1W_SXTW>;
+  defm : sve_masked_gather_x4<nxv4f16,  nonext_masked_gather_signed_unscaled,    GLD1H_S_SXTW>;
+  defm : sve_masked_gather_x4<nxv4f32,  nonext_masked_gather_signed_unscaled,    GLD1W_SXTW>;
+  defm : sve_masked_gather_x4<nxv4bf16, nonext_masked_gather_signed_unscaled,    GLD1H_S_SXTW>;
+
+  defm : sve_masked_gather_x4<nxv4i32,  azext_masked_gather_i16_unsigned_scaled, GLD1H_S_UXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4i32,  sext_masked_gather_i16_unsigned_scaled,  GLD1SH_S_UXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4i32,  nonext_masked_gather_unsigned_scaled,    GLD1W_UXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4f16,  nonext_masked_gather_unsigned_scaled,    GLD1H_S_UXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4f32,  nonext_masked_gather_unsigned_scaled,    GLD1W_UXTW_SCALED>;
+  defm : sve_masked_gather_x4<nxv4bf16, nonext_masked_gather_unsigned_scaled,    GLD1H_S_UXTW_SCALED>;
+
+  defm : sve_masked_gather_x4<nxv4i32,  azext_masked_gather_i8_unsigned_unscaled,  GLD1B_S_UXTW>;
+  defm : sve_masked_gather_x4<nxv4i32,  sext_masked_gather_i8_unsigned_unscaled,   GLD1SB_S_UXTW>;
+  defm : sve_masked_gather_x4<nxv4i32,  azext_masked_gather_i16_unsigned_unscaled, GLD1H_S_UXTW>;
+  defm : sve_masked_gather_x4<nxv4i32,  sext_masked_gather_i16_unsigned_unscaled,  GLD1SH_S_UXTW>;
+  defm : sve_masked_gather_x4<nxv4i32,  nonext_masked_gather_unsigned_unscaled,    GLD1W_UXTW>;
+  defm : sve_masked_gather_x4<nxv4f16,  nonext_masked_gather_unsigned_unscaled,    GLD1H_S_UXTW>;
+  defm : sve_masked_gather_x4<nxv4f32,  nonext_masked_gather_unsigned_unscaled,    GLD1W_UXTW>;
+  defm : sve_masked_gather_x4<nxv4bf16, nonext_masked_gather_unsigned_unscaled,    GLD1H_S_UXTW>;
 } // End HasSVE
 
 let Predicates = [HasSVEorSME] in {
@@ -1126,6 +1212,81 @@ let Predicates = [HasSVE] in {
   defm SST1H_D : sve_mem_sst_sv_64_scaled<0b01, "st1h", AArch64st1_scatter_scaled, ZPR64ExtLSL16, nxv2i16>;
   defm SST1W_D : sve_mem_sst_sv_64_scaled<0b10, "st1w", AArch64st1_scatter_scaled, ZPR64ExtLSL32, nxv2i32>;
   defm SST1D   : sve_mem_sst_sv_64_scaled<0b11, "st1d", AArch64st1_scatter_scaled, ZPR64ExtLSL64, nxv2i64>;
+
+  multiclass sve_masked_scatter_x2_scaled<ValueType Ty, SDPatternOperator Store, string Inst> {
+    // base + vector of scaled offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)),
+              (!cast<Instruction>(Inst # _SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of signed 32bit scaled offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)),
+              (!cast<Instruction>(Inst # _SXTW_SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of unsigned 32bit scaled offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))),
+              (!cast<Instruction>(Inst # _UXTW_SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  multiclass sve_masked_scatter_x2_unscaled<ValueType Ty, SDPatternOperator Store, string Inst, Operand ImmTy> {
+    // vector of pointers + immediate offset (includes zero)
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), (i64 ImmTy:$imm), (nxv2i64 ZPR:$ptrs)),
+              (!cast<Instruction>(Inst # _IMM) ZPR:$data, PPR:$gp, ZPR:$ptrs, ImmTy:$imm)>;
+    // base + vector of offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)),
+              (!cast<Instruction>(Inst) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of signed 32bit offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)),
+              (!cast<Instruction>(Inst # _SXTW) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of unsigned 32bit offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))),
+              (!cast<Instruction>(Inst # _UXTW) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  multiclass sve_masked_scatter_x4<ValueType Ty, SDPatternOperator Store, Instruction Inst> {
+    def : Pat<(Store (Ty ZPR:$data), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs)),
+              (Inst ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  defm : sve_masked_scatter_x2_scaled<nxv2i64,  trunc_masked_scatter_i16_signed_scaled, "SST1H_D">;
+  defm : sve_masked_scatter_x2_scaled<nxv2i64,  trunc_masked_scatter_i32_signed_scaled, "SST1W_D">;
+  defm : sve_masked_scatter_x2_scaled<nxv2i64,  nontrunc_masked_scatter_signed_scaled,  "SST1D">;
+  defm : sve_masked_scatter_x2_scaled<nxv2f16,  nontrunc_masked_scatter_signed_scaled,  "SST1H_D">;
+  defm : sve_masked_scatter_x2_scaled<nxv2f32,  nontrunc_masked_scatter_signed_scaled,  "SST1W_D">;
+  defm : sve_masked_scatter_x2_scaled<nxv2f64,  nontrunc_masked_scatter_signed_scaled,  "SST1D">;
+  defm : sve_masked_scatter_x2_scaled<nxv2bf16, nontrunc_masked_scatter_signed_scaled,  "SST1H_D">;
+
+  defm : sve_masked_scatter_x2_unscaled<nxv2i64,  trunc_masked_scatter_i8_signed_unscaled,  "SST1B_D" , imm0_31>;
+  defm : sve_masked_scatter_x2_unscaled<nxv2i64,  trunc_masked_scatter_i16_signed_unscaled, "SST1H_D",  uimm5s2>;
+  defm : sve_masked_scatter_x2_unscaled<nxv2i64,  trunc_masked_scatter_i32_signed_unscaled, "SST1W_D",  uimm5s4>;
+  defm : sve_masked_scatter_x2_unscaled<nxv2i64,  nontrunc_masked_scatter_signed_unscaled,  "SST1D",    uimm5s8>;
+  defm : sve_masked_scatter_x2_unscaled<nxv2f16,  nontrunc_masked_scatter_signed_unscaled,  "SST1H_D",  uimm5s2>;
+  defm : sve_masked_scatter_x2_unscaled<nxv2f32,  nontrunc_masked_scatter_signed_unscaled,  "SST1W_D",  uimm5s4>;
+  defm : sve_masked_scatter_x2_unscaled<nxv2f64,  nontrunc_masked_scatter_signed_unscaled,  "SST1D",    uimm5s8>;
+  defm : sve_masked_scatter_x2_unscaled<nxv2bf16, nontrunc_masked_scatter_signed_unscaled,  "SST1H_D",  uimm5s2>;
+
+  defm : sve_masked_scatter_x4<nxv4i32,  trunc_masked_scatter_i16_signed_scaled, SST1H_S_SXTW_SCALED>;
+  defm : sve_masked_scatter_x4<nxv4i32,  nontrunc_masked_scatter_signed_scaled,  SST1W_SXTW_SCALED>;
+  defm : sve_masked_scatter_x4<nxv4f16,  nontrunc_masked_scatter_signed_scaled,  SST1H_S_SXTW_SCALED>;
+  defm : sve_masked_scatter_x4<nxv4f32,  nontrunc_masked_scatter_signed_scaled,  SST1W_SXTW_SCALED>;
+  defm : sve_masked_scatter_x4<nxv4bf16, nontrunc_masked_scatter_signed_scaled,  SST1H_S_SXTW_SCALED>;
+
+  defm : sve_masked_scatter_x4<nxv4i32,  trunc_masked_scatter_i8_signed_unscaled,  SST1B_S_SXTW>;
+  defm : sve_masked_scatter_x4<nxv4i32,  trunc_masked_scatter_i16_signed_unscaled, SST1H_S_SXTW>;
+  defm : sve_masked_scatter_x4<nxv4i32,  nontrunc_masked_scatter_signed_unscaled,  SST1W_SXTW>;
+  defm : sve_masked_scatter_x4<nxv4f16,  nontrunc_masked_scatter_signed_unscaled,  SST1H_S_SXTW>;
+  defm : sve_masked_scatter_x4<nxv4f32,  nontrunc_masked_scatter_signed_unscaled,  SST1W_SXTW>;
+  defm : sve_masked_scatter_x4<nxv4bf16, nontrunc_masked_scatter_signed_unscaled,  SST1H_S_SXTW>;
+
+  defm : sve_masked_scatter_x4<nxv4i32,  trunc_masked_scatter_i16_unsigned_scaled, SST1H_S_UXTW_SCALED>;
+  defm : sve_masked_scatter_x4<nxv4i32,  nontrunc_masked_scatter_unsigned_scaled,  SST1W_UXTW_SCALED>;
+  defm : sve_masked_scatter_x4<nxv4f16,  nontrunc_masked_scatter_unsigned_scaled,  SST1H_S_UXTW_SCALED>;
+  defm : sve_masked_scatter_x4<nxv4f32,  nontrunc_masked_scatter_unsigned_scaled,  SST1W_UXTW_SCALED>;
+  defm : sve_masked_scatter_x4<nxv4bf16, nontrunc_masked_scatter_unsigned_scaled,  SST1H_S_UXTW_SCALED>;
+
+  defm : sve_masked_scatter_x4<nxv4i32,  trunc_masked_scatter_i8_unsigned_unscaled,  SST1B_S_UXTW>;
+  defm : sve_masked_scatter_x4<nxv4i32,  trunc_masked_scatter_i16_unsigned_unscaled, SST1H_S_UXTW>;
+  defm : sve_masked_scatter_x4<nxv4i32,  nontrunc_masked_scatter_unsigned_unscaled,  SST1W_UXTW>;
+  defm : sve_masked_scatter_x4<nxv4f16,  nontrunc_masked_scatter_unsigned_unscaled,  SST1H_S_UXTW>;
+  defm : sve_masked_scatter_x4<nxv4f32,  nontrunc_masked_scatter_unsigned_unscaled,  SST1W_UXTW>;
+  defm : sve_masked_scatter_x4<nxv4bf16, nontrunc_masked_scatter_unsigned_unscaled,  SST1H_S_UXTW>;
 } // End HasSVE
 
 let Predicates = [HasSVEorSME] in {

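The AArch64SVEInstrInfo.td patterns above then map each fragment, element type and addressing form onto a GLD1*/SST1* variant: a vector of pointers plus a small immediate uses the _IMM instructions, while a scalar base plus a vector of offsets picks the plain, _SXTW, _UXTW and/or _SCALED forms depending on how the offsets were produced. The sketch below is a simplified, self-contained model of that choice for the nxv2i64-indexed case; the function and enum are illustrative, not LLVM code.

  #include <string>

  // How the 64-bit offsets were formed in the DAG: plain nxv2i64 values,
  // sign-extended 32-bit values (sext_inreg), or zero-extended 32-bit values
  // (and with 0xFFFFFFFF).
  enum class OffsetForm { Plain64, SExt32, ZExt32 };

  std::string instructionVariant(bool VectorBasePlusImm, OffsetForm Form, bool Scaled) {
    if (VectorBasePlusImm)
      return "_IMM"; // only the unscaled fragments have this alternative
    std::string Suffix;
    if (Form == OffsetForm::SExt32)
      Suffix += "_SXTW";
    else if (Form == OffsetForm::ZExt32)
      Suffix += "_UXTW";
    if (Scaled)
      Suffix += "_SCALED";
    return Suffix; // empty string selects the plain base + offsets form
  }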
