[llvm] [AArch64] Extend performActiveLaneMaskCombine for more than two extracts (PR #146725)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 2 08:28:58 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

<details>
<summary>Changes</summary>

The combine was originally added to find a get.active.lane.mask used by exactly
two extract subvectors and replace it with the paired whilelo instruction. This
patch extends the combine to cover cases where there are more than two extracts.

---
Full diff: https://github.com/llvm/llvm-project/pull/146725.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+40-30) 
- (modified) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+87) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fb8bd81c033af..7e4ba0a776382 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18143,53 +18143,63 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
       (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
     return SDValue();
 
-  if (!N->hasNUsesOfValue(2, 0))
+  unsigned NumUses = N->use_size();
+  unsigned MaskMinElts = N->getValueType(0).getVectorMinNumElements();
+  if (MaskMinElts % NumUses != 0)
     return SDValue();
 
-  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
-  if (HalfSize < 2)
+  unsigned ExtMinElts = MaskMinElts / NumUses;
+  if (ExtMinElts < 2)
     return SDValue();
 
-  auto It = N->user_begin();
-  SDNode *Lo = *It++;
-  SDNode *Hi = *It;
+  SmallVector<SDNode *> Extracts(NumUses, nullptr);
+  for (SDNode *Use : N->users()) {
+    if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
+      return SDValue();
 
-  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
-      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
-    return SDValue();
+    // Ensure the extract type is correct (e.g. if NumUses is 4 and
+    // the mask return type is nxv8i1, each extract should be nxv2i1).
+    if (Use->getValueType(0).getVectorMinNumElements() != ExtMinElts)
+      return SDValue();
 
-  uint64_t OffLo = Lo->getConstantOperandVal(1);
-  uint64_t OffHi = Hi->getConstantOperandVal(1);
+    // There should be exactly one extract for each part of the mask.
+    unsigned Offset = Use->getConstantOperandVal(1);
+    unsigned Part = Offset / ExtMinElts;
+    if (Extracts[Part] != nullptr)
+      return SDValue();
 
-  if (OffLo > OffHi) {
-    std::swap(Lo, Hi);
-    std::swap(OffLo, OffHi);
+    Extracts[Part] = Use;
   }
 
-  if (OffLo != 0 || OffHi != HalfSize)
-    return SDValue();
-
-  EVT HalfVec = Lo->getValueType(0);
-  if (HalfVec != Hi->getValueType(0) ||
-      HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
-    return SDValue();
-
   SelectionDAG &DAG = DCI.DAG;
   SDLoc DL(N);
   SDValue ID =
       DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+
   SDValue Idx = N->getOperand(0);
   SDValue TC = N->getOperand(1);
-  if (Idx.getValueType() != MVT::i64) {
-    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
-    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  EVT OpVT = Idx.getValueType();
+  if (OpVT != MVT::i64) {
+    Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
+    TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
   }
-  auto R =
-      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
-                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
 
-  DCI.CombineTo(Lo, R.getValue(0));
-  DCI.CombineTo(Hi, R.getValue(1));
+  // Create the whilelo_x2 intrinsics from each pair of extracts
+  EVT ExtVT = Extracts[0]->getValueType(0);
+  for (unsigned I = 0; I < NumUses; I += 2) {
+    // After the first whilelo_x2, we need to increment the starting value.
+    if (I > 0) {
+      SDValue Elts =
+          DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount());
+      Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
+    }
+
+    auto R =
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
+
+    DCI.CombineTo(Extracts[I], R.getValue(0));
+    DCI.CombineTo(Extracts[I + 1], R.getValue(1));
+  }
 
   return SDValue(N, 0);
 }
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index c76b50d69b877..95788f4e6e83b 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -86,6 +86,64 @@ define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
     ret void
 }
 
+define void @test_legal_4x2bit_mask(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_legal_4x2bit_mask:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    punpkhi p1.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p4.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p2.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p4.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p4.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_legal_4x2bit_mask:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    cntd x8
+; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
+; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.d, p1.d }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
+; CHECK-SVE2p1-SME2-NEXT:    b use
+  %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+  %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+  %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+  %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+  %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+  tail call void @use(<vscale x 2 x i1> %v3, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v0)
+  ret void
+}
+
+; Negative test where the extract types are correct but we are not extracting all parts of the mask
+define void @test_partial_extract_correct_types(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_partial_extract_correct_types:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p2.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p2.h, p2.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_partial_extract_correct_types:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p2.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p2.h, p2.b
+; CHECK-SVE2p1-SME2-NEXT:    b use
+  %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+  %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+  %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+  %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+  tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2)
+  ret void
+}
+
 ; Negative test for when not extracting exactly two halves of the source vector
 define void @test_partial_extract(i64 %i, i64 %n) #0 {
 ; CHECK-SVE-LABEL: test_partial_extract:
@@ -167,6 +225,35 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 {
     ret void
 }
 
+; Negative test where the number of extracts is right, but they cannot be combined because
+; there is not an extract for each part
+define void @test_2x2bit_2x4bit_mask(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_2x2bit_2x4bit_mask:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    punpklo p2.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p2.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p2.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x2bit_2x4bit_mask:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p2.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p3.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p2.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p1.h, p2.b
+; CHECK-SVE2p1-SME2-NEXT:    b use
+  %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+  %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+  %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+  %v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+  %v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+  tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
+  ret void
+}
+
 ; Illegal Types
 
 define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {

``````````

</details>


https://github.com/llvm/llvm-project/pull/146725


More information about the llvm-commits mailing list