[llvm] 17f29c2 - [AArch64] Support lowering smaller than legal LOOP_DEP_MASKs to whilewr/rw (#171982)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 15 01:13:03 PST 2025
Author: Benjamin Maxwell
Date: 2025-12-15T09:12:58Z
New Revision: 17f29c22abc5426153847766bfffbb5e9e7c0241
URL: https://github.com/llvm/llvm-project/commit/17f29c22abc5426153847766bfffbb5e9e7c0241
DIFF: https://github.com/llvm/llvm-project/commit/17f29c22abc5426153847766bfffbb5e9e7c0241.diff
LOG: [AArch64] Support lowering smaller than legal LOOP_DEP_MASKs to whilewr/rw (#171982)
This adds support for lowering smaller-than-legal masks such as:
```
<vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
```
to a whilewr + unpack. It also slightly simplifies the lowering.
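Concretely, the nxv8i1 example above now lowers to a byte-granule whilewr followed by an unpack of the low half to the wider predicate element size (matching the whilewr_extract_nxv8i1 test added below):
```
whilewr p0.b, x0, x1
punpklo p0.h, p0.b
```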
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/alias_mask.ll
llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 614395c2ed06e..88dd2800fcc24 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5439,55 +5439,43 @@ static MVT getSVEContainerType(EVT ContentTy);
SDValue
AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
SelectionDAG &DAG) const {
+ assert((Subtarget->hasSVE2() ||
+ (Subtarget->hasSME() && Subtarget->isStreaming())) &&
+ "Lowering loop_dependence_raw_mask or loop_dependence_war_mask "
+ "requires SVE or SME");
+
SDLoc DL(Op);
EVT VT = Op.getValueType();
- SDValue EltSize = Op.getOperand(2);
- switch (EltSize->getAsZExtVal()) {
- case 1:
- if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
- return SDValue();
- break;
- case 2:
- if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
- return SDValue();
- break;
- case 4:
- if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
- return SDValue();
- break;
- case 8:
- if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
- return SDValue();
- break;
- default:
- // Other element sizes are incompatible with whilewr/rw, so expand instead
- return SDValue();
- }
+ unsigned LaneOffset = Op.getConstantOperandVal(3);
+ unsigned NumElements = VT.getVectorMinNumElements();
+ uint64_t EltSizeInBytes = Op.getConstantOperandVal(2);
- SDValue LaneOffset = Op.getOperand(3);
- if (LaneOffset->getAsZExtVal())
+ // Lane offsets and other element sizes are not supported by whilewr/rw.
+ if (LaneOffset != 0 || !is_contained({1u, 2u, 4u, 8u}, EltSizeInBytes))
return SDValue();
- SDValue PtrA = Op.getOperand(0);
- SDValue PtrB = Op.getOperand(1);
+ EVT EltVT = MVT::getIntegerVT(EltSizeInBytes * 8);
+ EVT PredVT = getPackedSVEVectorVT(EltVT).changeElementType(MVT::i1);
- if (VT.isScalableVT())
- return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, EltSize, LaneOffset);
+ // Legal whilewr/rw (lowered by tablegen matcher).
+ if (PredVT == VT)
+ return Op;
- // We can use the SVE whilewr/whilerw instruction to lower this
- // intrinsic by creating the appropriate sequence of scalable vector
- // operations and then extracting a fixed-width subvector from the scalable
- // vector. Scalable vector variants are already legal.
- EVT ContainerVT =
- EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
- VT.getVectorNumElements(), true);
- EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+ // Expand if this mask needs splitting (this will produce a whilelo).
+ if (NumElements > PredVT.getVectorMinNumElements())
+ return SDValue();
SDValue Mask =
- DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, EltSize, LaneOffset);
- SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
- DAG.getVectorIdxConstant(0, DL));
+ DAG.getNode(Op.getOpcode(), DL, PredVT, to_vector(Op->op_values()));
+
+ if (VT.isFixedLengthVector()) {
+ EVT WidePredVT = PredVT.changeElementType(VT.getScalarType());
+ SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, WidePredVT, Mask);
+ return convertFromScalableVector(DAG, VT, MaskAsInt);
+ }
+
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Mask,
+ DAG.getConstant(0, DL, MVT::i64));
}
SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
diff --git a/llvm/test/CodeGen/AArch64/alias_mask.ll b/llvm/test/CodeGen/AArch64/alias_mask.ll
index bf393b6e87710..7a57fc3be84ac 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask.ll
@@ -563,3 +563,40 @@ entry:
%0 = call <1 x i1> @llvm.loop.dependence.raw.mask.v1i1(ptr %a, ptr %b, i64 8)
ret <1 x i1> %0
}
+
+define <8 x i1> @whilewr_extract_v8i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_v8i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: whilewr p0.b, x0, x1
+; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %0 = call <8 x i1> @llvm.loop.dependence.war.mask.v8i1(ptr %a, ptr %b, i64 1)
+ ret <8 x i1> %0
+}
+
+define <4 x i1> @whilewr_extract_v4i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_v4i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: whilewr p0.b, x0, x1
+; CHECK-NEXT: punpklo p0.h, p0.b
+; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %0 = call <4 x i1> @llvm.loop.dependence.war.mask.v4i1(ptr %a, ptr %b, i64 1)
+ ret <4 x i1> %0
+}
+
+define <2 x i1> @whilewr_extract_v2i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_v2i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: whilewr p0.s, x0, x1
+; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %0 = call <2 x i1> @llvm.loop.dependence.war.mask.v2i1(ptr %a, ptr %b, i64 4)
+ ret <2 x i1> %0
+}
diff --git a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
index 8a2eff3fde396..e9463b5c571b6 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask_scalable.ll
@@ -309,3 +309,38 @@ entry:
%0 = call <vscale x 16 x i1> @llvm.loop.dependence.war.mask.nxv16i1(ptr %a, ptr %b, i64 3)
ret <vscale x 16 x i1> %0
}
+
+define <vscale x 8 x i1> @whilewr_extract_nxv8i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv8i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: whilewr p0.b, x0, x1
+; CHECK-NEXT: punpklo p0.h, p0.b
+; CHECK-NEXT: ret
+entry:
+ %0 = call <vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
+ ret <vscale x 8 x i1> %0
+}
+
+define <vscale x 4 x i1> @whilewr_extract_nxv4i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv4i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: whilewr p0.b, x0, x1
+; CHECK-NEXT: punpklo p0.h, p0.b
+; CHECK-NEXT: punpklo p0.h, p0.b
+; CHECK-NEXT: ret
+entry:
+ %0 = call <vscale x 4 x i1> @llvm.loop.dependence.war.mask.nxv4i1(ptr %a, ptr %b, i64 1)
+ ret <vscale x 4 x i1> %0
+}
+
+
+define <vscale x 2 x i1> @whilewr_extract_nxv2i1(ptr %a, ptr %b) {
+; CHECK-LABEL: whilewr_extract_nxv2i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: whilewr p0.s, x0, x1
+; CHECK-NEXT: punpklo p0.h, p0.b
+; CHECK-NEXT: ret
+entry:
+ %0 = call <vscale x 2 x i1> @llvm.loop.dependence.war.mask.nxv2i1(ptr %a, ptr %b, i64 4)
+ ret <vscale x 2 x i1> %0
+}