[llvm] [AArch64] Extend performActiveLaneMaskCombine for more than two extracts (PR #146725)

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 8 06:52:20 PDT 2025


================
@@ -18143,53 +18143,63 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
       (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
     return SDValue();
 
-  if (!N->hasNUsesOfValue(2, 0))
+  unsigned NumUses = N->use_size();
+  auto MaskEC = N->getValueType(0).getVectorElementCount();
+  if (!MaskEC.isKnownMultipleOf(NumUses))
     return SDValue();
 
-  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
-  if (HalfSize < 2)
+  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+  if (ExtMinEC.getKnownMinValue() < 2)
     return SDValue();
 
-  auto It = N->user_begin();
-  SDNode *Lo = *It++;
-  SDNode *Hi = *It;
+  SmallVector<SDNode *> Extracts(NumUses, nullptr);
+  for (SDNode *Use : N->users()) {
+    if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
+      return SDValue();
 
-  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
-      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
-    return SDValue();
+    // Ensure the extract type is correct (e.g. if NumUses is 4 and
+    // the mask return type is nxv8i1, each extract should be nxv2i1).
+    if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
+      return SDValue();
 
-  uint64_t OffLo = Lo->getConstantOperandVal(1);
-  uint64_t OffHi = Hi->getConstantOperandVal(1);
+    // There should be exactly one extract for each part of the mask.
+    unsigned Offset = Use->getConstantOperandVal(1);
+    unsigned Part = Offset / ExtMinEC.getKnownMinValue();
+    if (Extracts[Part] != nullptr)
+      return SDValue();
 
-  if (OffLo > OffHi) {
-    std::swap(Lo, Hi);
-    std::swap(OffLo, OffHi);
+    Extracts[Part] = Use;
   }
 
-  if (OffLo != 0 || OffHi != HalfSize)
-    return SDValue();
-
-  EVT HalfVec = Lo->getValueType(0);
-  if (HalfVec != Hi->getValueType(0) ||
-      HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
-    return SDValue();
-
   SelectionDAG &DAG = DCI.DAG;
   SDLoc DL(N);
   SDValue ID =
       DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+
   SDValue Idx = N->getOperand(0);
   SDValue TC = N->getOperand(1);
-  if (Idx.getValueType() != MVT::i64) {
-    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
-    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  EVT OpVT = Idx.getValueType();
+  if (OpVT != MVT::i64) {
+    Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
+    TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
   }
-  auto R =
-      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
-                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
 
-  DCI.CombineTo(Lo, R.getValue(0));
-  DCI.CombineTo(Hi, R.getValue(1));
+  // Create the whilelo_x2 intrinsics from each pair of extracts
+  EVT ExtVT = Extracts[0]->getValueType(0);
+  for (unsigned I = 0; I < NumUses; I += 2) {
+    // After the first whilelo_x2, we need to increment the starting value.
+    if (I > 0) {
+      SDValue Elts =
+          DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount());
----------------
kmclaughlin-arm wrote:

Yes, the element count should have been multiplied by two. I've fixed this in the latest commit.

https://github.com/llvm/llvm-project/pull/146725


More information about the llvm-commits mailing list