[llvm] 85f3f6b - [RISCV] Lower scalable vector masked loads to intrinsics to match fixed vectors and reduce isel patterns.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 19 10:47:20 PDT 2021


Author: Craig Topper
Date: 2021-03-19T10:39:35-07:00
New Revision: 85f3f6b3cc2969fa0e7b38209dfe02354f7153dd

URL: https://github.com/llvm/llvm-project/commit/85f3f6b3cc2969fa0e7b38209dfe02354f7153dd
DIFF: https://github.com/llvm/llvm-project/commit/85f3f6b3cc2969fa0e7b38209dfe02354f7153dd.diff

LOG: [RISCV] Lower scalable vector masked loads to intrinsics to match fixed vectors and reduce isel patterns.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D98840

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index cd47d65d50a8..6dfc2d46afe1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -474,6 +474,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
 
+      setOperationAction(ISD::MLOAD, VT, Custom);
+      setOperationAction(ISD::MSTORE, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
 
@@ -517,6 +519,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
       setOperationAction(ISD::FCOPYSIGN, VT, Legal);
 
+      setOperationAction(ISD::MLOAD, VT, Custom);
+      setOperationAction(ISD::MSTORE, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
 
@@ -1651,9 +1655,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::STORE:
     return lowerFixedLengthVectorStoreToRVV(Op, DAG);
   case ISD::MLOAD:
-    return lowerFixedLengthVectorMaskedLoadToRVV(Op, DAG);
+    return lowerMLOAD(Op, DAG);
   case ISD::MSTORE:
-    return lowerFixedLengthVectorMaskedStoreToRVV(Op, DAG);
+    return lowerMSTORE(Op, DAG);
   case ISD::SETCC:
     return lowerFixedLengthVectorSetccToRVV(Op, DAG);
   case ISD::ADD:
@@ -3194,50 +3198,63 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
       Store->getMemoryVT(), Store->getMemOperand());
 }
 
-SDValue RISCVTargetLowering::lowerFixedLengthVectorMaskedLoadToRVV(
-    SDValue Op, SelectionDAG &DAG) const {
+SDValue RISCVTargetLowering::lowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
   auto *Load = cast<MaskedLoadSDNode>(Op);
 
   SDLoc DL(Op);
   MVT VT = Op.getSimpleValueType();
-  MVT ContainerVT = getContainerForFixedLengthVector(VT);
-  MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
   MVT XLenVT = Subtarget.getXLenVT();
 
-  SDValue Mask =
-      convertToScalableVector(MaskVT, Load->getMask(), DAG, Subtarget);
-  SDValue PassThru =
-      convertToScalableVector(ContainerVT, Load->getPassThru(), DAG, Subtarget);
-  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+  SDValue Mask = Load->getMask();
+  SDValue PassThru = Load->getPassThru();
+  SDValue VL;
+
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(VT);
+    MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+
+    Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+    PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
+    VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+  } else
+    VL = DAG.getRegister(RISCV::X0, XLenVT);
 
   SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
   SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vle_mask, DL, XLenVT);
   SDValue Ops[] = {Load->getChain(),   IntID, PassThru,
                    Load->getBasePtr(), Mask,  VL};
-  SDValue NewLoad =
+  SDValue Result =
       DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
                               Load->getMemoryVT(), Load->getMemOperand());
+  SDValue Chain = Result.getValue(1);
 
-  SDValue Result = convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
-  return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
+  if (VT.isFixedLengthVector())
+    Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
+  return DAG.getMergeValues({Result, Chain}, DL);
 }
 
