[llvm] [LLVM] Add `llvm.masked.compress` intrinsic (PR #92289)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 18 06:51:56 PDT 2024


================
@@ -11336,6 +11336,105 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
                      MachinePointerInfo::getUnknownStack(MF));
 }
 
+SDValue TargetLowering::expandMASKED_COMPRESS(SDNode *Node,
+                                              SelectionDAG &DAG) const {
+  SDLoc DL(Node);
+  SDValue Vec = Node->getOperand(0);
+  SDValue Mask = Node->getOperand(1);
+  SDValue Passthru = Node->getOperand(2);
+
+  EVT VecVT = Vec.getValueType();
+  EVT ScalarVT = VecVT.getScalarType();
+  EVT MaskVT = Mask.getValueType();
+  EVT MaskScalarVT = MaskVT.getScalarType();
+
+  // Needs to be handled by targets that have scalable vector types.
+  if (VecVT.isScalableVector())
+    report_fatal_error("Cannot expand masked_compress for scalable vectors.");
+
+  SDValue StackPtr = DAG.CreateStackTemporary(
+      VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+  int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+  MachinePointerInfo PtrInfo =
+      MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+
+  SDValue Chain = DAG.getEntryNode();
+  SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);
+
+  bool HasPassthru = !Passthru.isUndef();
+
+  // If we have a passthru vector, store it on the stack, overwrite the matching
+  // positions and then re-write the last element that was potentially
+  // overwritten even though mask[i] = false.
+  if (HasPassthru)
+    Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);
+
+  SDValue LastWriteVal;
+  APInt PassthruSplatVal;
+  bool IsSplatPassthru =
+      ISD::isConstantSplatVector(Passthru.getNode(), PassthruSplatVal);
+
+  if (IsSplatPassthru) {
+    // As we do not know which position we wrote to last, we cannot simply
+    // access that index from the passthru vector. So we first check if passthru
+    // is a splat vector, to use any element ...
+    LastWriteVal = DAG.getConstant(PassthruSplatVal, DL, ScalarVT);
+  } else if (HasPassthru) {
+    // ... if it is not a splat vector, we need to get the passthru value at
+    // position = popcount(mask) and re-load it from the stack before it is
+    // overwritten in the loop below.
+    SDValue Popcount = DAG.getNode(
+        ISD::TRUNCATE, DL, MaskVT.changeVectorElementType(MVT::i1), Mask);
+    Popcount = DAG.getNode(ISD::ZERO_EXTEND, DL,
+                           MaskVT.changeVectorElementType(MVT::i32), Popcount);
+    Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, Popcount);
+    SDValue LastElmtPtr =
+        getVectorElementPointer(DAG, StackPtr, VecVT, Popcount);
+    LastWriteVal = DAG.getLoad(
+        ScalarVT, DL, Chain, LastElmtPtr,
+        MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+    Chain = LastWriteVal.getValue(1);
+  }
+
+  unsigned NumElms = VecVT.getVectorNumElements();
+  for (unsigned I = 0; I < NumElms; I++) {
+    SDValue Idx = DAG.getVectorIdxConstant(I, DL);
+
+    SDValue ValI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
+    SDValue OutPtr = getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
+    Chain = DAG.getStore(
+        Chain, DL, ValI, OutPtr,
+        MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+
+    // Get the mask value and add it to the current output position. This
+    // either increments by 1 if MaskI is true or adds 0 otherwise.
+    SDValue MaskI =
+        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MaskScalarVT, Mask, Idx);
+    MaskI = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, MaskI);
+    MaskI = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskI);
+    OutPos = DAG.getNode(ISD::ADD, DL, MVT::i32, OutPos, MaskI);
----------------
RKSimon wrote:

Might not want to hard-code i32 here (e.g. RISCV64 doesn't like it, as i32 isn't a legal scalar type there) - better off using getVectorIdxConstant/getVectorIdxTy?

https://github.com/llvm/llvm-project/pull/92289


More information about the llvm-commits mailing list