[llvm] [AArch64] Extend performActiveLaneMaskCombine for more than two extracts (PR #146725)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 8 05:14:40 PDT 2025
================
@@ -18143,53 +18143,63 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
return SDValue();
- if (!N->hasNUsesOfValue(2, 0))
+ unsigned NumUses = N->use_size();
+ auto MaskEC = N->getValueType(0).getVectorElementCount();
+ if (!MaskEC.isKnownMultipleOf(NumUses))
return SDValue();
- const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
- if (HalfSize < 2)
+ ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+ if (ExtMinEC.getKnownMinValue() < 2)
return SDValue();
- auto It = N->user_begin();
- SDNode *Lo = *It++;
- SDNode *Hi = *It;
+ SmallVector<SDNode *> Extracts(NumUses, nullptr);
+ for (SDNode *Use : N->users()) {
+ if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
+ return SDValue();
- if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
- Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
- return SDValue();
+ // Ensure the extract type is correct (e.g. if NumUses is 4 and
+ // the mask return type is nxv8i1, each extract should be nxv2i1.
+ if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
+ return SDValue();
- uint64_t OffLo = Lo->getConstantOperandVal(1);
- uint64_t OffHi = Hi->getConstantOperandVal(1);
+ // There should be exactly one extract for each part of the mask.
+ unsigned Offset = Use->getConstantOperandVal(1);
+ unsigned Part = Offset / ExtMinEC.getKnownMinValue();
+ if (Extracts[Part] != nullptr)
+ return SDValue();
- if (OffLo > OffHi) {
- std::swap(Lo, Hi);
- std::swap(OffLo, OffHi);
+ Extracts[Part] = Use;
}
- if (OffLo != 0 || OffHi != HalfSize)
- return SDValue();
-
- EVT HalfVec = Lo->getValueType(0);
- if (HalfVec != Hi->getValueType(0) ||
- HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
- return SDValue();
-
SelectionDAG &DAG = DCI.DAG;
SDLoc DL(N);
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+
SDValue Idx = N->getOperand(0);
SDValue TC = N->getOperand(1);
- if (Idx.getValueType() != MVT::i64) {
- Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
- TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+ EVT OpVT = Idx.getValueType();
+ if (OpVT != MVT::i64) {
+ Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
+ TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
}
- auto R =
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
- {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
- DCI.CombineTo(Lo, R.getValue(0));
- DCI.CombineTo(Hi, R.getValue(1));
+ // Create the whilelo_x2 intrinsics from each pair of extracts
+ EVT ExtVT = Extracts[0]->getValueType(0);
+ for (unsigned I = 0; I < NumUses; I += 2) {
+ // After the first whilelo_x2, we need to increment the starting value.
+ if (I > 0) {
+ SDValue Elts =
+ DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount());
+ Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
+ }
+
+ auto R =
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
+
+ DCI.CombineTo(Extracts[I], R.getValue(0));
+ DCI.CombineTo(Extracts[I + 1], R.getValue(1));
+ }
----------------
paulwalker-arm wrote:
What do you think about peeling the first loop iteration to streamline things a little? Something like:
```
auto R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
if (NumUses == 2)
return SDValue(N, 0);
SDValue Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount()*2);
for (unsigned I = 2; I < NumUses; I += 2) {
// After the first whilelo_x2, we need to increment the starting value.
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
auto R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[I], R.getValue(0));
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
}
return SDValue(N, 0);
```
No requirement here, just a suggestion to consider.
https://github.com/llvm/llvm-project/pull/146725
More information about the llvm-commits mailing list