[llvm] [AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine (PR #159360)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 17 06:39:16 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Kerry McLaughlin (kmclaughlin-arm)
<details>
<summary>Changes</summary>
The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic, such as a ptest and reinterpret casts, which
are currently not supported.
This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and allows other uses by these additional operations.
---
Full diff: https://github.com/llvm/llvm-project/pull/159360.diff
3 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+21-8)
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+13-7)
- (modified) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+75)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c9a756da0078d..9c7ecf944e763 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18693,21 +18693,31 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
return SDValue();
- unsigned NumUses = N->use_size();
+ // Count the number of users which are extract_subvectors.
+ // The only other valid users for this combine are ptest_first
+ // and reinterpret_cast.
+ unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
+ return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
+ });
+
auto MaskEC = N->getValueType(0).getVectorElementCount();
- if (!MaskEC.isKnownMultipleOf(NumUses))
+ if (!MaskEC.isKnownMultipleOf(NumExts))
return SDValue();
- ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+ ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
if (ExtMinEC.getKnownMinValue() < 2)
return SDValue();
- SmallVector<SDNode *> Extracts(NumUses, nullptr);
+ SmallVector<SDNode *> Extracts(NumExts, nullptr);
for (SDNode *Use : N->users()) {
+ if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
+ Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ continue;
+
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
return SDValue();
- // Ensure the extract type is correct (e.g. if NumUses is 4 and
+ // Ensure the extract type is correct (e.g. if NumExts is 4 and
// the mask return type is nxv8i1, each extract should be nxv2i1.
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
return SDValue();
@@ -18741,11 +18751,13 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
- if (NumUses == 2)
- return SDValue(N, 0);
+ if (NumExts == 2) {
+ DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
+ return SDValue(N, 0);
+ }
auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
- for (unsigned I = 2; I < NumUses; I += 2) {
+ for (unsigned I = 2; I < NumExts; I += 2) {
// After the first whilelo_x2, we need to increment the starting value.
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
@@ -18753,6 +18765,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
}
+ DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
return SDValue(N, 0);
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index bf3d47ac43607..069d08663fdea 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1495,13 +1495,19 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
return PredOpcode;
- // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
- // redundant since WHILE performs an implicit PTEST with an all active
- // mask.
- if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
- getElementSizeForOpcode(MaskOpcode) ==
- getElementSizeForOpcode(PredOpcode))
- return PredOpcode;
+ if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
+ auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
+ if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
+ PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
+ return PredOpcode;
+
+ // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
+ // redundant since WHILE performs an implicit PTEST with an all active
+ // mask.
+ if (getElementSizeForOpcode(MaskOpcode) ==
+ getElementSizeForOpcode(PredOpcode))
+ return PredOpcode;
+ }
return {};
}
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 5e01612e3881a..3b18008605413 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -310,6 +310,81 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
ret void
}
+; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.
+
+define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: whilelo p1.b, x0, x1
+; CHECK-SVE-NEXT: b.pl .LBB11_2
+; CHECK-SVE-NEXT: // %bb.1: // %if.then
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: b use
+; CHECK-SVE-NEXT: .LBB11_2: // %if.end
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2: // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
+; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: b use
+; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT: ret
+entry:
+ %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
+ %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+ %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+ %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+ br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+ tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+; Extra use of the get_active_lane_mask from an extractelement, which is
+; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
+
+define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: whilelo p1.h, x0, x1
+; CHECK-SVE-NEXT: b.pl .LBB12_2
+; CHECK-SVE-NEXT: // %bb.1: // %if.then
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: b use
+; CHECK-SVE-NEXT: .LBB12_2: // %if.end
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2: // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
+; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: b use
+; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT: ret
+entry:
+ %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+ %v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
+ %v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
+ %elt0 = extractelement <vscale x 8 x i1> %r, i64 0
+ br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+ tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
+ br label %if.end
+
+if.end:
+ ret void
+}
+
declare void @use(...)
attributes #0 = { nounwind }
``````````
</details>
https://github.com/llvm/llvm-project/pull/159360
More information about the llvm-commits
mailing list