[llvm] [LLVM] Add `llvm.masked.compress` intrinsic (PR #92289)

Lawrence Benson via llvm-commits llvm-commits at lists.llvm.org
Fri May 17 07:40:47 PDT 2024


https://github.com/lawben updated https://github.com/llvm/llvm-project/pull/92289

>From 3a7b06453eec84b5fd7c3178339fd230f21b5b35 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 15 May 2024 14:08:37 +0200
Subject: [PATCH 01/14] Add initial code for @llvm.masked.compress intrinsics

---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  5 ++
 llvm/include/llvm/IR/Intrinsics.td            |  5 ++
 .../include/llvm/Target/TargetSelectionDAG.td |  6 ++
 .../SelectionDAG/LegalizeIntegerTypes.cpp     | 20 +++++++
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |  3 +
 .../SelectionDAG/LegalizeVectorOps.cpp        | 50 ++++++++++++++++
 .../SelectionDAG/LegalizeVectorTypes.cpp      | 60 +++++++++++++++++++
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  7 +++
 .../SelectionDAG/SelectionDAGDumper.cpp       |  1 +
 llvm/lib/CodeGen/TargetLoweringBase.cpp       |  3 +
 10 files changed, 160 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index d8af97957e48e..71dfd8b43b710 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1294,6 +1294,11 @@ enum NodeType {
   MLOAD,
   MSTORE,
 
+  // Masked compress - consecutively place vector elements based on mask
+  // e.g., vec = {A, B, C, D} and mask = 1010
+  //         --> {A, C, ?, ?} where ? is undefined
+  MCOMPRESS,
+
   // Masked gather and scatter - load and store operations for a vector of
   // random addresses with additional mask operand that prevents memory
   // accesses to the masked-off lanes.
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index f1c7d950f9275..e924d28956b0a 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2362,6 +2362,11 @@ def int_masked_compressstore:
             [IntrWriteMem, IntrArgMemOnly, IntrWillReturn,
              NoCapture<ArgIndex<1>>]>;
 
+def int_masked_compress: // Result has the same type as the input vector.
+    DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+              [LLVMMatchType<0>, LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
+              [IntrNoMem, IntrWillReturn]>;
+
 // Test whether a pointer is associated with a type metadata identifier.
 def int_type_test : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_ptr_ty, llvm_metadata_ty],
                               [IntrNoMem, IntrWillReturn, IntrSpeculatable]>;
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 1684b424e3b44..061330fb4e08f 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -266,6 +266,10 @@ def SDTMaskedScatter : SDTypeProfile<0, 4, [
   SDTCisSameNumEltsAs<0, 1>, SDTCisSameNumEltsAs<0, 3>
 ]>;
 
+def SDTMaskedCompress : SDTypeProfile<1, 2, [ // (result, vec, mask)
+  SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisVec<2>, SDTCisSameNumEltsAs<1, 2>
+]>;
+
 def SDTVecShuffle : SDTypeProfile<1, 2, [
   SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>
 ]>;
@@ -731,6 +735,8 @@ def masked_gather : SDNode<"ISD::MGATHER", SDTMaskedGather,
 def masked_scatter : SDNode<"ISD::MSCATTER", SDTMaskedScatter,
                             [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
 
+def masked_compress : SDNode<"ISD::MCOMPRESS", SDTMaskedCompress, []>;
+
 // Do not use ld, st directly. Use load, extload, sextload, zextload, store,
 // and truncst (see below).
 def ld         : SDNode<"ISD::LOAD"       , SDTLoad,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 0aa36deda79dc..80f645b433cbe 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -87,6 +87,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
     break;
   case ISD::MGATHER:     Res = PromoteIntRes_MGATHER(cast<MaskedGatherSDNode>(N));
     break;
+  case ISD::MCOMPRESS:   Res = PromoteIntRes_MCOMPRESS(N); break;
   case ISD::SELECT:
   case ISD::VSELECT:
   case ISD::VP_SELECT:
@@ -948,6 +949,11 @@ SDValue DAGTypeLegalizer::PromoteIntRes_MGATHER(MaskedGatherSDNode *N) {
   return Res;
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_MCOMPRESS(SDNode *N) {
+  SDValue Op = GetPromotedInteger(N->getOperand(0)), Mask = N->getOperand(1);
+  return DAG.getNode(ISD::MCOMPRESS, SDLoc(N), Op.getValueType(), Op, Mask);
+}
+
 /// Promote the overflow flag of an overflowing arithmetic node.
 SDValue DAGTypeLegalizer::PromoteIntRes_Overflow(SDNode *N) {
   // Change the return type of the boolean result while obeying
@@ -1855,6 +1861,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
                                                  OpNo); break;
   case ISD::MSCATTER: Res = PromoteIntOp_MSCATTER(cast<MaskedScatterSDNode>(N),
                                                   OpNo); break;
+  case ISD::MCOMPRESS: Res = PromoteIntOp_MCOMPRESS(N, OpNo); break;
   case ISD::VP_TRUNCATE:
   case ISD::TRUNCATE:     Res = PromoteIntOp_TRUNCATE(N); break;
   case ISD::BF16_TO_FP:
@@ -2335,6 +2342,19 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N,
                               N->getIndexType(), TruncateStore);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntOp_MCOMPRESS(SDNode *N, unsigned OpNo) {
+  // Operand 0 has the same type as the result, so if it needed promotion the
+  // result was promoted first (PromoteIntRes_MCOMPRESS). Only the mask can get
+  // here; promoting Vec while keeping the old result type would be ill-typed.
+  assert(OpNo == 1 && "Can only promote the MCOMPRESS mask operand");
+  SDValue Vec = N->getOperand(0);
+  EVT VT = Vec.getValueType();
+  // Promote the mask according to the target's boolean contents for VT.
+  SDValue Mask = PromoteTargetBoolean(N->getOperand(1), VT);
+
+  return DAG.getNode(ISD::MCOMPRESS, SDLoc(N), VT, Vec, Mask);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntOp_TRUNCATE(SDNode *N) {
   SDValue Op = GetPromotedInteger(N->getOperand(0));
   if (N->getOpcode() == ISD::VP_TRUNCATE)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index d925089d5689f..5fb14757f8991 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -321,6 +321,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_LOAD(LoadSDNode *N);
   SDValue PromoteIntRes_MLOAD(MaskedLoadSDNode *N);
   SDValue PromoteIntRes_MGATHER(MaskedGatherSDNode *N);
+  SDValue PromoteIntRes_MCOMPRESS(SDNode *N);
   SDValue PromoteIntRes_Overflow(SDNode *N);
   SDValue PromoteIntRes_FFREXP(SDNode *N);
   SDValue PromoteIntRes_SADDSUBO(SDNode *N, unsigned ResNo);
@@ -390,6 +391,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MSCATTER(MaskedScatterSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MGATHER(MaskedGatherSDNode *N, unsigned OpNo);
+  SDValue PromoteIntOp_MCOMPRESS(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_FRAMERETURNADDR(SDNode *N);
   SDValue PromoteIntOp_FIX(SDNode *N);
   SDValue PromoteIntOp_ExpOp(SDNode *N);
@@ -882,6 +884,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_MLOAD(MaskedLoadSDNode *MLD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_Gather(MemSDNode *VPGT, SDValue &Lo, SDValue &Hi,
                           bool SplitSETCC = false);
+  void SplitVecRes_MCOMPRESS(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_ScalarOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_STEP_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 423df9ae6b2a5..759de775ba011 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -134,6 +134,7 @@ class VectorLegalizer {
   SDValue ExpandVSELECT(SDNode *Node);
   SDValue ExpandVP_SELECT(SDNode *Node);
   SDValue ExpandVP_MERGE(SDNode *Node);
+  SDValue ExpandMCOMPRESS(SDNode *Node);
   SDValue ExpandVP_REM(SDNode *Node);
   SDValue ExpandSELECT(SDNode *Node);
   std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
@@ -442,6 +443,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
   case ISD::MGATHER:
+  case ISD::MCOMPRESS:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
@@ -1101,6 +1103,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
       return;
 
     break;
+  case ISD::MCOMPRESS:
+    Results.push_back(ExpandMCOMPRESS(Node));
+    return;
   }
 
   SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1505,6 +1510,51 @@ SDValue VectorLegalizer::ExpandVP_MERGE(SDNode *Node) {
   return DAG.getSelect(DL, Node->getValueType(0), FullMask, Op1, Op2);
 }
 
+// Expand MCOMPRESS by compacting the elements through a stack temporary.
+SDValue VectorLegalizer::ExpandMCOMPRESS(SDNode *Node) {
+  SDLoc DL(Node);
+  SDValue Vec = Node->getOperand(0);
+  SDValue Mask = Node->getOperand(1);
+  EVT VecVT = Vec.getValueType();
+  EVT ScalarVT = VecVT.getScalarType();
+  EVT MaskScalarVT = Mask.getValueType().getScalarType();
+
+  assert(TLI.isTypeLegal(VecVT) &&  TLI.isTypeLegal(ScalarVT) && TLI.isTypeLegal(MaskScalarVT) &&
+         "Need legal vector/mask element types to scalarize masked compress.");
+
+  SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
+  SDValue Chain = DAG.getEntryNode();
+  SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);
+  MachineFunction &MF = DAG.getMachineFunction();
+
+  unsigned NumElms = VecVT.getVectorNumElements();
+  // Store each element at the current output position; the position only
+  // advances past selected lanes, so the selected elements end up contiguous.
+  for (unsigned I = 0; I < NumElms; I++) {
+    SDValue Idx = DAG.getVectorIdxConstant(I, DL);
+    SDValue ValI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
+    SDValue OutPtr = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
+    Chain = DAG.getStore(Chain, DL, ValI, OutPtr,
+                         MachinePointerInfo::getUnknownStack(MF));
+
+    // The position after the last element is never used, so skip computing it.
+    if (I < NumElms - 1) {
+      // Truncate to i1 first so only the low bit of the mask element counts,
+      // then zero-extend: OutPos advances by 1 if the lane is selected, else 0.
+      SDValue MaskI =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MaskScalarVT, Mask, Idx);
+      MaskI = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, MaskI);
+      MaskI = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskI);
+      OutPos = DAG.getNode(ISD::ADD, DL, MVT::i32, OutPos, MaskI);
+    }
+  }
+
+  // Reload the compacted vector; lanes past the selected ones are undefined.
+  int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+  MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(MF, FI);
+  return DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+}
+
 SDValue VectorLegalizer::ExpandVP_REM(SDNode *Node) {
   // Implement VP_SREM/UREM in terms of VP_SDIV/VP_UDIV, VP_MUL, VP_SUB.
   EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index cd858003cf03b..62e7febed6568 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1058,6 +1058,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_GATHER:
     SplitVecRes_Gather(cast<MemSDNode>(N), Lo, Hi, /*SplitSETCC*/ true);
     break;
+  case ISD::MCOMPRESS:
+    SplitVecRes_MCOMPRESS(N, Lo, Hi);
+    break;
   case ISD::SETCC:
   case ISD::VP_SETCC:
     SplitVecRes_SETCC(N, Lo, Hi);
@@ -2304,6 +2307,63 @@ void DAGTypeLegalizer::SplitVecRes_Gather(MemSDNode *N, SDValue &Lo,
   ReplaceValueWith(SDValue(N, 1), Ch);
 }
 
+void DAGTypeLegalizer::SplitVecRes_MCOMPRESS(SDNode *N, SDValue &Lo,
+                                             SDValue &Hi) {
+  // This cannot be split lane-wise: how many Hi elements move into the Lo
+  // result depends on the Lo half of the mask. So compress the whole vector
+  // at once through a stack temporary (like the generic MCOMPRESS expansion
+  // in LegalizeVectorOps) and load the Lo and Hi halves back from it. This
+  // removes the MCOMPRESS node; the nodes created here are legalized later
+  // as usual.
+  SDLoc DL(N);
+  SDValue Vec = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  MachineFunction &MF = DAG.getMachineFunction();
+
+  EVT VecVT = Vec.getValueType();
+  EVT SubVecVT = VecVT.getHalfNumVectorElementsVT(*DAG.getContext());
+  EVT ScalarVT = VecVT.getScalarType();
+  EVT MaskScalarVT = Mask.getValueType().getScalarType();
+
+  // TODO: This code is duplicated here and in LegalizeVectorOps.
+  SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
+  SDValue Chain = DAG.getEntryNode();
+  SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);
+
+  unsigned NumElms = VecVT.getVectorNumElements();
+  // Store each element at OutPos, which only advances past selected lanes.
+  for (unsigned I = 0; I < NumElms; I++) {
+    SDValue Idx = DAG.getVectorIdxConstant(I, DL);
+
+    SDValue ValI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
+    SDValue OutPtr = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
+    Chain = DAG.getStore(Chain, DL, ValI, OutPtr,
+                         MachinePointerInfo::getUnknownStack(MF));
+
+    // The position after the last element is never used, so skip computing it.
+    if (I < NumElms - 1) {
+      // Truncate to i1 first so only the low bit of the mask element counts,
+      // then zero-extend: OutPos advances by 1 if the lane is selected, else 0.
+      SDValue MaskI =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MaskScalarVT, Mask, Idx);
+      MaskI = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, MaskI);
+      MaskI = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskI);
+      OutPos = DAG.getNode(ISD::ADD, DL, MVT::i32, OutPos, MaskI);
+    }
+  }
+
+  // Reload the compacted vector as two halves; Hi starts at element
+  // NumElms / 2. Lanes past the selected elements are undefined.
+  int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+  MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(MF, FI);
+  SDValue HiPtr = TLI.getVectorElementPointer(
+      DAG, StackPtr, VecVT, DAG.getConstant(NumElms / 2, DL, MVT::i32));
+
+  Lo = DAG.getLoad(SubVecVT, DL, Chain, StackPtr, PtrInfo);
+  Hi = DAG.getLoad(SubVecVT, DL, Chain, HiPtr,
+                   MachinePointerInfo::getUnknownStack(MF));
+}
+
 void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
   assert(N->getValueType(0).isVector() &&
          N->getOperand(0).getValueType().isVector() &&
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ca352da5d36eb..665bab6121837 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6718,6 +6718,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   case Intrinsic::masked_compressstore:
     visitMaskedStore(I, true /* IsCompressing */);
     return;
+  case Intrinsic::masked_compress:
+    setValue(&I, DAG.getNode(ISD::MCOMPRESS, sdl,
+                             getValue(I.getArgOperand(0)).getValueType(),
+                             getValue(I.getArgOperand(0)),
+                             getValue(I.getArgOperand(1)),
+                             Flags));
+    return;
   case Intrinsic::powi:
     setValue(&I, ExpandPowI(sdl, getValue(I.getArgOperand(0)),
                             getValue(I.getArgOperand(1)), DAG));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 59742e90c6791..37288054b0e7b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -416,6 +416,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::MSTORE:                     return "masked_store";
   case ISD::MGATHER:                    return "masked_gather";
   case ISD::MSCATTER:                   return "masked_scatter";
+  case ISD::MCOMPRESS:                  return "masked_compress";
   case ISD::VAARG:                      return "vaarg";
   case ISD::VACOPY:                     return "vacopy";
   case ISD::VAEND:                      return "vaend";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 09b70cfb72278..5ee12be752b27 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -956,6 +956,9 @@ void TargetLoweringBase::initActions() {
     // Named vector shuffles default to expand.
     setOperationAction(ISD::VECTOR_SPLICE, VT, Expand);
 
+    // Only some targets support this vector operation. Most need to expand it.
+    setOperationAction(ISD::MCOMPRESS, VT, Expand);
+
     // VP operations default to expand.
 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...)                                   \
     setOperationAction(ISD::SDOPC, VT, Expand);

>From 75abf0b013f732335ced35002055f0da48f724e0 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 15 May 2024 15:32:47 +0200
Subject: [PATCH 02/14] Remove requirements for legal types

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 759de775ba011..ca32db26e511c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1519,9 +1519,6 @@ SDValue VectorLegalizer::ExpandMCOMPRESS(SDNode *Node) {
   EVT ScalarVT = VecVT.getScalarType();
   EVT MaskScalarVT = Mask.getValueType().getScalarType();
 
-  assert(TLI.isTypeLegal(VecVT) &&  TLI.isTypeLegal(ScalarVT) && TLI.isTypeLegal(MaskScalarVT) &&
-         "Need legal vector/mask element types to scalarize masked compress.");
-
   SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
   SDValue Chain = DAG.getEntryNode();
   SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);

>From 0329bc9652f9a6c633924d951d6694399d9f6af7 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 15 May 2024 16:33:36 +0200
Subject: [PATCH 03/14] Add tests for AArch64

---
 llvm/test/CodeGen/AArch64/masked-compress.ll | 280 +++++++++++++++++++
 1 file changed, 280 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/masked-compress.ll

diff --git a/llvm/test/CodeGen/AArch64/masked-compress.ll b/llvm/test/CodeGen/AArch64/masked-compress.ll
new file mode 100644
index 0000000000000..54c3beab82f76
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/masked-compress.ll
@@ -0,0 +1,280 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-apple-darwin -verify-machineinstrs < %s | FileCheck %s
+
+define <4 x i32> @test_compress_v4i32(<4 x i32> %vec, <4 x i1> %mask) {
+; CHECK-LABEL: test_compress_v4i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    ushll.4s v1, v1, #0
+; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    str s0, [sp]
+; CHECK-NEXT:    shl.4s v1, v1, #31
+; CHECK-NEXT:    cmlt.4s v1, v1, #0
+; CHECK-NEXT:    mov.s w9, v1[1]
+; CHECK-NEXT:    mov.s w10, v1[2]
+; CHECK-NEXT:    fmov w11, s1
+; CHECK-NEXT:    bfi x8, x11, #2, #1
+; CHECK-NEXT:    and w11, w11, #0x1
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    and w10, w10, #0x1
+; CHECK-NEXT:    add w9, w11, w9
+; CHECK-NEXT:    mov x11, sp
+; CHECK-NEXT:    st1.s { v0 }[1], [x8]
+; CHECK-NEXT:    add w10, w9, w10
+; CHECK-NEXT:    orr x9, x11, x9, lsl #2
+; CHECK-NEXT:    bfi x11, x10, #2, #2
+; CHECK-NEXT:    st1.s { v0 }[2], [x9]
+; CHECK-NEXT:    st1.s { v0 }[3], [x11]
+; CHECK-NEXT:    ldr q0, [sp], #16
+; CHECK-NEXT:    ret
+    %out = call <4 x i32> @llvm.masked.compress.v4i32(<4 x i32> %vec, <4 x i1> %mask)
+    ret <4 x i32> %out ; expanded via a stack temporary, one store per lane
+}
+
+define <16 x i8> @test_compress_v16i8(<16 x i8> %vec, <16 x i1> %mask) {
+; CHECK-LABEL: test_compress_v16i8:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    shl.16b v1, v1, #7
+; CHECK-NEXT:    mov x12, sp
+; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    st1.b { v0 }[0], [x8]
+; CHECK-NEXT:    mov x13, sp
+; CHECK-NEXT:    cmlt.16b v1, v1, #0
+; CHECK-NEXT:    umov.b w9, v1[0]
+; CHECK-NEXT:    umov.b w10, v1[1]
+; CHECK-NEXT:    umov.b w11, v1[2]
+; CHECK-NEXT:    umov.b w14, v1[3]
+; CHECK-NEXT:    bfxil x12, x9, #0, #1
+; CHECK-NEXT:    and w10, w10, #0x1
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    add w9, w9, w10
+; CHECK-NEXT:    umov.b w10, v1[4]
+; CHECK-NEXT:    and w11, w11, #0x1
+; CHECK-NEXT:    st1.b { v0 }[1], [x12]
+; CHECK-NEXT:    orr x12, x8, x9
+; CHECK-NEXT:    add w9, w9, w11
+; CHECK-NEXT:    umov.b w11, v1[5]
+; CHECK-NEXT:    and w14, w14, #0x1
+; CHECK-NEXT:    st1.b { v0 }[2], [x12]
+; CHECK-NEXT:    add w14, w9, w14
+; CHECK-NEXT:    umov.b w12, v1[6]
+; CHECK-NEXT:    orr x9, x8, x9
+; CHECK-NEXT:    and w10, w10, #0x1
+; CHECK-NEXT:    st1.b { v0 }[3], [x9]
+; CHECK-NEXT:    orr x9, x8, x14
+; CHECK-NEXT:    add w10, w14, w10
+; CHECK-NEXT:    umov.b w14, v1[7]
+; CHECK-NEXT:    st1.b { v0 }[4], [x9]
+; CHECK-NEXT:    and w11, w11, #0x1
+; CHECK-NEXT:    bfxil x13, x10, #0, #4
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    add w10, w10, w11
+; CHECK-NEXT:    umov.b w11, v1[8]
+; CHECK-NEXT:    and w12, w12, #0x1
+; CHECK-NEXT:    bfxil x9, x10, #0, #4
+; CHECK-NEXT:    st1.b { v0 }[5], [x13]
+; CHECK-NEXT:    umov.b w13, v1[9]
+; CHECK-NEXT:    add w10, w10, w12
+; CHECK-NEXT:    mov x12, sp
+; CHECK-NEXT:    and w14, w14, #0x1
+; CHECK-NEXT:    st1.b { v0 }[6], [x9]
+; CHECK-NEXT:    umov.b w9, v1[10]
+; CHECK-NEXT:    bfxil x12, x10, #0, #4
+; CHECK-NEXT:    add w10, w10, w14
+; CHECK-NEXT:    mov x14, sp
+; CHECK-NEXT:    and w11, w11, #0x1
+; CHECK-NEXT:    bfxil x14, x10, #0, #4
+; CHECK-NEXT:    add w10, w10, w11
+; CHECK-NEXT:    mov x11, sp
+; CHECK-NEXT:    and w13, w13, #0x1
+; CHECK-NEXT:    st1.b { v0 }[7], [x12]
+; CHECK-NEXT:    mov x12, sp
+; CHECK-NEXT:    bfxil x11, x10, #0, #4
+; CHECK-NEXT:    add w10, w10, w13
+; CHECK-NEXT:    umov.b w13, v1[11]
+; CHECK-NEXT:    st1.b { v0 }[8], [x14]
+; CHECK-NEXT:    umov.b w14, v1[12]
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    bfxil x12, x10, #0, #4
+; CHECK-NEXT:    add w9, w10, w9
+; CHECK-NEXT:    mov x10, sp
+; CHECK-NEXT:    st1.b { v0 }[9], [x11]
+; CHECK-NEXT:    umov.b w11, v1[13]
+; CHECK-NEXT:    bfxil x10, x9, #0, #4
+; CHECK-NEXT:    st1.b { v0 }[10], [x12]
+; CHECK-NEXT:    umov.b w12, v1[14]
+; CHECK-NEXT:    and w13, w13, #0x1
+; CHECK-NEXT:    and w14, w14, #0x1
+; CHECK-NEXT:    add w9, w9, w13
+; CHECK-NEXT:    st1.b { v0 }[11], [x10]
+; CHECK-NEXT:    mov x10, sp
+; CHECK-NEXT:    add w13, w9, w14
+; CHECK-NEXT:    mov x14, sp
+; CHECK-NEXT:    bfxil x10, x9, #0, #4
+; CHECK-NEXT:    and w9, w11, #0x1
+; CHECK-NEXT:    mov x11, sp
+; CHECK-NEXT:    add w9, w13, w9
+; CHECK-NEXT:    and w12, w12, #0x1
+; CHECK-NEXT:    bfxil x14, x13, #0, #4
+; CHECK-NEXT:    bfxil x11, x9, #0, #4
+; CHECK-NEXT:    add w9, w9, w12
+; CHECK-NEXT:    st1.b { v0 }[12], [x10]
+; CHECK-NEXT:    bfxil x8, x9, #0, #4
+; CHECK-NEXT:    st1.b { v0 }[13], [x14]
+; CHECK-NEXT:    st1.b { v0 }[14], [x11]
+; CHECK-NEXT:    st1.b { v0 }[15], [x8]
+; CHECK-NEXT:    ldr q0, [sp], #16
+; CHECK-NEXT:    ret
+    %out = call <16 x i8> @llvm.masked.compress.v16i8(<16 x i8> %vec, <16 x i1> %mask)
+    ret <16 x i8> %out ; one byte store per lane into the stack temporary
+}
+
+define <8 x i32> @test_compress_large(<8 x i32> %vec, <8 x i1> %mask) {
+; CHECK-LABEL: test_compress_large:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! ; 16-byte Folded Spill
+; CHECK-NEXT:    sub x9, sp, #48
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    and sp, x9, #0xffffffffffffffe0
+; CHECK-NEXT:    .cfi_def_cfa w29, 16
+; CHECK-NEXT:    .cfi_offset w30, -8
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    ; kill: def $d2 killed $d2 def $q2
+; CHECK-NEXT:    umov.b w9, v2[0]
+; CHECK-NEXT:    umov.b w10, v2[1]
+; CHECK-NEXT:    mov x12, sp
+; CHECK-NEXT:    umov.b w11, v2[2]
+; CHECK-NEXT:    umov.b w13, v2[3]
+; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    umov.b w14, v2[4]
+; CHECK-NEXT:    str s0, [sp]
+; CHECK-NEXT:    bfi x12, x9, #2, #1
+; CHECK-NEXT:    and w10, w10, #0x1
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    add w9, w9, w10
+; CHECK-NEXT:    and w10, w11, #0x1
+; CHECK-NEXT:    and w13, w13, #0x1
+; CHECK-NEXT:    orr x11, x8, x9, lsl #2
+; CHECK-NEXT:    add w9, w9, w10
+; CHECK-NEXT:    umov.b w10, v2[5]
+; CHECK-NEXT:    st1.s { v0 }[1], [x12]
+; CHECK-NEXT:    add w13, w9, w13
+; CHECK-NEXT:    orr x9, x8, x9, lsl #2
+; CHECK-NEXT:    st1.s { v0 }[2], [x11]
+; CHECK-NEXT:    umov.b w11, v2[6]
+; CHECK-NEXT:    mov x12, sp
+; CHECK-NEXT:    and w14, w14, #0x1
+; CHECK-NEXT:    bfi x12, x13, #2, #3
+; CHECK-NEXT:    st1.s { v0 }[3], [x9]
+; CHECK-NEXT:    add w13, w13, w14
+; CHECK-NEXT:    and w9, w10, #0x1
+; CHECK-NEXT:    mov x10, sp
+; CHECK-NEXT:    add w9, w13, w9
+; CHECK-NEXT:    mov x14, sp
+; CHECK-NEXT:    str s1, [x12]
+; CHECK-NEXT:    and w11, w11, #0x1
+; CHECK-NEXT:    bfi x10, x9, #2, #3
+; CHECK-NEXT:    bfi x14, x13, #2, #3
+; CHECK-NEXT:    add w9, w9, w11
+; CHECK-NEXT:    bfi x8, x9, #2, #3
+; CHECK-NEXT:    st1.s { v1 }[1], [x14]
+; CHECK-NEXT:    st1.s { v1 }[2], [x10]
+; CHECK-NEXT:    st1.s { v1 }[3], [x8]
+; CHECK-NEXT:    ldp q0, q1, [sp]
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <8 x i32> @llvm.masked.compress.v8i32(<8 x i32> %vec, <8 x i1> %mask)
+    ret <8 x i32> %out ; result reloaded as two <4 x i32> halves (ldp q0, q1)
+}
+
+define <4 x i32> @test_compress_const(<4 x i32> %vec, <4 x i1> %mask) {
+; CHECK-LABEL: test_compress_const:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    mov x8, #3 ; =0x3
+; CHECK-NEXT:    mov w9, #9 ; =0x9
+; CHECK-NEXT:    movk x8, #7, lsl #32
+; CHECK-NEXT:    str x8, [sp, #-16]!
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    mov w8, #5 ; =0x5
+; CHECK-NEXT:    str w9, [sp, #8]
+; CHECK-NEXT:    str w8, [sp]
+; CHECK-NEXT:    ldr q0, [sp], #16
+; CHECK-NEXT:    ret
+    %out = call <4 x i32> @llvm.masked.compress.v4i32(<4 x i32> <i32 3, i32 5, i32 7, i32 9>,
+                                                      <4 x i1>   <i1 0,  i1 1,  i1 1,  i1 0>)
+    ret <4 x i32> %out ; lanes 1 and 2 are selected, so the result begins <5, 7>
+}
+
+define <4 x i8> @test_compress_small(<4 x i8> %vec, <4 x i1> %mask) {
+; CHECK-LABEL: test_compress_small:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    shl.4h v1, v1, #15
+; CHECK-NEXT:    add x8, sp, #8
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    str h0, [sp, #8]
+; CHECK-NEXT:    cmlt.4h v1, v1, #0
+; CHECK-NEXT:    umov.h w9, v1[0]
+; CHECK-NEXT:    umov.h w10, v1[1]
+; CHECK-NEXT:    umov.h w11, v1[2]
+; CHECK-NEXT:    bfi x8, x9, #1, #1
+; CHECK-NEXT:    and w10, w10, #0x1
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    add w9, w9, w10
+; CHECK-NEXT:    and w11, w11, #0x1
+; CHECK-NEXT:    add x10, sp, #8
+; CHECK-NEXT:    add w11, w9, w11
+; CHECK-NEXT:    orr x9, x10, x9, lsl #1
+; CHECK-NEXT:    st1.h { v0 }[1], [x8]
+; CHECK-NEXT:    bfi x10, x11, #1, #2
+; CHECK-NEXT:    st1.h { v0 }[2], [x9]
+; CHECK-NEXT:    st1.h { v0 }[3], [x10]
+; CHECK-NEXT:    ldr d0, [sp, #8]
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+    %out = call <4 x i8> @llvm.masked.compress.v4i8(<4 x i8> %vec, <4 x i1> %mask)
+    ret <4 x i8> %out ; <4 x i8> is handled in 16-bit lanes (note the .h stores)
+}
+
+define <4 x i4> @test_compress_illegal_element_type(<4 x i4> %vec, <4 x i1> %mask) {
+; CHECK-LABEL: test_compress_illegal_element_type:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    shl.4h v1, v1, #15
+; CHECK-NEXT:    add x8, sp, #8
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    str h0, [sp, #8]
+; CHECK-NEXT:    cmlt.4h v1, v1, #0
+; CHECK-NEXT:    umov.h w9, v1[0]
+; CHECK-NEXT:    umov.h w10, v1[1]
+; CHECK-NEXT:    umov.h w11, v1[2]
+; CHECK-NEXT:    bfi x8, x9, #1, #1
+; CHECK-NEXT:    and w10, w10, #0x1
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    add w9, w9, w10
+; CHECK-NEXT:    and w11, w11, #0x1
+; CHECK-NEXT:    add x10, sp, #8
+; CHECK-NEXT:    add w11, w9, w11
+; CHECK-NEXT:    orr x9, x10, x9, lsl #1
+; CHECK-NEXT:    st1.h { v0 }[1], [x8]
+; CHECK-NEXT:    bfi x10, x11, #1, #2
+; CHECK-NEXT:    st1.h { v0 }[2], [x9]
+; CHECK-NEXT:    st1.h { v0 }[3], [x10]
+; CHECK-NEXT:    ldr d0, [sp, #8]
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+    %out = call <4 x i4> @llvm.masked.compress.v4i4(<4 x i4> %vec, <4 x i1> %mask)
+    ret <4 x i4> %out ; i4 lanes are promoted; codegen matches test_compress_small
+}
+
+declare <4 x i32> @llvm.masked.compress.v4i32(<4 x i32>, <4 x i1>)
+declare <16 x i8> @llvm.masked.compress.v16i8(<16 x i8>, <16 x i1>)
+declare <8 x i32> @llvm.masked.compress.v8i32(<8 x i32>, <8 x i1>)
+declare <4 x i8> @llvm.masked.compress.v4i8(<4 x i8>, <4 x i1>)
+declare <4 x i4> @llvm.masked.compress.v4i4(<4 x i4>, <4 x i1>)

>From 73bfebbdb3d7a1c96c8521e5789ad7e59f665da7 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 15 May 2024 16:41:13 +0200
Subject: [PATCH 04/14] Add floating point test

---
 llvm/test/CodeGen/AArch64/masked-compress.ll | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/masked-compress.ll b/llvm/test/CodeGen/AArch64/masked-compress.ll
index 54c3beab82f76..a2f39b9620c95 100644
--- a/llvm/test/CodeGen/AArch64/masked-compress.ll
+++ b/llvm/test/CodeGen/AArch64/masked-compress.ll
@@ -32,6 +32,25 @@ define <4 x i32> @test_compress_v4i32(<4 x i32> %vec, <4 x i1> %mask) {
     ret <4 x i32> %out
 }
 
+define <2 x double> @test_compress_v2f64(<2 x double> %vec, <2 x i1> %mask) {
+; CHECK-LABEL: test_compress_v2f64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    ushll.2d v1, v1, #0
+; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    str d0, [sp]
+; CHECK-NEXT:    shl.2d v1, v1, #63
+; CHECK-NEXT:    cmlt.2d v1, v1, #0
+; CHECK-NEXT:    fmov x9, d1
+; CHECK-NEXT:    bfi x8, x9, #3, #1
+; CHECK-NEXT:    st1.d { v0 }[1], [x8]
+; CHECK-NEXT:    ldr q0, [sp], #16
+; CHECK-NEXT:    ret
+    %out = call <2 x double> @llvm.masked.compress.v2f64(<2 x double> %vec, <2 x i1> %mask)
+    ret <2 x double> %out ; floating-point lanes take the same stack-based path
+}
+
 define <16 x i8> @test_compress_v16i8(<16 x i8> %vec, <16 x i1> %mask) {
 ; CHECK-LABEL: test_compress_v16i8:
 ; CHECK:       ; %bb.0:

>From e4423a1b434c086a7787594efbba96aa29e392c4 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 15 May 2024 17:47:42 +0200
Subject: [PATCH 05/14] Add documentation

---
 llvm/docs/LangRef.rst | 79 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 79 insertions(+)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 06809f8bf445d..773893b83a5d7 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -24975,6 +24975,85 @@ The '``llvm.masked.compressstore``' intrinsic is designed for compressing data i
 
 Other targets may support this intrinsic differently, for example, by lowering it into a sequence of branches that guard scalar store operations.
 
+Masked Vector Compress Intrinsic
+--------------------------------
+
+LLVM provides an intrinsic for compressing data within a vector based on a selection mask.
+Semantically, this is similar to :ref:``@llvm.masked.compressstore <_int_compressstore>`` but with weaker assumptions
+and without storing the results to memory, i.e., the data remains in the vector.
+
+.. _int_masked_compress:
+
+'``llvm.masked.compress.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic. A number of scalar values of integer, floating point or pointer data type are collected
+from an input vector and placed adjacently within the result vector. A mask defines which elements to collect from the vector.
+
+.. code-block:: llvm
+
+      declare <8 x i32> @llvm.masked.compress.v8i32(<8 x i32> <value>, <8 x i1> <mask>)
+      declare <16 x float> @llvm.masked.compress.v16f32(<16 x float> <value>, <16 x i1> <mask>)
+
+Overview:
+"""""""""
+
+Selects elements from input vector '``value``' according to the '``mask``'.
+All selected elements are written into adjacent lanes in the result vector, from lower to higher.
+The mask holds a bit for each vector lane, and is used to select elements to be kept.
+The number of valid lanes is equal to the number of active bits in the mask.
+The main difference to :ref:`llvm.masked.compressstore <_int_compressstore>` is that the remainder of the vector may
+contain undefined values.
+This allows for branchless code and better optimization for all targets that do not support the explicit semantics of
+:ref:`llvm.masked.compressstore <_int_compressstore>`.
+The result vector can be written with a similar effect, as all the selected values are at the lower positions of the
+vector, but without requiring branches to avoid writes where the mask is 0.
+
+
+Arguments:
+""""""""""
+
+The first operand is the input vector, from which elements are selected.
+The second operand is the mask, a vector of boolean values.
+The mask and the input vector must have the same number of vector elements.
+
+Semantics:
+""""""""""
+
+The '``llvm.masked.compress``' intrinsic is designed for compressing data within a vector, i.e., ideally within a register.
+It allows collecting elements from possibly non-adjacent lanes of a vector and placing them contiguously in the result vector in one IR operation.
+It is useful for all targets that support compress operations (e.g., AVX-512, ARM SVE, RISC-V V); more instruction sets support
+compressing within a register than an explicit compressstore, i.e., ``llvm.masked.compress`` may yield better performance on more
+targets than ``llvm.masked.compressstore`` due to its weaker constraints.
+This intrinsic allows vectorizing loops with cross-iteration dependencies like in the following example:
+
+.. code-block:: c
+
+    // Consecutively store selected values with branchless code.
+    int *in, *out; bool *mask; int pos = 0;
+    for (int i = 0; i < size; ++i) {
+      out[pos] = in[i];
+      // if mask[i] == 0, the current value is overwritten in the next iteration.
+      pos += mask[i];
+    }
+
+
+.. code-block:: llvm
+
+    ; Load elements from `in`.
+    %vec = load <4 x i32>, ptr %inPtr
+    %mask = load <4 x i1>, ptr %maskPtr
+    %compressed = call <4 x i32> @llvm.masked.compress.v4i32(<4 x i32> %vec, <4 x i1> %mask)
+    store <4 x i32> %compressed, ptr %outPtr
+
+    ; Advance %outPtr by the number of '1's in the mask, i.e., the number of valid elements written.
+    %iMask = bitcast <4 x i1> %mask to i4
+    %popcnt = call i4 @llvm.ctpop.i4(i4 %iMask)
+    %zextPopcnt = zext i4 %popcnt to i64
+    %nextOut = getelementptr i32, ptr %outPtr, i64 %zextPopcnt
+
 
 Memory Use Markers
 ------------------

>From 3e9967803d110116fc90823ae39a91cfe9d03d2c Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 15 May 2024 18:06:44 +0200
Subject: [PATCH 06/14] Fix formatting

---
 .../lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp | 11 ++++++++---
 llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp   | 10 +++++-----
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp |  3 +--
 3 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 80f645b433cbe..4063144f47393 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -87,7 +87,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
     break;
   case ISD::MGATHER:     Res = PromoteIntRes_MGATHER(cast<MaskedGatherSDNode>(N));
     break;
-  case ISD::MCOMPRESS:   Res = PromoteIntRes_MCOMPRESS(N); break;
+  case ISD::MCOMPRESS:
+    Res = PromoteIntRes_MCOMPRESS(N);
+    break;
   case ISD::SELECT:
   case ISD::VSELECT:
   case ISD::VP_SELECT:
@@ -951,7 +953,8 @@ SDValue DAGTypeLegalizer::PromoteIntRes_MGATHER(MaskedGatherSDNode *N) {
 
 SDValue DAGTypeLegalizer::PromoteIntRes_MCOMPRESS(SDNode *N) {
   SDValue Vec = GetPromotedInteger(N->getOperand(0));
-  return DAG.getNode(ISD::MCOMPRESS, SDLoc(N), Vec.getValueType(), Vec, N->getOperand(1));
+  return DAG.getNode(ISD::MCOMPRESS, SDLoc(N), Vec.getValueType(), Vec,
+                     N->getOperand(1));
 }
 
 /// Promote the overflow flag of an overflowing arithmetic node.
@@ -1861,7 +1864,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
                                                  OpNo); break;
   case ISD::MSCATTER: Res = PromoteIntOp_MSCATTER(cast<MaskedScatterSDNode>(N),
                                                   OpNo); break;
-  case ISD::MCOMPRESS: Res = PromoteIntOp_MCOMPRESS(N, OpNo); break;
+  case ISD::MCOMPRESS:
+    Res = PromoteIntOp_MCOMPRESS(N, OpNo);
+    break;
   case ISD::VP_TRUNCATE:
   case ISD::TRUNCATE:     Res = PromoteIntOp_TRUNCATE(N); break;
   case ISD::BF16_TO_FP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index ca32db26e511c..ebf0f63775d44 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1528,11 +1528,11 @@ SDValue VectorLegalizer::ExpandMCOMPRESS(SDNode *Node) {
   for (unsigned I = 0; I < NumElms; I++) {
     SDValue Idx = DAG.getVectorIdxConstant(I, DL);
 
-    SDValue ValI =
-        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
-    SDValue OutPtr =
-        TLI.getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
-    Chain = DAG.getStore(Chain, DL, ValI, OutPtr, MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+    SDValue ValI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
+    SDValue OutPtr = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
+    Chain = DAG.getStore(
+        Chain, DL, ValI, OutPtr,
+        MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
 
     // Skip this for last element.
     if (I < NumElms - 1) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 665bab6121837..20461511ac92f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6722,8 +6722,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     setValue(&I, DAG.getNode(ISD::MCOMPRESS, sdl,
                              getValue(I.getArgOperand(0)).getValueType(),
                              getValue(I.getArgOperand(0)),
-                             getValue(I.getArgOperand(1)),
-                             Flags));
+                             getValue(I.getArgOperand(1)), Flags));
     return;
   case Intrinsic::powi:
     setValue(&I, ExpandPowI(sdl, getValue(I.getArgOperand(0)),

>From b686f83ef295b505e8f3316efc66ac8bea163ebc Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Thu, 16 May 2024 11:15:41 +0200
Subject: [PATCH 07/14] Fix references in docs

---
 llvm/docs/LangRef.rst | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 773893b83a5d7..e2d3a986ddedf 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -24979,7 +24979,7 @@ Masked Vector Compress Intrinsic
 --------------------------------
 
 LLVM provides an intrinsic for compressing data within a vector based on a selection mask.
-Semantically, this is similar to :ref:``@llvm.masked.compressstore <_int_compressstore>`` but with weaker assumptions
+Semantically, this is similar to :ref:`llvm.masked.compressstore <int_compressstore>` but with weaker assumptions
 and without storing the results to memory, i.e., the data remains in the vector.
 
 .. _int_masked_compress:
@@ -25004,10 +25004,10 @@ Selects elements from input vector '``value``' according to the '``mask``'.
 All selected elements are written into adjacent lanes in the result vector, from lower to higher.
 The mask holds a bit for each vector lane, and is used to select elements to be kept.
 The number of valid lanes is equal to the number of active bits in the mask.
-The main difference to :ref:`llvm.masked.compressstore <_int_compressstore>` is that the remainder of the vector may
+The main difference to :ref:`llvm.masked.compressstore <int_compressstore>` is that the remainder of the vector may
 contain undefined values.
 This allows for branchless code and better optimization for all targets that do not support the explicit semantics of
-:ref:`llvm.masked.compressstore <_int_compressstore>`.
+:ref:`llvm.masked.compressstore <int_compressstore>`.
 The result vector can be written with a similar effect, as all the selected values are at the lower positions of the
 vector, but without requiring branches to avoid writes where the mask is 0.
 

>From 73cc28f42e5e18daf3b57997b4cc9084a78e6327 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Thu, 16 May 2024 12:44:59 +0200
Subject: [PATCH 08/14] Add widen for vector type legalization

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |  1 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      | 16 ++++
 llvm/test/CodeGen/AArch64/masked-compress.ll  | 82 +++++++++++++++++--
 3 files changed, 94 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 5fb14757f8991..83343ef5c173f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -975,6 +975,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecRes_LOAD(SDNode* N);
   SDValue WidenVecRes_VP_LOAD(VPLoadSDNode *N);
   SDValue WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N);
+  SDValue WidenVecRes_MCOMPRESS(SDNode* N);
   SDValue WidenVecRes_MLOAD(MaskedLoadSDNode* N);
   SDValue WidenVecRes_MGATHER(MaskedGatherSDNode* N);
   SDValue WidenVecRes_VP_GATHER(VPGatherSDNode* N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 62e7febed6568..2a9c8adfc3fc6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -4273,6 +4273,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
     Res = WidenVecRes_VP_STRIDED_LOAD(cast<VPStridedLoadSDNode>(N));
     break;
+  case ISD::MCOMPRESS:
+    Res = WidenVecRes_MCOMPRESS(N);
+    break;
   case ISD::MLOAD:
     Res = WidenVecRes_MLOAD(cast<MaskedLoadSDNode>(N));
     break;
@@ -5655,6 +5658,19 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
   return Res;
 }
 
+SDValue DAGTypeLegalizer::WidenVecRes_MCOMPRESS(SDNode *N) {
+  SDValue Vec = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  EVT WideVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
+  EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), Mask.getValueType());
+
+  // In the default expanded case, adding UNDEF values for the new widened lanes
+  // allows us to remove their access later, which reduces the number os stores.
+  SDValue WideVec = ModifyToType(Vec, WideVecVT, /*FillWithZeroes=*/true);
+  SDValue WideMask = ModifyToType(Mask, WideMaskVT, /*FillWithZeroes=*/true);
+  return DAG.getNode(ISD::MCOMPRESS, SDLoc(N), WideVecVT, WideVec, WideMask);
+}
+
 SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
 
   EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
diff --git a/llvm/test/CodeGen/AArch64/masked-compress.ll b/llvm/test/CodeGen/AArch64/masked-compress.ll
index a2f39b9620c95..92aea4108a7b7 100644
--- a/llvm/test/CodeGen/AArch64/masked-compress.ll
+++ b/llvm/test/CodeGen/AArch64/masked-compress.ll
@@ -292,8 +292,80 @@ define <4 x i4> @test_compress_illegal_element_type(<4 x i4> %vec, <4 x i1> %mas
     ret <4 x i4> %out
 }
 
-declare <4 x i32> @llvm.masked.compress.v4i32(<4 x i32>, <4 x i1>)
-declare <16 x i8> @llvm.masked.compress.v16i8(<16 x i8>, <16 x i1>)
-declare <4 x i4> @llvm.masked.compress.v4i4(<4 x i4>, <4 x i1>)
-declare <4 x i8> @llvm.masked.compress.v4i8(<4 x i8>, <4 x i1>)
-declare <8 x i32> @llvm.masked.compress.v8i32(<8 x i32>, <8 x i1>)
+define <3 x i32> @test_compress_narrow(<3 x i32> %vec, <3 x i1> %mask) {
+; CHECK-LABEL: test_compress_narrow:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    movi.2d v1, #0000000000000000
+; CHECK-NEXT:    mov x11, sp
+; CHECK-NEXT:    str s0, [sp]
+; CHECK-NEXT:    mov.h v1[0], w0
+; CHECK-NEXT:    mov.h v1[1], w1
+; CHECK-NEXT:    mov.h v1[2], w2
+; CHECK-NEXT:    ushll.4s v1, v1, #0
+; CHECK-NEXT:    shl.4s v1, v1, #31
+; CHECK-NEXT:    cmlt.4s v1, v1, #0
+; CHECK-NEXT:    mov.s w8, v1[1]
+; CHECK-NEXT:    mov.s w9, v1[2]
+; CHECK-NEXT:    fmov w10, s1
+; CHECK-NEXT:    bfi x11, x10, #2, #1
+; CHECK-NEXT:    and w10, w10, #0x1
+; CHECK-NEXT:    and w8, w8, #0x1
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    add w8, w10, w8
+; CHECK-NEXT:    mov x10, sp
+; CHECK-NEXT:    st1.s { v0 }[1], [x11]
+; CHECK-NEXT:    add w9, w8, w9
+; CHECK-NEXT:    orr x8, x10, x8, lsl #2
+; CHECK-NEXT:    bfi x10, x9, #2, #2
+; CHECK-NEXT:    st1.s { v0 }[2], [x8]
+; CHECK-NEXT:    str wzr, [x10]
+; CHECK-NEXT:    ldr q0, [sp], #16
+; CHECK-NEXT:    ret
+    %out = call <3 x i32> @llvm.masked.compress.v3i32(<3 x i32> %vec, <3 x i1> %mask)
+    ret <3 x i32> %out
+}
+
+define <3 x i3> @test_compress_narrow_illegal_element_type(<3 x i3> %vec, <3 x i1> %mask) {
+; CHECK-LABEL: test_compress_narrow_illegal_element_type:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    movi.2d v0, #0000000000000000
+; CHECK-NEXT:    fmov s1, w0
+; CHECK-NEXT:    add x12, sp, #8
+; CHECK-NEXT:    movi.4h v2, #7
+; CHECK-NEXT:    add x11, sp, #8
+; CHECK-NEXT:    mov.h v1[1], w1
+; CHECK-NEXT:    mov.h v0[0], w3
+; CHECK-NEXT:    mov.h v1[2], w2
+; CHECK-NEXT:    mov.h v0[1], w4
+; CHECK-NEXT:    mov.h v0[2], w5
+; CHECK-NEXT:    shl.4h v0, v0, #15
+; CHECK-NEXT:    cmlt.4h v0, v0, #0
+; CHECK-NEXT:    umov.h w8, v0[0]
+; CHECK-NEXT:    umov.h w9, v0[1]
+; CHECK-NEXT:    umov.h w10, v0[2]
+; CHECK-NEXT:    and.8b v0, v1, v2
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    and w13, w8, #0x1
+; CHECK-NEXT:    bfi x12, x8, #1, #1
+; CHECK-NEXT:    add w8, w13, w9
+; CHECK-NEXT:    and w9, w10, #0x1
+; CHECK-NEXT:    str h0, [sp, #8]
+; CHECK-NEXT:    orr x10, x11, x8, lsl #1
+; CHECK-NEXT:    add w8, w8, w9
+; CHECK-NEXT:    st1.h { v0 }[1], [x12]
+; CHECK-NEXT:    bfi x11, x8, #1, #2
+; CHECK-NEXT:    st1.h { v0 }[2], [x10]
+; CHECK-NEXT:    strh wzr, [x11]
+; CHECK-NEXT:    ldr d0, [sp, #8]
+; CHECK-NEXT:    umov.h w0, v0[0]
+; CHECK-NEXT:    umov.h w1, v0[1]
+; CHECK-NEXT:    umov.h w2, v0[2]
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+    %out = call <3 x i3> @llvm.masked.compress.v3i3(<3 x i3> %vec, <3 x i1> %mask)
+    ret <3 x i3> %out
+}

>From 8a613f37780c83e9874c0f5454ffec27eb9060f2 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Thu, 16 May 2024 14:48:43 +0200
Subject: [PATCH 09/14] Put expand logic in TargerLowering to avoid code
 duplication.

---
 llvm/include/llvm/CodeGen/TargetLowering.h    |  4 +
 .../SelectionDAG/LegalizeVectorOps.cpp        | 45 +----------
 .../SelectionDAG/LegalizeVectorTypes.cpp      | 56 ++-----------
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 47 +++++++++++
 llvm/test/CodeGen/AArch64/masked-compress.ll  | 79 +++++++++----------
 5 files changed, 95 insertions(+), 136 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 50a8c7eb75af5..ad79c22aa3e1b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5472,6 +5472,10 @@ class TargetLowering : public TargetLoweringBase {
   /// method accepts vectors as its arguments.
   SDValue expandVectorSplice(SDNode *Node, SelectionDAG &DAG) const;
 
+  /// Expand a vector MCOMPRESS into a sequence of extract element, store
+  /// temporarily, advance store position, before re-loading the final vector.
+  SDValue expandMCOMPRESS(SDNode *Node, SelectionDAG &DAG) const;
+
   /// Legalize a SETCC or VP_SETCC with given LHS and RHS and condition code CC
   /// on the current target. A VP_SETCC will additionally be given a Mask
   /// and/or EVL not equal to SDValue().
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index ebf0f63775d44..bb2d3a8a3a0d1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -134,7 +134,6 @@ class VectorLegalizer {
   SDValue ExpandVSELECT(SDNode *Node);
   SDValue ExpandVP_SELECT(SDNode *Node);
   SDValue ExpandVP_MERGE(SDNode *Node);
-  SDValue ExpandMCOMPRESS(SDNode *Node);
   SDValue ExpandVP_REM(SDNode *Node);
   SDValue ExpandSELECT(SDNode *Node);
   std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
@@ -1104,7 +1103,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
 
     break;
   case ISD::MCOMPRESS:
-    Results.push_back(ExpandMCOMPRESS(Node));
+    Results.push_back(TLI.expandMCOMPRESS(Node, DAG));
     return;
   }
 
@@ -1510,48 +1509,6 @@ SDValue VectorLegalizer::ExpandVP_MERGE(SDNode *Node) {
   return DAG.getSelect(DL, Node->getValueType(0), FullMask, Op1, Op2);
 }
 
-SDValue VectorLegalizer::ExpandMCOMPRESS(SDNode *Node) {
-  SDLoc DL(Node);
-  SDValue Vec = Node->getOperand(0);
-  SDValue Mask = Node->getOperand(1);
-
-  EVT VecVT = Vec.getValueType();
-  EVT ScalarVT = VecVT.getScalarType();
-  EVT MaskScalarVT = Mask.getValueType().getScalarType();
-
-  SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
-  SDValue Chain = DAG.getEntryNode();
-  SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);
-
-  unsigned NumElms = VecVT.getVectorNumElements();
-  // Skip element zero, as we always copy this to the output vector.
-  for (unsigned I = 0; I < NumElms; I++) {
-    SDValue Idx = DAG.getVectorIdxConstant(I, DL);
-
-    SDValue ValI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
-    SDValue OutPtr = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
-    Chain = DAG.getStore(
-        Chain, DL, ValI, OutPtr,
-        MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
-
-    // Skip this for last element.
-    if (I < NumElms - 1) {
-      // Get the mask value and add it to the current output position. This
-      // either increments by 1 if MaskI is true or adds 0 otherwise.
-      SDValue MaskI =
-          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MaskScalarVT, Mask, Idx);
-      MaskI = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, MaskI);
-      MaskI = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskI);
-      OutPos = DAG.getNode(ISD::ADD, DL, MVT::i32, OutPos, MaskI);
-    }
-  }
-
-  int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
-  MachinePointerInfo PtrInfo =
-      MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
-  return DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
-}
-
 SDValue VectorLegalizer::ExpandVP_REM(SDNode *Node) {
   // Implement VP_SREM/UREM in terms of VP_SDIV/VP_UDIV, VP_MUL, VP_SUB.
   EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 2a9c8adfc3fc6..34146e9a12308 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2312,56 +2312,14 @@ void DAGTypeLegalizer::SplitVecRes_MCOMPRESS(SDNode *N, SDValue &Lo,
   // This is not "trivial", as there is a dependency between the two subvectors.
   // Depending on the number of 1s in the mask, the elements from the Hi vector
   // need to be moved to the Lo vector. So we just perform this as one "big"
-  // operation (analogously to the default MCOMPRESS expand implementation), by
-  // writing to memory and then loading the Lo and Hi vectors from that. This
-  // gets rid of MCOMPRESS and all other operands can be legalized later.
-  SDLoc DL(N);
-  SDValue Vec = N->getOperand(0);
-  SDValue Mask = N->getOperand(1);
-
-  EVT VecVT = Vec.getValueType();
-  EVT SubVecVT = VecVT.getHalfNumVectorElementsVT(*DAG.getContext());
-  EVT ScalarVT = VecVT.getScalarType();
-  EVT MaskScalarVT = Mask.getValueType().getScalarType();
-
-  // TODO: This code is duplicated here and in LegalizeVectorOps.
-  SDValue StackPtr = DAG.CreateStackTemporary(VecVT);
-  SDValue Chain = DAG.getEntryNode();
-  SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);
-
-  unsigned NumElms = VecVT.getVectorNumElements();
-  // Skip element zero, as we always copy this to the output vector.
-  for (unsigned I = 0; I < NumElms; I++) {
-    SDValue Idx = DAG.getVectorIdxConstant(I, DL);
-
-    SDValue ValI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
-    SDValue OutPtr = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
-    Chain = DAG.getStore(
-        Chain, DL, ValI, OutPtr,
-        MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
-
-    // Skip this for last element.
-    if (I < NumElms - 1) {
-      // Get the mask value and add it to the current output position. This
-      // either increments by 1 if MaskI is true or adds 0 otherwise.
-      SDValue MaskI =
-          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MaskScalarVT, Mask, Idx);
-      MaskI = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, MaskI);
-      MaskI = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskI);
-      OutPos = DAG.getNode(ISD::ADD, DL, MVT::i32, OutPos, MaskI);
-    }
-  }
-
-  int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
-  MachinePointerInfo PtrInfo =
-      MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
-  SDValue HiPtr = TLI.getVectorElementPointer(
-      DAG, StackPtr, VecVT, DAG.getConstant(NumElms / 2, DL, MVT::i32));
+  // operation and then extract the Lo and Hi vectors from that. This gets rid
+  // of MCOMPRESS and all other operands can be legalized later.
+  SDValue Compressed = TLI.expandMCOMPRESS(N, DAG);
 
-  Lo = DAG.getLoad(SubVecVT, DL, Chain, StackPtr, PtrInfo);
-  Hi = DAG.getLoad(
-      SubVecVT, DL, Chain, HiPtr,
-      MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+  SDLoc DL(N);
+  EVT SubVecVT = Compressed.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
+  Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Compressed, DAG.getVectorIdxConstant(0, DL));
+  Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Compressed, DAG.getVectorIdxConstant(SubVecVT.getVectorNumElements(), DL));
 }
 
 void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 7beaeb9b7a171..f3d48e60c4b58 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11226,6 +11226,53 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
                      MachinePointerInfo::getUnknownStack(MF));
 }
 
+SDValue TargetLowering::expandMCOMPRESS(SDNode *Node, SelectionDAG &DAG) const {
+  SDLoc DL(Node);
+  SDValue Vec = Node->getOperand(0);
+  SDValue Mask = Node->getOperand(1);
+
+  EVT VecVT = Vec.getValueType();
+  EVT ScalarVT = VecVT.getScalarType();
+  EVT MaskScalarVT = Mask.getValueType().getScalarType();
+
+  if (VecVT.isScalableVector())
+    report_fatal_error(
+        "Expanding masked_compress for scalable vectors is undefined.");
+
+  SDValue StackPtr = DAG.CreateStackTemporary(VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+  SDValue Chain = DAG.getEntryNode();
+  SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);
+
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  unsigned NumElms = VecVT.getVectorNumElements();
+  // Write each lane to OutPos; only selected lanes advance OutPos.
+  for (unsigned I = 0; I < NumElms; I++) {
+    SDValue Idx = DAG.getVectorIdxConstant(I, DL);
+
+    SDValue ValI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec, Idx);
+    SDValue OutPtr = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
+    Chain = DAG.getStore(
+        Chain, DL, ValI, OutPtr,
+        MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
+
+    // The last lane needs no advance: no further store follows it.
+    if (I < NumElms - 1) {
+      // Get the mask value and add it to the current output position. This
+      // either increments by 1 if MaskI is true or adds 0 otherwise.
+      SDValue MaskI =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MaskScalarVT, Mask, Idx);
+      MaskI = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, MaskI);
+      MaskI = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskI);
+      OutPos = DAG.getNode(ISD::ADD, DL, MVT::i32, OutPos, MaskI);
+    }
+  }
+
+  int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+  MachinePointerInfo PtrInfo =
+      MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+  return DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+}
+
 bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
                                            SDValue &LHS, SDValue &RHS,
                                            SDValue &CC, SDValue Mask,
diff --git a/llvm/test/CodeGen/AArch64/masked-compress.ll b/llvm/test/CodeGen/AArch64/masked-compress.ll
index 92aea4108a7b7..9057e6f8967fa 100644
--- a/llvm/test/CodeGen/AArch64/masked-compress.ll
+++ b/llvm/test/CodeGen/AArch64/masked-compress.ll
@@ -154,57 +154,50 @@ define <16 x i8> @test_compress_v16i8(<16 x i8> %vec, <16 x i1> %mask) {
 define <8 x i32> @test_compress_large(<8 x i32> %vec, <8 x i1> %mask) {
 ; CHECK-LABEL: test_compress_large:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    stp x29, x30, [sp, #-16]! ; 16-byte Folded Spill
-; CHECK-NEXT:    sub x9, sp, #48
-; CHECK-NEXT:    mov x29, sp
-; CHECK-NEXT:    and sp, x9, #0xffffffffffffffe0
-; CHECK-NEXT:    .cfi_def_cfa w29, 16
-; CHECK-NEXT:    .cfi_offset w30, -8
-; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    sub sp, sp, #32
+; CHECK-NEXT:    .cfi_def_cfa_offset 32
 ; CHECK-NEXT:    ; kill: def $d2 killed $d2 def $q2
-; CHECK-NEXT:    umov.b w9, v2[0]
-; CHECK-NEXT:    umov.b w10, v2[1]
+; CHECK-NEXT:    umov.b w8, v2[0]
+; CHECK-NEXT:    umov.b w9, v2[1]
 ; CHECK-NEXT:    mov x12, sp
-; CHECK-NEXT:    umov.b w11, v2[2]
+; CHECK-NEXT:    umov.b w10, v2[2]
 ; CHECK-NEXT:    umov.b w13, v2[3]
-; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    mov x11, sp
 ; CHECK-NEXT:    umov.b w14, v2[4]
 ; CHECK-NEXT:    str s0, [sp]
-; CHECK-NEXT:    bfi x12, x9, #2, #1
-; CHECK-NEXT:    and w10, w10, #0x1
 ; CHECK-NEXT:    and w9, w9, #0x1
-; CHECK-NEXT:    add w9, w9, w10
-; CHECK-NEXT:    and w10, w11, #0x1
-; CHECK-NEXT:    and w13, w13, #0x1
-; CHECK-NEXT:    orr x11, x8, x9, lsl #2
-; CHECK-NEXT:    add w9, w9, w10
+; CHECK-NEXT:    and w15, w8, #0x1
+; CHECK-NEXT:    bfi x12, x8, #2, #1
+; CHECK-NEXT:    and w8, w10, #0x1
+; CHECK-NEXT:    add w9, w15, w9
 ; CHECK-NEXT:    umov.b w10, v2[5]
+; CHECK-NEXT:    add w8, w9, w8
+; CHECK-NEXT:    orr x15, x11, x9, lsl #2
+; CHECK-NEXT:    umov.b w9, v2[6]
 ; CHECK-NEXT:    st1.s { v0 }[1], [x12]
-; CHECK-NEXT:    add w13, w9, w13
-; CHECK-NEXT:    orr x9, x8, x9, lsl #2
-; CHECK-NEXT:    st1.s { v0 }[2], [x11]
-; CHECK-NEXT:    umov.b w11, v2[6]
-; CHECK-NEXT:    mov x12, sp
-; CHECK-NEXT:    and w14, w14, #0x1
-; CHECK-NEXT:    bfi x12, x13, #2, #3
-; CHECK-NEXT:    st1.s { v0 }[3], [x9]
-; CHECK-NEXT:    add w13, w13, w14
-; CHECK-NEXT:    and w9, w10, #0x1
-; CHECK-NEXT:    mov x10, sp
-; CHECK-NEXT:    add w9, w13, w9
-; CHECK-NEXT:    mov x14, sp
-; CHECK-NEXT:    str s1, [x12]
-; CHECK-NEXT:    and w11, w11, #0x1
-; CHECK-NEXT:    bfi x10, x9, #2, #3
-; CHECK-NEXT:    bfi x14, x13, #2, #3
-; CHECK-NEXT:    add w9, w9, w11
-; CHECK-NEXT:    bfi x8, x9, #2, #3
-; CHECK-NEXT:    st1.s { v1 }[1], [x14]
-; CHECK-NEXT:    st1.s { v1 }[2], [x10]
-; CHECK-NEXT:    st1.s { v1 }[3], [x8]
-; CHECK-NEXT:    ldp q0, q1, [sp]
-; CHECK-NEXT:    mov sp, x29
-; CHECK-NEXT:    ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
+; CHECK-NEXT:    add x12, x11, w8, uxtw #2
+; CHECK-NEXT:    and w13, w13, #0x1
+; CHECK-NEXT:    st1.s { v0 }[2], [x15]
+; CHECK-NEXT:    add w8, w8, w13
+; CHECK-NEXT:    st1.s { v0 }[3], [x12]
+; CHECK-NEXT:    and w12, w14, #0x1
+; CHECK-NEXT:    and w10, w10, #0x1
+; CHECK-NEXT:    add w12, w8, w12
+; CHECK-NEXT:    and w9, w9, #0x1
+; CHECK-NEXT:    and x8, x8, #0x7
+; CHECK-NEXT:    add w10, w12, w10
+; CHECK-NEXT:    and x12, x12, #0x7
+; CHECK-NEXT:    str s1, [x11, x8, lsl #2]
+; CHECK-NEXT:    add w9, w10, w9
+; CHECK-NEXT:    and x10, x10, #0x7
+; CHECK-NEXT:    add x12, x11, x12, lsl #2
+; CHECK-NEXT:    and x9, x9, #0x7
+; CHECK-NEXT:    add x8, x11, x10, lsl #2
+; CHECK-NEXT:    add x9, x11, x9, lsl #2
+; CHECK-NEXT:    st1.s { v1 }[1], [x12]
+; CHECK-NEXT:    st1.s { v1 }[2], [x8]
+; CHECK-NEXT:    st1.s { v1 }[3], [x9]
+; CHECK-NEXT:    ldp q0, q1, [sp], #32
 ; CHECK-NEXT:    ret
     %out = call <8 x i32> @llvm.masked.compress.v8i32(<8 x i32> %vec, <8 x i1> %mask)
     ret <8 x i32> %out

>From a4df959d8047012661d868e48722711a77e60a68 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Thu, 16 May 2024 14:49:21 +0200
Subject: [PATCH 10/14] Fix formatting

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h    |  2 +-
 .../CodeGen/SelectionDAG/LegalizeVectorTypes.cpp | 16 +++++++++++-----
 llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp |  3 ++-
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 83343ef5c173f..26ce361f2d580 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -975,7 +975,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecRes_LOAD(SDNode* N);
   SDValue WidenVecRes_VP_LOAD(VPLoadSDNode *N);
   SDValue WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N);
-  SDValue WidenVecRes_MCOMPRESS(SDNode* N);
+  SDValue WidenVecRes_MCOMPRESS(SDNode *N);
   SDValue WidenVecRes_MLOAD(MaskedLoadSDNode* N);
   SDValue WidenVecRes_MGATHER(MaskedGatherSDNode* N);
   SDValue WidenVecRes_VP_GATHER(VPGatherSDNode* N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 34146e9a12308..e94316003c35c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2317,9 +2317,13 @@ void DAGTypeLegalizer::SplitVecRes_MCOMPRESS(SDNode *N, SDValue &Lo,
   SDValue Compressed = TLI.expandMCOMPRESS(N, DAG);
 
   SDLoc DL(N);
-  EVT SubVecVT = Compressed.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
-  Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Compressed, DAG.getVectorIdxConstant(0, DL));
-  Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Compressed, DAG.getVectorIdxConstant(SubVecVT.getVectorNumElements(), DL));
+  EVT SubVecVT =
+      Compressed.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
+  Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Compressed,
+                   DAG.getVectorIdxConstant(0, DL));
+  Hi = DAG.getNode(
+      ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Compressed,
+      DAG.getVectorIdxConstant(SubVecVT.getVectorNumElements(), DL));
 }
 
 void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
@@ -5619,8 +5623,10 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
 SDValue DAGTypeLegalizer::WidenVecRes_MCOMPRESS(SDNode *N) {
   SDValue Vec = N->getOperand(0);
   SDValue Mask = N->getOperand(1);
-  EVT WideVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
-  EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), Mask.getValueType());
+  EVT WideVecVT =
+      TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
+  EVT WideMaskVT =
+      TLI.getTypeToTransformTo(*DAG.getContext(), Mask.getValueType());
 
   // In the default expanded case, adding UNDEF values for the new widened lanes
   // allows us to remove their access later, which reduces the number os stores.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index f3d48e60c4b58..6509f4441d266 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11239,7 +11239,8 @@ SDValue TargetLowering::expandMCOMPRESS(SDNode *Node, SelectionDAG &DAG) const {
     report_fatal_error(
         "Expanding masked_compress for scalable vectors is undefined.");
 
-  SDValue StackPtr = DAG.CreateStackTemporary(VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+  SDValue StackPtr = DAG.CreateStackTemporary(
+      VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
   SDValue Chain = DAG.getEntryNode();
   SDValue OutPos = DAG.getConstant(0, DL, MVT::i32);
 

>From 17004b99ba19b1886f6dc38d331e78bd5e6ebe44 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Fri, 17 May 2024 12:07:23 +0200
Subject: [PATCH 11/14] Add basic lowering of MCOMPRESS in GlobalISel

---
 .../llvm/CodeGen/GlobalISel/LegalizerHelper.h |  1 +
 llvm/include/llvm/Support/TargetOpcodes.def   |  3 ++
 llvm/include/llvm/Target/GenericOpcodes.td    |  7 +++
 .../Target/GlobalISel/SelectionDAGCompat.td   |  1 +
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  2 +
 .../CodeGen/GlobalISel/LegalizerHelper.cpp    | 49 +++++++++++++++++++
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |  1 -
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    |  4 ++
 8 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
index 284f434fbb9b0..132ed3fab57fc 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
@@ -410,6 +410,7 @@ class LegalizerHelper {
   LegalizeResult lowerUnmergeValues(MachineInstr &MI);
   LegalizeResult lowerExtractInsertVectorElt(MachineInstr &MI);
   LegalizeResult lowerShuffleVector(MachineInstr &MI);
+  LegalizeResult lowerMCOMPRESS(MachineInstr &MI);
   Register getDynStackAllocTargetPtr(Register SPReg, Register AllocSize,
                                      Align Alignment, LLT PtrTy);
   LegalizeResult lowerDynStackAlloc(MachineInstr &MI);
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 559a588c25148..f85aca25fb945 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -751,6 +751,9 @@ HANDLE_TARGET_OPCODE(G_SHUFFLE_VECTOR)
 /// Generic splatvector.
 HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR)
 
+/// Generic masked compress.
+HANDLE_TARGET_OPCODE(G_MCOMPRESS)
+
 /// Generic count trailing zeroes.
 HANDLE_TARGET_OPCODE(G_CTTZ)
 
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index c40498e554215..7cbb6cc1bb737 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1500,6 +1500,13 @@ def G_SPLAT_VECTOR: GenericInstruction {
   let hasSideEffects = false;
 }
 
+// Generic masked compress.
+def G_MCOMPRESS: GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type1:$vec, type2:$mask);
+  let hasSideEffects = false;
+}
+
 //------------------------------------------------------------------------------
 // Vector reductions
 //------------------------------------------------------------------------------
diff --git a/llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td b/llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
index 8fa0e4b86d6dc..69fb235dfc2b5 100644
--- a/llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
+++ b/llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
@@ -186,6 +186,7 @@ def : GINodeEquiv<G_VECREDUCE_UMAX, vecreduce_umax>;
 def : GINodeEquiv<G_VECREDUCE_SMIN, vecreduce_smin>;
 def : GINodeEquiv<G_VECREDUCE_SMAX, vecreduce_smax>;
 def : GINodeEquiv<G_VECREDUCE_ADD, vecreduce_add>;
+def : GINodeEquiv<G_MCOMPRESS, masked_compress>;
 
 def : GINodeEquiv<G_STRICT_FADD, strict_fadd>;
 def : GINodeEquiv<G_STRICT_FSUB, strict_fsub>;
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 5289b993476db..858ba547ac22a 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1986,6 +1986,8 @@ unsigned IRTranslator::getSimpleIntrinsicOpcode(Intrinsic::ID ID) {
       return TargetOpcode::G_VECREDUCE_UMAX;
     case Intrinsic::vector_reduce_umin:
       return TargetOpcode::G_VECREDUCE_UMIN;
+    case Intrinsic::masked_compress:
+      return TargetOpcode::G_MCOMPRESS;
     case Intrinsic::lround:
       return TargetOpcode::G_LROUND;
     case Intrinsic::llround:
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 6a76ad7f5db74..5bab609cdefca 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -3946,6 +3946,8 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
     return lowerExtractInsertVectorElt(MI);
   case G_SHUFFLE_VECTOR:
     return lowerShuffleVector(MI);
+  case G_MCOMPRESS:
+    return lowerMCOMPRESS(MI);
   case G_DYN_STACKALLOC:
     return lowerDynStackAlloc(MI);
   case G_STACKSAVE:
@@ -7502,6 +7504,53 @@ LegalizerHelper::lowerShuffleVector(MachineInstr &MI) {
   return Legalized;
 }
 
+LegalizerHelper::LegalizeResult
+LegalizerHelper::lowerMCOMPRESS(llvm::MachineInstr &MI) {
+  auto [Dst, DstTy, Vec, VecTy, Mask, MaskTy] = MI.getFirst3RegLLTs();
+
+  if (VecTy.isScalableVector())
+    report_fatal_error(
+        "Lowering masked_compress for scalable vectors is undefined.");
+
+  MachinePointerInfo PtrInfo;
+  Register StackPtr =
+      createStackTemporary(TypeSize::getFixed(VecTy.getSizeInBytes()),
+                           getStackTemporaryAlignment(VecTy), PtrInfo)
+          .getReg(0);
+
+  LLT IdxTy = LLT::scalar(32);
+  LLT ValTy = VecTy.getElementType();
+  Align ValAlign = getStackTemporaryAlignment(ValTy);
+
+  Register OutPos = MIRBuilder.buildConstant(IdxTy, 0).getReg(0);
+
+  unsigned NumElmts = VecTy.getNumElements();
+  for (unsigned I = 0; I < NumElmts; ++I) {
+    auto Idx = MIRBuilder.buildConstant(IdxTy, I);
+    auto Val = MIRBuilder.buildExtractVectorElement(ValTy, Vec, Idx);
+    Register ElmtPtr = getVectorElementPointer(StackPtr, VecTy, OutPos);
+    MIRBuilder.buildStore(Val, ElmtPtr, PtrInfo, ValAlign);
+
+    if (I < NumElmts - 1) {
+      LLT MaskITy = MaskTy.getElementType();
+      auto MaskI = MIRBuilder.buildExtractVectorElement(MaskITy, Mask, Idx);
+      if (MaskITy.getSizeInBits() > 1)
+        MaskI = MIRBuilder.buildTrunc(LLT::scalar(1), MaskI);
+
+      MaskI = MIRBuilder.buildZExt(IdxTy, MaskI);
+      OutPos = MIRBuilder.buildAdd(IdxTy, OutPos, MaskI).getReg(0);
+    }
+  }
+
+  Align VecAlign = getStackTemporaryAlignment(VecTy);
+  MIRBuilder.buildLoad(Dst, StackPtr, PtrInfo, VecAlign);
+
+  MI.eraseFromParent();
+  // TODO: This is not true! We don't know if the input vector type is legal.
+  //       Find out how to assert this!
+  return Legalized;
+}
+
 Register LegalizerHelper::getDynStackAllocTargetPtr(Register SPReg,
                                                     Register AllocSize,
                                                     Align Alignment,
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 6509f4441d266..a8275012ef0d6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11246,7 +11246,6 @@ SDValue TargetLowering::expandMCOMPRESS(SDNode *Node, SelectionDAG &DAG) const {
 
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   unsigned NumElms = VecVT.getVectorNumElements();
-  // Skip element zero, as we always copy this to the output vector.
   for (unsigned I = 0; I < NumElms; I++) {
     SDValue Idx = DAG.getVectorIdxConstant(I, DL);
 
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index d4aac94d24f12..98dacdd3c6eb1 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -1132,6 +1132,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .scalarize(1)
       .lower();
 
+  // TODO: Fix this! Lowering G_MCOMPRESS unconditionally is a placeholder.
+  // DO NOT COMMIT
+  getActionDefinitionsBuilder(G_MCOMPRESS).lower();
+
   getActionDefinitionsBuilder({G_FSHL, G_FSHR})
       .customFor({{s32, s32}, {s32, s64}, {s64, s64}})
       .lower();

>From 984cad1ee82cdce9c866efaa87bf600b79455150 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Fri, 17 May 2024 15:52:35 +0200
Subject: [PATCH 12/14] Add basic AArch64 MIR test

---
 .../GlobalISel/legalize-masked-compress.mir   | 67 +++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/GlobalISel/legalize-masked-compress.mir

diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalize-masked-compress.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalize-masked-compress.mir
new file mode 100644
index 0000000000000..2d28b4a597ed3
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalize-masked-compress.mir
@@ -0,0 +1,67 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+# RUN: llc -mtriple=aarch64 -run-pass=legalizer %s -o - | FileCheck %s
+---
+name:            test_mcompress_v4s32
+body:             |
+  bb.0:
+    liveins: $q0, $d1
+
+    ; CHECK-LABEL: name: test_mcompress_v4s32
+    ; CHECK: liveins: $q0, $d1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(<4 x s32>) = COPY $q0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(<4 x s16>) = COPY $d1
+    ; CHECK-NEXT: [[FRAME_INDEX:%[0-9]+]]:_(p0) = G_FRAME_INDEX %stack.0
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+    ; CHECK-NEXT: [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 0
+    ; CHECK-NEXT: [[EVEC:%[0-9]+]]:_(s32) = G_EXTRACT_VECTOR_ELT [[COPY]](<4 x s32>), [[C1]](s64)
+    ; CHECK-NEXT: [[C2:%[0-9]+]]:_(s64) = G_CONSTANT i64 4
+    ; CHECK-NEXT: [[MUL:%[0-9]+]]:_(s64) = G_MUL [[C1]], [[C2]]
+    ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[MUL]](s64)
+    ; CHECK-NEXT: G_STORE [[EVEC]](s32), [[PTR_ADD]](p0) :: (store (s32) into %stack.0)
+    ; CHECK-NEXT: [[EVEC1:%[0-9]+]]:_(s16) = G_EXTRACT_VECTOR_ELT [[COPY1]](<4 x s16>), [[C1]](s64)
+    ; CHECK-NEXT: [[ANYEXT:%[0-9]+]]:_(s32) = G_ANYEXT [[EVEC1]](s16)
+    ; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
+    ; CHECK-NEXT: [[AND:%[0-9]+]]:_(s32) = G_AND [[ANYEXT]], [[C3]]
+    ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[C]], [[AND]]
+    ; CHECK-NEXT: [[C4:%[0-9]+]]:_(s64) = G_CONSTANT i64 1
+    ; CHECK-NEXT: [[EVEC2:%[0-9]+]]:_(s32) = G_EXTRACT_VECTOR_ELT [[COPY]](<4 x s32>), [[C4]](s64)
+    ; CHECK-NEXT: [[C5:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
+    ; CHECK-NEXT: [[AND1:%[0-9]+]]:_(s32) = G_AND [[ADD]], [[C5]]
+    ; CHECK-NEXT: [[SEXT:%[0-9]+]]:_(s64) = G_SEXT [[AND1]](s32)
+    ; CHECK-NEXT: [[MUL1:%[0-9]+]]:_(s64) = G_MUL [[SEXT]], [[C2]]
+    ; CHECK-NEXT: [[PTR_ADD1:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[MUL1]](s64)
+    ; CHECK-NEXT: G_STORE [[EVEC2]](s32), [[PTR_ADD1]](p0) :: (store (s32) into %stack.0)
+    ; CHECK-NEXT: [[EVEC3:%[0-9]+]]:_(s16) = G_EXTRACT_VECTOR_ELT [[COPY1]](<4 x s16>), [[C4]](s64)
+    ; CHECK-NEXT: [[ANYEXT1:%[0-9]+]]:_(s32) = G_ANYEXT [[EVEC3]](s16)
+    ; CHECK-NEXT: [[AND2:%[0-9]+]]:_(s32) = G_AND [[ANYEXT1]], [[C3]]
+    ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[ADD]], [[AND2]]
+    ; CHECK-NEXT: [[C6:%[0-9]+]]:_(s64) = G_CONSTANT i64 2
+    ; CHECK-NEXT: [[EVEC4:%[0-9]+]]:_(s32) = G_EXTRACT_VECTOR_ELT [[COPY]](<4 x s32>), [[C6]](s64)
+    ; CHECK-NEXT: [[AND3:%[0-9]+]]:_(s32) = G_AND [[ADD1]], [[C5]]
+    ; CHECK-NEXT: [[SEXT1:%[0-9]+]]:_(s64) = G_SEXT [[AND3]](s32)
+    ; CHECK-NEXT: [[MUL2:%[0-9]+]]:_(s64) = G_MUL [[SEXT1]], [[C2]]
+    ; CHECK-NEXT: [[PTR_ADD2:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[MUL2]](s64)
+    ; CHECK-NEXT: G_STORE [[EVEC4]](s32), [[PTR_ADD2]](p0) :: (store (s32) into %stack.0)
+    ; CHECK-NEXT: [[EVEC5:%[0-9]+]]:_(s16) = G_EXTRACT_VECTOR_ELT [[COPY1]](<4 x s16>), [[C6]](s64)
+    ; CHECK-NEXT: [[ANYEXT2:%[0-9]+]]:_(s32) = G_ANYEXT [[EVEC5]](s16)
+    ; CHECK-NEXT: [[AND4:%[0-9]+]]:_(s32) = G_AND [[ANYEXT2]], [[C3]]
+    ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[ADD1]], [[AND4]]
+    ; CHECK-NEXT: [[C7:%[0-9]+]]:_(s64) = G_CONSTANT i64 3
+    ; CHECK-NEXT: [[EVEC6:%[0-9]+]]:_(s32) = G_EXTRACT_VECTOR_ELT [[COPY]](<4 x s32>), [[C7]](s64)
+    ; CHECK-NEXT: [[AND5:%[0-9]+]]:_(s32) = G_AND [[ADD2]], [[C5]]
+    ; CHECK-NEXT: [[SEXT2:%[0-9]+]]:_(s64) = G_SEXT [[AND5]](s32)
+    ; CHECK-NEXT: [[MUL3:%[0-9]+]]:_(s64) = G_MUL [[SEXT2]], [[C2]]
+    ; CHECK-NEXT: [[PTR_ADD3:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[MUL3]](s64)
+    ; CHECK-NEXT: G_STORE [[EVEC6]](s32), [[PTR_ADD3]](p0) :: (store (s32) into %stack.0)
+    ; CHECK-NEXT: [[LOAD:%[0-9]+]]:_(<4 x s32>) = G_LOAD [[FRAME_INDEX]](p0) :: (load (<4 x s32>) from %stack.0)
+    ; CHECK-NEXT: $q0 = COPY [[LOAD]](<4 x s32>)
+    ; CHECK-NEXT: RET_ReallyLR implicit $q0
+    %0:_(<4 x s32>) = COPY $q0
+    %1:_(<4 x s16>) = COPY $d1
+    %2:_(<4 x s32>) = G_MCOMPRESS %0(<4 x s32>), %1(<4 x s16>)
+    $q0 = COPY %2(<4 x s32>)
+    RET_ReallyLR implicit $q0
+...
+
+

>From 0ea24159782792f4164737b53cd4a077aca91b52 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Fri, 17 May 2024 16:02:41 +0200
Subject: [PATCH 13/14] Address PR comments

---
 llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 5bab609cdefca..dc29a080882ca 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -7508,10 +7508,6 @@ LegalizerHelper::LegalizeResult
 LegalizerHelper::lowerMCOMPRESS(llvm::MachineInstr &MI) {
   auto [Dst, DstTy, Vec, VecTy, Mask, MaskTy] = MI.getFirst3RegLLTs();
 
-  if (VecTy.isScalableVector())
-    report_fatal_error(
-        "Lowering masked_compress for scalable vectors is undefined.");
-
   MachinePointerInfo PtrInfo;
   Register StackPtr =
       createStackTemporary(TypeSize::getFixed(VecTy.getSizeInBytes()),
@@ -7546,8 +7542,6 @@ LegalizerHelper::lowerMCOMPRESS(llvm::MachineInstr &MI) {
   MIRBuilder.buildLoad(Dst, StackPtr, PtrInfo, VecAlign);
 
   MI.eraseFromParent();
-  // TODO: This is not true! We don't know if the input vector type is legal.
-  //       Find out how to assert this!
   return Legalized;
 }
 

>From 1dc79b4124feecec746905a1dc4b903f4dac93f2 Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Fri, 17 May 2024 16:39:58 +0200
Subject: [PATCH 14/14] Update docs according to PR comments

---
 llvm/docs/LangRef.rst | 53 ++++++++++++++++---------------------------
 1 file changed, 20 insertions(+), 33 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e2d3a986ddedf..ad7e20ec59e34 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -25002,15 +25002,14 @@ Overview:
 
 Selects elements from input vector '``value``' according to the '``mask``'.
 All selected elements are written into adjacent lanes in the result vector, from lower to higher.
-The mask holds a bit for each vector lane, and is used to select elements to be kept.
-The number of valid lanes is equal to the number of active bits in the mask.
+The mask holds an entry for each vector lane, and is used to select elements to be kept.
+The number of valid lanes is equal to the number of ``true`` entries in the mask; all lanes at positions greater than or equal to that number are undefined.
 The main difference to :ref:`llvm.masked.compressstore <int_compressstore>` is that the remainder of the vector may
 contain undefined values.
 This allows for branchless code and better optimization for all targets that do not support the explicit semantics of
-:ref:`llvm.masked.compressstore <int_compressstore>`.
+:ref:`llvm.masked.compressstore <int_compressstore>` but still have some form of compress operation (e.g., ARM SVE and RISC-V V).
 The result vector can be written with a similar effect, as all the selected values are at the lower positions of the
-vector, but without requiring branches to avoid writes where the mask is 0.
-
+vector, but without requiring branches to avoid writes where the mask is ``false``.
 
 Arguments:
 """"""""""
@@ -25022,39 +25021,27 @@ The mask and the input vector must have the same number of vector elements.
 Semantics:
 """"""""""
 
-The '``llvm.masked.compress``' intrinsic is designed for compressing data within a vector, i.e., ideally within a register.
-It allows to collect elements from possibly non-adjacent lanes of a vector and place them contiguously in the result vector in one IR operation.
-It is useful for targets all that support compress operations (e.g., AVX-512, ARM SVE, RISCV V), which more instruction
-sets do than explicit compressstore, i.e., ``llvm.masked.compress`` may yield better performance on more targets than
-``llvm.masked.compressstore`` due to weaker constraints.
-This intrinsic allows vectorizing loops with cross-iteration dependencies like in the following example:
+The ``llvm.masked.compress`` intrinsic compresses data within a vector.
+It collects elements from possibly non-adjacent lanes of a vector and places them contiguously in the result vector based on a selection mask.
+This intrinsic performs the logic of the following C++ example.
+All values in ``out`` after the last selected one are undefined.
+If all entries in the ``mask`` are ``false`` (zero), the entire ``out`` vector is undefined.
 
-.. code-block:: c
+.. code-block:: cpp
 
-    // Consecutively store selected values with branchless code.
-    int *in, *out; bool *mask; int pos = 0;
-    for (int i = 0; i < size; ++i) {
-      out[pos] = in[i];
-      // if mask[i] == 0, the current value is overwritten in the next iteration.
-      pos += mask[i];
+    // Consecutively place selected values in a vector.
+    using VecT __attribute__((vector_size(N))) = int;
+    VecT compress(VecT vec, VecT mask) {
+      VecT out;
+      int idx = 0;
+      for (int i = 0; i < N / sizeof(int); ++i) {
+        out[idx] = vec[i];
+        idx += static_cast<bool>(mask[i]);
+      }
+      return out;
     }
 
 
-.. code-block:: llvm
-
-    ; Load elements from `in`.
-    %vec = load <4 x i32>, ptr %inPtr
-    %mask = load <4 x i1>, ptr %maskPtr
-    %compressed = call <4 x i32> @llvm.masked.compress.v4i32(<4 x i32> %vec, <4 x i1> %mask)
-    store <4 x i32> %compressed, ptr %outPtr
-
-    ; %outPtr should be increased in each iteration by the number of '1's in the mask.
-    %iMask = bitcast <4 x i1> %mask to i4
-    %popcnt = call i4 @llvm.ctpop.i4(i4 %iMask)
-    %zextPopcnt = zext i4 %popcnt to i64
-    %nextOut = add i64 %outPos, %zextPopcnt
-
-
 Memory Use Markers
 ------------------
 



More information about the llvm-commits mailing list