[llvm] [AArch64] Add lowering for `@llvm.experimental.vector.compress` (PR #101015)

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 31 22:01:01 PDT 2024


================
@@ -6615,6 +6633,132 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
   return DAG.getMergeValues({Ext, Chain}, DL);
 }
 
+SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Vec = Op.getOperand(0);
+  SDValue Mask = Op.getOperand(1);
+  SDValue Passthru = Op.getOperand(2);
+  EVT VecVT = Vec.getValueType();
+  EVT MaskVT = Mask.getValueType();
+  EVT ElmtVT = VecVT.getVectorElementType();
+  const bool IsFixedLength = VecVT.isFixedLengthVector();
+  const bool HasPassthru = !Passthru.isUndef();
+  unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
+  EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+
+  assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");
+
+  if (!Subtarget->hasSVE())
+    return SDValue();
+
+  if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
+    return SDValue();
+
+  // For fixed-length vectors, we can use an SVE register that holds the
+  // NEON vector in its lowest bits.
+  if (IsFixedLength) {
+    EVT ScalableVecVT =
+        MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+    EVT ScalableMaskVT = MVT::getScalableVectorVT(
+        MaskVT.getVectorElementType().getSimpleVT(), MinElmts);
+
+    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                      DAG.getUNDEF(ScalableVecVT), Vec,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
+                       DAG.getUNDEF(ScalableMaskVT), Mask,
+                       DAG.getConstant(0, DL, MVT::i64));
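+    // Truncate the integer mask down to an i1-based predicate vector.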
+    Mask = DAG.getNode(ISD::TRUNCATE, DL,
+                       ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
+    Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                           DAG.getUNDEF(ScalableVecVT), Passthru,
+                           DAG.getConstant(0, DL, MVT::i64));
+
+    VecVT = Vec.getValueType();
+    MaskVT = Mask.getValueType();
+  }
+
+  // Special case where we can't use svcompact but can do a compressing store
+  // and then reload the vector.
+  if (VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8 || VecVT == MVT::nxv8i16) {
+    SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
+    int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+    MachinePointerInfo PtrInfo =
+        MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+
+    MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+        PtrInfo, MachineMemOperand::Flags::MOStore,
+        LocationSize::precise(VecVT.getStoreSize()),
+        DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+
+    SDValue Chain = DAG.getEntryNode();
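+    // If there is a passthru, store it first so that the lanes beyond the
+    // compressed elements keep their passthru values after the reload.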
+    if (HasPassthru)
+      Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);
+
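+    // The compressing store packs the active lanes contiguously to the
+    // start of the stack slot.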
+    Chain = DAG.getMaskedStore(Chain, DL, Vec, StackPtr, DAG.getUNDEF(MVT::i64),
+                               Mask, VecVT, MMO, ISD::UNINDEXED,
+                               /*IsTruncating=*/false, /*IsCompressing=*/true);
+
+    SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+
+    if (IsFixedLength)
+      Compressed = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FixedVecVT,
+                               Compressed, DAG.getConstant(0, DL, MVT::i64));
+
+    return Compressed;
+  }
+
+  // Only <vscale x 4 x i32> and <vscale x 2 x i64> are supported for
+  // svcompact.
+  if (MinElmts != 2 && MinElmts != 4)
+    return SDValue();
+
+  // Get the legal container type for the svcompact instruction.
+  EVT ContainerVT = getSVEContainerType(VecVT);
+  EVT CastVT = VecVT.changeVectorElementTypeToInteger();
+
+  // Convert to i32 or i64 for smaller types, as these are the only supported
+  // sizes for svcompact.
+  if (ContainerVT != VecVT) {
+    Vec = DAG.getBitcast(CastVT, Vec);
+    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+  }
+
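+  // svcompact packs the active elements into the lowest-numbered lanes and
+  // zeroes the remaining lanes.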
+  SDValue Compressed = DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+      DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+
+  // svcompact fills with 0s, so if our passthru is all 0s, do nothing here.
+  if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
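+    // The number of active lanes is the popcount of the mask; every lane
+    // past that point in the compressed vector must come from the passthru.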
+    SDValue Offset = DAG.getNode(
+        ISD::ZERO_EXTEND, DL, MaskVT.changeVectorElementType(MVT::i32), Mask);
+    Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, Offset);
+    Compressed =
+        DAG.getNode(ISD::VP_MERGE, DL, VecVT,
----------------
efriedma-quic wrote:

The passthru exists because it's useful for some combinations of target and passthru value. For SVE in particular, a non-zero passthru means we need to explicitly construct a mask, but other targets support a passthru directly. This was discussed in #92289.
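
For reference, the intrinsic being lowered takes the source vector, a mask, and a passthru; a fixed-length use (types picked purely for illustration) looks like:

    %res = call <4 x i32> @llvm.experimental.vector.compress.v4i32(
               <4 x i32> %vec, <4 x i1> %mask, <4 x i32> %passthru)

Lanes whose mask bit is set are packed into the low lanes of %res in order; the remaining lanes are taken from %passthru, or are undefined when the passthru is undef.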

https://github.com/llvm/llvm-project/pull/101015

