[llvm] [InstCombine] Fold selects into masked loads (PR #160522)
Matthew Devereau via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 24 08:21:34 PDT 2025
https://github.com/MDevereau updated https://github.com/llvm/llvm-project/pull/160522
>From c223655e2b48b50dc246e6d0abbe1eaf14efffa8 Mon Sep 17 00:00:00 2001
From: Matthew Devereau <matthew.devereau at arm.com>
Date: Wed, 24 Sep 2025 12:54:25 +0000
Subject: [PATCH 1/3] [InstCombine] Fold selects into masked loads
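Fold a select into a masked load when the select's condition is the same value as the load's mask and the select's true operand is that load: select(M, masked.load(P, A, M, X), F) -> masked.load(P, A, M, F), i.e. the select's false operand becomes the new load's passthrough.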
---
.../InstCombine/InstCombineSelect.cpp | 11 +++++++++
.../InstCombine/select-masked_load.ll | 24 +++++++++++++++++--
2 files changed, 33 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4ea75409252bd..50dbc965a3b9b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4611,5 +4611,16 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
return replaceOperand(SI, 2, ConstantInt::get(FalseVal->getType(), 0));
}
+ Value *MaskedLoadPtr;
+ const APInt *MaskedLoadAlignment;
+ if (match(TrueVal,
+ m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment),
+ m_Specific(CondVal), m_Value())))
+ return replaceInstUsesWith(
+ SI, Builder.CreateMaskedLoad(
+ TrueVal->getType(), MaskedLoadPtr,
+ llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal,
+ FalseVal));
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/select-masked_load.ll b/llvm/test/Transforms/InstCombine/select-masked_load.ll
index b6bac612d6f9b..650b0b79c7cbf 100644
--- a/llvm/test/Transforms/InstCombine/select-masked_load.ll
+++ b/llvm/test/Transforms/InstCombine/select-masked_load.ll
@@ -26,8 +26,7 @@ define <4 x i32> @masked_load_and_zero_inactive_2(ptr %ptr, <4 x i1> %mask) {
; No transform when the load's passthrough cannot be reused or altered.
define <4 x i32> @masked_load_and_zero_inactive_3(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthrough) {
; CHECK-LABEL: @masked_load_and_zero_inactive_3(
-; CHECK-NEXT: [[LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHROUGH:%.*]])
-; CHECK-NEXT: [[MASKED:%.*]] = select <4 x i1> [[MASK]], <4 x i32> [[LOAD]], <4 x i32> zeroinitializer
+; CHECK-NEXT: [[MASKED:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> zeroinitializer)
; CHECK-NEXT: ret <4 x i32> [[MASKED]]
;
%load = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)
@@ -116,6 +115,27 @@ entry:
ret <8 x float> %1
}
+define <vscale x 4 x float> @fold_sel_into_masked_load_scalable(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough) {
+; CHECK-LABEL: @fold_sel_into_masked_load_scalable(
+; CHECK-NEXT: [[SEL:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT: ret <vscale x 4 x float> [[SEL]]
+;
+ %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+ ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @neg_fold_sel_into_masked_load_mask_mismatch(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask2, <vscale x 4 x float> %passthrough) {
+; CHECK-LABEL: @neg_fold_sel_into_masked_load_mask_mismatch(
+; CHECK-NEXT: [[LOAD:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT: [[SEL:%.*]] = select <vscale x 4 x i1> [[MASK2:%.*]], <vscale x 4 x float> [[LOAD]], <vscale x 4 x float> [[PASSTHROUGH]]
+; CHECK-NEXT: ret <vscale x 4 x float> [[SEL]]
+;
+ %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough)
+ %sel = select <vscale x 4 x i1> %mask2, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+ ret <vscale x 4 x float> %sel
+}
+
declare <8 x float> @llvm.masked.load.v8f32.p0(ptr, i32 immarg, <8 x i1>, <8 x float>)
declare <4 x i32> @llvm.masked.load.v4i32.p0(ptr, i32 immarg, <4 x i1>, <4 x i32>)
declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32 immarg, <4 x i1>, <4 x float>)
>From c802e73bc1b7abda1419b2ae40934245a00c1e34 Mon Sep 17 00:00:00 2001
From: Matthew Devereau <matthew.devereau at arm.com>
Date: Wed, 24 Sep 2025 14:04:56 +0000
Subject: [PATCH 2/3] Only fold when the masked load has a single use, and add a test
---
.../Transforms/InstCombine/InstCombineSelect.cpp | 3 ++-
.../Transforms/InstCombine/select-masked_load.ll | 13 +++++++++++++
2 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 50dbc965a3b9b..6ffaebf425394 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4613,7 +4613,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *MaskedLoadPtr;
const APInt *MaskedLoadAlignment;
- if (match(TrueVal,
+ if (TrueVal->hasOneUse() &&
+ match(TrueVal,
m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment),
m_Specific(CondVal), m_Value())))
return replaceInstUsesWith(
diff --git a/llvm/test/Transforms/InstCombine/select-masked_load.ll b/llvm/test/Transforms/InstCombine/select-masked_load.ll
index 650b0b79c7cbf..22e30ac019a5d 100644
--- a/llvm/test/Transforms/InstCombine/select-masked_load.ll
+++ b/llvm/test/Transforms/InstCombine/select-masked_load.ll
@@ -136,6 +136,19 @@ define <vscale x 4 x float> @neg_fold_sel_into_masked_load_mask_mismatch(ptr %lo
ret <vscale x 4 x float> %sel
}
+define <vscale x 4 x float> @fold_sel_into_masked_load_scalable_one_use_check(ptr %loc1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough, ptr %loc2) {
+; CHECK-LABEL: @fold_sel_into_masked_load_scalable_one_use_check(
+; CHECK-NEXT: [[LOAD:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> zeroinitializer)
+; CHECK-NEXT: [[SEL:%.*]] = select <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> [[LOAD]], <vscale x 4 x float> [[PASSTHROUGH:%.*]]
+; CHECK-NEXT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[LOAD]], ptr [[LOC2:%.*]], i32 1, <vscale x 4 x i1> [[MASK]])
+; CHECK-NEXT: ret <vscale x 4 x float> [[SEL]]
+;
+ %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load, ptr %loc2, i32 1, <vscale x 4 x i1> %mask)
+ ret <vscale x 4 x float> %sel
+}
+
declare <8 x float> @llvm.masked.load.v8f32.p0(ptr, i32 immarg, <8 x i1>, <8 x float>)
declare <4 x i32> @llvm.masked.load.v4i32.p0(ptr, i32 immarg, <4 x i1>, <4 x i32>)
declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32 immarg, <4 x i1>, <4 x float>)
>From 438081406d4eaa4db7b9d991148ae7aedf31b4bf Mon Sep 17 00:00:00 2001
From: Matthew Devereau <matthew.devereau at arm.com>
Date: Wed, 24 Sep 2025 15:17:55 +0000
Subject: [PATCH 3/3] Address review comments: use m_OneUse matcher and drop redundant llvm:: qualifier
---
.../Transforms/InstCombine/InstCombineSelect.cpp | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 6ffaebf425394..b6b3a95f35c76 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4613,15 +4613,13 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *MaskedLoadPtr;
const APInt *MaskedLoadAlignment;
- if (TrueVal->hasOneUse() &&
- match(TrueVal,
- m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment),
- m_Specific(CondVal), m_Value())))
+ if (match(TrueVal, m_OneUse(m_MaskedLoad(m_Value(MaskedLoadPtr),
+ m_APInt(MaskedLoadAlignment),
+ m_Specific(CondVal), m_Value()))))
return replaceInstUsesWith(
- SI, Builder.CreateMaskedLoad(
- TrueVal->getType(), MaskedLoadPtr,
- llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal,
- FalseVal));
+ SI, Builder.CreateMaskedLoad(TrueVal->getType(), MaskedLoadPtr,
+ Align(MaskedLoadAlignment->getZExtValue()),
+ CondVal, FalseVal));
return nullptr;
}
More information about the llvm-commits
mailing list