[llvm] [LLVM][DAGCombiner] Look through freeze when combining extensions of extending-masked-loads. (PR #172484)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 16 06:41:18 PST 2025
https://github.com/paulwalker-arm created https://github.com/llvm/llvm-project/pull/172484
Extensions in this context mean post-legalisation extensions (i.e. and, sext-in-reg), because that is the point at which the freeze blocks the existing combines.
NOTE: The use check for the signed case looks overly restrictive, but nonetheless I figured it best for the freeze variant to maintain the same behaviour.
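For reference, the zero-extend case boils down to IR like this (a condensed form of the tests added below), where the freeze sits between the load and the extend:

  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr %a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> poison)
  %load.frozen = freeze <vscale x 2 x i8> %load
  %ext = zext <vscale x 2 x i8> %load.frozen to <vscale x 2 x i64>

After legalisation the zext is represented as an AND with a splat of 0xff whose operand is the FREEZE rather than the masked load itself, so the existing visitAND combine failed to match; the sign-extend case hits the same problem via SIGN_EXTEND_INREG.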
From c0a84b0b72933fad5b2a6df91cd4e956e17df1e5 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Mon, 15 Dec 2025 16:35:30 +0000
Subject: [PATCH 1/2] Add tests showing how freeze prevents extload combine.
---
.../CodeGen/AArch64/sve-masked-ldst-sext.ll | 29 +++++++++++++++++++
.../CodeGen/AArch64/sve-masked-ldst-zext.ll | 27 +++++++++++++++++
2 files changed, 56 insertions(+)
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
index 5277c2efab85d..94b7bbf5b135f 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -257,6 +257,35 @@ define <vscale x 8 x i64> @masked_sload_x2_8i8_8i64(ptr %a, ptr %b, <vscale x 8
ret <vscale x 8 x i64> %res
}
+define <vscale x 2 x i64> @masked_load_frozen_before_sext(ptr %a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_frozen_before_sext:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld1b { z0.d }, p0/z, [x0]
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: sxtb z0.d, p0/m, z0.d
+; CHECK-NEXT: ret
+ %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr %a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> poison)
+ %load.frozen = freeze <vscale x 2 x i8> %load
+ %ext = sext <vscale x 2 x i8> %load.frozen to <vscale x 2 x i64>
+ ret <vscale x 2 x i64> %ext
+}
+
+; A multi-use freeze effectively means the load is also multi-use.
+define <vscale x 2 x i64> @masked_load_frozen_before_sext_multiuse(ptr %a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_frozen_before_sext_multiuse:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld1b { z1.d }, p0/z, [x0]
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: movprfx z0, z1
+; CHECK-NEXT: sxtb z0.d, p0/m, z1.d
+; CHECK-NEXT: // fake_use: $z1
+; CHECK-NEXT: ret
+ %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr %a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> poison)
+ %load.frozen = freeze <vscale x 2 x i8> %load
+ %ext = sext <vscale x 2 x i8> %load.frozen to <vscale x 2 x i64>
+ call void (...) @llvm.fake.use(<vscale x 2 x i8> %load.frozen)
+ ret <vscale x 2 x i64> %ext
+}
declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
index f69ab0de06e08..f3b57999c4e22 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -252,6 +252,33 @@ define <vscale x 8 x i64> @masked_zload_x2_8i8_8i64(ptr %a, ptr %b, <vscale x 8
ret <vscale x 8 x i64> %res
}
+define <vscale x 2 x i64> @masked_load_frozen_before_zext(ptr %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_frozen_before_zext:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld1b { z0.d }, p0/z, [x0]
+; CHECK-NEXT: and z0.d, z0.d, #0xff
+; CHECK-NEXT: ret
+ %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> poison)
+ %load.frozen = freeze <vscale x 2 x i8> %load
+ %ext = zext <vscale x 2 x i8> %load.frozen to <vscale x 2 x i64>
+ ret <vscale x 2 x i64> %ext
+}
+
+; A multi-use freeze effectively means the load is also multi-use.
+define <vscale x 2 x i64> @masked_load_frozen_before_zext_multiuse(ptr %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_frozen_before_zext_multiuse:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld1b { z0.d }, p0/z, [x0]
+; CHECK-NEXT: mov z1.d, z0.d
+; CHECK-NEXT: and z0.d, z0.d, #0xff
+; CHECK-NEXT: // fake_use: $z1
+; CHECK-NEXT: ret
+ %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> poison)
+ %load.frozen = freeze <vscale x 2 x i8> %load
+ %ext = zext <vscale x 2 x i8> %load.frozen to <vscale x 2 x i64>
+ call void (...) @llvm.fake.use(<vscale x 2 x i8> %load.frozen)
+ ret <vscale x 2 x i64> %ext
+}
declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(ptr, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
From e27397754f2180aac48f087a84735a1aa1812b26 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Mon, 15 Dec 2025 16:58:32 +0000
Subject: [PATCH 2/2] [LLVM][DAGCombiner] Look through freeze when combining
extensions of extending-masked-loads.
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 31 +++++++++----------
...rleaving-reductions-predicated-scalable.ll | 1 -
.../CodeGen/AArch64/sve-masked-ldst-sext.ll | 4 +--
.../CodeGen/AArch64/sve-masked-ldst-zext.ll | 5 +--
4 files changed, 16 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6a99d4e29b64f..879041be473f3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7518,27 +7518,23 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
return N0;
// fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
- auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
+ bool Frozen = N0.getOpcode() == ISD::FREEZE;
+ auto *MLoad = dyn_cast<MaskedLoadSDNode>(Frozen ? N0.getOperand(0) : N0);
ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
- EVT LoadVT = MLoad->getMemoryVT();
- EVT ExtVT = VT;
- if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
+ EVT MemVT = MLoad->getMemoryVT();
+ if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) {
// For this AND to be a zero extension of the masked load the elements
// of the BuildVec must mask the bottom bits of the extended element
// type
- uint64_t ElementSize =
- LoadVT.getVectorElementType().getScalarSizeInBits();
- if (Splat->getAPIntValue().isMask(ElementSize)) {
+ if (Splat->getAPIntValue().isMask(MemVT.getScalarSizeInBits())) {
SDValue NewLoad = DAG.getMaskedLoad(
- ExtVT, DL, MLoad->getChain(), MLoad->getBasePtr(),
+ VT, DL, MLoad->getChain(), MLoad->getBasePtr(),
MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
- LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
+ MemVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
ISD::ZEXTLOAD, MLoad->isExpandingLoad());
- bool LoadHasOtherUsers = !N0.hasOneUse();
- CombineTo(N, NewLoad);
- if (LoadHasOtherUsers)
- CombineTo(MLoad, NewLoad.getValue(0), NewLoad.getValue(1));
+ CombineTo(N, Frozen ? N0 : NewLoad);
+ CombineTo(MLoad, NewLoad, NewLoad.getValue(1));
return SDValue(N, 0);
}
}
@@ -15992,16 +15988,17 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
// fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
// ignore it if the masked load is already sign extended
- if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
- if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
+ bool Frozen = N0.getOpcode() == ISD::FREEZE && N0.hasOneUse();
+ if (auto *Ld = dyn_cast<MaskedLoadSDNode>(Frozen ? N0.getOperand(0) : N0)) {
+ if (ExtVT == Ld->getMemoryVT() && Ld->hasNUsesOfValue(1, 0) &&
Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
SDValue ExtMaskedLoad = DAG.getMaskedLoad(
VT, DL, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
- CombineTo(N, ExtMaskedLoad);
- CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
+ CombineTo(N, Frozen ? N0 : ExtMaskedLoad);
+ CombineTo(Ld, ExtMaskedLoad, ExtMaskedLoad.getValue(1));
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
index 79f0cd345f95c..fc672dfa84edd 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
@@ -225,7 +225,6 @@ define %"class.std::complex" @complex_mul_predicated_x2_v2f64(ptr %a, ptr %b, pt
; CHECK-NEXT: mov z6.d, z0.d
; CHECK-NEXT: mov z7.d, z1.d
; CHECK-NEXT: add x8, x8, x10
-; CHECK-NEXT: and z2.d, z2.d, #0xffffffff
; CHECK-NEXT: cmpne p1.d, p1/z, z2.d, #0
; CHECK-NEXT: zip2 p2.d, p1.d, p1.d
; CHECK-NEXT: zip1 p1.d, p1.d, p1.d
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
index 94b7bbf5b135f..5c506ab20e6f4 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -260,9 +260,7 @@ define <vscale x 8 x i64> @masked_sload_x2_8i8_8i64(ptr %a, ptr %b, <vscale x 8
define <vscale x 2 x i64> @masked_load_frozen_before_sext(ptr %a, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: masked_load_frozen_before_sext:
; CHECK: // %bb.0:
-; CHECK-NEXT: ld1b { z0.d }, p0/z, [x0]
-; CHECK-NEXT: ptrue p0.d
-; CHECK-NEXT: sxtb z0.d, p0/m, z0.d
+; CHECK-NEXT: ld1sb { z0.d }, p0/z, [x0]
; CHECK-NEXT: ret
%load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr %a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> poison)
%load.frozen = freeze <vscale x 2 x i8> %load
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
index f3b57999c4e22..09e276f1cc433 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -256,7 +256,6 @@ define <vscale x 2 x i64> @masked_load_frozen_before_zext(ptr %src, <vscale x 2
; CHECK-LABEL: masked_load_frozen_before_zext:
; CHECK: // %bb.0:
; CHECK-NEXT: ld1b { z0.d }, p0/z, [x0]
-; CHECK-NEXT: and z0.d, z0.d, #0xff
; CHECK-NEXT: ret
%load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> poison)
%load.frozen = freeze <vscale x 2 x i8> %load
@@ -269,9 +268,7 @@ define <vscale x 2 x i64> @masked_load_frozen_before_zext_multiuse(ptr %src, <vs
; CHECK-LABEL: masked_load_frozen_before_zext_multiuse:
; CHECK: // %bb.0:
; CHECK-NEXT: ld1b { z0.d }, p0/z, [x0]
-; CHECK-NEXT: mov z1.d, z0.d
-; CHECK-NEXT: and z0.d, z0.d, #0xff
-; CHECK-NEXT: // fake_use: $z1
+; CHECK-NEXT: // fake_use: $z0
; CHECK-NEXT: ret
%load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(ptr %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> poison)
%load.frozen = freeze <vscale x 2 x i8> %load