[llvm] [InstCombine] Increase alignment in masked load / store intrinsics if known (PR #156057)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 29 09:37:10 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Joseph Huber (jhuber6)

<details>
<summary>Changes</summary>

Summary:
The masked load / store LLVM intrinsics take an argument for the
alignment. If the user is pessimistic about alignment they can provide a
value of `1`, which describes an unaligned access. This patch updates
InstCombine to raise the alignment argument when the pointer's known
alignment is greater than the provided one.

The gather / scatter versions are ignored for now since they operate on a
vector of pointers rather than a single pointer.


---
Full diff: https://github.com/llvm/llvm-project/pull/156057.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+34-14) 
- (modified) llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll (+1-1) 
- (modified) llvm/test/Transforms/InstCombine/masked_intrinsics.ll (+31) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 42b65dde67255..7e50e55ae24c8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -288,8 +288,11 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
 // * Narrow width by halfs excluding zero/undef lanes
 Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
   Value *LoadPtr = II.getArgOperand(0);
-  const Align Alignment =
-      cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+  Align Alignment = cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+
+  Align KnownAlign = getKnownAlignment(LoadPtr, DL, &II, &AC, &DT);
+  if (Alignment < KnownAlign)
+    Alignment = KnownAlign;
 
   // If the mask is all ones or undefs, this is a plain vector load of the 1st
   // argument.
@@ -310,6 +313,15 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
     return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3));
   }
 
+  // Update the alignment if the known value is greater than the provided one.
+  if (cast<ConstantInt>(II.getArgOperand(1))->getAlignValue() < Alignment) {
+    SmallVector<Value *> Args(II.arg_begin(), II.arg_end());
+    Args[1] = Builder.getInt32(Alignment.value());
+    CallInst *CI = Builder.CreateCall(II.getCalledFunction(), Args);
+    CI->copyMetadata(II);
+    return CI;
+  }
+
   return nullptr;
 }
 
@@ -317,33 +329,41 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
 // * Single constant active lane -> store
 // * Narrow width by halfs excluding zero/undef lanes
 Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
-  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
-  if (!ConstMask)
-    return nullptr;
+  Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+
+  Align KnownAlign = getKnownAlignment(II.getArgOperand(1), DL, &II, &AC, &DT);
+  if (Alignment < KnownAlign)
+    Alignment = KnownAlign;
 
   // If the mask is all zeros, this instruction does nothing.
-  if (ConstMask->isNullValue())
+  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
+  if (ConstMask && ConstMask->isNullValue())
     return eraseInstFromFunction(II);
 
   // If the mask is all ones, this is a plain vector store of the 1st argument.
-  if (ConstMask->isAllOnesValue()) {
+  if (ConstMask && ConstMask->isAllOnesValue()) {
     Value *StorePtr = II.getArgOperand(1);
-    Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
     StoreInst *S =
         new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
     S->copyMetadata(II);
     return S;
   }
 
-  if (isa<ScalableVectorType>(ConstMask->getType()))
+  if (ConstMask && isa<ScalableVectorType>(ConstMask->getType()))
     return nullptr;
 
   // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
-  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
-  APInt PoisonElts(DemandedElts.getBitWidth(), 0);
-  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts,
-                                            PoisonElts))
-    return replaceOperand(II, 0, V);
+  if (ConstMask) {
+    APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+    APInt PoisonElts(DemandedElts.getBitWidth(), 0);
+    if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts,
+                                              PoisonElts))
+      return replaceOperand(II, 0, V);
+  }
+
+  // Update the alignment if the known value is greater than the provided one.
+  if (cast<ConstantInt>(II.getArgOperand(2))->getAlignValue() < Alignment)
+    return replaceOperand(II, 2, Builder.getInt32(Alignment.value()));
 
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll b/llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll
index 918ea605a10bf..6ba52c178b8d4 100644
--- a/llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll
+++ b/llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll
@@ -7,7 +7,7 @@
 define void @combine_masked_load_store_from_constant_array(ptr %ptr) {
 ; CHECK-LABEL: @combine_masked_load_store_from_constant_array(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i32(i32 0, i32 10)
-; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64.p0(ptr nonnull @contant_int_array, i32 8, <vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> zeroinitializer)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64.p0(ptr nonnull @contant_int_array, i32 16, <vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> zeroinitializer)
 ; CHECK-NEXT:    call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[TMP2]], ptr [[PTR:%.*]], i32 1, <vscale x 2 x i1> [[TMP1]])
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index 8f7683419a82a..6c1168e6c1c70 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -439,3 +439,34 @@ define <2 x i64> @negative_gather_v2i64_uniform_ptrs_no_all_active_mask(ptr %src
 declare <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x ptr>, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
 declare <2 x i64> @llvm.masked.gather.v2i64(<2 x ptr>, i32, <2 x i1>, <2 x i64>)
 
+; Alignment tests
+
+define <2 x i32> @unaligned_load(<2 x i1> %mask, ptr %ptr) {
+; CHECK-LABEL: @unaligned_load(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[PTR:%.*]], i64 64) ]
+; CHECK-NEXT:    [[MASKED_LOAD:%.*]] = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr [[PTR]], i32 64, <2 x i1> [[MASK:%.*]], <2 x i32> poison)
+; CHECK-NEXT:    ret <2 x i32> [[MASKED_LOAD]]
+;
+entry:
+  call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
+  %masked_load = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr %ptr, i32 1, <2 x i1> %mask, <2 x i32> poison)
+  ret <2 x i32> %masked_load
+}
+
+define void @unaligned_store(<2 x i1> %mask, <2 x i32> %val, ptr %ptr) {
+; CHECK-LABEL: @unaligned_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[PTR:%.*]], i64 64) ]
+; CHECK-NEXT:    tail call void @llvm.masked.store.v2i32.p0(<2 x i32> [[VAL:%.*]], ptr [[PTR]], i32 64, <2 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
+  tail call void @llvm.masked.store.v2i32.p0(<2 x i32> %val, ptr %ptr, i32 1, <2 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.assume(i1)
+declare <2 x i32> @llvm.masked.load.v2i32.p0(ptr, i32, <2 x i1>, <2 x i32>)
+declare void @llvm.masked.store.v2i32.p0(<2 x i32>, ptr, i32, <2 x i1>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/156057


More information about the llvm-commits mailing list