[llvm] 85f3f6b - [RISCV] Lower scalable vector masked loads to intrinsics to match fixed vectors and reduce isel patterns.
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 19 10:47:20 PDT 2021
Author: Craig Topper
Date: 2021-03-19T10:39:35-07:00
New Revision: 85f3f6b3cc2969fa0e7b38209dfe02354f7153dd
URL: https://github.com/llvm/llvm-project/commit/85f3f6b3cc2969fa0e7b38209dfe02354f7153dd
DIFF: https://github.com/llvm/llvm-project/commit/85f3f6b3cc2969fa0e7b38209dfe02354f7153dd.diff
LOG: [RISCV] Lower scalable vector masked loads to intrinsics to match fixed vectors and reduce isel patterns.
Reviewed By: frasercrmck
Differential Revision: https://reviews.llvm.org/D98840
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index cd47d65d50a8..6dfc2d46afe1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -474,6 +474,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+ setOperationAction(ISD::MLOAD, VT, Custom);
+ setOperationAction(ISD::MSTORE, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
@@ -517,6 +519,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
setOperationAction(ISD::FCOPYSIGN, VT, Legal);
+ setOperationAction(ISD::MLOAD, VT, Custom);
+ setOperationAction(ISD::MSTORE, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
@@ -1651,9 +1655,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::STORE:
return lowerFixedLengthVectorStoreToRVV(Op, DAG);
case ISD::MLOAD:
- return lowerFixedLengthVectorMaskedLoadToRVV(Op, DAG);
+ return lowerMLOAD(Op, DAG);
case ISD::MSTORE:
- return lowerFixedLengthVectorMaskedStoreToRVV(Op, DAG);
+ return lowerMSTORE(Op, DAG);
case ISD::SETCC:
return lowerFixedLengthVectorSetccToRVV(Op, DAG);
case ISD::ADD:
@@ -3194,50 +3198,63 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
Store->getMemoryVT(), Store->getMemOperand());
}
-SDValue RISCVTargetLowering::lowerFixedLengthVectorMaskedLoadToRVV(
- SDValue Op, SelectionDAG &DAG) const {
+SDValue RISCVTargetLowering::lowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
auto *Load = cast<MaskedLoadSDNode>(Op);
SDLoc DL(Op);
MVT VT = Op.getSimpleValueType();
- MVT ContainerVT = getContainerForFixedLengthVector(VT);
- MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
MVT XLenVT = Subtarget.getXLenVT();
- SDValue Mask =
- convertToScalableVector(MaskVT, Load->getMask(), DAG, Subtarget);
- SDValue PassThru =
- convertToScalableVector(ContainerVT, Load->getPassThru(), DAG, Subtarget);
- SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+ SDValue Mask = Load->getMask();
+ SDValue PassThru = Load->getPassThru();
+ SDValue VL;
+
+ MVT ContainerVT = VT;
+ if (VT.isFixedLengthVector()) {
+ ContainerVT = getContainerForFixedLengthVector(VT);
+ MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+
+ Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+ PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
+ VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+ } else
+ VL = DAG.getRegister(RISCV::X0, XLenVT);
SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vle_mask, DL, XLenVT);
SDValue Ops[] = {Load->getChain(), IntID, PassThru,
Load->getBasePtr(), Mask, VL};
- SDValue NewLoad =
+ SDValue Result =
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
Load->getMemoryVT(), Load->getMemOperand());
+ SDValue Chain = Result.getValue(1);
- SDValue Result = convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
- return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
+ if (VT.isFixedLengthVector())
+ Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
+ return DAG.getMergeValues({Result, Chain}, DL);
}
-SDValue RISCVTargetLowering::lowerFixedLengthVectorMaskedStoreToRVV(
- SDValue Op, SelectionDAG &DAG) const {
+SDValue RISCVTargetLowering::lowerMSTORE(SDValue Op, SelectionDAG &DAG) const {
auto *Store = cast<MaskedStoreSDNode>(Op);
SDLoc DL(Op);
SDValue Val = Store->getValue();
+ SDValue Mask = Store->getMask();
MVT VT = Val.getSimpleValueType();
- MVT ContainerVT = getContainerForFixedLengthVector(VT);
- MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
MVT XLenVT = Subtarget.getXLenVT();
+ SDValue VL;
- Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
- SDValue Mask =
- convertToScalableVector(MaskVT, Store->getMask(), DAG, Subtarget);
+ MVT ContainerVT = VT;
+ if (VT.isFixedLengthVector()) {
+ ContainerVT = getContainerForFixedLengthVector(VT);
+ MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
- SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+ Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
+ Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+ VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+ } else
+ VL = DAG.getRegister(RISCV::X0, XLenVT);
SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vse_mask, DL, XLenVT);
return DAG.getMemIntrinsicNode(
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 35fdf2921e22..4546ee4d0f89 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -475,15 +475,13 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVECTOR_REVERSE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerABS(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorFCOPYSIGNToRVV(SDValue Op,
SelectionDAG &DAG) const;
SDValue lowerMGATHERMSCATTER(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
- SDValue lowerFixedLengthVectorMaskedLoadToRVV(SDValue Op,
- SelectionDAG &DAG) const;
- SDValue lowerFixedLengthVectorMaskedStoreToRVV(SDValue Op,
- SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorLogicOpToRVV(SDValue Op, SelectionDAG &DAG,
unsigned MaskOpc,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index d847296e7e25..eaa404fa3be8 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -33,21 +33,6 @@ def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector,
def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", [splat_vector, rv32_splat_i64], [], 2>;
-def masked_load :
- PatFrag<(ops node:$ptr, node:$mask, node:$maskedoff),
- (masked_ld node:$ptr, undef, node:$mask, node:$maskedoff), [{
- return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
- cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD &&
- cast<MaskedLoadSDNode>(N)->isUnindexed();
-}]>;
-def masked_store :
- PatFrag<(ops node:$val, node:$ptr, node:$mask),
- (masked_st node:$val, node:$ptr, undef, node:$mask), [{
- return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
- !cast<MaskedStoreSDNode>(N)->isCompressingStore() &&
- cast<MaskedStoreSDNode>(N)->isUnindexed();
-}]>;
-
class SwapHelper<dag Prefix, dag A, dag B, dag Suffix, bit swap> {
dag Value = !con(Prefix, !if(swap, B, A), !if(swap, A, B), Suffix);
}
@@ -68,25 +53,6 @@ multiclass VPatUSLoadStoreSDNode<ValueType type,
(store_instr reg_class:$rs2, BaseAddr:$rs1, avl, sew)>;
}
-multiclass VPatUSLoadStoreSDNodeMask<ValueType type,
- ValueType mask_type,
- int sew,
- LMULInfo vlmul,
- OutPatFrag avl,
- VReg reg_class>
-{
- defvar load_instr = !cast<Instruction>("PseudoVLE"#sew#"_V_"#vlmul.MX#"_MASK");
- defvar store_instr = !cast<Instruction>("PseudoVSE"#sew#"_V_"#vlmul.MX#"_MASK");
- // Load
- def : Pat<(type (masked_load BaseAddr:$rs1, (mask_type V0), type:$merge)),
- (load_instr reg_class:$merge, BaseAddr:$rs1, (mask_type V0),
- avl, sew)>;
- // Store
- def : Pat<(masked_store type:$rs2, BaseAddr:$rs1, (mask_type V0)),
- (store_instr reg_class:$rs2, BaseAddr:$rs1, (mask_type V0),
- avl, sew)>;
-}
-
multiclass VPatUSLoadStoreWholeVRSDNode<ValueType type,
int sew,
LMULInfo vlmul,
@@ -394,9 +360,6 @@ foreach vti = !listconcat(FractionalGroupIntegerVectors,
FractionalGroupFloatVectors) in
defm "" : VPatUSLoadStoreSDNode<vti.Vector, vti.SEW, vti.LMul,
vti.AVL, vti.RegClass>;
-foreach vti = AllVectors in
- defm "" : VPatUSLoadStoreSDNodeMask<vti.Vector, vti.Mask, vti.SEW, vti.LMul,
- vti.AVL, vti.RegClass>;
foreach vti = [VI8M1, VI16M1, VI32M1, VI64M1, VF16M1, VF32M1, VF64M1] in
defm "" : VPatUSLoadStoreWholeVRSDNode<vti.Vector, vti.SEW, vti.LMul,
vti.RegClass>;
More information about the llvm-commits mailing list