[llvm] [InstCombine] Increase alignment in masked load / store intrinsics if known (PR #156057)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 29 09:37:10 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-transforms
Author: Joseph Huber (jhuber6)
<details>
<summary>Changes</summary>
Summary:
The masked load / store LLVM intrinsics take an explicit alignment argument. A
user who is pessimistic about alignment can pass a value of `1` for an
unaligned access. This patch updates InstCombine to raise the alignment
argument when the pointer's known alignment is greater than the provided
value, as sketched below.
The gather / scatter versions are ignored for now since they operate on
vectors of pointers.
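For illustration, a minimal before/after sketch of the load case (it mirrors the new `unaligned_load` test in the diff; the function names `@before` / `@after` are just placeholders for the sketch, and the `llvm.assume` align bundle is what establishes the known alignment):

```llvm
; Input: the alignment argument is the pessimistic value 1, but an assume
; bundle establishes that %ptr is 64-byte aligned.
define <2 x i32> @before(<2 x i1> %mask, ptr %ptr) {
  call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
  %v = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr %ptr, i32 1, <2 x i1> %mask, <2 x i32> poison)
  ret <2 x i32> %v
}

; After running instcombine, the alignment argument is raised to the known value.
define <2 x i32> @after(<2 x i1> %mask, ptr %ptr) {
  call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
  %v = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr %ptr, i32 64, <2 x i1> %mask, <2 x i32> poison)
  ret <2 x i32> %v
}

declare void @llvm.assume(i1)
declare <2 x i32> @llvm.masked.load.v2i32.p0(ptr, i32, <2 x i1>, <2 x i32>)
```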
---
Full diff: https://github.com/llvm/llvm-project/pull/156057.diff
3 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+34-14)
- (modified) llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll (+1-1)
- (modified) llvm/test/Transforms/InstCombine/masked_intrinsics.ll (+31)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 42b65dde67255..7e50e55ae24c8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -288,8 +288,11 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
// * Narrow width by halfs excluding zero/undef lanes
Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
Value *LoadPtr = II.getArgOperand(0);
- const Align Alignment =
- cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+ Align Alignment = cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+
+ Align KnownAlign = getKnownAlignment(LoadPtr, DL, &II, &AC, &DT);
+ if (Alignment < KnownAlign)
+ Alignment = KnownAlign;
// If the mask is all ones or undefs, this is a plain vector load of the 1st
// argument.
@@ -310,6 +313,15 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3));
}
+ // Update the alignment if the known value is greater than the provided one.
+ if (cast<ConstantInt>(II.getArgOperand(1))->getAlignValue() < Alignment) {
+ SmallVector<Value *> Args(II.arg_begin(), II.arg_end());
+ Args[1] = Builder.getInt32(Alignment.value());
+ CallInst *CI = Builder.CreateCall(II.getCalledFunction(), Args);
+ CI->copyMetadata(II);
+ return CI;
+ }
+
return nullptr;
}
@@ -317,33 +329,41 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
// * Single constant active lane -> store
// * Narrow width by halfs excluding zero/undef lanes
Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
- auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
- if (!ConstMask)
- return nullptr;
+ Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+
+ Align KnownAlign = getKnownAlignment(II.getArgOperand(1), DL, &II, &AC, &DT);
+ if (Alignment < KnownAlign)
+ Alignment = KnownAlign;
// If the mask is all zeros, this instruction does nothing.
- if (ConstMask->isNullValue())
+ auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
+ if (ConstMask && ConstMask->isNullValue())
return eraseInstFromFunction(II);
// If the mask is all ones, this is a plain vector store of the 1st argument.
- if (ConstMask->isAllOnesValue()) {
+ if (ConstMask && ConstMask->isAllOnesValue()) {
Value *StorePtr = II.getArgOperand(1);
- Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
StoreInst *S =
new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
S->copyMetadata(II);
return S;
}
- if (isa<ScalableVectorType>(ConstMask->getType()))
+ if (ConstMask && isa<ScalableVectorType>(ConstMask->getType()))
return nullptr;
// Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
- APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
- APInt PoisonElts(DemandedElts.getBitWidth(), 0);
- if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts,
- PoisonElts))
- return replaceOperand(II, 0, V);
+ if (ConstMask) {
+ APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+ APInt PoisonElts(DemandedElts.getBitWidth(), 0);
+ if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts,
+ PoisonElts))
+ return replaceOperand(II, 0, V);
+ }
+
+ // Update the alignment if the known value is greater than the provided one.
+ if (cast<ConstantInt>(II.getArgOperand(2))->getAlignValue() < Alignment)
+ return replaceOperand(II, 2, Builder.getInt32(Alignment.value()));
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll b/llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll
index 918ea605a10bf..6ba52c178b8d4 100644
--- a/llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll
+++ b/llvm/test/Transforms/InstCombine/load-store-masked-constant-array.ll
@@ -7,7 +7,7 @@
define void @combine_masked_load_store_from_constant_array(ptr %ptr) {
; CHECK-LABEL: @combine_masked_load_store_from_constant_array(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i32(i32 0, i32 10)
-; CHECK-NEXT: [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64.p0(ptr nonnull @contant_int_array, i32 8, <vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> zeroinitializer)
+; CHECK-NEXT: [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64.p0(ptr nonnull @contant_int_array, i32 16, <vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> zeroinitializer)
; CHECK-NEXT: call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[TMP2]], ptr [[PTR:%.*]], i32 1, <vscale x 2 x i1> [[TMP1]])
; CHECK-NEXT: ret void
;
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index 8f7683419a82a..6c1168e6c1c70 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -439,3 +439,34 @@ define <2 x i64> @negative_gather_v2i64_uniform_ptrs_no_all_active_mask(ptr %src
declare <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x ptr>, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
declare <2 x i64> @llvm.masked.gather.v2i64(<2 x ptr>, i32, <2 x i1>, <2 x i64>)
+; Alignment tests
+
+define <2 x i32> @unaligned_load(<2 x i1> %mask, ptr %ptr) {
+; CHECK-LABEL: @unaligned_load(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(ptr [[PTR:%.*]], i64 64) ]
+; CHECK-NEXT: [[MASKED_LOAD:%.*]] = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr [[PTR]], i32 64, <2 x i1> [[MASK:%.*]], <2 x i32> poison)
+; CHECK-NEXT: ret <2 x i32> [[MASKED_LOAD]]
+;
+entry:
+ call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
+ %masked_load = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr %ptr, i32 1, <2 x i1> %mask, <2 x i32> poison)
+ ret <2 x i32> %masked_load
+}
+
+define void @unaligned_store(<2 x i1> %mask, <2 x i32> %val, ptr %ptr) {
+; CHECK-LABEL: @unaligned_store(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(ptr [[PTR:%.*]], i64 64) ]
+; CHECK-NEXT: tail call void @llvm.masked.store.v2i32.p0(<2 x i32> [[VAL:%.*]], ptr [[PTR]], i32 64, <2 x i1> [[MASK:%.*]])
+; CHECK-NEXT: ret void
+;
+entry:
+ call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
+ tail call void @llvm.masked.store.v2i32.p0(<2 x i32> %val, ptr %ptr, i32 1, <2 x i1> %mask)
+ ret void
+}
+
+declare void @llvm.assume(i1)
+declare <2 x i32> @llvm.masked.load.v2i32.p0(ptr, i32, <2 x i1>, <2 x i32>)
+declare void @llvm.masked.store.v2i32.p0(<2 x i32>, ptr, i32, <2 x i1>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/156057