[llvm] [AArch64] Fix active.lane.mask(0, cttz.elts(x)) -> 'brkb' transform (PR #180177)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 10 09:30:11 PST 2026


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/180177

>From dc75c215ebb0c84c2c2c704d468513492f852240 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 6 Feb 2026 10:22:44 +0000
Subject: [PATCH 1/4] Pre-commit test

---
 .../CodeGen/AArch64/sve-mask-partition.ll     | 52 +++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/sve-mask-partition.ll b/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
index 8b712bd7e42a7..e2e7d80545dc8 100644
--- a/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
+++ b/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
@@ -558,3 +558,55 @@ define <vscale x 16 x i1> @mask_exclude_active_nxv16_nonzero_lower_bound(<vscale
   %mask.out = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 1, i64 %tz.elts)
   ret <vscale x 16 x i1> %mask.out
 }
+
+define <vscale x 4 x i1> @mask_exclude_active_narrower_result_type(<vscale x 8 x i1> %mask.in) {
+; CHECK-LABEL: mask_exclude_active_narrower_result_type:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    brkb p0.b, p1/z, p0.b
+; CHECK-NEXT:    ret
+  %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<vscale x 8 x i1> %mask.in, i1 false)
+  %mask.out = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
+  ret <vscale x 4 x i1> %mask.out
+}
+
+define <vscale x 16 x i1> @mask_exclude_active_wider_result_type(<vscale x 8 x i1> %mask.in) {
+; CHECK-LABEL: mask_exclude_active_wider_result_type:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    brkb p0.b, p1/z, p0.b
+; CHECK-NEXT:    ret
+  %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<vscale x 8 x i1> %mask.in, i1 false)
+  %mask.out = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
+  ret <vscale x 16 x i1> %mask.out
+}
+
+define <4 x i1> @mask_exclude_active_narrower_result_type_fixed(<8 x i1> %mask.in) {
+; CHECK-LABEL: mask_exclude_active_narrower_result_type_fixed:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.8b, v0.8b, #7
+; CHECK-NEXT:    ptrue p0.b, vl8
+; CHECK-NEXT:    cmpne p1.b, p0/z, z0.b, #0
+; CHECK-NEXT:    brkb p0.b, p0/z, p1.b
+; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<8 x i1> %mask.in, i1 false)
+  %mask.out = call <4 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
+  ret <4 x i1> %mask.out
+}
+
+define <16 x i1> @mask_exclude_active_wider_result_type_fixed(<8 x i1> %mask.in) {
+; CHECK-LABEL: mask_exclude_active_wider_result_type_fixed:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.8b, v0.8b, #7
+; CHECK-NEXT:    ptrue p0.b, vl8
+; CHECK-NEXT:    cmpne p1.b, p0/z, z0.b, #0
+; CHECK-NEXT:    brkb p0.b, p0/z, p1.b
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<8 x i1> %mask.in, i1 false)
+  %mask.out = call <16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
+  ret <16 x i1> %mask.out
+}

