[llvm] 7ea3714 - [AArch64] Extend performActiveLaneMaskCombine for more than two extracts (#146725)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 9 02:19:19 PDT 2025


Author: Kerry McLaughlin
Date: 2025-07-09T10:19:14+01:00
New Revision: 7ea371443b83f20edadb538c4b6ea19ee4de1094

URL: https://github.com/llvm/llvm-project/commit/7ea371443b83f20edadb538c4b6ea19ee4de1094
DIFF: https://github.com/llvm/llvm-project/commit/7ea371443b83f20edadb538c4b6ea19ee4de1094.diff

LOG: [AArch64] Extend performActiveLaneMaskCombine for more than two extracts (#146725)

The combine was added to find a get.active.lane.mask used by two
extract subvectors and try to replace it with the paired whilelo
instruction. This extends the combine to cover cases where there
are more than two extracts.
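For illustration, a minimal sketch of the kind of pattern the extended combine
now handles, adapted from the test_legal_4x2bit_mask test added below: an
nxv8i1 get.active.lane.mask whose result is split into four nxv2i1 extract
subvectors, one per quarter of the mask.

    %r  = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
    %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
    %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
    %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
    %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)

With +sve2p1 (or +sme2 in streaming mode) this now selects to two paired
whilelo instructions rather than a single whilelo followed by a chain of
punpklo/punpkhi, with the start index of the second pair bumped by the number
of elements already covered (a saturating add, matching the UADDSAT in the
combine):

    whilelo { p0.d, p1.d }, x0, x1
    whilelo { p2.d, p3.d }, x8, x1    // x8 = saturating add of x0 and cntw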

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6b7e9357aab5a..01be10be433fd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18170,53 +18170,65 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
       (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
     return SDValue();
 
-  if (!N->hasNUsesOfValue(2, 0))
+  unsigned NumUses = N->use_size();
+  auto MaskEC = N->getValueType(0).getVectorElementCount();
+  if (!MaskEC.isKnownMultipleOf(NumUses))
     return SDValue();
 
-  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
-  if (HalfSize < 2)
+  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+  if (ExtMinEC.getKnownMinValue() < 2)
     return SDValue();
 
-  auto It = N->user_begin();
-  SDNode *Lo = *It++;
-  SDNode *Hi = *It;
+  SmallVector<SDNode *> Extracts(NumUses, nullptr);
+  for (SDNode *Use : N->users()) {
+    if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
+      return SDValue();
 
-  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
-      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
-    return SDValue();
+    // Ensure the extract type is correct (e.g. if NumUses is 4 and
+    // the mask return type is nxv8i1, each extract should be nxv2i1).
+    if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
+      return SDValue();
 
-  uint64_t OffLo = Lo->getConstantOperandVal(1);
-  uint64_t OffHi = Hi->getConstantOperandVal(1);
+    // There should be exactly one extract for each part of the mask.
+    unsigned Offset = Use->getConstantOperandVal(1);
+    unsigned Part = Offset / ExtMinEC.getKnownMinValue();
+    if (Extracts[Part] != nullptr)
+      return SDValue();
 
-  if (OffLo > OffHi) {
-    std::swap(Lo, Hi);
-    std::swap(OffLo, OffHi);
+    Extracts[Part] = Use;
   }
 
-  if (OffLo != 0 || OffHi != HalfSize)
-    return SDValue();
-
-  EVT HalfVec = Lo->getValueType(0);
-  if (HalfVec != Hi->getValueType(0) ||
-      HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
-    return SDValue();
-
   SelectionDAG &DAG = DCI.DAG;
   SDLoc DL(N);
   SDValue ID =
       DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+
   SDValue Idx = N->getOperand(0);
   SDValue TC = N->getOperand(1);
-  if (Idx.getValueType() != MVT::i64) {
-    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
-    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  EVT OpVT = Idx.getValueType();
+  if (OpVT != MVT::i64) {
+    Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
+    TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
   }
+
+  // Create the whilelo_x2 intrinsics from each pair of extracts
+  EVT ExtVT = Extracts[0]->getValueType(0);
   auto R =
-      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
-                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
+  DCI.CombineTo(Extracts[0], R.getValue(0));
+  DCI.CombineTo(Extracts[1], R.getValue(1));
+
+  if (NumUses == 2)
+    return SDValue(N, 0);
 
-  DCI.CombineTo(Lo, R.getValue(0));
-  DCI.CombineTo(Hi, R.getValue(1));
+  auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
+  for (unsigned I = 2; I < NumUses; I += 2) {
+    // After the first whilelo_x2, we need to increment the starting value.
+    Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
+    R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
+    DCI.CombineTo(Extracts[I], R.getValue(0));
+    DCI.CombineTo(Extracts[I + 1], R.getValue(1));
+  }
 
   return SDValue(N, 0);
 }

diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index c76b50d69b877..5e01612e3881a 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -86,6 +86,65 @@ define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
     ret void
 }
 
+define void @test_legal_4x2bit_mask(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_legal_4x2bit_mask:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    punpkhi p1.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p4.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p2.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p4.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p4.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_legal_4x2bit_mask:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    cntw x8
+; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
+; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.d, p1.d }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
+; CHECK-SVE2p1-SME2-NEXT:    b use
+  %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+  %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+  %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+  %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+  %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+  tail call void @use(<vscale x 2 x i1> %v3, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v0)
+  ret void
+}
+
+; Negative test where the extract types are correct but we are not extracting all parts of the mask
+; Note: We could still create a whilelo_x2 for the first two extracts, but we don't expect this case often yet.
+define void @test_partial_extract_correct_types(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_partial_extract_correct_types:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p2.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p2.h, p2.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_partial_extract_correct_types:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p2.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p2.h, p2.b
+; CHECK-SVE2p1-SME2-NEXT:    b use
+  %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+  %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+  %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+  %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+  tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2)
+  ret void
+}
+
 ; Negative test for when not extracting exactly two halves of the source vector
 define void @test_partial_extract(i64 %i, i64 %n) #0 {
 ; CHECK-SVE-LABEL: test_partial_extract:
@@ -116,57 +175,80 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 {
 define void @test_fixed_extract(i64 %i, i64 %n) #0 {
 ; CHECK-SVE-LABEL: test_fixed_extract:
 ; CHECK-SVE:       // %bb.0:
-; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    whilelo p0.s, x0, x1
 ; CHECK-SVE-NEXT:    cset w8, mi
-; CHECK-SVE-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-SVE-NEXT:    umov w9, v0.h[4]
-; CHECK-SVE-NEXT:    umov w10, v0.h[1]
-; CHECK-SVE-NEXT:    umov w11, v0.h[5]
+; CHECK-SVE-NEXT:    mov z1.s, p0/z, #1 // =0x1
 ; CHECK-SVE-NEXT:    fmov s0, w8
-; CHECK-SVE-NEXT:    fmov s1, w9
-; CHECK-SVE-NEXT:    mov v0.s[1], w10
+; CHECK-SVE-NEXT:    mov v0.s[1], v1.s[1]
+; CHECK-SVE-NEXT:    ext z1.b, z1.b, z1.b, #8
 ; CHECK-SVE-NEXT:    // kill: def $d0 killed $d0 killed $q0
-; CHECK-SVE-NEXT:    mov v1.s[1], w11
-; CHECK-SVE-NEXT:    // kill: def $d1 killed $d1 killed $q1
+; CHECK-SVE-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-SVE-NEXT:    b use
 ;
 ; CHECK-SVE2p1-LABEL: test_fixed_extract:
 ; CHECK-SVE2p1:       // %bb.0:
-; CHECK-SVE2p1-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE2p1-NEXT:    whilelo p0.s, x0, x1
 ; CHECK-SVE2p1-NEXT:    cset w8, mi
-; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-SVE2p1-NEXT:    umov w9, v0.h[4]
-; CHECK-SVE2p1-NEXT:    umov w10, v0.h[1]
-; CHECK-SVE2p1-NEXT:    umov w11, v0.h[5]
+; CHECK-SVE2p1-NEXT:    mov z1.s, p0/z, #1 // =0x1
 ; CHECK-SVE2p1-NEXT:    fmov s0, w8
-; CHECK-SVE2p1-NEXT:    fmov s1, w9
-; CHECK-SVE2p1-NEXT:    mov v0.s[1], w10
+; CHECK-SVE2p1-NEXT:    mov v0.s[1], v1.s[1]
+; CHECK-SVE2p1-NEXT:    ext z1.b, z1.b, z1.b, #8
 ; CHECK-SVE2p1-NEXT:    // kill: def $d0 killed $d0 killed $q0
-; CHECK-SVE2p1-NEXT:    mov v1.s[1], w11
-; CHECK-SVE2p1-NEXT:    // kill: def $d1 killed $d1 killed $q1
+; CHECK-SVE2p1-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-SVE2p1-NEXT:    b use
 ;
 ; CHECK-SME2-LABEL: test_fixed_extract:
 ; CHECK-SME2:       // %bb.0:
-; CHECK-SME2-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SME2-NEXT:    whilelo p0.s, x0, x1
 ; CHECK-SME2-NEXT:    cset w8, mi
-; CHECK-SME2-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-SME2-NEXT:    mov z1.h, z0.h[1]
-; CHECK-SME2-NEXT:    mov z2.h, z0.h[5]
-; CHECK-SME2-NEXT:    mov z3.h, z0.h[4]
-; CHECK-SME2-NEXT:    fmov s0, w8
-; CHECK-SME2-NEXT:    zip1 z0.s, z0.s, z1.s
-; CHECK-SME2-NEXT:    zip1 z1.s, z3.s, z2.s
-; CHECK-SME2-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-SME2-NEXT:    mov z1.s, p0/z, #1 // =0x1
+; CHECK-SME2-NEXT:    fmov s2, w8
+; CHECK-SME2-NEXT:    mov z0.s, z1.s[1]
+; CHECK-SME2-NEXT:    ext z1.b, z1.b, z1.b, #8
 ; CHECK-SME2-NEXT:    // kill: def $d1 killed $d1 killed $z1
+; CHECK-SME2-NEXT:    zip1 z0.s, z2.s, z0.s
+; CHECK-SME2-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-SME2-NEXT:    b use
-    %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
-    %v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
-    %v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+    %r = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %i, i64 %n)
+    %v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 0)
+    %v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 2)
     tail call void @use(<2 x i1> %v0, <2 x i1> %v1)
     ret void
 }
 
+; Negative test where the number of extracts is right, but they cannot be combined because
+; there is not an extract for each part
+define void @test_4x2bit_duplicate_mask(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_4x2bit_duplicate_mask:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p0.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p2.h, p3.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT:    mov p1.b, p0.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_duplicate_mask:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p3.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p2.h, p3.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p3.h, p3.b
+; CHECK-SVE2p1-SME2-NEXT:    mov p1.b, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    b use
+  %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+  %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+  %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+  %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv4i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+  %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv4i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+  tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3)
+  ret void
+}
+
 ; Illegal Types
 
 define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {


        

