[llvm] [AArch64] Support lowering smaller than legal LOOP_DEP_MASKs to whilewr/rw (PR #171982)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 12 02:11:02 PST 2025
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/171982
This adds support for lowering smaller-than-legal masks such as:
```
<vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
```
to a whilewr + unpack. It also slightly simplifies the lowering.
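For example, after this change the nxv8i1 case above is emitted as a full-width whilewr followed by a predicate unpack, roughly (this mirrors the updated tests below; exact register allocation may differ):
```
whilewr p0.b, x0, x1   // nxv16i1 WAR mask for 1-byte elements
punpklo p0.h, p0.b     // keep the low half as the requested nxv8i1 predicate
```
Masks that are more than one step narrower than the legal predicate type simply chain additional punpklo instructions, as in the nxv4i1 test below.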
From 6bd943eb2d61d2e147384152e403bc5c1fc7e97d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 12 Dec 2025 10:01:59 +0000
Subject: [PATCH 1/2] Precommit tests
---
.../CodeGen/AArch64/alias_mask_scalable.ll | 43 +++++++++++++++++++
1 file changed, 43 insertions(+)
diff --git a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
index 8a2eff3fde396..559d75715c725 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
@@ -309,3 +309,46 @@ entry:
%0 = call <vscale x 16 x i1> @llvm.loop.dependence.war.mask.nxv16i1(ptr %a, ptr %b, i64 3)
ret <vscale x 16 x i1> %0
}
+
+define <vscale x 8 x i1> @whilewr_extract_nxv8i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv8i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sub x8, x1, x0
+; CHECK-NEXT: cmp x8, #1
+; CHECK-NEXT: csinv x8, x8, xzr, ge
+; CHECK-NEXT: whilelo p0.h, xzr, x8
+; CHECK-NEXT: ret
+entry:
+ %0 = call <vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
+ ret <vscale x 8 x i1> %0
+}
+
+define <vscale x 4 x i1> @whilewr_extract_nxv4i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv4i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sub x8, x1, x0
+; CHECK-NEXT: cmp x8, #1
+; CHECK-NEXT: csinv x8, x8, xzr, ge
+; CHECK-NEXT: whilelo p0.s, xzr, x8
+; CHECK-NEXT: ret
+entry:
+ %0 = call <vscale x 4 x i1> @llvm.loop.dependence.war.mask.nxv4i1(ptr %a, ptr %b, i64 1)
+ ret <vscale x 4 x i1> %0
+}
+
+
+define <vscale x 2 x i1> @whilewr_extract_nxv2i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv2i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: subs x8, x1, x0
+; CHECK-NEXT: add x9, x8, #3
+; CHECK-NEXT: csel x8, x9, x8, mi
+; CHECK-NEXT: asr x8, x8, #2
+; CHECK-NEXT: cmp x8, #1
+; CHECK-NEXT: csinv x8, x8, xzr, ge
+; CHECK-NEXT: whilelo p0.d, xzr, x8
+; CHECK-NEXT: ret
+entry:
+ %0 = call <vscale x 2 x i1> @llvm.loop.dependence.war.mask.nxv2i1(ptr %a, ptr %b, i64 4)
+ ret <vscale x 2 x i1> %0
+}
From d38d555dbb6c83e14b7c30696d65ec1ea94e2d0c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 11 Dec 2025 12:06:07 +0000
Subject: [PATCH 2/2] [AArch64] Support lowering smaller than legal
LOOP_DEP_MASKs to whilewr/rw
This adds support for lowering smaller-than-legal masks such as:
```
<vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
```
to a whilewr + unpack. It also slightly simplifies the lowering.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 68 ++++++++-----------
.../CodeGen/AArch64/alias_mask_scalable.ll | 22 ++----
2 files changed, 35 insertions(+), 55 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1ade1df88f010..4d3542c9277e8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5439,55 +5439,43 @@ static MVT getSVEContainerType(EVT ContentTy);
SDValue
AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
SelectionDAG &DAG) const {
+ assert((Subtarget->hasSVE2() ||
+ (Subtarget->hasSME() && Subtarget->isStreaming())) &&
+ "Lowering loop_dependence_raw_mask or loop_dependence_war_mask "
+ "requires SVE or SME");
+
SDLoc DL(Op);
EVT VT = Op.getValueType();
- SDValue EltSize = Op.getOperand(2);
- switch (EltSize->getAsZExtVal()) {
- case 1:
- if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
- return SDValue();
- break;
- case 2:
- if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
- return SDValue();
- break;
- case 4:
- if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
- return SDValue();
- break;
- case 8:
- if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
- return SDValue();
- break;
- default:
- // Other element sizes are incompatible with whilewr/rw, so expand instead
- return SDValue();
- }
+ unsigned LaneOffset = Op.getConstantOperandVal(3);
+ unsigned NumElements = VT.getVectorMinNumElements();
+ uint64_t EltSizeInBytes = Op.getConstantOperandVal(2);
- SDValue LaneOffset = Op.getOperand(3);
- if (LaneOffset->getAsZExtVal())
+ // Lane offsets and other element sizes are not supported by whilewr/rw.
+ if (LaneOffset != 0 || !is_contained({1u, 2u, 4u, 8u}, EltSizeInBytes))
return SDValue();
- SDValue PtrA = Op.getOperand(0);
- SDValue PtrB = Op.getOperand(1);
+ EVT EltVT = MVT::getIntegerVT(EltSizeInBytes * 8);
+ EVT PredVT = getPackedSVEVectorVT(EltVT).changeElementType(MVT::i1);
- if (VT.isScalableVT())
- return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, EltSize, LaneOffset);
+ // Legal whilewr/rw (lowered by tablegen matcher).
+ if (PredVT == VT)
+ return Op;
- // We can use the SVE whilewr/whilerw instruction to lower this
- // intrinsic by creating the appropriate sequence of scalable vector
- // operations and then extracting a fixed-width subvector from the scalable
- // vector. Scalable vector variants are already legal.
- EVT ContainerVT =
- EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
- VT.getVectorNumElements(), true);
- EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+ // Expand if this mask needs splitting (this will produce a whilelo).
+ if (NumElements > PredVT.getVectorMinNumElements())
+ return SDValue();
SDValue Mask =
- DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, EltSize, LaneOffset);
- SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
- DAG.getVectorIdxConstant(0, DL));
+ DAG.getNode(Op.getOpcode(), DL, PredVT, to_vector(Op->op_values()));
+
+ if (VT.isFixedLengthVector()) {
+ EVT WidePredVT = PredVT.changeElementType(VT.getScalarType());
+ SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, WidePredVT, Mask);
+ return convertFromScalableVector(DAG, VT, MaskAsInt);
+ }
+
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Mask,
+ DAG.getConstant(0, DL, MVT::i64));
}
SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
diff --git a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
index 559d75715c725..33f1c194ff3bb 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
@@ -313,10 +313,8 @@ entry:
define <vscale x 8 x i1> @whilewr_extract_nxv8i1(ptr %a, ptr %b) {
; CHECK-LABEL: whilewr_extract_nxv8i1:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sub x8, x1, x0
-; CHECK-NEXT: cmp x8, #1
-; CHECK-NEXT: csinv x8, x8, xzr, ge
-; CHECK-NEXT: whilelo p0.h, xzr, x8
+; CHECK-NEXT: whilewr p0.b, x0, x1
+; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: ret
entry:
%0 = call <vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
@@ -326,10 +324,9 @@ entry:
define <vscale x 4 x i1> @whilewr_extract_nxv4i1(ptr %a, ptr %b) {
; CHECK-LABEL: whilewr_extract_nxv4i1:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sub x8, x1, x0
-; CHECK-NEXT: cmp x8, #1
-; CHECK-NEXT: csinv x8, x8, xzr, ge
-; CHECK-NEXT: whilelo p0.s, xzr, x8
+; CHECK-NEXT: whilewr p0.b, x0, x1
+; CHECK-NEXT: punpklo p0.h, p0.b
+; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: ret
entry:
%0 = call <vscale x 4 x i1> @llvm.loop.dependence.war.mask.nxv4i1(ptr %a, ptr %b, i64 1)
@@ -340,13 +337,8 @@ entry:
define <vscale x 2 x i1> @whilewr_extract_nxv2i1(ptr %a, ptr %b) {
; CHECK-LABEL: whilewr_extract_nxv2i1:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: subs x8, x1, x0
-; CHECK-NEXT: add x9, x8, #3
-; CHECK-NEXT: csel x8, x9, x8, mi
-; CHECK-NEXT: asr x8, x8, #2
-; CHECK-NEXT: cmp x8, #1
-; CHECK-NEXT: csinv x8, x8, xzr, ge
-; CHECK-NEXT: whilelo p0.d, xzr, x8
+; CHECK-NEXT: whilewr p0.s, x0, x1
+; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: ret
entry:
%0 = call <vscale x 2 x i1> @llvm.loop.dependence.war.mask.nxv2i1(ptr %a, ptr %b, i64 4)