[llvm] [RISCV] Support llvm.masked.expandload intrinsic (PR #101954)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 28 17:06:55 PDT 2024
================
@@ -11134,25 +11136,110 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
if (!VL)
VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
- unsigned IntID =
- IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
- SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
- if (IsUnmasked)
- Ops.push_back(DAG.getUNDEF(ContainerVT));
- else
+ SDValue Result;
+ if (!IsUnmasked && IsExpandingLoad &&
+ Subtarget.hasOptimizedIndexedLoadStore()) {
+ MVT IndexVT = ContainerVT;
+ if (ContainerVT.isFloatingPoint())
+ IndexVT = IndexVT.changeVectorElementTypeToInteger();
+
+ MVT IndexEltVT = IndexVT.getVectorElementType();
+ if (Subtarget.isRV32() && IndexEltVT.bitsGT(XLenVT))
+ IndexVT = IndexVT.changeVectorElementType(XLenVT);
+
+ // If index vector is an i8 vector and the element count exceeds 256, we
+ // should change the element type of index vector to i16 to avoid
+ // overflow.
+ if (IndexEltVT == MVT::i8 && VT.getVectorNumElements() > 256) {
+ // FIXME: We need to do vector splitting manually for LMUL=8 cases.
+ if (getLMUL(IndexVT) == RISCVII::LMUL_8)
+ return SDValue();
+ IndexVT = IndexVT.changeVectorElementType(MVT::i16);
+ }
+
+ SDValue Index =
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
+ DAG.getTargetConstant(Intrinsic::riscv_viota, DL, XLenVT),
+ DAG.getUNDEF(IndexVT), Mask, VL);
+ if (uint64_t EltSize = ContainerVT.getScalarSizeInBits(); EltSize > 8)
+ Index = DAG.getNode(RISCVISD::SHL_VL, DL, IndexVT, Index,
+ DAG.getConstant(Log2_64(EltSize / 8), DL, IndexVT),
+ DAG.getUNDEF(IndexVT), Mask, VL);
+ unsigned IntID = Intrinsic::riscv_vluxei_mask;
+ SmallVector<SDValue, 8> Ops{Chain,
+ DAG.getTargetConstant(IntID, DL, XLenVT)};
Ops.push_back(PassThru);
- Ops.push_back(BasePtr);
- if (!IsUnmasked)
+ Ops.push_back(BasePtr);
+ Ops.push_back(Index);
Ops.push_back(Mask);
- Ops.push_back(VL);
- if (!IsUnmasked)
+ Ops.push_back(VL);
Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
- SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+ SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
- SDValue Result =
- DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
- Chain = Result.getValue(1);
+ Result = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
+ MemVT, MMO);
+ Chain = Result.getValue(1);
+ } else {
+ SDValue ExpandingVL;
+ if (!IsUnmasked && IsExpandingLoad &&
+ !Subtarget.hasOptimizedIndexedLoadStore()) {
+ ExpandingVL = VL;
+ VL = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
+ getAllOnesMask(Mask.getSimpleValueType(), VL, DL, DAG),
+ VL);
+ }
+
+ unsigned IntID = IsUnmasked || (IsExpandingLoad &&
+ !Subtarget.hasOptimizedIndexedLoadStore())
+ ? Intrinsic::riscv_vle
+ : Intrinsic::riscv_vle_mask;
+ SmallVector<SDValue, 8> Ops{Chain,
+ DAG.getTargetConstant(IntID, DL, XLenVT)};
+ if (IntID == Intrinsic::riscv_vle)
+ Ops.push_back(DAG.getUNDEF(ContainerVT));
+ else
+ Ops.push_back(PassThru);
+ Ops.push_back(BasePtr);
+ if (IntID == Intrinsic::riscv_vle_mask)
+ Ops.push_back(Mask);
+ Ops.push_back(VL);
+ if (IntID == Intrinsic::riscv_vle_mask)
+ Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
+
+ SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+
+ Result = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
+ MemVT, MMO);
+ Chain = Result.getValue(1);
+ if (ExpandingVL) {
+ MVT IndexVT = ContainerVT;
+ if (ContainerVT.isFloatingPoint())
+ IndexVT = ContainerVT.changeVectorElementTypeToInteger();
+
+ MVT IndexEltVT = IndexVT.getVectorElementType();
+ bool UseVRGATHEREI16 = false;
+ // If index vector is an i8 vector and the element count exceeds 256, we
+ // should change the element type of index vector to i16 to avoid
+ // overflow.
+ if (IndexEltVT == MVT::i8 && VT.getVectorNumElements() > 256) {
+ // FIXME: We need to do vector splitting manually for LMUL=8 cases.
+ if (getLMUL(IndexVT) == RISCVII::LMUL_8)
+ return SDValue();
----------------
topperc wrote:
What does the legalizer do when it receives SDValue() here?
https://github.com/llvm/llvm-project/pull/101954
More information about the llvm-commits mailing list