[llvm] [InstCombine] Fold `select(load, val) + store` into `llvm.masked.store` (PR #144298)

Abhishek Kaushik via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 15 22:56:25 PDT 2025


https://github.com/abhishek-kaushik22 created https://github.com/llvm/llvm-project/pull/144298

This patch adds a new InstCombine optimization that transforms a pattern of the form:
```
%load = load <8 x i32>, ptr %ptr, align 32
%sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
store <8 x i32> %sel, ptr %ptr, align 32
```
into:
```
call void @llvm.masked.store.v8i32.p0(<8 x i32> %x, ptr %ptr, i32 32, <8 x i1> %cmp)
```

From ef5bab701685b224e1c42f2d1c73beb6304cda75 Mon Sep 17 00:00:00 2001
From: Abhishek Kaushik <abhishek.kaushik at intel.com>
Date: Mon, 16 Jun 2025 11:24:52 +0530
Subject: [PATCH] [InstCombine] Fold `select(load, val) + store` into
 `llvm.masked.store`

This patch adds a new InstCombine optimization that transforms a pattern of the form:
```
%load = load <8 x i32>, ptr %ptr, align 32
%sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
store <8 x i32> %sel, ptr %ptr, align 32
```
into:
```
call void @llvm.masked.store.v8i32.p0(<8 x i32> %x, ptr %ptr, i32 32, <8 x i1> %cmp)
```
---
 .../InstCombineLoadStoreAlloca.cpp            | 45 +++++++++++++
 .../Transforms/InstCombine/masked-store.ll    | 63 +++++++++++++++++++
 .../Transforms/LoopVectorize/if-conversion.ll |  3 +-
 3 files changed, 109 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/masked-store.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 1d208de75db3b..ab0228d33db21 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -1363,6 +1363,48 @@ static bool equivalentAddressValues(Value *A, Value *B) {
   return false;
 }
 
+/// Fold a load/select/store sequence on a single pointer into a masked store:
+///   %load = load <8 x i32>, ptr %ptr, align 32
+///   %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+///   store <8 x i32> %sel, ptr %ptr, align 32
+/// becomes
+///   call void @llvm.masked.store.v8i32.p0(<8 x i32> %x, ptr %ptr, i32 32,
+///                                         <8 x i1> %cmp)
+/// Returns true if the masked store was emitted; the caller erases \p Store.
+static bool combineToMaskedStore(InstCombinerImpl &IC, StoreInst &Store) {
+  auto *Select = dyn_cast<SelectInst>(Store.getValueOperand());
+  if (!Select)
+    return false;
+
+  // A scalar i1 condition may select between two vectors, but a masked store
+  // needs one mask bit per lane, so require a vector condition.
+  Value *Condition = Select->getCondition();
+  if (!Condition->getType()->isVectorTy())
+    return false;
+
+  Value *Pointer = Store.getPointerOperand();
+  const auto *Load = dyn_cast<LoadInst>(Select->getFalseValue());
+  if (!Load || Load->getPointerOperand() != Pointer)
+    return false;
+
+  // Only simple (non-volatile, non-atomic) accesses may be rewritten.
+  if (!Load->isSimple() || !Store.isSimple())
+    return false;
+
+  // Walk from the load to the store. Reaching the end of the block first
+  // means they are in different blocks, where this scan proves nothing. Any
+  // instruction with side effects (which includes every store) may have
+  // changed the memory the load read.
+  for (const Instruction *I = Load->getNextNode(); I != &Store;
+       I = I->getNextNode())
+    if (!I || I->mayHaveSideEffects())
+      return false;
+
+  IC.Builder.CreateMaskedStore(Select->getTrueValue(), Pointer,
+                               Store.getAlign(), Condition);
+  return true;
+}
+
 Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
   Value *Val = SI.getOperand(0);
   Value *Ptr = SI.getOperand(1);
@@ -1375,6 +1417,9 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
   if (unpackStoreToAggregate(*this, SI))
     return eraseInstFromFunction(SI);
 
+  if (combineToMaskedStore(*this, SI))
+    return eraseInstFromFunction(SI);
+
   // Replace GEP indices if possible.
   if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI))
     return replaceOperand(SI, 1, NewGEPI);
