[llvm] cf50bbf - [AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine (#159360)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 30 06:53:05 PDT 2025


Author: Kerry McLaughlin
Date: 2025-09-30T14:53:01+01:00
New Revision: cf50bbf983c6ff032c7ad0de27ffaff412947ffb

URL: https://github.com/llvm/llvm-project/commit/cf50bbf983c6ff032c7ad0de27ffaff412947ffb
DIFF: https://github.com/llvm/llvm-project/commit/cf50bbf983c6ff032c7ad0de27ffaff412947ffb.diff

LOG: [AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine (#159360)

The combine replaces a get_active_lane_mask that is used by two
extract_subvector nodes with a single paired whilelo intrinsic. When the
mask is also used for control flow in a vector loop, an additional extract
of element 0 may introduce other uses of the intrinsic, such as ptest and
reinterpret casts, which are currently not supported.

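This patch changes performActiveLaneMaskCombine to count only the
extract_subvector users of the mask, instead of its total number of uses.
The combine also now returns the concatenated whilelo results as the
replacement for the get_active_lane_mask, so any remaining non-extract
users still see the full mask.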

Added: 
    

Modified: 
    llvm/include/llvm/Support/TypeSize.h
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
    llvm/unittests/Support/TypeSizeTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/TypeSize.h b/llvm/include/llvm/Support/TypeSize.h
index 29d1c6894b4b6..0a7ae15edbb33 100644
--- a/llvm/include/llvm/Support/TypeSize.h
+++ b/llvm/include/llvm/Support/TypeSize.h
@@ -179,7 +179,7 @@ template <typename LeafTy, typename ValueTy> class FixedOrScalableQuantity {
   /// This function tells the caller whether the element count is known at
   /// compile time to be a multiple of the scalar value RHS.
   constexpr bool isKnownMultipleOf(ScalarTy RHS) const {
-    return getKnownMinValue() % RHS == 0;
+    return RHS != 0 && getKnownMinValue() % RHS == 0;
   }
 
   /// Returns whether or not the callee is known to be a multiple of RHS.
@@ -191,7 +191,8 @@ template <typename LeafTy, typename ValueTy> class FixedOrScalableQuantity {
     // x % y == 0 !=> x % (vscale * y) == 0
     if (!isScalable() && RHS.isScalable())
       return false;
-    return getKnownMinValue() % RHS.getKnownMinValue() == 0;
+    return RHS.getKnownMinValue() != 0 &&
+           getKnownMinValue() % RHS.getKnownMinValue() == 0;
   }
 
   // Return the minimum value with the assumption that the count is exact.

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9078675da0e95..45f52352d45fd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18867,21 +18867,25 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
       (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
     return SDValue();
 
-  unsigned NumUses = N->use_size();
+  // Count the number of users which are extract_vectors.
+  unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
+    return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
+  });
+
   auto MaskEC = N->getValueType(0).getVectorElementCount();
-  if (!MaskEC.isKnownMultipleOf(NumUses))
+  if (!MaskEC.isKnownMultipleOf(NumExts))
     return SDValue();
 
-  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
   if (ExtMinEC.getKnownMinValue() < 2)
     return SDValue();
 
-  SmallVector<SDNode *> Extracts(NumUses, nullptr);
+  SmallVector<SDNode *> Extracts(NumExts, nullptr);
   for (SDNode *Use : N->users()) {
     if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
-      return SDValue();
+      continue;
 
-    // Ensure the extract type is correct (e.g. if NumUses is 4 and
+    // Ensure the extract type is correct (e.g. if NumExts is 4 and
     // the mask return type is nxv8i1, each extract should be nxv2i1.
     if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
       return SDValue();
@@ -18902,32 +18906,39 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
 
   SDValue Idx = N->getOperand(0);
   SDValue TC = N->getOperand(1);
-  EVT OpVT = Idx.getValueType();
-  if (OpVT != MVT::i64) {
+  if (Idx.getValueType() != MVT::i64) {
     Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
     TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
   }
 
   // Create the whilelo_x2 intrinsics from each pair of extracts
   EVT ExtVT = Extracts[0]->getValueType(0);
+  EVT DoubleExtVT = ExtVT.getDoubleNumVectorElementsVT(*DAG.getContext());
   auto R =
       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
   DCI.CombineTo(Extracts[0], R.getValue(0));
   DCI.CombineTo(Extracts[1], R.getValue(1));
+  SmallVector<SDValue> Concats = {DAG.getNode(
+      ISD::CONCAT_VECTORS, DL, DoubleExtVT, R.getValue(0), R.getValue(1))};
 
-  if (NumUses == 2)
-    return SDValue(N, 0);
+  if (NumExts == 2) {
+    assert(N->getValueType(0) == DoubleExtVT);
+    return Concats[0];
+  }
 
-  auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
-  for (unsigned I = 2; I < NumUses; I += 2) {
+  auto Elts =
+      DAG.getElementCount(DL, MVT::i64, ExtVT.getVectorElementCount() * 2);
+  for (unsigned I = 2; I < NumExts; I += 2) {
     // After the first whilelo_x2, we need to increment the starting value.
-    Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
+    Idx = DAG.getNode(ISD::UADDSAT, DL, MVT::i64, Idx, Elts);
     R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
     DCI.CombineTo(Extracts[I], R.getValue(0));
     DCI.CombineTo(Extracts[I + 1], R.getValue(1));
+    Concats.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, DoubleExtVT,
+                                  R.getValue(0), R.getValue(1)));
   }
 
-  return SDValue(N, 0);
+  return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Concats);
 }
 
 // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce

diff  --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 5e01612e3881a..b89f55188b0f2 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -310,6 +310,187 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
   ret void
 }
 
+; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.
+
+define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p1.b, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB11_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB11_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    ptrue p2.b
+; CHECK-SVE2p1-SME2-NEXT:    uzp1 p3.b, p0.b, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    ptest p2, p3.b
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB11_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB11_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
+; Extra use of the get_active_lane_mask from an extractelement, which is
+; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
+
+define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p1.h, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB12_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB12_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    ptrue p2.h
+; CHECK-SVE2p1-SME2-NEXT:    uzp1 p3.h, p0.h, p1.h
+; CHECK-SVE2p1-SME2-NEXT:    ptest p2, p3.b
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB12_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB12_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+    %v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
+    %v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
+    %elt0 = extractelement <vscale x 8 x i1> %r, i64 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
+define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p0.b, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB13_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p2.h, p3.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB13_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    cnth x8
+; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
+; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.s, p3.s }, x8, x1
+; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.h, p0.h, p1.h
+; CHECK-SVE2p1-SME2-NEXT:    uzp1 p5.h, p2.h, p3.h
+; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.b, p4.b, p5.b
+; CHECK-SVE2p1-SME2-NEXT:    ptrue p5.b
+; CHECK-SVE2p1-SME2-NEXT:    ptest p5, p4.b
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB13_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB13_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
+    %v0 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 4)
+    %v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 12)
+    %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
+define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB14_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p2.h, p3.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB14_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    cntw x8
+; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
+; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.d, p1.d }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
+; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.s, p0.s, p1.s
+; CHECK-SVE2p1-SME2-NEXT:    uzp1 p5.s, p2.s, p3.s
+; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.h, p4.h, p5.h
+; CHECK-SVE2p1-SME2-NEXT:    ptrue p5.h
+; CHECK-SVE2p1-SME2-NEXT:    ptest p5, p4.b
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB14_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB14_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i64 %i, i64 %n)
+    %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+    %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+    %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+    %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+    %elt0 = extractelement <vscale x 8 x i1> %r, i32 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
 declare void @use(...)
 
 attributes #0 = { nounwind }

diff  --git a/llvm/unittests/Support/TypeSizeTest.cpp b/llvm/unittests/Support/TypeSizeTest.cpp
index b02b7e6009535..018b2405d4005 100644
--- a/llvm/unittests/Support/TypeSizeTest.cpp
+++ b/llvm/unittests/Support/TypeSizeTest.cpp
@@ -58,6 +58,7 @@ static_assert(ElementCount::getFixed(8).divideCoefficientBy(2) ==
 static_assert(ElementCount::getFixed(8).multiplyCoefficientBy(3) ==
               ElementCount::getFixed(24));
 static_assert(ElementCount::getFixed(8).isKnownMultipleOf(2));
+static_assert(!ElementCount::getFixed(8).isKnownMultipleOf(0));
 
 constexpr TypeSize TSFixed0 = TypeSize::getFixed(0);
 constexpr TypeSize TSFixed1 = TypeSize::getFixed(1);


        


More information about the llvm-commits mailing list