[llvm] [AArch64] Support lowering smaller than legal LOOP_DEP_MASKs to whilewr/rw (PR #171982)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 03:03:38 PST 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/171982

From 6bd943eb2d61d2e147384152e403bc5c1fc7e97d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 12 Dec 2025 10:01:59 +0000
Subject: [PATCH 1/3] Precommit tests

---
 .../CodeGen/AArch64/alias_mask_scalable.ll    | 43 +++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
index 8a2eff3fde396..559d75715c725 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
@@ -309,3 +309,46 @@ entry:
   %0 = call <vscale x 16 x i1> @llvm.loop.dependence.war.mask.nxv16i1(ptr %a, ptr %b, i64 3)
   ret <vscale x 16 x i1> %0
 }
+
+define <vscale x 8 x i1> @whilewr_extract_nxv8i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv8i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub x8, x1, x0
+; CHECK-NEXT:    cmp x8, #1
+; CHECK-NEXT:    csinv x8, x8, xzr, ge
+; CHECK-NEXT:    whilelo p0.h, xzr, x8
+; CHECK-NEXT:    ret
+entry:
+  %0 = call <vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
+  ret <vscale x 8 x i1> %0
+}
+
+define <vscale x 4 x i1> @whilewr_extract_nxv4i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv4i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub x8, x1, x0
+; CHECK-NEXT:    cmp x8, #1
+; CHECK-NEXT:    csinv x8, x8, xzr, ge
+; CHECK-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NEXT:    ret
+entry:
+  %0 = call <vscale x 4 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
+  ret <vscale x 4 x i1> %0
+}
+
+
+define <vscale x 2 x i1> @whilewr_extract_nxv2i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv2i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    subs x8, x1, x0
+; CHECK-NEXT:    add x9, x8, #3
+; CHECK-NEXT:    csel x8, x9, x8, mi
+; CHECK-NEXT:    asr x8, x8, #2
+; CHECK-NEXT:    cmp x8, #1
+; CHECK-NEXT:    csinv x8, x8, xzr, ge
+; CHECK-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NEXT:    ret
+entry:
+  %0 = call <vscale x 2 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 4)
+  ret <vscale x 2 x i1> %0
+}

From d38d555dbb6c83e14b7c30696d65ec1ea94e2d0c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 11 Dec 2025 12:06:07 +0000
Subject: [PATCH 2/3] [AArch64] Support lowering smaller than legal
 LOOP_DEP_MASKs to whilewr/rw

This adds support for lowering smaller-than-legal masks such as:

```
<vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
```

to a whilewr + unpack. It also slightly simplifies the lowering.
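
For reference, with this change the nxv8i1 case above is lowered to a full-width whilewr followed by a predicate unpack (this matches the updated alias_mask_scalable.ll checks below):

```
whilewr p0.b, x0, x1
punpklo p0.h, p0.b
```

rather than the sub/cmp/csinv/whilelo sequence produced by the generic expansion in the precommit tests.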
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 68 ++++++++-----------
 .../CodeGen/AArch64/alias_mask_scalable.ll    | 22 ++----
 2 files changed, 35 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1ade1df88f010..4d3542c9277e8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5439,55 +5439,43 @@ static MVT getSVEContainerType(EVT ContentTy);
 SDValue
 AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
                                                  SelectionDAG &DAG) const {
+  assert((Subtarget->hasSVE2() ||
+          (Subtarget->hasSME() && Subtarget->isStreaming())) &&
+         "Lowering loop_dependence_raw_mask or loop_dependence_war_mask "
+         "requires SVE or SME");
+
   SDLoc DL(Op);
   EVT VT = Op.getValueType();
-  SDValue EltSize = Op.getOperand(2);
-  switch (EltSize->getAsZExtVal()) {
-  case 1:
-    if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
-      return SDValue();
-    break;
-  case 2:
-    if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
-      return SDValue();
-    break;
-  case 4:
-    if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
-      return SDValue();
-    break;
-  case 8:
-    if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
-      return SDValue();
-    break;
-  default:
-    // Other element sizes are incompatible with whilewr/rw, so expand instead
-    return SDValue();
-  }
+  unsigned LaneOffset = Op.getConstantOperandVal(3);
+  unsigned NumElements = VT.getVectorMinNumElements();
+  uint64_t EltSizeInBytes = Op.getConstantOperandVal(2);
 
-  SDValue LaneOffset = Op.getOperand(3);
-  if (LaneOffset->getAsZExtVal())
+  // Lane offsets and other element sizes are not supported by whilewr/rw.
+  if (LaneOffset != 0 || !is_contained({1u, 2u, 4u, 8u}, EltSizeInBytes))
     return SDValue();
 
-  SDValue PtrA = Op.getOperand(0);
-  SDValue PtrB = Op.getOperand(1);
+  EVT EltVT = MVT::getIntegerVT(EltSizeInBytes * 8);
+  EVT PredVT = getPackedSVEVectorVT(EltVT).changeElementType(MVT::i1);
 
-  if (VT.isScalableVT())
-    return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, EltSize, LaneOffset);
+  // Legal whilewr/rw (lowered by tablegen matcher).
+  if (PredVT == VT)
+    return Op;
 
-  // We can use the SVE whilewr/whilerw instruction to lower this
-  // intrinsic by creating the appropriate sequence of scalable vector
-  // operations and then extracting a fixed-width subvector from the scalable
-  // vector. Scalable vector variants are already legal.
-  EVT ContainerVT =
-      EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
-                       VT.getVectorNumElements(), true);
-  EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+  // Expand if this mask needs splitting (this will produce a whilelo).
+  if (NumElements > PredVT.getVectorMinNumElements())
+    return SDValue();
 
   SDValue Mask =
-      DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, EltSize, LaneOffset);
-  SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
-  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
-                     DAG.getVectorIdxConstant(0, DL));
+      DAG.getNode(Op.getOpcode(), DL, PredVT, to_vector(Op->op_values()));
+
+  if (VT.isFixedLengthVector()) {
+    EVT WidePredVT = PredVT.changeElementType(VT.getScalarType());
+    SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, WidePredVT, Mask);
+    return convertFromScalableVector(DAG, VT, MaskAsInt);
+  }
+
+  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Mask,
+                     DAG.getConstant(0, DL, MVT::i64));
 }
 
 SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
diff --git a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
index 559d75715c725..33f1c194ff3bb 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
@@ -313,10 +313,8 @@ entry:
 define <vscale x 8 x i1> @whilewr_extract_nxv8i1(ptr %a, ptr %b) {
 ; CHECK-LABEL: whilewr_extract_nxv8i1:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sub x8, x1, x0
-; CHECK-NEXT:    cmp x8, #1
-; CHECK-NEXT:    csinv x8, x8, xzr, ge
-; CHECK-NEXT:    whilelo p0.h, xzr, x8
+; CHECK-NEXT:    whilewr p0.b, x0, x1
+; CHECK-NEXT:    punpklo p0.h, p0.b
 ; CHECK-NEXT:    ret
 entry:
   %0 = call <vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
@@ -326,10 +324,9 @@ entry:
 define <vscale x 4 x i1> @whilewr_extract_nxv4i1(ptr %a, ptr %b) {
 ; CHECK-LABEL: whilewr_extract_nxv4i1:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sub x8, x1, x0
-; CHECK-NEXT:    cmp x8, #1
-; CHECK-NEXT:    csinv x8, x8, xzr, ge
-; CHECK-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NEXT:    whilewr p0.b, x0, x1
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    punpklo p0.h, p0.b
 ; CHECK-NEXT:    ret
 entry:
   %0 = call <vscale x 4 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
@@ -340,13 +337,8 @@ entry:
 define <vscale x 2 x i1> @whilewr_extract_nxv2i1(ptr %a, ptr %b) {
 ; CHECK-LABEL: whilewr_extract_nxv2i1:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    subs x8, x1, x0
-; CHECK-NEXT:    add x9, x8, #3
-; CHECK-NEXT:    csel x8, x9, x8, mi
-; CHECK-NEXT:    asr x8, x8, #2
-; CHECK-NEXT:    cmp x8, #1
-; CHECK-NEXT:    csinv x8, x8, xzr, ge
-; CHECK-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NEXT:    whilewr p0.s, x0, x1
+; CHECK-NEXT:    punpklo p0.h, p0.b
 ; CHECK-NEXT:    ret
 entry:
   %0 = call <vscale x 2 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 4)

From 7feacc0acd6469fce3606600434ec440e7ca4648 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 12 Dec 2025 11:02:46 +0000
Subject: [PATCH 3/3] Add fixed-length tests

---
 llvm/test/CodeGen/AArch64/alias_mask.ll       | 37 +++++++++++++++++++
 .../CodeGen/AArch64/alias_mask_scalable.ll    |  4 +-
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/alias_mask.ll b/llvm/test/CodeGen/AArch64/alias_mask.ll
index bf393b6e87710..7a57fc3be84ac 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask.ll
@@ -563,3 +563,40 @@ entry:
   %0 = call <1 x i1> @llvm.loop.dependence.raw.mask.v1i1(ptr %a, ptr %b, i64 8)
   ret <1 x i1> %0
 }
+
+define <8 x i1> @whilewr_extract_v8i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_v8i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    whilewr p0.b, x0, x1
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %0 = call <8 x i1> @llvm.loop.dependence.war.mask.v8i1(ptr %a, ptr %b, i64 1)
+  ret <8 x i1> %0
+}
+
+define <4 x i1> @whilewr_extract_v4i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_v4i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    whilewr p0.b, x0, x1
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %0 = call <4 x i1> @llvm.loop.dependence.war.mask.v4i1(ptr %a, ptr %b, i64 1)
+  ret <4 x i1> %0
+}
+
+define <2 x i1> @whilewr_extract_v2i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_v2i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    whilewr p0.s, x0, x1
+; CHECK-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %0 = call <2 x i1> @llvm.loop.dependence.war.mask.v2i1(ptr %a, ptr %b, i64 4)
+  ret <2 x i1> %0
+}
diff --git a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
index 33f1c194ff3bb..e9463b5c571b6 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
@@ -329,7 +329,7 @@ define <vscale x 4 x i1> @whilewr_extract_nxv4i1(ptr %a, ptr %b) {
 ; CHECK-NEXT:    punpklo p0.h, p0.b
 ; CHECK-NEXT:    ret
 entry:
-  %0 = call <vscale x 4 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
+  %0 = call <vscale x 4 x i1> @llvm.loop.dependence.war.mask.nxv4i1(ptr %a, ptr %b, i64 1)
   ret <vscale x 4 x i1> %0
 }
 
@@ -341,6 +341,6 @@ define <vscale x 2 x i1> @whilewr_extract_nxv2i1(ptr %a, ptr %b) {
 ; CHECK-NEXT:    punpklo p0.h, p0.b
 ; CHECK-NEXT:    ret
 entry:
-  %0 = call <vscale x 2 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 4)
+  %0 = call <vscale x 2 x i1> @llvm.loop.dependence.war.mask.nxv2i1(ptr %a, ptr %b, i64 4)
   ret <vscale x 2 x i1> %0
 }