diff --git a/llvm/test/Transforms/InstCombine/masked-store.ll b/llvm/test/Transforms/InstCombine/masked-store.ll
new file mode 100644
index 0000000000000..bbbf2587a35ef
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/masked-store.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+define void @test_masked_store_success(<8 x i32> %x, ptr %ptr, <8 x i1> %cmp) {
+; CHECK-LABEL: define void @test_masked_store_success(
+; CHECK-SAME: <8 x i32> [[X:%.*]], ptr [[PTR:%.*]], <8 x i1> [[CMP:%.*]]) {
+; CHECK-NEXT:    call void @llvm.masked.store.v8i32.p0(<8 x i32> [[X]], ptr [[PTR]], i32 32, <8 x i1> [[CMP]])
+; CHECK-NEXT:    ret void
+; Positive test: load, select and store share %ptr and fold to one masked store.
+  %load = load <8 x i32>, ptr %ptr, align 32
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  store <8 x i32> %sel, ptr %ptr, align 32
+  ret void
+}
+
+define void @test_masked_store_volatile_load(<8 x i32> %x, ptr %ptr, <8 x i1> %cmp) {
+; CHECK-LABEL: define void @test_masked_store_volatile_load(
+; CHECK-SAME: <8 x i32> [[X:%.*]], ptr [[PTR:%.*]], <8 x i1> [[CMP:%.*]]) {
+; CHECK-NEXT:    [[LOAD:%.*]] = load volatile <8 x i32>, ptr [[PTR]], align 32
+; CHECK-NEXT:    [[SEL:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[X]], <8 x i32> [[LOAD]]
+; CHECK-NEXT:    store <8 x i32> [[SEL]], ptr [[PTR]], align 32
+; CHECK-NEXT:    ret void
+; Negative test: the load is volatile, so the sequence is left unchanged.
+  %load = load volatile <8 x i32>, ptr %ptr, align 32
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  store <8 x i32> %sel, ptr %ptr, align 32
+  ret void
+}
+
+define void @test_masked_store_volatile_store(<8 x i32> %x, ptr %ptr, <8 x i1> %cmp) {
+; CHECK-LABEL: define void @test_masked_store_volatile_store(
+; CHECK-SAME: <8 x i32> [[X:%.*]], ptr [[PTR:%.*]], <8 x i1> [[CMP:%.*]]) {
+; CHECK-NEXT:    [[LOAD:%.*]] = load <8 x i32>, ptr [[PTR]], align 32
+; CHECK-NEXT:    [[SEL:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[X]], <8 x i32> [[LOAD]]
+; CHECK-NEXT:    store volatile <8 x i32> [[SEL]], ptr [[PTR]], align 32
+; CHECK-NEXT:    ret void
+; Negative test: the store is volatile, so the sequence is left unchanged.
+  %load = load <8 x i32>, ptr %ptr, align 32
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  store volatile <8 x i32> %sel, ptr %ptr, align 32
+  ret void
+}
+
+declare void @use_vec(<8 x i32>)
+
+define void @test_masked_store_intervening(<8 x i32> %x, ptr %ptr, <8 x i1> %cmp) {
+; CHECK-LABEL: define void @test_masked_store_intervening(
+; CHECK-SAME: <8 x i32> [[X:%.*]], ptr [[PTR:%.*]], <8 x i1> [[CMP:%.*]]) {
+; CHECK-NEXT:    [[LOAD:%.*]] = load <8 x i32>, ptr [[PTR]], align 32
+; CHECK-NEXT:    store <8 x i32> zeroinitializer, ptr [[PTR]], align 32
+; CHECK-NEXT:    call void @use_vec(<8 x i32> zeroinitializer)
+; CHECK-NEXT:    [[SEL:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[X]], <8 x i32> [[LOAD]]
+; CHECK-NEXT:    store <8 x i32> [[SEL]], ptr [[PTR]], align 32
+; CHECK-NEXT:    ret void
+; Negative test: %ptr is overwritten between the load and the final store.
+  %load = load <8 x i32>, ptr %ptr, align 32
+  store <8 x i32> zeroinitializer, ptr %ptr, align 32
+  %tmp = load <8 x i32>, ptr %ptr
+  call void @use_vec(<8 x i32> %tmp)
+  %sel = select <8 x i1> %cmp, <8 x i32> %x, <8 x i32> %load
+  store <8 x i32> %sel, ptr %ptr, align 32
+  ret void
+}
diff --git a/llvm/test/Transforms/LoopVectorize/if-conversion.ll b/llvm/test/Transforms/LoopVectorize/if-conversion.ll
index 8a7f4a386fda1..622726f6d1fe7 100644
--- a/llvm/test/Transforms/LoopVectorize/if-conversion.ll
+++ b/llvm/test/Transforms/LoopVectorize/if-conversion.ll
@@ -61,8 +61,7 @@ define i32 @function0(ptr nocapture %a, ptr nocapture %b, i32 %start, i32 %end)
 ; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp sgt <4 x i32> [[WIDE_LOAD]], [[WIDE_LOAD4]]
 ; CHECK-NEXT:    [[TMP15:%.*]] = mul <4 x i32> [[WIDE_LOAD]], splat (i32 5)
 ; CHECK-NEXT:    [[TMP16:%.*]] = add <4 x i32> [[TMP15]], splat (i32 3)
-; CHECK-NEXT:    [[PREDPHI:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i32> [[TMP16]], <4 x i32> [[WIDE_LOAD]]
-; CHECK-NEXT:    store <4 x i32> [[PREDPHI]], ptr [[TMP13]], align 4, !alias.scope [[META0]], !noalias [[META3]]
+; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0(<4 x i32> [[TMP16]], ptr [[TMP13]], i32 4, <4 x i1> [[DOTNOT]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP17]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]



More information about the llvm-commits mailing list