[llvm-branch-commits] [llvm] [AArch64] Split large loop dependence masks (PR #153187)
Benjamin Maxwell via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Aug 20 08:22:23 PDT 2025
================
@@ -5248,49 +5248,94 @@ AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
                                                  SelectionDAG &DAG) const {
   SDLoc DL(Op);
   uint64_t EltSize = Op.getConstantOperandVal(2);
-  EVT VT = Op.getValueType();
+  EVT FullVT = Op.getValueType();
+  unsigned NumElements = FullVT.getVectorMinNumElements();
+  unsigned NumSplits = 0;
+  EVT EltVT;
   switch (EltSize) {
   case 1:
-    if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
-      return SDValue();
+    EltVT = MVT::i8;
     break;
   case 2:
-    if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
-      return SDValue();
+    if (NumElements >= 16)
+      NumSplits = NumElements / 16;
+    EltVT = MVT::i16;
     break;
   case 4:
-    if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
-      return SDValue();
+    if (NumElements >= 8)
+      NumSplits = NumElements / 8;
+    EltVT = MVT::i32;
     break;
   case 8:
-    if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
-      return SDValue();
+    if (NumElements >= 4)
+      NumSplits = NumElements / 4;
+    EltVT = MVT::i64;
     break;
   default:
     // Other element sizes are incompatible with whilewr/rw, so expand instead
     return SDValue();
   }
-  SDValue PtrA = Op.getOperand(0);
-  SDValue PtrB = Op.getOperand(1);
+  auto LowerToWhile = [&](EVT VT, unsigned AddrScale) {
+    SDValue PtrA = Op.getOperand(0);
+    SDValue PtrB = Op.getOperand(1);
-  if (VT.isScalableVT())
-    return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
+    EVT StoreVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
+                                   VT.getVectorMinNumElements(), false);
+    if (AddrScale > 0) {
+      unsigned Offset = StoreVT.getStoreSizeInBits() / 8 * AddrScale;
+      SDValue Addend;
-  // We can use the SVE whilewr/whilerw instruction to lower this
-  // intrinsic by creating the appropriate sequence of scalable vector
-  // operations and then extracting a fixed-width subvector from the scalable
-  // vector. Scalable vector variants are already legal.
-  EVT ContainerVT =
-      EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
-                       VT.getVectorNumElements(), true);
-  EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+      if (VT.isScalableVT())
+        Addend = DAG.getVScale(DL, MVT::i64, APInt(64, Offset));
+      else
+        Addend = DAG.getConstant(Offset, DL, MVT::i64);
-  SDValue Mask =
-      DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
-  SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
-  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
-                     DAG.getVectorIdxConstant(0, DL));
+      PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend);
+      PtrB = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrB, Addend);
+    }
+
+    if (VT.isScalableVT())
+      return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
+
+    // We can use the SVE whilewr/whilerw instruction to lower this
+    // intrinsic by creating the appropriate sequence of scalable vector
+    // operations and then extracting a fixed-width subvector from the scalable
+    // vector. Scalable vector variants are already legal.
+    EVT ContainerVT =
+        EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
+                         VT.getVectorNumElements(), true);
+    EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+
+    SDValue Mask =
+        DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
+    SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
+                       DAG.getVectorIdxConstant(0, DL));
+  };
+
+  if (NumSplits == 0)
+    return LowerToWhile(FullVT, 0);
+
+  SDValue FullVec = DAG.getUNDEF(FullVT);
+
+  unsigned NumElementsPerSplit = NumElements / (2 * NumSplits);
+  EVT PartVT =
+      EVT::getVectorVT(*DAG.getContext(), FullVT.getVectorElementType(),
+                       NumElementsPerSplit, FullVT.isScalableVT());
+  for (unsigned Split = 0, InsertIdx = 0; Split < NumSplits;
+       Split++, InsertIdx += 2) {
+    SDValue Low = LowerToWhile(PartVT, InsertIdx);
+    SDValue High = LowerToWhile(PartVT, InsertIdx + 1);
+    unsigned InsertIdxLow = InsertIdx * NumElementsPerSplit;
+    unsigned InsertIdxHigh = (InsertIdx + 1) * NumElementsPerSplit;
+    SDValue Insert =
+        DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, FullVec, Low,
+                    DAG.getVectorIdxConstant(InsertIdxLow, DL));
+    FullVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, Insert, High,
+                          DAG.getVectorIdxConstant(InsertIdxHigh, DL));
+  }
----------------
MacDue wrote:
Following on from my first suggestion (which adds `NumWhiles`), this can be simplified to:
```cpp
if (NumWhiles <= 1)
  return LowerToWhile(FullVT, 0);
unsigned NumElementsPerSplit = NumElements / NumWhiles;
EVT PartVT =
    EVT::getVectorVT(*DAG.getContext(), FullVT.getVectorElementType(),
                     NumElementsPerSplit, FullVT.isScalableVT());
SDValue FullVec = DAG.getUNDEF(FullVT);
for (unsigned I = 0; I < NumWhiles; I++) {
  SDValue While = LowerToWhile(PartVT, I);
  unsigned InsertIdx = I * NumElementsPerSplit;
  FullVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, FullVec, While,
                        DAG.getVectorIdxConstant(InsertIdx, DL));
}
```
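
For context, since the first suggestion is not quoted in this thread: the simplification counts whilewr/whilerw operations directly rather than pairs of them. A minimal sketch of one way `NumWhiles` could be derived, assuming each while produces a predicate covering a 16-byte (128-bit) span; this matches the per-`EltSize` divisors in the patch above, though the exact formulation in the earlier suggestion may differ:

```cpp
// Hypothetical sketch, not quoted from the PR: one WHILEWR/WHILERW covers a
// 16-byte span, so the while count is the mask's byte span divided by 16.
// E.g. EltSize == 2 with NumElements == 16 gives NumWhiles == 2, matching
// NumSplits = NumElements / 16 with two whiles per split in the patch.
unsigned NumWhiles = (NumElements * EltSize) / 16;
```

Note the DAG node count is the same either way (one INSERT_SUBVECTOR per while); the gain is dropping the Low/High duplication and the doubled index bookkeeping.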
https://github.com/llvm/llvm-project/pull/153187
More information about the llvm-branch-commits mailing list