[llvm] df48dfa - [AArch64] Add custom lowering of nxv32i1 get.active.lane.mask nodes (#141969)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 6 05:35:39 PDT 2025


Author: Kerry McLaughlin
Date: 2025-06-06T13:35:34+01:00
New Revision: df48dfa0aece477f8f9990e26e91b43969851559

URL: https://github.com/llvm/llvm-project/commit/df48dfa0aece477f8f9990e26e91b43969851559
DIFF: https://github.com/llvm/llvm-project/commit/df48dfa0aece477f8f9990e26e91b43969851559.diff

LOG: [AArch64] Add custom lowering of nxv32i1 get.active.lane.mask nodes (#141969)

performActiveLaneMaskCombine already tries to combine a single
get.active.lane.mask, where the low and high halves of the result are
extracted, into a single whilelo which operates on a predicate pair.

If the get.active.lane.mask node requires splitting, multiple nodes are
created with saturating adds to increment the starting index. We cannot
combine these into a single whilelo_x2 at this point unless we know
the add will not overflow.

This patch adds custom lowering for the node if the return type is
nxv32i1, as this can be replaced with a whilelo_x2 using legal types.
Anything wider than nxv32i1 will still require splitting first.

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e1b22b3eaf7bc..b11e1553e8146 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1511,6 +1511,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Legal);
     }
 
+    if (Subtarget->hasSVE2p1() ||
+        (Subtarget->hasSME2() && Subtarget->isStreaming()))
+      setOperationAction(ISD::GET_ACTIVE_LANE_MASK, MVT::nxv32i1, Custom);
+
     for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
       setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
   }
@@ -17981,7 +17985,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                                 /*IsEqual=*/false))
     return While;
 
-  if (!ST->hasSVE2p1())
+  if (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming()))
     return SDValue();
 
   if (!N->hasNUsesOfValue(2, 0))
@@ -27138,6 +27142,37 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
   Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
 }
 
+void AArch64TargetLowering::ReplaceGetActiveLaneMaskResults(
+    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
+  assert((Subtarget->hasSVE2p1() ||
+          (Subtarget->hasSME2() && Subtarget->isStreaming())) &&
+         "Custom lower of get.active.lane.mask missing required feature.");
+
+  assert(N->getValueType(0) == MVT::nxv32i1 &&
+         "Unexpected result type for get.active.lane.mask");
+
+  SDLoc DL(N);
+  SDValue Idx = N->getOperand(0);
+  SDValue TC = N->getOperand(1);
+
+  assert(Idx.getValueType().getFixedSizeInBits() <= 64 &&
+         "Unexpected operand type for get.active.lane.mask");
+
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
+    TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
+  }
+
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  EVT HalfVT = N->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
+  auto WideMask =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {HalfVT, HalfVT}, {ID, Idx, TC});
+
+  Results.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0),
+                                {WideMask.getValue(0), WideMask.getValue(1)}));
+}
+
 // Create an even/odd pair of X registers holding integer value V.
 static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
   SDLoc dl(V.getNode());
@@ -27524,6 +27559,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
     // CONCAT_VECTORS -- but delegate to common code for result type
     // legalisation
     return;
