[llvm] [AArch64] Add lowering for `@llvm.experimental.vector.compress` (PR #101015)
Lawrence Benson via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 7 05:39:15 PDT 2024
================
@@ -6616,6 +6628,102 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
return DAG.getMergeValues({Ext, Chain}, DL);
}
+SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ SDValue Vec = Op.getOperand(0);
+ SDValue Mask = Op.getOperand(1);
+ SDValue Passthru = Op.getOperand(2);
+ EVT VecVT = Vec.getValueType();
+ EVT MaskVT = Mask.getValueType();
+ EVT ElmtVT = VecVT.getVectorElementType();
+ const bool IsFixedLength = VecVT.isFixedLengthVector();
+ const bool HasPassthru = !Passthru.isUndef();
+ unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
+ EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+
+ assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");
+
+ if (!Subtarget->isSVEAvailable())
+ return SDValue();
+
+ if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
+ return SDValue();
+
+ // Only <vscale x {4|2} x {i32|i64}> supported for compact.
+ if (MinElmts != 2 && MinElmts != 4)
+ return SDValue();
+
+ // We can use the SVE register containing the NEON vector in its lowest bits.
+ if (IsFixedLength) {
+ EVT ScalableVecVT =
+ MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+ EVT ScalableMaskVT = MVT::getScalableVectorVT(
+ MaskVT.getVectorElementType().getSimpleVT(), MinElmts);
+
+ Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+ DAG.getUNDEF(ScalableVecVT), Vec,
+ DAG.getConstant(0, DL, MVT::i64));
+ Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
+ DAG.getUNDEF(ScalableMaskVT), Mask,
+ DAG.getConstant(0, DL, MVT::i64));
+ Mask = DAG.getNode(ISD::TRUNCATE, DL,
+ ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
+ Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+ DAG.getUNDEF(ScalableVecVT), Passthru,
+ DAG.getConstant(0, DL, MVT::i64));
+
+ VecVT = Vec.getValueType();
+ MaskVT = Mask.getValueType();
+ }
+
+ // Get legal type for compact instruction
+ EVT ContainerVT = getSVEContainerType(VecVT);
+ EVT CastVT = VecVT.changeVectorElementTypeToInteger();
+
+ // Convert to i32 or i64 for smaller types, as these are the only supported
+ // sizes for compact.
+ if (ContainerVT != VecVT) {
+ Vec = DAG.getBitcast(CastVT, Vec);
+ Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+ }
+
+ SDValue Compressed = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+ DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+
+ // compact fills with 0s, so if our passthru is all 0s, do nothing here.
+ if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
+ SDValue Offset = DAG.getNode(
+ ISD::ZERO_EXTEND, DL, MaskVT.changeVectorElementType(MVT::i32), Mask);
+ Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, Offset);
+ Compressed =
+ DAG.getNode(ISD::VP_MERGE, DL, VecVT,
----------------
lawben wrote:
done.
https://github.com/llvm/llvm-project/pull/101015
More information about the llvm-commits mailing list