[llvm] [AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine (PR #159360)
Kerry McLaughlin via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 18 08:36:01 PDT 2025
https://github.com/kmclaughlin-arm updated https://github.com/llvm/llvm-project/pull/159360
>From 7f042ecbbe0c0afd85304c4742d8d2fa392fa66e Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Tue, 16 Sep 2025 16:11:22 +0000
Subject: [PATCH 1/3] - Tests for get_active_lane_mask with more uses than
extract_vector
---
.../AArch64/get-active-lane-mask-extract.ll | 173 ++++++++++++++++++
1 file changed, 173 insertions(+)
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 5e01612e3881a..f6dc912ce9f07 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -310,6 +310,179 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
ret void
}
+; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.
+
+define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: whilelo p1.b, x0, x1
+; CHECK-SVE-NEXT: b.pl .LBB11_2
+; CHECK-SVE-NEXT: // %bb.1: // %if.then
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: b use
+; CHECK-SVE-NEXT: .LBB11_2: // %if.end
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2: // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT: whilelo p1.b, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
+; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: b use
+; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT: ret
+entry:
+ %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
+ %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+ %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+ %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+ br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+ tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+; Extra use of the get_active_lane_mask from an extractelement, which is
+; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
+
+define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: whilelo p1.h, x0, x1
+; CHECK-SVE-NEXT: b.pl .LBB12_2
+; CHECK-SVE-NEXT: // %bb.1: // %if.then
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: b use
+; CHECK-SVE-NEXT: .LBB12_2: // %if.end
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2: // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT: whilelo p1.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
+; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: b use
+; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT: ret
+entry:
+ %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+ %v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
+ %v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
+ %elt0 = extractelement <vscale x 8 x i1> %r, i64 0
+ br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+ tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: whilelo p0.b, x0, x1
+; CHECK-SVE-NEXT: b.pl .LBB13_2
+; CHECK-SVE-NEXT: // %bb.1: // %if.then
+; CHECK-SVE-NEXT: punpklo p1.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: punpklo p2.h, p3.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT: b use
+; CHECK-SVE-NEXT: .LBB13_2: // %if.end
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2: // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT: whilelo p0.b, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
+; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: punpklo p2.h, p3.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p3.b
+; CHECK-SVE2p1-SME2-NEXT: b use
+; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT: ret
+entry:
+ %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
+ %v0 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+ %v1 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 4)
+ %v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+ %v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 12)
+ %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+ br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+ tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
+; CHECK-SVE-NEXT: b.pl .LBB14_2
+; CHECK-SVE-NEXT: // %bb.1: // %if.then
+; CHECK-SVE-NEXT: punpklo p1.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: punpklo p2.h, p3.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT: b use
+; CHECK-SVE-NEXT: .LBB14_2: // %if.end
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2: // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
+; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p0.b
+; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE2p1-SME2-NEXT: punpklo p2.h, p3.b
+; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p3.b
+; CHECK-SVE2p1-SME2-NEXT: b use
+; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT: ret
+entry:
+ %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i64 %i, i64 %n)
+ %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
+ %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
+ %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
+ %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
+ %elt0 = extractelement <vscale x 8 x i1> %r, i32 0
+ br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+ tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3)
+ br label %if.end
+
+if.end:
+ ret void
+}
+
declare void @use(...)
attributes #0 = { nounwind }
>From 64b6d612f29581f4ffe3a7c838b0a94f5f329a1b Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Tue, 16 Sep 2025 09:04:36 +0000
Subject: [PATCH 2/3] [AArch64][SVE2p1] Allow more uses of mask in
performActiveLaneMaskCombine
The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic, such as ptest and reinterpret cast, which
are currently not supported.
This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and to allow the additional uses introduced by these operations.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 29 ++++++++++++++-----
llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 20 ++++++++-----
.../AArch64/get-active-lane-mask-extract.ll | 8 ++---
3 files changed, 36 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fc3efb072d57b..0b2bf4554271d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18785,21 +18785,31 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
return SDValue();
- unsigned NumUses = N->use_size();
+ // Count the number of users which are extract_vectors
+ // The only other valid users for this combine are ptest_first
+ // and reinterpret_cast.
+ unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
+ return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
+ });
+
auto MaskEC = N->getValueType(0).getVectorElementCount();
- if (!MaskEC.isKnownMultipleOf(NumUses))
+ if (!MaskEC.isKnownMultipleOf(NumExts))
return SDValue();
- ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+ ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
if (ExtMinEC.getKnownMinValue() < 2)
return SDValue();
- SmallVector<SDNode *> Extracts(NumUses, nullptr);
+ SmallVector<SDNode *> Extracts(NumExts, nullptr);
for (SDNode *Use : N->users()) {
+ if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
+ Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ continue;
+
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
return SDValue();
- // Ensure the extract type is correct (e.g. if NumUses is 4 and
+ // Ensure the extract type is correct (e.g. if NumExts is 4 and
// the mask return type is nxv8i1, each extract should be nxv2i1.
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
return SDValue();
@@ -18833,11 +18843,13 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
- if (NumUses == 2)
- return SDValue(N, 0);
+ if (NumExts == 2) {
+ DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
+ return SDValue(SDValue(N, 0));
+ }
auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
- for (unsigned I = 2; I < NumUses; I += 2) {
+ for (unsigned I = 2; I < NumExts; I += 2) {
// After the first whilelo_x2, we need to increment the starting value.
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
@@ -18845,6 +18857,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
}
+ DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
return SDValue(N, 0);
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 5a51c812732e6..35084841eef7e 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1495,13 +1495,19 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
return PredOpcode;
- // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
- // redundant since WHILE performs an implicit PTEST with an all active
- // mask.
- if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
- getElementSizeForOpcode(MaskOpcode) ==
- getElementSizeForOpcode(PredOpcode))
- return PredOpcode;
+ if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
+ auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
+ if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
+ PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
+ return PredOpcode;
+
+ // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
+ // redundant since WHILE performs an implicit PTEST with an all active
+ // mask.
+ if (getElementSizeForOpcode(MaskOpcode) ==
+ getElementSizeForOpcode(PredOpcode))
+ return PredOpcode;
+ }
return {};
}
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index f6dc912ce9f07..5fd05478fba57 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -326,11 +326,9 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
;
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
-; CHECK-SVE2p1-SME2-NEXT: whilelo p1.b, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
-; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
-; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
@@ -366,11 +364,9 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
;
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
-; CHECK-SVE2p1-SME2-NEXT: whilelo p1.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
-; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
-; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
>From 70adaf7fc0097fbed3839aa87e0a3068ae57f8bf Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 18 Sep 2025 09:27:49 +0000
Subject: [PATCH 3/3] - Return a concat_vector created from the results of
whilelo_x2 from performActiveLaneMaskCombine - Add tests for the 4 extracts
case which will use ptest & reinterpret_cast - Remove changes to
canRemovePTestInstr
---
.../Target/AArch64/AArch64ISelLowering.cpp | 22 +++++-----
llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 20 ++++------
.../AArch64/get-active-lane-mask-extract.ll | 40 ++++++++++++-------
3 files changed, 42 insertions(+), 40 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0b2bf4554271d..45d2081c0dd2a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18774,7 +18774,7 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
static SDValue
performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *ST) {
- if (DCI.isBeforeLegalize())
+ if (DCI.isBeforeLegalize() && !!DCI.isBeforeLegalizeOps())
return SDValue();
if (SDValue While = optimizeIncrementingWhile(N, DCI.DAG, /*IsSigned=*/false,
@@ -18793,7 +18793,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
});
auto MaskEC = N->getValueType(0).getVectorElementCount();
- if (!MaskEC.isKnownMultipleOf(NumExts))
+ if (NumExts == 0 || !MaskEC.isKnownMultipleOf(NumExts))
return SDValue();
ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
@@ -18802,12 +18802,8 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
SmallVector<SDNode *> Extracts(NumExts, nullptr);
for (SDNode *Use : N->users()) {
- if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
- Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
- continue;
-
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
- return SDValue();
+ continue;
// Ensure the extract type is correct (e.g. if NumExts is 4 and
// the mask return type is nxv8i1, each extract should be nxv2i1.
@@ -18842,11 +18838,10 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
+ SmallVector<SDValue> Results = {R.getValue(0), R.getValue(1)};
- if (NumExts == 2) {
- DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
- return SDValue(SDValue(N, 0));
- }
+ if (NumExts == 2)
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results);
auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
for (unsigned I = 2; I < NumExts; I += 2) {
@@ -18855,10 +18850,11 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[I], R.getValue(0));
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
+ Results.push_back(R.getValue(0));
+ Results.push_back(R.getValue(1));
}
- DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
- return SDValue(N, 0);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results);
}
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 35084841eef7e..5a51c812732e6 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1495,19 +1495,13 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
return PredOpcode;
- if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
- auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
- if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
- PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
- return PredOpcode;
-
- // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
- // redundant since WHILE performs an implicit PTEST with an all active
- // mask.
- if (getElementSizeForOpcode(MaskOpcode) ==
- getElementSizeForOpcode(PredOpcode))
- return PredOpcode;
- }
+ // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
+ // redundant since WHILE performs an implicit PTEST with an all active
+ // mask.
+ if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+ getElementSizeForOpcode(MaskOpcode) ==
+ getElementSizeForOpcode(PredOpcode))
+ return PredOpcode;
return {};
}
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 5fd05478fba57..b89f55188b0f2 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -327,6 +327,9 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
+; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
+; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -365,6 +368,9 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
+; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
+; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -403,15 +409,18 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
;
; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
-; CHECK-SVE2p1-SME2-NEXT: whilelo p0.b, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: cnth x8
+; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
+; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
+; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
+; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
+; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
+; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
+; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
-; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
-; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p0.b
-; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
-; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
-; CHECK-SVE2p1-SME2-NEXT: punpklo p2.h, p3.b
-; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p3.b
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
@@ -450,15 +459,18 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
;
; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
-; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: cntw x8
+; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
+; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
+; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
+; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
+; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
+; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
+; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
-; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
-; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p0.b
-; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
-; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
-; CHECK-SVE2p1-SME2-NEXT: punpklo p2.h, p3.b
-; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p3.b
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
More information about the llvm-commits
mailing list