[llvm] [MemoryLocation] Size Scalable Masked MemOps (PR #154785)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 21 08:18:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Matthew Devereau (MDevereau)

<details>
<summary>Changes</summary>

Scalable masked loads and stores whose mask is an `llvm.get.active.lane.mask` call with constant operands spanning no more than the scalable type's minimum number of elements can be proven to have a fixed size. Adding this information allows scalable masked loads and stores to benefit from alias analysis optimizations.
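
For example (a minimal IR sketch, not taken from the patch; the function name is illustrative): for `<vscale x 4 x i1>` the minimum element count is 4, so a mask of `get.active.lane.mask(i32 0, i32 4)` enables exactly the first four lanes for any vscale, and the store below is known to access exactly 16 bytes:

```llvm
; The mask enables lanes [0, 4), and <vscale x 4 x float> has at least four
; elements, so this store writes exactly 4 x 4 = 16 bytes for any vscale.
define void @example(ptr %p, <vscale x 4 x float> %v) {
  %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %v, ptr %p, i32 1, <vscale x 4 x i1> %mask)
  ret void
}
```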

---
Full diff: https://github.com/llvm/llvm-project/pull/154785.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/MemoryLocation.cpp (+44-11) 
- (added) llvm/test/Analysis/BasicAA/scalable-dse-aa.ll (+145) 


``````````diff
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index 72b643c56a994..f2c3b843f70f6 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -150,6 +150,29 @@ MemoryLocation::getForDest(const CallBase *CB, const TargetLibraryInfo &TLI) {
   return MemoryLocation::getBeforeOrAfter(UsedV, CB->getAAMetadata());
 }
 
+static std::optional<FixedVectorType *>
+getFixedTypeFromScalableMemOp(Value *Mask, Type *Ty) {
+  auto *ActiveLaneMask = dyn_cast<IntrinsicInst>(Mask);
+  if (!ActiveLaneMask ||
+      ActiveLaneMask->getIntrinsicID() != Intrinsic::get_active_lane_mask)
+    return std::nullopt;
+
+  auto *ScalableTy = dyn_cast<ScalableVectorType>(Ty);
+  if (!ScalableTy)
+    return std::nullopt;
+
+  auto *LaneMaskLo = dyn_cast<ConstantInt>(ActiveLaneMask->getOperand(0));
+  auto *LaneMaskHi = dyn_cast<ConstantInt>(ActiveLaneMask->getOperand(1));
+  if (!LaneMaskLo || !LaneMaskHi)
+    return std::nullopt;
+
+  uint64_t NumElts = LaneMaskHi->getZExtValue() - LaneMaskLo->getZExtValue();
+  if (NumElts > ScalableTy->getMinNumElements())
+    return std::nullopt;
+
+  return FixedVectorType::get(ScalableTy->getElementType(), NumElts);
+}
+
 MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
                                               unsigned ArgIdx,
                                               const TargetLibraryInfo *TLI) {
@@ -213,20 +236,30 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
               cast<ConstantInt>(II->getArgOperand(0))->getZExtValue()),
           AATags);
 
-    case Intrinsic::masked_load:
+    case Intrinsic::masked_load: {
       assert(ArgIdx == 0 && "Invalid argument index");
-      return MemoryLocation(
-          Arg,
-          LocationSize::upperBound(DL.getTypeStoreSize(II->getType())),
-          AATags);
 
-    case Intrinsic::masked_store:
+      Type *Ty = II->getType();
+      auto KnownScalableSize =
+          getFixedTypeFromScalableMemOp(II->getOperand(2), Ty);
+      if (KnownScalableSize)
+        return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownScalableSize),
+                              AATags);
+
+      return MemoryLocation(Arg, DL.getTypeStoreSize(Ty), AATags);
+    }
+    case Intrinsic::masked_store: {
       assert(ArgIdx == 1 && "Invalid argument index");
-      return MemoryLocation(
-          Arg,
-          LocationSize::upperBound(
-              DL.getTypeStoreSize(II->getArgOperand(0)->getType())),
-          AATags);
+
+      Type *Ty = II->getArgOperand(0)->getType();
+      auto KnownScalableSize =
+          getFixedTypeFromScalableMemOp(II->getOperand(3), Ty);
+      if (KnownScalableSize)
+        return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownScalableSize),
+                              AATags);
+
+      return MemoryLocation(Arg, DL.getTypeStoreSize(Ty), AATags);
+    }
 
     case Intrinsic::invariant_end:
       // The first argument to an invariant.end is a "descriptor" type (e.g. a
diff --git a/llvm/test/Analysis/BasicAA/scalable-dse-aa.ll b/llvm/test/Analysis/BasicAA/scalable-dse-aa.ll
new file mode 100644
index 0000000000000..c12d1c2f25835
--- /dev/null
+++ b/llvm/test/Analysis/BasicAA/scalable-dse-aa.ll
@@ -0,0 +1,145 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -aa-pipeline=basic-aa -passes=dse -S | FileCheck %s
+
+define <vscale x 4 x float> @dead_scalable_store(i32 %0, ptr %1) {
+; CHECK-LABEL: define <vscale x 4 x float> @dead_scalable_store(
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+; CHECK-NOT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.32, ptr nonnull %gep.arr.32, i32 1, <vscale x 4 x i1> %mask)
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.48, ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask)
+;
+  %arr = alloca [64 x i32], align 4
+  %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
+
+  %gep.1.16 = getelementptr inbounds nuw i8, ptr %1, i64 16
+  %gep.1.32 = getelementptr inbounds nuw i8, ptr %1, i64 32
+  %gep.1.48 = getelementptr inbounds nuw i8, ptr %1, i64 48
+  %gep.arr.16 = getelementptr inbounds nuw i8, ptr %arr, i64 16
+  %gep.arr.32 = getelementptr inbounds nuw i8, ptr %arr, i64 32
+  %gep.arr.48 = getelementptr inbounds nuw i8, ptr %arr, i64 48
+
+  %load.1.16 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.1.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+
+  %load.1.32 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.1.32, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.32, ptr nonnull %gep.arr.32, i32 1, <vscale x 4 x i1> %mask)
+
+  %load.1.48 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.1.48, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.1.48, ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask)
+
+  %faddop0 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  %faddop1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  %fadd = fadd <vscale x 4 x float> %faddop0, %faddop1
+
+  ret <vscale x 4 x float> %fadd
+}
+
+define <vscale x 4 x float> @scalable_store_partial_overwrite(ptr %0) {
+; CHECK-LABEL: define <vscale x 4 x float> @scalable_store_partial_overwrite(
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.30, ptr nonnull %gep.arr.30, i32 1, <vscale x 4 x i1> %mask)
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.48, ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask)
+;
+  %arr = alloca [64 x i32], align 4
+  %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
+
+  %gep.0.16 = getelementptr inbounds nuw i8, ptr %0, i64 16
+  %gep.0.30 = getelementptr inbounds nuw i8, ptr %0, i64 30
+  %gep.0.48 = getelementptr inbounds nuw i8, ptr %0, i64 48
+  %gep.arr.16 = getelementptr inbounds nuw i8, ptr %arr, i64 16
+  %gep.arr.30 = getelementptr inbounds nuw i8, ptr %arr, i64 30
+  %gep.arr.48 = getelementptr inbounds nuw i8, ptr %arr, i64 48
+
+  %load.0.16 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+
+  %load.0.30 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.30, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.30, ptr nonnull %gep.arr.30, i32 1, <vscale x 4 x i1> %mask)
+
+  %load.0.48 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.48, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.48, ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask)
+
+  %faddop0 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  %faddop1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.48, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  %fadd = fadd <vscale x 4 x float> %faddop0, %faddop1
+
+  ret <vscale x 4 x float> %fadd
+}
+
+define <vscale x 4 x float> @dead_scalable_store_small_mask(ptr %0) {
+; CHECK-LABEL: define <vscale x 4 x float> @dead_scalable_store_small_mask(
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+; CHECK-NOT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.30, ptr nonnull %gep.arr.30, i32 1, <vscale x 4 x i1> %mask)
+; CHECK: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.46, ptr nonnull %gep.arr.46, i32 1, <vscale x 4 x i1> %mask)
+  %arr = alloca [64 x i32], align 4
+  %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
+
+  %gep.0.16 = getelementptr inbounds nuw i8, ptr %0, i64 16
+  %gep.0.30 = getelementptr inbounds nuw i8, ptr %0, i64 30
+  %gep.0.46 = getelementptr inbounds nuw i8, ptr %0, i64 46
+  %gep.arr.16 = getelementptr inbounds nuw i8, ptr %arr, i64 16
+  %gep.arr.30 = getelementptr inbounds nuw i8, ptr %arr, i64 30
+  %gep.arr.46 = getelementptr inbounds nuw i8, ptr %arr, i64 46
+
+  %load.0.16 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.16, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.16, ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %mask)
+
+  %load.0.30 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.30, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.30, ptr nonnull %gep.arr.30, i32 1, <vscale x 4 x i1> %mask)
+
+  %load.0.46 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.0.46, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0.46, ptr nonnull %gep.arr.46, i32 1, <vscale x 4 x i1> %mask)
+
+  %smallmask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 2)
+  %faddop0 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.16, i32 1, <vscale x 4 x i1> %smallmask, <vscale x 4 x float> zeroinitializer)
+  %faddop1 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %gep.arr.46, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  %fadd = fadd <vscale x 4 x float> %faddop0, %faddop1
+
+  ret <vscale x 4 x float> %fadd
+}
+
+define <vscale x 4 x float> @dead_scalar_store(ptr noalias %0, ptr %1) {
+; CHECK-LABEL: define <vscale x 4 x float> @dead_scalar_store(
+; CHECK-NOT: store i32 20, ptr %gep.1.12
+;
+  %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 4)
+  %gep.1.12 = getelementptr inbounds nuw i8, ptr %1, i64 12
+  store i32 20, ptr %gep.1.12
+
+  %load.0 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %0, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0, ptr nonnull %1, i32 1, <vscale x 4 x i1> %mask)
+  %retval = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  ret <vscale x 4 x float> %retval
+}
+
+; We don't know if the scalar store is dead, as we can't determine vscale.
+; This get.active.lane.mask may cover 4 or 8 integers.
+define <vscale x 4 x float> @mask_gt_minimum_num_elts(ptr noalias %0, ptr %1) {
+; CHECK-LABEL: define <vscale x 4 x float> @mask_gt_minimum_num_elts(
+; CHECK: store i32 20, ptr %gep.1.28
+;
+  %mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 8)
+  %gep.1.28 = getelementptr inbounds nuw i8, ptr %1, i64 28
+  store i32 20, ptr %gep.1.28
+
+  %load.0 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %0, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load.0, ptr nonnull %1, i32 1, <vscale x 4 x i1> %mask)
+  %retval = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nonnull %1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  ret <vscale x 4 x float> %retval
+}
+
+define <vscale x 16 x i8> @scalar_stores_small_mask(ptr noalias %0, ptr %1) {
+; CHECK-LABEL: define <vscale x 16 x i8> @scalar_stores_small_mask(
+; CHECK-NOT: store i8 60, ptr %gep.1.6
+; CHECK: store i8 120, ptr %gep.1.8
+;
+  %mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 0, i32 7)
+  %gep.1.6 = getelementptr inbounds nuw i8, ptr %1, i64 6
+  store i8 60, ptr %gep.1.6
+  %gep.1.8 = getelementptr inbounds nuw i8, ptr %1, i64 8
+  store i8 120, ptr %gep.1.8
+
+  %load.0 = tail call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr nonnull %0, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> zeroinitializer)
+  call void @llvm.masked.store.nxv16i8.p0(<vscale x 16 x i8> %load.0, ptr %1, i32 1, <vscale x 16 x i1> %mask)
+  %retval = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr %1, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> zeroinitializer)
+  ret <vscale x 16 x i8> %retval
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/154785


More information about the llvm-commits mailing list