+  case ISD::GET_ACTIVE_LANE_MASK:
+    ReplaceGetActiveLaneMaskResults(N, Results, DAG);
+    return;
   case ISD::INTRINSIC_WO_CHAIN: {
     EVT VT = N->getValueType(0);
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index b2174487c2fe8..7b7f020f7c771 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -822,6 +822,9 @@ class AArch64TargetLowering : public TargetLowering {
   void ReplaceExtractSubVectorResults(SDNode *N,
                                       SmallVectorImpl<SDValue> &Results,
                                       SelectionDAG &DAG) const;
+  void ReplaceGetActiveLaneMaskResults(SDNode *N,
+                                       SmallVectorImpl<SDValue> &Results,
+                                       SelectionDAG &DAG) const;
 
   bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;
 

diff  --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 2d84a69f3144e..c76b50d69b877 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
 ; RUN: llc -mattr=+sve    < %s | FileCheck %s -check-prefix CHECK-SVE
-; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SVE2p1
+; RUN: llc -mattr=+sve -mattr=+sme2 -force-streaming < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SME2
 target triple = "aarch64-linux"
 
 ; Test combining of getActiveLaneMask with a pair of extract_vector operations.
@@ -13,12 +14,12 @@ define void @test_2x8bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0
 ; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
 ; CHECK-SVE-NEXT:    b use
 ;
-; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
-; CHECK-SVE2p1:       // %bb.0:
-; CHECK-SVE2p1-NEXT:    mov w8, w1
-; CHECK-SVE2p1-NEXT:    mov w9, w0
-; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x9, x8
-; CHECK-SVE2p1-NEXT:    b use
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    mov w8, w1
+; CHECK-SVE2p1-SME2-NEXT:    mov w9, w0
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-SME2-NEXT:    b use
     %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
     %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
     %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
@@ -34,10 +35,10 @@ define void @test_2x8bit_mask_with_64bit_index_and_trip_count(i64 %i, i64 %n) #0
 ; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
 ; CHECK-SVE-NEXT:    b use
 ;
-; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
-; CHECK-SVE2p1:       // %bb.0:
-; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x0, x1
-; CHECK-SVE2p1-NEXT:    b use
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    b use
     %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
     %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
     %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
@@ -53,12 +54,12 @@ define void @test_edge_case_2x1bit_mask(i64 %i, i64 %n) #0 {
 ; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
 ; CHECK-SVE-NEXT:    b use
 ;
-; CHECK-SVE2p1-LABEL: test_edge_case_2x1bit_mask:
-; CHECK-SVE2p1:       // %bb.0:
-; CHECK-SVE2p1-NEXT:    whilelo p1.d, x0, x1
-; CHECK-SVE2p1-NEXT:    punpklo p0.h, p1.b
-; CHECK-SVE2p1-NEXT:    punpkhi p1.h, p1.b
-; CHECK-SVE2p1-NEXT:    b use
+; CHECK-SVE2p1-SME2-LABEL: test_edge_case_2x1bit_mask:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    whilelo p1.d, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    b use
     %r = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 %i, i64 %n)
     %v0 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 0)
     %v1 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 1)
@@ -74,10 +75,10 @@ define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
 ; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
 ; CHECK-SVE-NEXT:    b use
 ;
-; CHECK-SVE2p1-LABEL: test_boring_case_2x2bit_mask:
-; CHECK-SVE2p1:       // %bb.0:
-; CHECK-SVE2p1-NEXT:    whilelo { p0.d, p1.d }, x0, x1
-; CHECK-SVE2p1-NEXT:    b use
+; CHECK-SVE2p1-SME2-LABEL: test_boring_case_2x2bit_mask:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.d, p1.d }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    b use
     %r = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %i, i64 %n)
     %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 0)
     %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 2)
