[llvm] [LLVM] Add `llvm.experimental.vector.compress` intrinsic (PR #92289)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 2 20:03:36 PDT 2024


================
@@ -11336,6 +11336,108 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
                      MachinePointerInfo::getUnknownStack(MF));
 }
 
+SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
+                                              SelectionDAG &DAG) const {
+  SDLoc DL(Node);
+  SDValue Vec = Node->getOperand(0);
+  SDValue Mask = Node->getOperand(1);
+  SDValue Passthru = Node->getOperand(2);
+
+  EVT VecVT = Vec.getValueType();
+  EVT ScalarVT = VecVT.getScalarType();
+  EVT MaskVT = Mask.getValueType();
+  EVT MaskScalarVT = MaskVT.getScalarType();
+
+  // Needs to be handled by targets that have scalable vector types.
+  if (VecVT.isScalableVector())
+    report_fatal_error("Cannot expand masked_compress for scalable vectors.");
+
+  SDValue StackPtr = DAG.CreateStackTemporary(
+      VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+  int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+  MachinePointerInfo PtrInfo =
+      MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+
+  MVT PositionVT = getVectorIdxTy(DAG.getDataLayout());
+  SDValue Chain = DAG.getEntryNode();
+  SDValue OutPos = DAG.getConstant(0, DL, PositionVT);
+
+  bool HasPassthru = !Passthru.isUndef();
+
+  // If we have a passthru vector, store it on the stack, overwrite the matching
+  // positions and then re-write the last element that was potentially
+  // overwritten even though mask[i] = false.
+  if (HasPassthru)
+    Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);
+
+  SDValue LastWriteVal;
+  APInt PassthruSplatVal;
+  bool IsSplatPassthru =
+      ISD::isConstantSplatVector(Passthru.getNode(), PassthruSplatVal);
+
+  if (IsSplatPassthru) {
+    // As we do not know which position we wrote to last, we cannot simply
+    // access that index from the passthru vector. So we first check if passthru
+    // is a splat vector, to use any element ...
+    LastWriteVal = DAG.getConstant(PassthruSplatVal, DL, ScalarVT);
+  } else if (HasPassthru) {
+    // ... if it is not a splat vector, we need to get the passthru value at
+    // position = popcount(mask) and re-load it from the stack before it is
+    // overwritten in the loop below.
+    SDValue Popcount = DAG.getNode(
+        ISD::TRUNCATE, DL, MaskVT.changeVectorElementType(MVT::i1), Mask);
+    Popcount = DAG.getNode(ISD::ZERO_EXTEND, DL,
+                           MaskVT.changeVectorElementType(ScalarVT), Popcount);
+    Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, ScalarVT, Popcount);
+    SDValue LastElmtPtr =
+        getVectorElementPointer(DAG, StackPtr, VecVT, Popcount);
+    LastWriteVal = DAG.getLoad(
+        ScalarVT, DL, Chain, LastElmtPtr,
+        MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+    Chain = LastWriteVal.getValue(1);
----------------
lukel97 wrote:

Can we just extract LastWriteVal from the passthru directly with extract_vector_elt? 

https://github.com/llvm/llvm-project/pull/92289


More information about the llvm-commits mailing list