>From fb96582be5ac6268522bc211a1d6d7ff0bb7180c Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 6 Feb 2026 10:58:01 +0000
Subject: [PATCH 2/4] Don't optimize to 'brkb' if the result type and mask
 type don't match.

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 ++++-
 llvm/test/CodeGen/AArch64/sve-mask-partition.ll | 6 ++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 94216f6572f0a..20513411558e7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6049,11 +6049,14 @@ static SDValue optimizeBrk(SDNode *N, SelectionDAG &DAG) {
   // We're looking for an upper bound based on CTTZ_ELTS; this would be selected
   // as a cntp(brk(Pg, Mask)), but if we're just going to make a whilelo based
   // on that then we just need the brk.
-  if (Upper.getOpcode() != AArch64ISD::CTTZ_ELTS || !VT.isScalableVector())
+  if (Upper.getOpcode() != AArch64ISD::CTTZ_ELTS || !VT.isScalableVector() ||
+      Upper.getOperand(0).getValueType() != VT)
     return SDValue();
 
   SDValue Pg = Upper->getOperand(0);
   SDValue Mask = Upper->getOperand(1);
+  assert(Pg.getValueType() == Mask.getValueType() &&
+         "predicate types must match");
 
   // brk{a,b} only support .b forms, so cast to make sure all our p regs match.
   Pg = getSVEPredicateBitCast(MVT::nxv16i1, Pg, DAG);
diff --git a/llvm/test/CodeGen/AArch64/sve-mask-partition.ll b/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
index e2e7d80545dc8..30c5220decda7 100644
--- a/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
+++ b/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
@@ -564,6 +564,8 @@ define <vscale x 4 x i1> @mask_exclude_active_narrower_result_type(<vscale x 8 x
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p1.h
 ; CHECK-NEXT:    brkb p0.b, p1/z, p0.b
+; CHECK-NEXT:    cntp x8, p0, p0.h
+; CHECK-NEXT:    whilelo p0.s, xzr, x8
 ; CHECK-NEXT:    ret
   %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<vscale x 8 x i1> %mask.in, i1 false)
   %mask.out = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
@@ -575,6 +577,8 @@ define <vscale x 16 x i1> @mask_exclude_active_wider_result_type(<vscale x 8 x i
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p1.h
 ; CHECK-NEXT:    brkb p0.b, p1/z, p0.b
+; CHECK-NEXT:    cntp x8, p0, p0.h
+; CHECK-NEXT:    whilelo p0.b, xzr, x8
 ; CHECK-NEXT:    ret
   %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<vscale x 8 x i1> %mask.in, i1 false)
   %mask.out = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
@@ -588,6 +592,8 @@ define <4 x i1> @mask_exclude_active_narrower_result_type_fixed(<8 x i1> %mask.i
 ; CHECK-NEXT:    ptrue p0.b, vl8
 ; CHECK-NEXT:    cmpne p1.b, p0/z, z0.b, #0
 ; CHECK-NEXT:    brkb p0.b, p0/z, p1.b
+; CHECK-NEXT:    cntp x8, p0, p0.b
+; CHECK-NEXT:    whilelo p0.h, xzr, x8
 ; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret

>From 95ead3493f931ae6568812b7a4c81f6648f0de34 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 6 Feb 2026 13:48:04 +0000
Subject: [PATCH 3/4] Move assert to verifyTargetNode

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp     | 2 --
 llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp | 4 ++++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 20513411558e7..0929a37af6850 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6055,8 +6055,6 @@ static SDValue optimizeBrk(SDNode *N, SelectionDAG &DAG) {
 
   SDValue Pg = Upper->getOperand(0);
   SDValue Mask = Upper->getOperand(1);
-  assert(Pg.getValueType() == Mask.getValueType() &&
-         "predicate types must match");
 
   // brk{a,b} only support .b forms, so cast to make sure all our p regs match.
   Pg = getSVEPredicateBitCast(MVT::nxv16i1, Pg, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index dbbb871bcec33..8653c72a5c96a 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -48,6 +48,10 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
 #ifndef NDEBUG
   // Some additional checks not yet implemented by verifyTargetNode.
   switch (N->getOpcode()) {
+  case AArch64ISD::CTTZ_ELTS:
+    assert(N->getOperand(0).getValueType() == N->getOperand(1).getValueType() &&
+           "Expected the mask and general-predicate have matching types");
+    break;
   case AArch64ISD::SADDWT:
   case AArch64ISD::SADDWB:
   case AArch64ISD::UADDWT:

>From 5a3f41067778dbe343965ad076bc3495f5c3d84e Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Tue, 10 Feb 2026 17:28:26 +0000
Subject: [PATCH 4/4] Address review comments (reword assert message, use
 unmangled intrinsic names in tests)

---
 .../Target/AArch64/AArch64SelectionDAGInfo.cpp   |  2 +-
 llvm/test/CodeGen/AArch64/sve-mask-partition.ll  | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index 8653c72a5c96a..3c208e01d7f63 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -50,7 +50,7 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
   switch (N->getOpcode()) {
   case AArch64ISD::CTTZ_ELTS:
     assert(N->getOperand(0).getValueType() == N->getOperand(1).getValueType() &&
-           "Expected the mask and general-predicate have matching types");
+           "Expected the general-predicate and mask to have matching types");
     break;
   case AArch64ISD::SADDWT:
   case AArch64ISD::SADDWB:
diff --git a/llvm/test/CodeGen/AArch64/sve-mask-partition.ll b/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
index 30c5220decda7..e4bad94f08b45 100644
--- a/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
+++ b/llvm/test/CodeGen/AArch64/sve-mask-partition.ll
@@ -567,8 +567,8 @@ define <vscale x 4 x i1> @mask_exclude_active_narrower_result_type(<vscale x 8 x
 ; CHECK-NEXT:    cntp x8, p0, p0.h
 ; CHECK-NEXT:    whilelo p0.s, xzr, x8
 ; CHECK-NEXT:    ret
-  %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<vscale x 8 x i1> %mask.in, i1 false)
-  %mask.out = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
+  %tz.elts = call i64 @llvm.experimental.cttz.elts(<vscale x 8 x i1> %mask.in, i1 false)
+  %mask.out = call <vscale x 4 x i1> @llvm.get.active.lane.mask(i64 0, i64 %tz.elts)
   ret <vscale x 4 x i1> %mask.out
 }
 
@@ -580,8 +580,8 @@ define <vscale x 16 x i1> @mask_exclude_active_wider_result_type(<vscale x 8 x i
 ; CHECK-NEXT:    cntp x8, p0, p0.h
 ; CHECK-NEXT:    whilelo p0.b, xzr, x8
 ; CHECK-NEXT:    ret
-  %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<vscale x 8 x i1> %mask.in, i1 false)
-  %mask.out = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
+  %tz.elts = call i64 @llvm.experimental.cttz.elts(<vscale x 8 x i1> %mask.in, i1 false)
+  %mask.out = call <vscale x 16 x i1> @llvm.get.active.lane.mask(i64 0, i64 %tz.elts)
   ret <vscale x 16 x i1> %mask.out
 }
 
@@ -597,8 +597,8 @@ define <4 x i1> @mask_exclude_active_narrower_result_type_fixed(<8 x i1> %mask.i
 ; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
-  %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<8 x i1> %mask.in, i1 false)
-  %mask.out = call <4 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
+  %tz.elts = call i64 @llvm.experimental.cttz.elts(<8 x i1> %mask.in, i1 false)
+  %mask.out = call <4 x i1> @llvm.get.active.lane.mask(i64 0, i64 %tz.elts)
   ret <4 x i1> %mask.out
 }
 
@@ -612,7 +612,7 @@ define <16 x i1> @mask_exclude_active_wider_result_type_fixed(<8 x i1> %mask.in)
 ; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
 ; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
-  %tz.elts = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<8 x i1> %mask.in, i1 false)
-  %mask.out = call <16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %tz.elts)
+  %tz.elts = call i64 @llvm.experimental.cttz.elts(<8 x i1> %mask.in, i1 false)
+  %mask.out = call <16 x i1> @llvm.get.active.lane.mask(i64 0, i64 %tz.elts)
   ret <16 x i1> %mask.out
 }



More information about the llvm-commits mailing list