-SDValue RISCVTargetLowering::lowerFixedLengthVectorMaskedStoreToRVV(
-    SDValue Op, SelectionDAG &DAG) const {
+SDValue RISCVTargetLowering::lowerMSTORE(SDValue Op, SelectionDAG &DAG) const {
   auto *Store = cast<MaskedStoreSDNode>(Op);
 
   SDLoc DL(Op);
   SDValue Val = Store->getValue();
+  SDValue Mask = Store->getMask();
   MVT VT = Val.getSimpleValueType();
-  MVT ContainerVT = getContainerForFixedLengthVector(VT);
-  MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
   MVT XLenVT = Subtarget.getXLenVT();
+  SDValue VL;
 
-  Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
-  SDValue Mask =
-      convertToScalableVector(MaskVT, Store->getMask(), DAG, Subtarget);
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(VT);
+    MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
 
-  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+    Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
+    Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+    VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+  } else
+    VL = DAG.getRegister(RISCV::X0, XLenVT);
 
   SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vse_mask, DL, XLenVT);
   return DAG.getMemIntrinsicNode(

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 35fdf2921e22..4546ee4d0f89 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -475,15 +475,13 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVECTOR_REVERSE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerABS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorFCOPYSIGNToRVV(SDValue Op,
                                                SelectionDAG &DAG) const;
   SDValue lowerMGATHERMSCATTER(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
-  SDValue lowerFixedLengthVectorMaskedLoadToRVV(SDValue Op,
-                                                SelectionDAG &DAG) const;
-  SDValue lowerFixedLengthVectorMaskedStoreToRVV(SDValue Op,
-                                                 SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLogicOpToRVV(SDValue Op, SelectionDAG &DAG,
                                              unsigned MaskOpc,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index d847296e7e25..eaa404fa3be8 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -33,21 +33,6 @@ def SplatPat       : ComplexPattern<vAny, 1, "selectVSplat",      [splat_vector,
 def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
 def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", [splat_vector, rv32_splat_i64], [], 2>;
 
-def masked_load :
-  PatFrag<(ops node:$ptr, node:$mask, node:$maskedoff),
-          (masked_ld node:$ptr, undef, node:$mask, node:$maskedoff), [{
-  return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
-    cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD &&
-    cast<MaskedLoadSDNode>(N)->isUnindexed();
-}]>;
-def masked_store :
-  PatFrag<(ops node:$val, node:$ptr, node:$mask),
-          (masked_st node:$val, node:$ptr, undef, node:$mask), [{
-  return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
-         !cast<MaskedStoreSDNode>(N)->isCompressingStore() &&
-         cast<MaskedStoreSDNode>(N)->isUnindexed();
-}]>;
-
 class SwapHelper<dag Prefix, dag A, dag B, dag Suffix, bit swap> {
    dag Value = !con(Prefix, !if(swap, B, A), !if(swap, A, B), Suffix);
 }
@@ -68,25 +53,6 @@ multiclass VPatUSLoadStoreSDNode<ValueType type,
             (store_instr reg_class:$rs2, BaseAddr:$rs1, avl, sew)>;
 }
 
-multiclass VPatUSLoadStoreSDNodeMask<ValueType type,
-                                     ValueType mask_type,
-                                     int sew,
-                                     LMULInfo vlmul,
-                                     OutPatFrag avl,
-                                     VReg reg_class>
-{
-  defvar load_instr = !cast<Instruction>("PseudoVLE"#sew#"_V_"#vlmul.MX#"_MASK");
-  defvar store_instr = !cast<Instruction>("PseudoVSE"#sew#"_V_"#vlmul.MX#"_MASK");
-  // Load
-  def : Pat<(type (masked_load BaseAddr:$rs1, (mask_type V0), type:$merge)),
-            (load_instr reg_class:$merge, BaseAddr:$rs1, (mask_type V0),
-                        avl, sew)>;
-  // Store
-  def : Pat<(masked_store type:$rs2, BaseAddr:$rs1, (mask_type V0)),
-            (store_instr reg_class:$rs2, BaseAddr:$rs1, (mask_type V0),
-                         avl, sew)>;
-}
-
 multiclass VPatUSLoadStoreWholeVRSDNode<ValueType type,
                                         int sew,
                                         LMULInfo vlmul,
@@ -394,9 +360,6 @@ foreach vti = !listconcat(FractionalGroupIntegerVectors,
                           FractionalGroupFloatVectors) in
   defm "" : VPatUSLoadStoreSDNode<vti.Vector, vti.SEW, vti.LMul,
                                   vti.AVL, vti.RegClass>;
-foreach vti = AllVectors in
-  defm "" : VPatUSLoadStoreSDNodeMask<vti.Vector, vti.Mask, vti.SEW, vti.LMul,
-                                      vti.AVL, vti.RegClass>;
 foreach vti = [VI8M1, VI16M1, VI32M1, VI64M1, VF16M1, VF32M1, VF64M1] in
   defm "" : VPatUSLoadStoreWholeVRSDNode<vti.Vector, vti.SEW, vti.LMul,
                                          vti.RegClass>;


        


More information about the llvm-commits mailing list