[llvm] [MemoryLocation] Size Scalable Masked MemOps (PR #154785)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 21 08:18:26 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Matthew Devereau (MDevereau)
<details>
<summary>Changes</summary>
Scalable masked loads and stores with a get active lane mask whose size is less than or equal to the scalable minimum number of elements can be proven to have a fixed size. Adding this information allows scalable masked loads and stores to benefit from alias analysis optimizations.
---
Full diff: https://github.com/llvm/llvm-project/pull/154785.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/MemoryLocation.cpp (+44-11)
- (added) llvm/test/Analysis/BasicAA/scalable-dse-aa.ll (+145)
``````````diff
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index 72b643c56a994..f2c3b843f70f6 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -150,6 +150,29 @@ MemoryLocation::getForDest(const CallBase *CB, const TargetLibraryInfo &TLI) {
return MemoryLocation::getBeforeOrAfter(UsedV, CB->getAAMetadata());
}
+static std::optional<FixedVectorType *>
+getFixedTypeFromScalableMemOp(Value *Mask, Type *Ty) {
+ auto ActiveLaneMask = dyn_cast<IntrinsicInst>(Mask);
+ if (!ActiveLaneMask ||
+ ActiveLaneMask->getIntrinsicID() != Intrinsic::get_active_lane_mask)
+ return std::nullopt;
+
+ auto ScalableTy = dyn_cast<ScalableVectorType>(Ty);
+ if (!ScalableTy)
+ return std::nullopt;
+
+ auto LaneMaskLo = dyn_cast<ConstantInt>(ActiveLaneMask->getOperand(0));
+ auto LaneMaskHi = dyn_cast<ConstantInt>(ActiveLaneMask->getOperand(1));
+ if (!LaneMaskLo || !LaneMaskHi)
+ return std::nullopt;
+
+ uint64_t NumElts = LaneMaskHi->getZExtValue() - LaneMaskLo->getZExtValue();
+ if (NumElts > ScalableTy->getMinNumElements())
+ return std::nullopt;
+
+ return FixedVectorType::get(ScalableTy->getElementType(), NumElts);
+}
+
MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
unsigned ArgIdx,
const TargetLibraryInfo *TLI) {
@@ -213,20 +236,30 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
cast<ConstantInt>(II->getArgOperand(0))->getZExtValue()),
AATags);
- case Intrinsic::masked_load:
+ case Intrinsic::masked_load: {
assert(ArgIdx == 0 && "Invalid argument index");
- return MemoryLocation(
- Arg,
- LocationSize::upperBound(DL.getTypeStoreSize(II->getType())),
- AATags);
- case Intrinsic::masked_store:
+ Type *Ty = II->getType();
+ auto KnownScalableSize =
+ getFixedTypeFromScalableMemOp(II->getOperand(2), Ty);
+ if (KnownScalableSize)
+ return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownScalableSize),
+ AATags);
+
+ return MemoryLocation(Arg, DL.getTypeStoreSize(Ty), AATags);
+ }
+ case Intrinsic::masked_store: {
assert(ArgIdx == 1 && "Invalid argument index");
- return MemoryLocation(
- Arg,
- LocationSize::upperBound(
- DL.getTypeStoreSize(II->getArgOperand(0)->getType())),
- AATags);
+
+ Type *Ty = II->getArgOperand(0)->getType();
+ auto KnownScalableSize =
+ getFixedTypeFromScalableMemOp(II->getOperand(3), Ty);
+ if (KnownScalableSize)
+ return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownScalableSize),
+ AATags);
+
+ return MemoryLocation(Arg, DL.getTypeStoreSize(Ty), AATags);
+ }
case Intrinsic::invariant_end:
// The first argument to an invariant.end is a "descriptor" type (e.g. a
diff --git a/llvm/test/Analysis/BasicAA/scalable-dse-aa.ll b/llvm/test/Analysis/BasicAA/scalable-dse-aa.ll
new file mode 100644
index 0000000000000..c12d1c2f25835
--- /dev/null
+++ b/llvm/test/Analysis/BasicAA/scalable-dse-aa.ll
@@ -0,0 +1,145 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -aa-pipeline=basic-aa -passes=dse -S | FileCheck %s
+
+define <vscale x 4 x float> @dead_scalable_store(i32 %0, ptr %1) {
+; CHECK-LABEL: define <vscale x 4 x float> @dead_scalable_store(
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+; CHECK-NOT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.32, ptr nonnull %gep.arr.32, i32 1, <vscale x 4 x i1> %mask)
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.48, ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask)
+;
+ %arr = alloca [64 x i32], align 4
+ %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
+
+ %gep.1.16 = getelementptr inbounds nuw i8, ptr %1, i64 16
+ %gep.1.32 = getelementptr inbounds nuw i8, ptr %1, i64 32
+ %gep.1.48 = getelementptr inbounds nuw i8, ptr %1, i64 48
+ %gep.arr.16 = getelementptr inbounds nuw i8, ptr %arr, i64 16
+ %gep.arr.32 = getelementptr inbounds nuw i8, ptr %arr, i64 32
+ %gep.arr.48 = getelementptr inbounds nuw i8, ptr %arr, i64 48
+
+ %load.1.16 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.1.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+
+ %load.1.32 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.1.32, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.32, ptr nonnull %gep.arr.32, i32 1, <vscale x 4 x i1> %mask)
+
+ %load.1.48 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.1.48, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.48, ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask)
+
+ %faddop0 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ %faddop1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ %fadd = fadd <vscale x 4 x float> %faddop0, %faddop1
+
+ ret <vscale x 4 x float> %fadd
+}
+
+define <vscale x 4 x float> @scalable_store_partial_overwrite(ptr %0) {
+; CHECK-LABEL: define <vscale x 4 x float> @scalable_store_partial_overwrite(
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.30, ptr nonnull %gep.arr.30, i32 1, <vscale x 4 x i1> %mask)
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.48, ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask)
+;
+ %arr = alloca [64 x i32], align 4
+ %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
+
+ %gep.0.16 = getelementptr inbounds nuw i8, ptr %0, i64 16
+ %gep.0.30 = getelementptr inbounds nuw i8, ptr %0, i64 30
+ %gep.0.48 = getelementptr inbounds nuw i8, ptr %0, i64 48
+ %gep.arr.16 = getelementptr inbounds nuw i8, ptr %arr, i64 16
+ %gep.arr.30 = getelementptr inbounds nuw i8, ptr %arr, i64 30
+ %gep.arr.48 = getelementptr inbounds nuw i8, ptr %arr, i64 48
+
+ %load.0.16 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+
+ %load.0.30 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.30, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.30, ptr nonnull %gep.arr.30, i32 1, <vscale x 4 x i1> %mask)
+
+ %load.0.48 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.48, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.48, ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask)
+
+ %faddop0 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ %faddop1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ %fadd = fadd <vscale x 4 x float> %faddop0, %faddop1
+
+ ret <vscale x 4 x float> %fadd
+}
+
+define <vscale x 4 x float> @dead_scalable_store_small_mask(ptr %0) {
+; CHECK-LABEL: define <vscale x 4 x float> @dead_scalable_store_small_mask(
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+; CHECK-NOT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.30, ptr nonnull %gep.arr.30, i32 1, <vscale x 4 x i1> %mask)
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.46, ptr nonnull %gep.arr.46, i32 1, <vscale x 4 x i1> %mask)
+ %arr = alloca [64 x i32], align 4
+ %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
+
+ %gep.0.16 = getelementptr inbounds nuw i8, ptr %0, i64 16
+ %gep.0.30 = getelementptr inbounds nuw i8, ptr %0, i64 30
+ %gep.0.46 = getelementptr inbounds nuw i8, ptr %0, i64 46
+ %gep.arr.16 = getelementptr inbounds nuw i8, ptr %arr, i64 16
+ %gep.arr.30 = getelementptr inbounds nuw i8, ptr %arr, i64 30
+ %gep.arr.46 = getelementptr inbounds nuw i8, ptr %arr, i64 46
+
+ %load.0.16 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+
+ %load.0.30 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.30, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.30, ptr nonnull %gep.arr.30, i32 1, <vscale x 4 x i1> %mask)
+
+ %load.0.46 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.46, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.46, ptr nonnull %gep.arr.46, i32 1, <vscale x 4 x i1> %mask)
+
+ %smallmask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.32(i32 0, i32 2)
+ %faddop0 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %smallmask, <vscale x 4 x float> zeroinitializer)
+ %faddop1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.46, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ %fadd = fadd <vscale x 4 x float> %faddop0, %faddop1
+
+ ret <vscale x 4 x float> %fadd
+}
+
+define <vscale x 4 x float> @dead_scalar_store(ptr noalias %0, ptr %1) {
+; CHECK-LABEL: define <vscale x 4 x float> @dead_scalar_store(
+; CHECK-NOT: store i32 20, ptr %gep.1.12
+;
+ %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
+ %gep.1.12 = getelementptr inbounds nuw i8, ptr %1, i64 12
+ store i32 20, ptr %gep.1.12
+
+ %load.0 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %0, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0, ptr nonnull %1, i32 1, <vscale x 4 x i1> %mask)
+ %retval = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ ret <vscale x 4 x float> %retval
+}
+
+; We don't know if the scalar store is dead as we can't determine vscale.
+; This get active lane mask may cover 4 or 8 integers
+define <vscale x 4 x float> @mask_gt_minimum_num_elts(ptr noalias %0, ptr %1) {
+; CHECK-LABEL: define <vscale x 4 x float> @mask_gt_minimum_num_elts(
+; CHECK: store i32 20, ptr %gep.1.28
+;
+ %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 8)
+ %gep.1.28 = getelementptr inbounds nuw i8, ptr %1, i64 28
+ store i32 20, ptr %gep.1.28
+
+ %load.0 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %0, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0, ptr nonnull %1, i32 1, <vscale x 4 x i1> %mask)
+ %retval = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ ret <vscale x 4 x float> %retval
+}
+
+define <vscale x 16 x i8> @scalar_stores_small_mask(ptr noalias %0, ptr %1) {
+; CHECK-LABEL: define <vscale x 16 x i8> @scalar_stores_small_mask(
+; CHECK-NOT: store i8 60, ptr %gep.1.6
+; CHECK: store i8 120, ptr %gep.1.8
+;
+ %mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i8.i32(i32 0, i32 7)
+ %gep.1.6 = getelementptr inbounds nuw i8, ptr %1, i64 6
+ store i8 60, ptr %gep.1.6
+ %gep.1.8 = getelementptr inbounds nuw i8, ptr %1, i64 8
+ store i8 120, ptr %gep.1.8
+
+ %load.0 = tail call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr nonnull %0, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> zeroinitializer)
+ call void @llvm.masked.store.nxv16i8.p0(<vscale x 16 x i8> %load.0, ptr %1, i32 1, <vscale x 16 x i1> %mask)
+ %retval = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr %1, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> zeroinitializer)
+ ret <vscale x 16 x i8> %retval
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/154785
More information about the llvm-commits
mailing list