[llvm] [AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine (PR #159360)

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 17 06:38:47 PDT 2025


https://github.com/kmclaughlin-arm created https://github.com/llvm/llvm-project/pull/159360

The combine replaces a get_active_lane_mask used by two extract
subvectors with a single paired whilelo intrinsic. When the mask is used
for control flow in a vector loop, an additional extract of element 0
may introduce other uses of the intrinsic, such as ptest_first and
reinterpret_cast nodes, which the combine does not currently support.
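
A minimal sketch of the pattern the combine targets, assuming an
nxv16i1 mask split into two nxv8i1 halves (other type combinations
follow the same shape):

  %mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
  %lo = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %mask, i64 0)
  %hi = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %mask, i64 8)

  ; is combined into the equivalent of:
  %pair = call { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelo.x2.nxv8i1(i64 %i, i64 %n)
  ; which selects to a single instruction: whilelo { p0.h, p1.h }, x0, x1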

This patch changes performActiveLaneMaskCombine to count the number of
extract_subvector nodes using the mask instead of the total number of
uses, and to allow the additional uses introduced by ptest_first and
reinterpret_cast, as in the sketch below.
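
A sketch of the case this patch enables (mirroring the new tests below,
with the same assumed types and hypothetical block labels): the mask may
now also feed the loop-control branch through an element-0 extract,
which lowering turns into ptest_first, plus reinterpret_cast when the
mask type is narrower than nxv16i1:

  %mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
  %lo = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %mask, i64 0)
  %hi = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %mask, i64 8)
  %elt0 = extractelement <vscale x 16 x i1> %mask, i64 0
  br i1 %elt0, label %loop.body, label %exit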

From 5443c18e6e238aa62af1d4af970ab895e6ecbc22 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Tue, 16 Sep 2025 16:11:22 +0000
Subject: [PATCH 1/2] - Tests for get_active_lane_mask with more uses than
 extract_vector

---
 .../AArch64/get-active-lane-mask-extract.ll   | 79 +++++++++++++++++++
 1 file changed, 79 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 5e01612e3881a..290731355797c 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -310,6 +310,85 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
   ret void
 }
 
+; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.
+
+define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p1.b, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB11_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB11_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    whilelo p1.b, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB11_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB11_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
+; Extra use of the get_active_lane_mask from an extractelement, which is
+; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
+
+define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p1.h, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB12_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB12_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    whilelo p1.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB12_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB12_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+    %v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
+    %v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
+    %elt0 = extractelement <vscale x 8 x i1> %r, i64 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
 declare void @use(...)
 
 attributes #0 = { nounwind }

From ea8a0537b360428533dc0ce4d53a6ae53b76226f Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Tue, 16 Sep 2025 09:04:36 +0000
Subject: [PATCH 2/2] [AArch64][SVE2p1] Allow more uses of mask in
 performActiveLaneMaskCombine

The combine replaces a get_active_lane_mask used by two extract
subvectors with a single paired whilelo intrinsic. When the mask is used
for control flow in a vector loop, an additional extract of element 0
may introduce other uses of the intrinsic, such as ptest_first and
reinterpret_cast nodes, which the combine does not currently support.

This patch changes performActiveLaneMaskCombine to count the number of
extract_subvector nodes using the mask instead of the total number of
uses, and to allow the additional uses introduced by ptest_first and
reinterpret_cast.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 29 ++++++++++++++-----
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  | 20 ++++++++-----
 .../AArch64/get-active-lane-mask-extract.ll   |  8 ++---
 3 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c9a756da0078d..9c7ecf944e763 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18693,21 +18693,31 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
       (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
     return SDValue();
 
-  unsigned NumUses = N->use_size();
+  // Count the number of users which are extract_subvectors.
+  // The only other valid users for this combine are ptest_first
+  // and reinterpret_cast.
+  unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
+    return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
+  });
+
   auto MaskEC = N->getValueType(0).getVectorElementCount();
-  if (!MaskEC.isKnownMultipleOf(NumUses))
+  if (!MaskEC.isKnownMultipleOf(NumExts))
     return SDValue();
 
-  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
   if (ExtMinEC.getKnownMinValue() < 2)
     return SDValue();
 
-  SmallVector<SDNode *> Extracts(NumUses, nullptr);
+  SmallVector<SDNode *> Extracts(NumExts, nullptr);
   for (SDNode *Use : N->users()) {
+    if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
+        Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+      continue;
+
     if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
       return SDValue();
 
-    // Ensure the extract type is correct (e.g. if NumUses is 4 and
+    // Ensure the extract type is correct (e.g. if NumExts is 4 and
     // the mask return type is nxv8i1, each extract should be nxv2i1).
     if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
       return SDValue();
@@ -18741,11 +18751,13 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   DCI.CombineTo(Extracts[0], R.getValue(0));
   DCI.CombineTo(Extracts[1], R.getValue(1));
 
-  if (NumUses == 2)
-    return SDValue(N, 0);
+  if (NumExts == 2) {
+    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
+    return SDValue(N, 0);
+  }
 
   auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
-  for (unsigned I = 2; I < NumUses; I += 2) {
+  for (unsigned I = 2; I < NumExts; I += 2) {
     // After the first whilelo_x2, we need to increment the starting value.
     Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
     R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
@@ -18753,6 +18765,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     DCI.CombineTo(Extracts[I + 1], R.getValue(1));
   }
 
+  DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
   return SDValue(N, 0);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index bf3d47ac43607..069d08663fdea 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1495,13 +1495,19 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
       return PredOpcode;
 
-    // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
-    // redundant since WHILE performs an implicit PTEST with an all active
-    // mask.
-    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
-        getElementSizeForOpcode(MaskOpcode) ==
-            getElementSizeForOpcode(PredOpcode))
-      return PredOpcode;
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
+      auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
+      if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
+          PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
+        return PredOpcode;
+
+      // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
+      // redundant since WHILE performs an implicit PTEST with an all active
+      // mask.
+      if (getElementSizeForOpcode(MaskOpcode) ==
+          getElementSizeForOpcode(PredOpcode))
+        return PredOpcode;
+    }
 
     return {};
   }
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 290731355797c..3b18008605413 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -326,11 +326,9 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
 ;
 ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
 ; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
-; CHECK-SVE2p1-SME2-NEXT:    whilelo p1.b, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB11_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
-; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p1.b
-; CHECK-SVE2p1-SME2-NEXT:    punpkhi p1.h, p1.b
 ; CHECK-SVE2p1-SME2-NEXT:    b use
 ; CHECK-SVE2p1-SME2-NEXT:  .LBB11_2: // %if.end
 ; CHECK-SVE2p1-SME2-NEXT:    ret
@@ -366,11 +364,9 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
 ;
 ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
 ; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
-; CHECK-SVE2p1-SME2-NEXT:    whilelo p1.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB12_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
-; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p1.b
-; CHECK-SVE2p1-SME2-NEXT:    punpkhi p1.h, p1.b
 ; CHECK-SVE2p1-SME2-NEXT:    b use
 ; CHECK-SVE2p1-SME2-NEXT:  .LBB12_2: // %if.end
 ; CHECK-SVE2p1-SME2-NEXT:    ret


