[llvm] [LLVM] Add `llvm.masked.compress` intrinsic (PR #92289)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 18 06:51:56 PDT 2024
================
@@ -11336,6 +11336,105 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
MachinePointerInfo::getUnknownStack(MF));
}
+SDValue TargetLowering::expandMASKED_COMPRESS(SDNode *Node,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Node);
+ SDValue Vec = Node->getOperand(0);
+ SDValue Mask = Node->getOperand(1);
+ SDValue Passthru = Node->getOperand(2);
+
+ EVT VecVT = Vec.getValueType();
+ EVT ScalarVT = VecVT.getScalarType();
+ EVT MaskVT = Mask.getValueType();
+ EVT MaskScalarVT = MaskVT.getScalarType();
+
+ // Needs to be handled by targets that have scalable vector types.
+ if (VecVT.isScalableVector())
+ report_fatal_error("Cannot expand masked_compress for scalable vectors.");
+
+ SDValue StackPtr = DAG.CreateStackTemporary(
+ VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+ int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+ MachinePointerInfo PtrInfo =
+ MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+
+ SDValue Chain = DAG.getEntryNode();
+ SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);
+
+ bool HasPassthru = !Passthru.isUndef();
+
+ // If we have a passthru vector, store it on the stack, overwrite the matching
+ // positions and then re-write the last element that was potentially
+ // overwritten even though mask[i] = false.
+ if (HasPassthru)
+ Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);
+
+ SDValue LastWriteVal;
+ APInt PassthruSplatVal;
+ bool IsSplatPassthru =
+ ISD::isConstantSplatVector(Passthru.getNode(), PassthruSplatVal);
+
+ if (IsSplatPassthru) {
+ // As we do not know which position we wrote to last, we cannot simply
+ // access that index from the passthru vector. So we first check if passthru
+ // is a splat vector, to use any element ...
+ LastWriteVal = DAG.getConstant(PassthruSplatVal, DL, ScalarVT);
+ } else if (HasPassthru) {
+ // ... if it is not a splat vector, we need to get the passthru value at
+ // position = popcount(mask) and re-load it from the stack before it is
+ // overwritten in the loop below.
+ SDValue Popcount = DAG.getNode(
+ ISD::TRUNCATE, DL, MaskVT.changeVectorElementType(MVT::i1), Mask);
+ Popcount = DAG.getNode(ISD::ZERO_EXTEND, DL,
+ MaskVT.changeVectorElementType(MVT::i32), Popcount);
+ Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, Popcount);
+ SDValue LastElmtPtr =
+ getVectorElementPointer(DAG, StackPtr, VecVT, Popcount);
+ LastWriteVal = DAG.getLoad(
+ ScalarVT, DL, Chain, LastElmtPtr,
+ MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+ Chain = LastWriteVal.getValue(1);
+ }
+
+ unsigned NumElms = VecVT.getVectorNumElements();
+ for (unsigned I = 0; I < NumElms; I++) {
+ SDValue Idx = DAG.getVectorIdxConstant(I, DL);
+
+ SDValue ValI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
+ SDValue OutPtr = getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
+ Chain = DAG.getStore(
+ Chain, DL, ValI, OutPtr,
+ MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+
+ // Get the mask value and add it to the current output position. This
+ // either increments by 1 if MaskI is true or adds 0 otherwise.
+ SDValue MaskI =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MaskScalarVT, Mask, Idx);
+ MaskI = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, MaskI);
+ MaskI = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskI);
+ OutPos = DAG.getNode(ISD::ADD, DL, MVT::i32, OutPos, MaskI);
----------------
RKSimon wrote:
You might not want to hard-code i32 here (e.g. RISCV64 doesn't like i32 types) - it would be better to use getVectorIdxConstant/getVectorIdxTy instead?
https://github.com/llvm/llvm-project/pull/92289
More information about the llvm-commits
mailing list