[llvm] [RISCV] Support llvm.masked.expandload intrinsic (PR #101954)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 17:06:55 PDT 2024


================
@@ -11134,25 +11136,110 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
   if (!VL)
     VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
 
-  unsigned IntID =
-      IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
-  SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
-  if (IsUnmasked)
-    Ops.push_back(DAG.getUNDEF(ContainerVT));
-  else
+  SDValue Result;
+  if (!IsUnmasked && IsExpandingLoad &&
+      Subtarget.hasOptimizedIndexedLoadStore()) {
+    MVT IndexVT = ContainerVT;
+    if (ContainerVT.isFloatingPoint())
+      IndexVT = IndexVT.changeVectorElementTypeToInteger();
+
+    MVT IndexEltVT = IndexVT.getVectorElementType();
+    if (Subtarget.isRV32() && IndexEltVT.bitsGT(XLenVT))
+      IndexVT = IndexVT.changeVectorElementType(XLenVT);
+
+    // If index vector is an i8 vector and the element count exceeds 256, we
+    // should change the element type of index vector to i16 to avoid
+    // overflow.
+    if (IndexEltVT == MVT::i8 && VT.getVectorNumElements() > 256) {
+      // FIXME: We need to do vector splitting manually for LMUL=8 cases.
+      if (getLMUL(IndexVT) == RISCVII::LMUL_8)
+        return SDValue();
+      IndexVT = IndexVT.changeVectorElementType(MVT::i16);
+    }
+
+    SDValue Index =
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
+                    DAG.getTargetConstant(Intrinsic::riscv_viota, DL, XLenVT),
+                    DAG.getUNDEF(IndexVT), Mask, VL);
+    if (uint64_t EltSize = ContainerVT.getScalarSizeInBits(); EltSize > 8)
+      Index = DAG.getNode(RISCVISD::SHL_VL, DL, IndexVT, Index,
+                          DAG.getConstant(Log2_64(EltSize / 8), DL, IndexVT),
+                          DAG.getUNDEF(IndexVT), Mask, VL);
+    unsigned IntID = Intrinsic::riscv_vluxei_mask;
+    SmallVector<SDValue, 8> Ops{Chain,
+                                DAG.getTargetConstant(IntID, DL, XLenVT)};
     Ops.push_back(PassThru);
-  Ops.push_back(BasePtr);
-  if (!IsUnmasked)
+    Ops.push_back(BasePtr);
+    Ops.push_back(Index);
     Ops.push_back(Mask);
-  Ops.push_back(VL);
-  if (!IsUnmasked)
+    Ops.push_back(VL);
     Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
 
-  SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+    SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
 
-  SDValue Result =
-      DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
-  Chain = Result.getValue(1);
+    Result = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
+                                     MemVT, MMO);
+    Chain = Result.getValue(1);
+  } else {
+    SDValue ExpandingVL;
+    if (!IsUnmasked && IsExpandingLoad &&
+        !Subtarget.hasOptimizedIndexedLoadStore()) {
+      ExpandingVL = VL;
+      VL = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
+                       getAllOnesMask(Mask.getSimpleValueType(), VL, DL, DAG),
+                       VL);
+    }
+
+    unsigned IntID = IsUnmasked || (IsExpandingLoad &&
+                                    !Subtarget.hasOptimizedIndexedLoadStore())
+                         ? Intrinsic::riscv_vle
+                         : Intrinsic::riscv_vle_mask;
+    SmallVector<SDValue, 8> Ops{Chain,
+                                DAG.getTargetConstant(IntID, DL, XLenVT)};
+    if (IntID == Intrinsic::riscv_vle)
+      Ops.push_back(DAG.getUNDEF(ContainerVT));
+    else
+      Ops.push_back(PassThru);
+    Ops.push_back(BasePtr);
+    if (IntID == Intrinsic::riscv_vle_mask)
+      Ops.push_back(Mask);
+    Ops.push_back(VL);
+    if (IntID == Intrinsic::riscv_vle_mask)
+      Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
+
+    SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+
+    Result = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
+                                     MemVT, MMO);
+    Chain = Result.getValue(1);
+    if (ExpandingVL) {
+      MVT IndexVT = ContainerVT;
+      if (ContainerVT.isFloatingPoint())
+        IndexVT = ContainerVT.changeVectorElementTypeToInteger();
+
+      MVT IndexEltVT = IndexVT.getVectorElementType();
+      bool UseVRGATHEREI16 = false;
+      // If index vector is an i8 vector and the element count exceeds 256, we
+      // should change the element type of index vector to i16 to avoid
+      // overflow.
+      if (IndexEltVT == MVT::i8 && VT.getVectorNumElements() > 256) {
+        // FIXME: We need to do vector splitting manually for LMUL=8 cases.
+        if (getLMUL(IndexVT) == RISCVII::LMUL_8)
+          return SDValue();
----------------
topperc wrote:

What does the legalizer do when it receives SDValue() here?

https://github.com/llvm/llvm-project/pull/101954


More information about the llvm-commits mailing list