[llvm] 1fe4953 - [SVE] Remove custom lowering of scalable vector MGATHER & MSCATTER operations.
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 2 03:28:31 PDT 2022
Author: Paul Walker
Date: 2022-06-02T11:19:52+01:00
New Revision: 1fe4953d8939ab4f8f0a5de060c0a35758d835a8
URL: https://github.com/llvm/llvm-project/commit/1fe4953d8939ab4f8f0a5de060c0a35758d835a8
DIFF: https://github.com/llvm/llvm-project/commit/1fe4953d8939ab4f8f0a5de060c0a35758d835a8.diff
LOG: [SVE] Remove custom lowering of scalable vector MGATHER & MSCATTER operations.
Differential Revision: https://reviews.llvm.org/D126255
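
For reference, a minimal IR sketch of the kind of scalable-vector gather involved here: the front end emits an llvm.masked.gather intrinsic, SelectionDAG builds an ISD::MGATHER node from it, and with this commit that node is matched directly by the new TableGen patterns instead of the removed custom lowering. The function name, types, and alignment below are illustrative only and are not taken from the commit:

  ; Gather <vscale x 4 x i32> elements addressed by %base plus per-lane
  ; offsets, under the predicate %mask; inactive lanes return undef.
  define <vscale x 4 x i32> @gather_example(ptr %base, <vscale x 4 x i32> %offs,
                                            <vscale x 4 x i1> %mask) {
    %ptrs = getelementptr i32, ptr %base, <vscale x 4 x i32> %offs
    %vals = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(
                <vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask,
                <vscale x 4 x i32> undef)
    ret <vscale x 4 x i32> %vals
  }

  declare <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr>, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
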
Added:
Modified:
llvm/include/llvm/Target/TargetSelectionDAG.td
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64InstrInfo.td
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Removed:
################################################################################
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 7b2a25605acb0..47b686aca7b5d 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -238,6 +238,16 @@ def SDTMaskedLoad: SDTypeProfile<1, 4, [ // masked load
SDTCisSameNumEltsAs<0, 3>
]>;
+def SDTMaskedGather : SDTypeProfile<1, 4, [
+ SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisVec<2>, SDTCisPtrTy<3>, SDTCisVec<4>,
+ SDTCisSameNumEltsAs<0, 2>, SDTCisSameNumEltsAs<0, 4>
+]>;
+
+def SDTMaskedScatter : SDTypeProfile<0, 4, [
+ SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, SDTCisVec<3>,
+ SDTCisSameNumEltsAs<0, 1>, SDTCisSameNumEltsAs<0, 3>
+]>;
+
def SDTVecShuffle : SDTypeProfile<1, 2, [
SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>
]>;
@@ -652,6 +662,12 @@ def masked_st : SDNode<"ISD::MSTORE", SDTMaskedStore,
def masked_ld : SDNode<"ISD::MLOAD", SDTMaskedLoad,
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def masked_gather : SDNode<"ISD::MGATHER", SDTMaskedGather,
+ [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
+def masked_scatter : SDNode<"ISD::MSCATTER", SDTMaskedScatter,
+ [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
// Do not use ld, st directly. Use load, extload, sextload, zextload, store,
// and truncst (see below).
def ld : SDNode<"ISD::LOAD" , SDTLoad,
@@ -1628,6 +1644,124 @@ def atomic_load_64 :
let MemoryVT = i64;
}
+def nonext_masked_gather :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ return cast<MaskedGatherSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+
+// Any extending masked gather fragments.
+def ext_masked_gather_i8 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::EXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def ext_masked_gather_i16 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::EXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def ext_masked_gather_i32 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::EXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
+// Sign extending masked gather fragments.
+def sext_masked_gather_i8 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::SEXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def sext_masked_gather_i16 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::SEXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def sext_masked_gather_i32 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::SEXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
+// Zero extending masked gather fragments.
+def zext_masked_gather_i8 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::ZEXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def zext_masked_gather_i16 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::ZEXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def zext_masked_gather_i32 :
+ PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{
+ auto MGN = cast<MaskedGatherSDNode>(N);
+ return MGN->getExtensionType() == ISD::ZEXTLOAD &&
+ MGN->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
+// Any/Zero extending masked gather fragments.
+def azext_masked_gather_i8 :
+ PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ [(ext_masked_gather_i8 node:$def, node:$pred, node:$ptr, node:$idx),
+ (zext_masked_gather_i8 node:$def, node:$pred, node:$ptr, node:$idx)]>;
+def azext_masked_gather_i16 :
+ PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ [(ext_masked_gather_i16 node:$def, node:$pred, node:$ptr, node:$idx),
+ (zext_masked_gather_i16 node:$def, node:$pred, node:$ptr, node:$idx)]>;
+def azext_masked_gather_i32 :
+ PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+ [(ext_masked_gather_i32 node:$def, node:$pred, node:$ptr, node:$idx),
+ (zext_masked_gather_i32 node:$def, node:$pred, node:$ptr, node:$idx)]>;
+
+def nontrunc_masked_scatter :
+ PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+ (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+ return !cast<MaskedScatterSDNode>(N)->isTruncatingStore();
+}]>;
+
+// Truncating masked scatter fragments.
+def trunc_masked_scatter_i8 :
+ PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+ (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+ auto MSN = cast<MaskedScatterSDNode>(N);
+ return MSN->isTruncatingStore() &&
+ MSN->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def trunc_masked_scatter_i16 :
+ PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+ (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+ auto MSN = cast<MaskedScatterSDNode>(N);
+ return MSN->isTruncatingStore() &&
+ MSN->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def trunc_masked_scatter_i32 :
+ PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+ (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+ auto MSN = cast<MaskedScatterSDNode>(N);
+ return MSN->isTruncatingStore() &&
+ MSN->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
//===----------------------------------------------------------------------===//
// Selection DAG Pattern Support.
//
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 00c8381472183..6243aa9dc2bb5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4562,29 +4562,6 @@ unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
return AddrModes.find(Key)->second;
}
-unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
- std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
- {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
- AArch64ISD::SST1_PRED},
- {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
- AArch64ISD::SST1_UXTW_PRED},
- {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
- AArch64ISD::SST1_PRED},
- {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
- AArch64ISD::SST1_SXTW_PRED},
- {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
- AArch64ISD::SST1_SCALED_PRED},
- {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
- AArch64ISD::SST1_UXTW_SCALED_PRED},
- {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
- AArch64ISD::SST1_SCALED_PRED},
- {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
- AArch64ISD::SST1_SXTW_SCALED_PRED},
- };
- auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
- return AddrModes.find(Key)->second;
-}
-
unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
switch (Opcode) {
default:
@@ -4607,53 +4584,6 @@ unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
}
}
-bool getGatherScatterIndexIsExtended(SDValue Index) {
- // Ignore non-pointer sized indices.
- if (Index.getValueType() != MVT::nxv2i64)
- return false;
-
- unsigned Opcode = Index.getOpcode();
- if (Opcode == ISD::SIGN_EXTEND_INREG)
- return cast<VTSDNode>(Index.getOperand(1))->getVT() == MVT::nxv2i32;
-
- if (Opcode == ISD::AND) {
- SDValue Splat = Index.getOperand(1);
- if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
- return false;
- ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(Splat.getOperand(0));
- if (!Mask || Mask->getZExtValue() != 0xFFFFFFFF)
- return false;
- return true;
- }
-
- return false;
-}
-
-// If the base pointer of a masked gather or scatter is constant, we
-// may be able to swap BasePtr & Index and use the vector + immediate addressing
-// mode, e.g.
-// VECTOR + IMMEDIATE:
-// getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
-// -> getelementptr #x, <vscale x N x T> %indices
-void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
- bool IsScaled, EVT MemVT, unsigned &Opcode,
- bool IsGather, SelectionDAG &DAG) {
- ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(BasePtr);
- if (!Offset || IsScaled)
- return;
-
- uint64_t OffsetVal = Offset->getZExtValue();
- unsigned ScalarSizeInBytes = MemVT.getScalarSizeInBits() / 8;
-
- if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31)
- return;
-
- // Immediate is in range
- Opcode =
- IsGather ? AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED;
- std::swap(BasePtr, Index);
-}
-
SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
SelectionDAG &DAG) const {
MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
@@ -4749,37 +4679,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
}
- bool IdxNeedsExtend =
- getGatherScatterIndexIsExtended(Index) ||
- Index.getSimpleValueType().getVectorElementType() == MVT::i32;
-
- EVT IndexVT = Index.getSimpleValueType();
- SDValue InputVT = DAG.getValueType(MemVT);
-
- // Handle FP data by using an integer gather and casting the result.
- if (VT.isFloatingPoint())
- InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
-
- SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
-
- if (getGatherScatterIndexIsExtended(Index))
- Index = Index.getOperand(0);
-
- unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
- selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
- /*isGather=*/true, DAG);
-
- if (ExtType == ISD::SEXTLOAD)
- Opcode = getSignExtendedGatherOpcode(Opcode);
-
- SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
- SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
- Chain = Result.getValue(1);
-
- if (VT.isFloatingPoint())
- Result = getSVESafeBitCast(VT, Result, DAG);
-
- return DAG.getMergeValues({Result, Chain}, DL);
+ // Everything else is legal.
+ return Op;
}
SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
@@ -4859,29 +4760,8 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
MSC->getMemOperand(), IndexType, Truncating);
}
- bool NeedsExtend =
- getGatherScatterIndexIsExtended(Index) ||
- Index.getSimpleValueType().getVectorElementType() == MVT::i32;
-
- SDVTList VTs = DAG.getVTList(MVT::Other);
- SDValue InputVT = DAG.getValueType(MemVT);
-
- if (VT.isFloatingPoint()) {
- // Handle FP data by casting the data so an integer scatter can be used.
- EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
- StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
- InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
- }
-
- if (getGatherScatterIndexIsExtended(Index))
- Index = Index.getOperand(0);
-
- unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
- selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
- /*isGather=*/false, DAG);
-
- SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
- return DAG.getNode(Opcode, DL, VTs, Ops);
+ // Everything else is legal.
+ return Op;
}
SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 003e2abf9ce59..3d42ac84a6266 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -443,6 +443,58 @@ def non_temporal_store :
cast<MaskedStoreSDNode>(N)->isNonTemporal();
}]>;
+multiclass masked_gather_scatter<PatFrags GatherScatterOp> {
+ // offsets = (signed)Index << sizeof(elt)
+ def NAME#_signed_scaled :
+ PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+ (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{
+ auto MGS = cast<MaskedGatherScatterSDNode>(N);
+ bool Signed = MGS->isIndexSigned() ||
+ MGS->getIndex().getValueType().getVectorElementType() == MVT::i64;
+ return Signed && MGS->isIndexScaled();
+ }]>;
+ // offsets = (signed)Index
+ def NAME#_signed_unscaled :
+ PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+ (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{
+ auto MGS = cast<MaskedGatherScatterSDNode>(N);
+ bool Signed = MGS->isIndexSigned() ||
+ MGS->getIndex().getValueType().getVectorElementType() == MVT::i64;
+ return Signed && !MGS->isIndexScaled();
+ }]>;
+ // offsets = (unsigned)Index << sizeof(elt)
+ def NAME#_unsigned_scaled :
+ PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+ (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{
+ auto MGS = cast<MaskedGatherScatterSDNode>(N);
+ bool Signed = MGS->isIndexSigned() ||
+ MGS->getIndex().getValueType().getVectorElementType() == MVT::i64;
+ return !Signed && MGS->isIndexScaled();
+ }]>;
+ // offsets = (unsigned)Index
+ def NAME#_unsigned_unscaled :
+ PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+ (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{
+ auto MGS = cast<MaskedGatherScatterSDNode>(N);
+ bool Signed = MGS->isIndexSigned() ||
+ MGS->getIndex().getValueType().getVectorElementType() == MVT::i64;
+ return !Signed && !MGS->isIndexScaled();
+ }]>;
+}
+
+defm nonext_masked_gather : masked_gather_scatter<nonext_masked_gather>;
+defm azext_masked_gather_i8 : masked_gather_scatter<azext_masked_gather_i8>;
+defm azext_masked_gather_i16 : masked_gather_scatter<azext_masked_gather_i16>;
+defm azext_masked_gather_i32 : masked_gather_scatter<azext_masked_gather_i32>;
+defm sext_masked_gather_i8 : masked_gather_scatter<sext_masked_gather_i8>;
+defm sext_masked_gather_i16 : masked_gather_scatter<sext_masked_gather_i16>;
+defm sext_masked_gather_i32 : masked_gather_scatter<sext_masked_gather_i32>;
+
+defm nontrunc_masked_scatter : masked_gather_scatter<nontrunc_masked_scatter>;
+defm trunc_masked_scatter_i8 : masked_gather_scatter<trunc_masked_scatter_i8>;
+defm trunc_masked_scatter_i16 : masked_gather_scatter<trunc_masked_scatter_i16>;
+defm trunc_masked_scatter_i32 : masked_gather_scatter<trunc_masked_scatter_i32>;
+
// top16Zero - answer true if the upper 16 bits of $src are 0, false otherwise
def top16Zero: PatLeaf<(i32 GPR32:$src), [{
return SDValue(N,0)->getValueType(0) == MVT::i32 &&
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 5e60e0fe2cdb1..3262b3e20d81e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1036,6 +1036,92 @@ let Predicates = [HasSVE] in {
defm GLDFF1W_D : sve_mem_64b_gld_sv_32_scaled<0b1011, "ldff1w", AArch64ldff1_gather_sxtw_scaled_z, AArch64ldff1_gather_uxtw_scaled_z, ZPR64ExtSXTW32, ZPR64ExtUXTW32, nxv2i32>;
defm GLD1D : sve_mem_64b_gld_sv_32_scaled<0b1110, "ld1d", AArch64ld1_gather_sxtw_scaled_z, AArch64ld1_gather_uxtw_scaled_z, ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>;
defm GLDFF1D : sve_mem_64b_gld_sv_32_scaled<0b1111, "ldff1d", AArch64ldff1_gather_sxtw_scaled_z, AArch64ldff1_gather_uxtw_scaled_z, ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>;
+
+ multiclass sve_masked_gather_x2_scaled<ValueType Ty, SDPatternOperator Load, string Inst> {
+ // base + vector of scaled offsets
+ def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))),
+ (!cast<Instruction>(Inst # _SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ // base + vector of signed 32bit scaled offsets
+ def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))),
+ (!cast<Instruction>(Inst # _SXTW_SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ // base + vector of unsigned 32bit scaled offsets
+ def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))),
+ (!cast<Instruction>(Inst # _UXTW_SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ }
+
+ multiclass sve_masked_gather_x2_unscaled<ValueType Ty, SDPatternOperator Load, string Inst, Operand ImmTy> {
+ // vector of pointers + immediate offset (includes zero)
+ def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), (i64 ImmTy:$imm), (nxv2i64 ZPR:$ptrs))),
+ (!cast<Instruction>(Inst # _IMM) PPR:$gp, ZPR:$ptrs, ImmTy:$imm)>;
+ // base + vector of offsets
+ def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))),
+ (!cast<Instruction>(Inst) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ // base + vector of signed 32bit offsets
+ def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))),
+ (!cast<Instruction>(Inst # _SXTW) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ // base + vector of unsigned 32bit offsets
+ def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))),
+ (!cast<Instruction>(Inst # _UXTW) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ }
+
+ multiclass sve_masked_gather_x4<ValueType Ty, SDPatternOperator Load, Instruction Inst> {
+ def : Pat<(Ty (Load (SVEDup0Undef), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs))),
+ (Inst PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ }
+
+ defm : sve_masked_gather_x2_scaled<nxv2i64, azext_masked_gather_i16_signed_scaled, "GLD1H_D">;
+ defm : sve_masked_gather_x2_scaled<nxv2i64, sext_masked_gather_i16_signed_scaled, "GLD1SH_D">;
+ defm : sve_masked_gather_x2_scaled<nxv2i64, azext_masked_gather_i32_signed_scaled, "GLD1W_D">;
+ defm : sve_masked_gather_x2_scaled<nxv2i64, sext_masked_gather_i32_signed_scaled, "GLD1SW_D">;
+ defm : sve_masked_gather_x2_scaled<nxv2i64, nonext_masked_gather_signed_scaled, "GLD1D">;
+ defm : sve_masked_gather_x2_scaled<nxv2f16, nonext_masked_gather_signed_scaled, "GLD1H_D">;
+ defm : sve_masked_gather_x2_scaled<nxv2f32, nonext_masked_gather_signed_scaled, "GLD1W_D">;
+ defm : sve_masked_gather_x2_scaled<nxv2f64, nonext_masked_gather_signed_scaled, "GLD1D">;
+ defm : sve_masked_gather_x2_scaled<nxv2bf16, nonext_masked_gather_signed_scaled, "GLD1H_D">;
+
+ defm : sve_masked_gather_x2_unscaled<nxv2i64, azext_masked_gather_i8_signed_unscaled, "GLD1B_D" , imm0_31>;
+ defm : sve_masked_gather_x2_unscaled<nxv2i64, sext_masked_gather_i8_signed_unscaled, "GLD1SB_D", imm0_31>;
+ defm : sve_masked_gather_x2_unscaled<nxv2i64, azext_masked_gather_i16_signed_unscaled, "GLD1H_D", uimm5s2>;
+ defm : sve_masked_gather_x2_unscaled<nxv2i64, sext_masked_gather_i16_signed_unscaled, "GLD1SH_D", uimm5s2>;
+ defm : sve_masked_gather_x2_unscaled<nxv2i64, azext_masked_gather_i32_signed_unscaled, "GLD1W_D", uimm5s4>;
+ defm : sve_masked_gather_x2_unscaled<nxv2i64, sext_masked_gather_i32_signed_unscaled, "GLD1SW_D", uimm5s4>;
+ defm : sve_masked_gather_x2_unscaled<nxv2i64, nonext_masked_gather_signed_unscaled, "GLD1D", uimm5s8>;
+ defm : sve_masked_gather_x2_unscaled<nxv2f16, nonext_masked_gather_signed_unscaled, "GLD1H_D", uimm5s2>;
+ defm : sve_masked_gather_x2_unscaled<nxv2f32, nonext_masked_gather_signed_unscaled, "GLD1W_D", uimm5s4>;
+ defm : sve_masked_gather_x2_unscaled<nxv2f64, nonext_masked_gather_signed_unscaled, "GLD1D", uimm5s8>;
+ defm : sve_masked_gather_x2_unscaled<nxv2bf16, nonext_masked_gather_signed_unscaled, "GLD1H_D", uimm5s2>;
+
+ defm : sve_masked_gather_x4<nxv4i32, azext_masked_gather_i16_signed_scaled, GLD1H_S_SXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4i32, sext_masked_gather_i16_signed_scaled, GLD1SH_S_SXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4i32, nonext_masked_gather_signed_scaled, GLD1W_SXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4f16, nonext_masked_gather_signed_scaled, GLD1H_S_SXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4f32, nonext_masked_gather_signed_scaled, GLD1W_SXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4bf16, nonext_masked_gather_signed_scaled, GLD1H_S_SXTW_SCALED>;
+
+ defm : sve_masked_gather_x4<nxv4i32, azext_masked_gather_i8_signed_unscaled, GLD1B_S_SXTW>;
+ defm : sve_masked_gather_x4<nxv4i32, sext_masked_gather_i8_signed_unscaled, GLD1SB_S_SXTW>;
+ defm : sve_masked_gather_x4<nxv4i32, azext_masked_gather_i16_signed_unscaled, GLD1H_S_SXTW>;
+ defm : sve_masked_gather_x4<nxv4i32, sext_masked_gather_i16_signed_unscaled, GLD1SH_S_SXTW>;
+ defm : sve_masked_gather_x4<nxv4i32, nonext_masked_gather_signed_unscaled, GLD1W_SXTW>;
+ defm : sve_masked_gather_x4<nxv4f16, nonext_masked_gather_signed_unscaled, GLD1H_S_SXTW>;
+ defm : sve_masked_gather_x4<nxv4f32, nonext_masked_gather_signed_unscaled, GLD1W_SXTW>;
+ defm : sve_masked_gather_x4<nxv4bf16, nonext_masked_gather_signed_unscaled, GLD1H_S_SXTW>;
+
+ defm : sve_masked_gather_x4<nxv4i32, azext_masked_gather_i16_unsigned_scaled, GLD1H_S_UXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4i32, sext_masked_gather_i16_unsigned_scaled, GLD1SH_S_UXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4i32, nonext_masked_gather_unsigned_scaled, GLD1W_UXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4f16, nonext_masked_gather_unsigned_scaled, GLD1H_S_UXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4f32, nonext_masked_gather_unsigned_scaled, GLD1W_UXTW_SCALED>;
+ defm : sve_masked_gather_x4<nxv4bf16, nonext_masked_gather_unsigned_scaled, GLD1H_S_UXTW_SCALED>;
+
+ defm : sve_masked_gather_x4<nxv4i32, azext_masked_gather_i8_unsigned_unscaled, GLD1B_S_UXTW>;
+ defm : sve_masked_gather_x4<nxv4i32, sext_masked_gather_i8_unsigned_unscaled, GLD1SB_S_UXTW>;
+ defm : sve_masked_gather_x4<nxv4i32, azext_masked_gather_i16_unsigned_unscaled, GLD1H_S_UXTW>;
+ defm : sve_masked_gather_x4<nxv4i32, sext_masked_gather_i16_unsigned_unscaled, GLD1SH_S_UXTW>;
+ defm : sve_masked_gather_x4<nxv4i32, nonext_masked_gather_unsigned_unscaled, GLD1W_UXTW>;
+ defm : sve_masked_gather_x4<nxv4f16, nonext_masked_gather_unsigned_unscaled, GLD1H_S_UXTW>;
+ defm : sve_masked_gather_x4<nxv4f32, nonext_masked_gather_unsigned_unscaled, GLD1W_UXTW>;
+ defm : sve_masked_gather_x4<nxv4bf16, nonext_masked_gather_unsigned_unscaled, GLD1H_S_UXTW>;
} // End HasSVE
let Predicates = [HasSVEorSME] in {
@@ -1126,6 +1212,81 @@ let Predicates = [HasSVE] in {
defm SST1H_D : sve_mem_sst_sv_64_scaled<0b01, "st1h", AArch64st1_scatter_scaled, ZPR64ExtLSL16, nxv2i16>;
defm SST1W_D : sve_mem_sst_sv_64_scaled<0b10, "st1w", AArch64st1_scatter_scaled, ZPR64ExtLSL32, nxv2i32>;
defm SST1D : sve_mem_sst_sv_64_scaled<0b11, "st1d", AArch64st1_scatter_scaled, ZPR64ExtLSL64, nxv2i64>;
+
+ multiclass sve_masked_scatter_x2_scaled<ValueType Ty, SDPatternOperator Store, string Inst> {
+ // base + vector of scaled offsets
+ def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)),
+ (!cast<Instruction>(Inst # _SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ // base + vector of signed 32bit scaled offsets
+ def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)),
+ (!cast<Instruction>(Inst # _SXTW_SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ // base + vector of unsigned 32bit scaled offsets
+ def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))),
+ (!cast<Instruction>(Inst # _UXTW_SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ }
+
+ multiclass sve_masked_scatter_x2_unscaled<ValueType Ty, SDPatternOperator Store, string Inst, Operand ImmTy> {
+ // vector of pointers + immediate offset (includes zero)
+ def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), (i64 ImmTy:$imm), (nxv2i64 ZPR:$ptrs)),
+ (!cast<Instruction>(Inst # _IMM) ZPR:$data, PPR:$gp, ZPR:$ptrs, ImmTy:$imm)>;
+ // base + vector of offsets
+ def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)),
+ (!cast<Instruction>(Inst) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ // base + vector of signed 32bit offsets
+ def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)),
+ (!cast<Instruction>(Inst # _SXTW) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ // base + vector of unsigned 32bit offsets
+ def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))),
+ (!cast<Instruction>(Inst # _UXTW) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ }
+
+ multiclass sve_masked_scatter_x4<ValueType Ty, SDPatternOperator Store, Instruction Inst> {
+ def : Pat<(Store (Ty ZPR:$data), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs)),
+ (Inst ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+ }
+
+ defm : sve_masked_scatter_x2_scaled<nxv2i64, trunc_masked_scatter_i16_signed_scaled, "SST1H_D">;
+ defm : sve_masked_scatter_x2_scaled<nxv2i64, trunc_masked_scatter_i32_signed_scaled, "SST1W_D">;
+ defm : sve_masked_scatter_x2_scaled<nxv2i64, nontrunc_masked_scatter_signed_scaled, "SST1D">;
+ defm : sve_masked_scatter_x2_scaled<nxv2f16, nontrunc_masked_scatter_signed_scaled, "SST1H_D">;
+ defm : sve_masked_scatter_x2_scaled<nxv2f32, nontrunc_masked_scatter_signed_scaled, "SST1W_D">;
+ defm : sve_masked_scatter_x2_scaled<nxv2f64, nontrunc_masked_scatter_signed_scaled, "SST1D">;
+ defm : sve_masked_scatter_x2_scaled<nxv2bf16, nontrunc_masked_scatter_signed_scaled, "SST1H_D">;
+
+ defm : sve_masked_scatter_x2_unscaled<nxv2i64, trunc_masked_scatter_i8_signed_unscaled, "SST1B_D" , imm0_31>;
+ defm : sve_masked_scatter_x2_unscaled<nxv2i64, trunc_masked_scatter_i16_signed_unscaled, "SST1H_D", uimm5s2>;
+ defm : sve_masked_scatter_x2_unscaled<nxv2i64, trunc_masked_scatter_i32_signed_unscaled, "SST1W_D", uimm5s4>;
+ defm : sve_masked_scatter_x2_unscaled<nxv2i64, nontrunc_masked_scatter_signed_unscaled, "SST1D", uimm5s8>;
+ defm : sve_masked_scatter_x2_unscaled<nxv2f16, nontrunc_masked_scatter_signed_unscaled, "SST1H_D", uimm5s2>;
+ defm : sve_masked_scatter_x2_unscaled<nxv2f32, nontrunc_masked_scatter_signed_unscaled, "SST1W_D", uimm5s4>;
+ defm : sve_masked_scatter_x2_unscaled<nxv2f64, nontrunc_masked_scatter_signed_unscaled, "SST1D", uimm5s8>;
+ defm : sve_masked_scatter_x2_unscaled<nxv2bf16, nontrunc_masked_scatter_signed_unscaled, "SST1H_D", uimm5s2>;
+
+ defm : sve_masked_scatter_x4<nxv4i32, trunc_masked_scatter_i16_signed_scaled, SST1H_S_SXTW_SCALED>;
+ defm : sve_masked_scatter_x4<nxv4i32, nontrunc_masked_scatter_signed_scaled, SST1W_SXTW_SCALED>;
+ defm : sve_masked_scatter_x4<nxv4f16, nontrunc_masked_scatter_signed_scaled, SST1H_S_SXTW_SCALED>;
+ defm : sve_masked_scatter_x4<nxv4f32, nontrunc_masked_scatter_signed_scaled, SST1W_SXTW_SCALED>;
+ defm : sve_masked_scatter_x4<nxv4bf16, nontrunc_masked_scatter_signed_scaled, SST1H_S_SXTW_SCALED>;
+
+ defm : sve_masked_scatter_x4<nxv4i32, trunc_masked_scatter_i8_signed_unscaled, SST1B_S_SXTW>;
+ defm : sve_masked_scatter_x4<nxv4i32, trunc_masked_scatter_i16_signed_unscaled, SST1H_S_SXTW>;
+ defm : sve_masked_scatter_x4<nxv4i32, nontrunc_masked_scatter_signed_unscaled, SST1W_SXTW>;
+ defm : sve_masked_scatter_x4<nxv4f16, nontrunc_masked_scatter_signed_unscaled, SST1H_S_SXTW>;
+ defm : sve_masked_scatter_x4<nxv4f32, nontrunc_masked_scatter_signed_unscaled, SST1W_SXTW>;
+ defm : sve_masked_scatter_x4<nxv4bf16, nontrunc_masked_scatter_signed_unscaled, SST1H_S_SXTW>;
+
+ defm : sve_masked_scatter_x4<nxv4i32, trunc_masked_scatter_i16_unsigned_scaled, SST1H_S_UXTW_SCALED>;
+ defm : sve_masked_scatter_x4<nxv4i32, nontrunc_masked_scatter_unsigned_scaled, SST1W_UXTW_SCALED>;
+ defm : sve_masked_scatter_x4<nxv4f16, nontrunc_masked_scatter_unsigned_scaled, SST1H_S_UXTW_SCALED>;
+ defm : sve_masked_scatter_x4<nxv4f32, nontrunc_masked_scatter_unsigned_scaled, SST1W_UXTW_SCALED>;
+ defm : sve_masked_scatter_x4<nxv4bf16, nontrunc_masked_scatter_unsigned_scaled, SST1H_S_UXTW_SCALED>;
+
+ defm : sve_masked_scatter_x4<nxv4i32, trunc_masked_scatter_i8_unsigned_unscaled, SST1B_S_UXTW>;
+ defm : sve_masked_scatter_x4<nxv4i32, trunc_masked_scatter_i16_unsigned_unscaled, SST1H_S_UXTW>;
+ defm : sve_masked_scatter_x4<nxv4i32, nontrunc_masked_scatter_unsigned_unscaled, SST1W_UXTW>;
+ defm : sve_masked_scatter_x4<nxv4f16, nontrunc_masked_scatter_unsigned_unscaled, SST1H_S_UXTW>;
+ defm : sve_masked_scatter_x4<nxv4f32, nontrunc_masked_scatter_unsigned_unscaled, SST1W_UXTW>;
+ defm : sve_masked_scatter_x4<nxv4bf16, nontrunc_masked_scatter_unsigned_unscaled, SST1H_S_UXTW>;
} // End HasSVE
let Predicates = [HasSVEorSME] in {