[llvm] [LLVM] Add `llvm.masked.compress` intrinsic (PR #92289)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 18 06:51:55 PDT 2024


================
@@ -12091,6 +12093,57 @@ SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitMASKED_COMPRESS(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Vec = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue Passthru = N->getOperand(2);
+  EVT VecVT = Vec.getValueType();
+
+  bool HasPassthru = !Passthru.isUndef();
+
+  APInt SplatVal;
+  if (ISD::isConstantSplatVector(Mask.getNode(), SplatVal))
+    return SplatVal.isAllOnes()
+               ? Vec
+               : (HasPassthru ? Passthru : DAG.getUNDEF(VecVT));
+
+  if (Vec.isUndef() || Mask.isUndef())
+    return DAG.getUNDEF(VecVT);
+
+  // No need for potentially expensive compress if the mask is constant.
+  if (ISD::isBuildVectorOfConstantSDNodes(Mask.getNode())) {
+    SmallVector<SDValue, 16> Ops;
+    EVT ScalarVT = VecVT.getVectorElementType();
+    unsigned NumSelected = 0;
+    unsigned NumElmts = VecVT.getVectorNumElements();
+    for (unsigned I = 0; I < NumElmts; ++I) {
+      SDValue MaskI = Mask.getOperand(I);
+      if (MaskI.isUndef())
+        continue;
+
+      ConstantSDNode *CMaskI = cast<ConstantSDNode>(MaskI);
+      if (CMaskI->isAllOnes()) {
+        SDValue VecI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec,
+                                   DAG.getVectorIdxConstant(I, DL));
+        Ops.push_back(VecI);
+        NumSelected++;
+      }
+    }
+    for (unsigned Rest = NumSelected; Rest < NumElmts; ++Rest) {
+      SDValue Val =
+          HasPassthru
+              ? DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Passthru,
+                            DAG.getVectorIdxConstant(Rest - NumSelected, DL))
----------------
RKSimon wrote:

Should this be `DAG.getVectorIdxConstant(Rest, DL)` rather than indexing Passthru with `Rest - NumSelected`?

https://github.com/llvm/llvm-project/pull/92289


More information about the llvm-commits mailing list