[llvm] [LLVM] Add `llvm.experimental.vector.compress` intrinsic (PR #92289)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 2 20:03:36 PDT 2024
================
@@ -11336,6 +11336,108 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
MachinePointerInfo::getUnknownStack(MF));
}
+SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Node);
+ SDValue Vec = Node->getOperand(0);
+ SDValue Mask = Node->getOperand(1);
+ SDValue Passthru = Node->getOperand(2);
+
+ EVT VecVT = Vec.getValueType();
+ EVT ScalarVT = VecVT.getScalarType();
+ EVT MaskVT = Mask.getValueType();
+ EVT MaskScalarVT = MaskVT.getScalarType();
+
+ // Needs to be handled by targets that have scalable vector types.
+ if (VecVT.isScalableVector())
+ report_fatal_error("Cannot expand masked_compress for scalable vectors.");
+
+ SDValue StackPtr = DAG.CreateStackTemporary(
+ VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+ int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+ MachinePointerInfo PtrInfo =
+ MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+
+ MVT PositionVT = getVectorIdxTy(DAG.getDataLayout());
+ SDValue Chain = DAG.getEntryNode();
+ SDValue OutPos = DAG.getConstant(0, DL, PositionVT);
+
+ bool HasPassthru = !Passthru.isUndef();
+
+ // If we have a passthru vector, store it on the stack, overwrite the matching
+ // positions and then re-write the last element that was potentially
+ // overwritten even though mask[i] = false.
+ if (HasPassthru)
+ Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);
+
+ SDValue LastWriteVal;
+ APInt PassthruSplatVal;
+ bool IsSplatPassthru =
+ ISD::isConstantSplatVector(Passthru.getNode(), PassthruSplatVal);
+
+ if (IsSplatPassthru) {
+ // As we do not know which position we wrote to last, we cannot simply
+ // access that index from the passthru vector. So we first check if passthru
+ // is a splat vector, to use any element ...
+ LastWriteVal = DAG.getConstant(PassthruSplatVal, DL, ScalarVT);
+ } else if (HasPassthru) {
+ // ... if it is not a splat vector, we need to get the passthru value at
+ // position = popcount(mask) and re-load it from the stack before it is
+ // overwritten in the loop below.
+ SDValue Popcount = DAG.getNode(
+ ISD::TRUNCATE, DL, MaskVT.changeVectorElementType(MVT::i1), Mask);
+ Popcount = DAG.getNode(ISD::ZERO_EXTEND, DL,
+ MaskVT.changeVectorElementType(ScalarVT), Popcount);
+ Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, ScalarVT, Popcount);
+ SDValue LastElmtPtr =
+ getVectorElementPointer(DAG, StackPtr, VecVT, Popcount);
+ LastWriteVal = DAG.getLoad(
+ ScalarVT, DL, Chain, LastElmtPtr,
+ MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+ Chain = LastWriteVal.getValue(1);
----------------
lukel97 wrote:
Can we just extract LastWriteVal from the passthru directly with extract_vector_elt?
https://github.com/llvm/llvm-project/pull/92289
More information about the llvm-commits
mailing list