[llvm] [AArch64] Extend performActiveLaneMaskCombine for more than two extracts (PR #146725)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 2 08:28:58 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-aarch64
Author: Kerry McLaughlin (kmclaughlin-arm)
The combine was originally added to find a get.active.lane.mask whose only users are two extract_subvector nodes, and to replace it with the paired whilelo instruction. This patch extends the combine to cover cases where there are more than two extracts.
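For illustration, a minimal IR sketch of the pattern the extended combine now matches, mirroring the new test below (the function and callee names are illustrative):

```llvm
; With +sve2p1 (or +sme2 in streaming mode), the four nxv2i1 extracts below
; can be lowered to two whilelo predicate-pair instructions rather than a
; single whilelo followed by a chain of punpklo/punpkhi unpacks.
define void @sketch(i64 %i, i64 %n) {
  %mask = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
  %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %mask, i64 0)
  %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %mask, i64 2)
  %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %mask, i64 4)
  %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %mask, i64 6)
  tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3)
  ret void
}

declare void @use(<vscale x 2 x i1>, <vscale x 2 x i1>, <vscale x 2 x i1>, <vscale x 2 x i1>)
```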
---
Full diff: https://github.com/llvm/llvm-project/pull/146725.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+40-30)
- (modified) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+87)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fb8bd81c033af..7e4ba0a776382 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18143,53 +18143,63 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
return SDValue();
- if (!N->hasNUsesOfValue(2, 0))
+ unsigned NumUses = N->use_size();
+ unsigned MaskMinElts = N->getValueType(0).getVectorMinNumElements();
+ if (MaskMinElts % NumUses != 0)
return SDValue();
- const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
- if (HalfSize < 2)
+ unsigned ExtMinElts = MaskMinElts / NumUses;
+ if (ExtMinElts < 2)
return SDValue();
- auto It = N->user_begin();
- SDNode *Lo = *It++;
- SDNode *Hi = *It;
+ SmallVector<SDNode *> Extracts(NumUses, nullptr);
+ for (SDNode *Use : N->users()) {
+ if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
+ return SDValue();
- if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
- Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
- return SDValue();
+ // Ensure the extract type is correct (e.g. if NumUses is 4 and
+ // the mask return type is nxv8i1, each extract should be nxv2i1).
+ if (Use->getValueType(0).getVectorMinNumElements() != ExtMinElts)
+ return SDValue();
- uint64_t OffLo = Lo->getConstantOperandVal(1);
- uint64_t OffHi = Hi->getConstantOperandVal(1);
+ // There should be exactly one extract for each part of the mask.
+ unsigned Offset = Use->getConstantOperandVal(1);
+ unsigned Part = Offset / ExtMinElts;
+ if (Extracts[Part] != nullptr)
+ return SDValue();
- if (OffLo > OffHi) {
- std::swap(Lo, Hi);
- std::swap(OffLo, OffHi);
+ Extracts[Part] = Use;
}
- if (OffLo != 0 || OffHi != HalfSize)
- return SDValue();
-
- EVT HalfVec = Lo->getValueType(0);
- if (HalfVec != Hi->getValueType(0) ||
- HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
- return SDValue();
-
SelectionDAG &DAG = DCI.DAG;
SDLoc DL(N);
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+
SDValue Idx = N->getOperand(0);
SDValue TC = N->getOperand(1);
- if (Idx.getValueType() != MVT::i64) {
- Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
- TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+ EVT OpVT = Idx.getValueType();
+ if (OpVT != MVT::i64) {
+ Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
+ TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
}
- auto R =
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
- {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
- DCI.CombineTo(Lo, R.getValue(0));
- DCI.CombineTo(Hi, R.getValue(1));
+ // Create a whilelo_x2 intrinsic for each pair of extracts.
+ EVT ExtVT = Extracts[0]->getValueType(0);
+ for (unsigned I = 0; I < NumUses; I += 2) {
+ // After the first whilelo_x2, step the start index by the lanes covered by each pair.
+ if (I > 0) {
+ SDValue Elts = DAG.getElementCount(
+ DL, OpVT, ExtVT.getVectorElementCount() * 2);
+ Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
+ }
+
+ auto R =
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
+
+ DCI.CombineTo(Extracts[I], R.getValue(0));
+ DCI.CombineTo(Extracts[I + 1], R.getValue(1));
+ }
return SDValue(N, 0);
}
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index c76b50d69b877..95788f4e6e83b 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -86,6 +86,64 @@ define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
ret void
}
+define void @test_legal_4x2bit_mask(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_legal_4x2bit_mask:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT: punpkhi p1.h, p0.b
+; CHECK-SVE-NEXT: punpklo p4.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p1.b
+; CHECK-SVE-NEXT: punpklo p2.h, p1.b
+; CHECK-SVE-NEXT: punpklo p0.h, p4.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p4.b
+; CHECK-SVE-NEXT: b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_legal_4x2bit_mask:
+; CHECK-SVE2p1-SME2: // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT: cntw x8
+; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
+; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
+; CHECK-SVE2p1-SME2-NEXT: b use
+ %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+ %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+ %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+ %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+ %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+ tail call void @use(<vscale x 2 x i1> %v3, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v0)
+ ret void
+}
+
+; Negative test where the extract types are correct but we are not extracting all parts of the mask
+define void @test_partial_extract_correct_types(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_partial_extract_correct_types:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT: punpklo p1.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p2.h, p0.b
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p2.h, p2.b
+; CHECK-SVE-NEXT: b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_partial_extract_correct_types:
+; CHECK-SVE2p1-SME2: // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p2.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p2.h, p2.b
+; CHECK-SVE2p1-SME2-NEXT: b use
+ %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+ %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+ %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+ %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+ tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2)
+ ret void
+}
+
; Negative test for when not extracting exactly two halves of the source vector
define void @test_partial_extract(i64 %i, i64 %n) #0 {
; CHECK-SVE-LABEL: test_partial_extract:
@@ -167,6 +225,35 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 {
ret void
}
+; Negative test where the number of extracts is correct, but they cannot be combined
+; because there is not a matching extract for each part of the mask
+define void @test_2x2bit_2x4bit_mask(i64 %i, i64 %n) #0 {
+; CHECK-SVE-LABEL: test_2x2bit_2x4bit_mask:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT: punpklo p2.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT: punpklo p0.h, p2.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p2.b
+; CHECK-SVE-NEXT: b use
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x2bit_2x4bit_mask:
+; CHECK-SVE2p1-SME2: // %bb.0:
+; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: punpklo p2.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p2.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p2.b
+; CHECK-SVE2p1-SME2-NEXT: b use
+ %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+ %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+ %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+ %v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+ %v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+ tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
+ ret void
+}
+
; Illegal Types
define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
``````````
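As a cross-check of the index arithmetic above, here is a minimal C++ sketch of the start value fed to each whilelo_x2 (an assumption for illustration: the helper names are hypothetical, and it assumes each predicate pair covers two extracts' worth of lanes):

```cpp
#include <cstdint>
#include <limits>

// Unsigned saturating add, matching the ISD::UADDSAT node: if the add
// overflows, the index clamps to UINT64_MAX so the resulting predicates
// are all-false rather than wrapping back to low lane numbers.
static uint64_t satAddU64(uint64_t A, uint64_t B) {
  uint64_t R = A + B;
  return R < A ? std::numeric_limits<uint64_t>::max() : R;
}

// Start index for whilelo_x2 pair number `Pair` (0, 1, ...). Each pair
// produces two predicates, so it covers 2 * ExtElts lanes of the original
// mask, where ExtElts is the runtime element count of one extract
// (ExtMinElts * vscale).
uint64_t pairStartIndex(uint64_t Idx, uint64_t ExtElts, unsigned Pair) {
  uint64_t Start = Idx;
  for (unsigned P = 0; P < Pair; ++P)
    Start = satAddU64(Start, 2 * ExtElts);
  return Start;
}
```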
https://github.com/llvm/llvm-project/pull/146725