@@ -96,14 +97,14 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 {
 ; CHECK-SVE-NEXT:    punpklo p1.h, p2.b
 ; CHECK-SVE-NEXT:    b use
 ;
-; CHECK-SVE2p1-LABEL: test_partial_extract:
-; CHECK-SVE2p1:       // %bb.0:
-; CHECK-SVE2p1-NEXT:    whilelo p0.h, x0, x1
-; CHECK-SVE2p1-NEXT:    punpklo p1.h, p0.b
-; CHECK-SVE2p1-NEXT:    punpkhi p2.h, p0.b
-; CHECK-SVE2p1-NEXT:    punpklo p0.h, p1.b
-; CHECK-SVE2p1-NEXT:    punpklo p1.h, p2.b
-; CHECK-SVE2p1-NEXT:    b use
+; CHECK-SVE2p1-SME2-LABEL: test_partial_extract:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpkhi p2.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT:    punpklo p1.h, p2.b
+; CHECK-SVE2p1-SME2-NEXT:    b use
     %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
     %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
     %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
@@ -111,7 +112,7 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 {
     ret void
 }
 
-;; Negative test for when extracting a fixed-length vector.
+; Negative test for when extracting a fixed-length vector.
 define void @test_fixed_extract(i64 %i, i64 %n) #0 {
 ; CHECK-SVE-LABEL: test_fixed_extract:
 ; CHECK-SVE:       // %bb.0:
@@ -144,6 +145,21 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 {
 ; CHECK-SVE2p1-NEXT:    mov v1.s[1], w11
 ; CHECK-SVE2p1-NEXT:    // kill: def $d1 killed $d1 killed $q1
 ; CHECK-SVE2p1-NEXT:    b use
+;
+; CHECK-SME2-LABEL: test_fixed_extract:
+; CHECK-SME2:       // %bb.0:
+; CHECK-SME2-NEXT:    whilelo p0.h, x0, x1
+; CHECK-SME2-NEXT:    cset w8, mi
+; CHECK-SME2-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-SME2-NEXT:    mov z1.h, z0.h[1]
+; CHECK-SME2-NEXT:    mov z2.h, z0.h[5]
+; CHECK-SME2-NEXT:    mov z3.h, z0.h[4]
+; CHECK-SME2-NEXT:    fmov s0, w8
+; CHECK-SME2-NEXT:    zip1 z0.s, z0.s, z1.s
+; CHECK-SME2-NEXT:    zip1 z1.s, z3.s, z2.s
+; CHECK-SME2-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-SME2-NEXT:    // kill: def $d1 killed $d1 killed $z1
+; CHECK-SME2-NEXT:    b use
     %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
     %v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
     %v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
@@ -151,6 +167,67 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 {
     ret void
 }
 
+; Illegal Types
+
+define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    rdvl x8, #1
+; CHECK-SVE-NEXT:    adds w8, w0, w8
+; CHECK-SVE-NEXT:    csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT:    whilelo p1.b, w8, w1
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    mov w8, w1
+; CHECK-SVE2p1-SME2-NEXT:    mov w9, w0
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.b, p1.b }, x9, x8
+; CHECK-SVE2p1-SME2-NEXT:    b use
+  %r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
+  %v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0)
+  %v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16)
+  tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1)
+  ret void
+}
+
+define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    rdvl x8, #2
+; CHECK-SVE-NEXT:    rdvl x9, #1
+; CHECK-SVE-NEXT:    adds w8, w0, w8
+; CHECK-SVE-NEXT:    csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT:    adds w10, w8, w9
+; CHECK-SVE-NEXT:    csinv w10, w10, wzr, lo
+; CHECK-SVE-NEXT:    whilelo p3.b, w10, w1
+; CHECK-SVE-NEXT:    adds w9, w0, w9
+; CHECK-SVE-NEXT:    csinv w9, w9, wzr, lo
+; CHECK-SVE-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT:    whilelo p1.b, w9, w1
+; CHECK-SVE-NEXT:    whilelo p2.b, w8, w1
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE2p1-SME2:       // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT:    rdvl x8, #2
+; CHECK-SVE2p1-SME2-NEXT:    mov w9, w1
+; CHECK-SVE2p1-SME2-NEXT:    mov w10, w0
+; CHECK-SVE2p1-SME2-NEXT:    adds w8, w0, w8
+; CHECK-SVE2p1-SME2-NEXT:    csinv w8, w8, wzr, lo
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.b, p1.b }, x10, x9
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.b, p3.b }, x8, x9
+; CHECK-SVE2p1-SME2-NEXT:    b use
+  %r = call <vscale x 64 x i1> @llvm.get.active.lane.mask.nxv64i1.i32(i32 %i, i32 %n)
+  %v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 0)
+  %v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 16)
+  %v2 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 32)
+  %v3 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 48)
+  tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1, <vscale x 16 x i1> %v2, <vscale x 16 x i1> %v3)
+  ret void
+}
+
 declare void @use(...)
 
 attributes #0 = { nounwind }


        


More information about the llvm-commits mailing list