[llvm] [MemoryLocation] Support strided matrix loads / stores (PR #163368)

Nathan Corbyn via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 22 08:21:09 PDT 2025


https://github.com/cofibrant updated https://github.com/llvm/llvm-project/pull/163368

>From adbfde0f7be8458737f5c2dbc57377ce5537ee92 Mon Sep 17 00:00:00 2001
From: Nathan Corbyn <n_corbyn at apple.com>
Date: Wed, 22 Oct 2025 14:30:33 +0100
Subject: [PATCH 1/2] [MemoryLocation] Support strided matrix loads / stores

---
 llvm/lib/Analysis/MemoryLocation.cpp          | 28 +++++++++++++++++++
 .../DeadStoreElimination/matrix-intrinsics.ll |  8 +-----
 llvm/test/Transforms/GVN/matrix-intrinsics.ll | 12 +++-----
 3 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index 1c5f08e13498c..947ef598c8c52 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -288,6 +288,34 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
                             LocationSize::precise(DL.getTypeStoreSize(
                                 II->getArgOperand(1)->getType())),
                             AATags);
+    case Intrinsic::matrix_column_major_load:
+    case Intrinsic::matrix_column_major_store: {
+      bool IsLoad = II->getIntrinsicID() == Intrinsic::matrix_column_major_load;
+      assert(ArgIdx == (IsLoad ? 0 : 1) && "Invalid argument index");
+
+      auto *Stride = dyn_cast<ConstantInt>(II->getArgOperand(IsLoad ? 1 : 2));
+      uint64_t Rows =
+          cast<ConstantInt>(II->getArgOperand(IsLoad ? 3 : 4))->getZExtValue();
+      uint64_t Cols =
+          cast<ConstantInt>(II->getArgOperand(IsLoad ? 4 : 5))->getZExtValue();
+
+      // The stride is dynamic, so there's nothing we can say.
+      if (!Stride)
+        return MemoryLocation(Arg, LocationSize::afterPointer(), AATags);
+
+      uint64_t ConstStride = Stride->getZExtValue();
+      auto *VT = cast<VectorType>(IsLoad ? II->getType()
+                                         : II->getArgOperand(0)->getType());
+      assert(Cols != 0 && "Matrix cannot have 0 columns");
+      TypeSize Size = DL.getTypeStoreSize(VT->getScalarType()) *
+                      (ConstStride * (Cols - 1) + Rows);
+
+      // In the unstrided case, we have a precise size, ...
+      if (ConstStride == Rows)
+        return MemoryLocation(Arg, LocationSize::precise(Size), AATags);
+      // otherwise we merely obtain an upper bound.
+      return MemoryLocation(Arg, LocationSize::upperBound(Size), AATags);
+    }
     }
 
     assert(
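
For reference, the footprint computed in the new case for a constant stride is

    element size in bytes * (Stride * (Cols - 1) + Rows)

since the last column begins Stride * (Cols - 1) elements past the base pointer and contributes a further Rows contiguous elements. A minimal standalone sketch of that arithmetic (hypothetical helper, not part of the patch; the stride is given in elements, as for the llvm.matrix.column.major.* intrinsics):

    #include <cassert>
    #include <cstdint>

    // Bytes touched by a column-major matrix access with a constant stride.
    // EltSize is the element size in bytes; Stride, Rows and Cols mirror the
    // intrinsic's immediate arguments.
    static uint64_t matrixAccessBytes(uint64_t EltSize, uint64_t Stride,
                                      uint64_t Rows, uint64_t Cols) {
      assert(Cols != 0 && "Matrix cannot have 0 columns");
      // Columns 0 .. Cols-2 each start Stride elements after the previous
      // one; the final column touches Rows contiguous elements.
      return EltSize * (Stride * (Cols - 1) + Rows);
    }

For the 4 x 2 double matrices used in the tests below, a stride of 4 gives 8 * (4 * 1 + 4) = 64 bytes and the size is precise (stride == rows, so there are no gaps); a stride of 8 gives 8 * (8 * 1 + 4) = 96 bytes, which is only an upper bound because the 32-byte gap between the two columns is never accessed.
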
diff --git a/llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll b/llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll
index ae3c7464656df..2eaa275c33211 100644
--- a/llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll
+++ b/llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll
@@ -5,8 +5,8 @@ define void @dead_unstrided_store_non_matrix_load(ptr noalias %src, ptr noalias
 ; CHECK-LABEL: define void @dead_unstrided_store_non_matrix_load(
 ; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    [[L:%.*]] = load double, ptr [[SRC]], align 8
+; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -173,7 +173,6 @@ define void @dead_unstrided_store(ptr noalias %src, ptr noalias %dst) {
 ; CHECK-LABEL: define void @dead_unstrided_store(
 ; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    ret void
@@ -241,7 +240,6 @@ define void @dead_matrix_store_non_matrix_overwrite_unstrided(ptr noalias %src,
 ; CHECK-LABEL: define void @dead_matrix_store_non_matrix_overwrite_unstrided(
 ; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    store <8 x double> zeroinitializer, ptr [[DST]], align 64
 ; CHECK-NEXT:    ret void
@@ -257,7 +255,6 @@ define void @dead_matrix_store_non_matrix_overwrite_strided(ptr noalias %src, pt
 ; CHECK-LABEL: define void @dead_matrix_store_non_matrix_overwrite_strided(
 ; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    store <16 x double> zeroinitializer, ptr [[DST]], align 128
 ; CHECK-NEXT:    ret void
@@ -289,7 +286,6 @@ define void @live_matrix_store_non_matrix_overwrite_strided(ptr noalias %src, pt
 ; CHECK-LABEL: define void @live_matrix_store_non_matrix_overwrite_strided(
 ; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    store <8 x double> zeroinitializer, ptr [[DST]], align 64
 ; CHECK-NEXT:    ret void
@@ -305,8 +301,6 @@ define void @dead_matrix_store_dimension_change(ptr noalias %src, ptr noalias %d
 ; CHECK-LABEL: define void @dead_matrix_store_dimension_change(
 ; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
-; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v9f64.i32(<9 x double> zeroinitializer, ptr [[DST]], i32 3, i1 false, i32 3, i32 3)
 ; CHECK-NEXT:    ret void
 ;
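
The updated CHECK lines above reflect that the leading matrix store to %dst can now be proven dead: its footprint has a known size, and a later store to the same pointer with a precise size at least as large provably overwrites it. A hypothetical sketch of the covering check this enables (an illustration of the idea, not DSE's actual isOverwrite logic):

    #include <cstdint>

    // An earlier store of EarlierBytes at some address is fully overwritten
    // by a later store of LaterBytes at the same address only if the later
    // size is precise; an upper bound might overstate what is written.
    static bool fullyOverwritten(uint64_t EarlierBytes, uint64_t LaterBytes,
                                 bool LaterIsPrecise) {
      return LaterIsPrecise && LaterBytes >= EarlierBytes;
    }

For example, the unstrided 4 x 2 double store (64 bytes, precise) is covered both by an identical matrix store and by a plain store of <16 x double> (128 bytes), while a killing location known only as an upper bound cannot establish a complete overwrite.
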
diff --git a/llvm/test/Transforms/GVN/matrix-intrinsics.ll b/llvm/test/Transforms/GVN/matrix-intrinsics.ll
index 78dbfe1ef6bd8..03bd45b7fcde7 100644
--- a/llvm/test/Transforms/GVN/matrix-intrinsics.ll
+++ b/llvm/test/Transforms/GVN/matrix-intrinsics.ll
@@ -8,9 +8,8 @@ define void @redundant_unstrided_load(ptr %src) {
 ; CHECK-NEXT:    [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 8
 ; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
-; CHECK-NEXT:    [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    call void @use(<8 x double> [[L]])
-; CHECK-NEXT:    call void @use(<8 x double> [[L_2]])
+; CHECK-NEXT:    call void @use(<8 x double> [[L]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -30,9 +29,8 @@ define void @redundant_unstrided_load_non_matrix_store(ptr %src) {
 ; CHECK-NEXT:    [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 1
 ; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    store double 4.200000e+01, ptr [[SRC]], align 8
-; CHECK-NEXT:    [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    call void @use(<8 x double> [[L]])
-; CHECK-NEXT:    call void @use(<8 x double> [[L_2]])
+; CHECK-NEXT:    call void @use(<8 x double> [[L]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -52,9 +50,8 @@ define void @redundant_strided_load(ptr %src) {
 ; CHECK-NEXT:    [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 16
 ; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
-; CHECK-NEXT:    [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    call void @use(<8 x double> [[L]])
-; CHECK-NEXT:    call void @use(<8 x double> [[L_2]])
+; CHECK-NEXT:    call void @use(<8 x double> [[L]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -75,9 +72,8 @@ define void @redundant_strided_load_non_matrix_store(ptr %src) {
 ; CHECK-NEXT:    [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 16
 ; CHECK-NEXT:    [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    store double 4.200000e+01, ptr [[SRC]], align 8
-; CHECK-NEXT:    [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
 ; CHECK-NEXT:    call void @use(<8 x double> [[L]])
-; CHECK-NEXT:    call void @use(<8 x double> [[L_2]])
+; CHECK-NEXT:    call void @use(<8 x double> [[L]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
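
The GVN changes above show the second, identical matrix load being replaced by the first: with a constant stride the reload has a known footprint starting at %src.offset, so the intervening store can be shown not to clobber any byte the reload reads. In @redundant_unstrided_load, for instance, the matrix store to %src covers [0, 64) while the reload at %src + 64 bytes reads [64, 128); in the strided variants the reload at %src + 128 bytes reads at most [128, 224), which is disjoint from anything the intervening store (at most 96 bytes from %src) might write, so even an upper-bound size suffices. A hypothetical byte-interval check illustrating the disjointness argument:

    #include <cstdint>

    // True if the half-open byte ranges [AStart, AStart + ASize) and
    // [BStart, BStart + BSize) share at least one byte.
    static bool overlaps(uint64_t AStart, uint64_t ASize,
                         uint64_t BStart, uint64_t BSize) {
      return AStart < BStart + BSize && BStart < AStart + ASize;
    }

    // overlaps(0, 64, 64, 64) == false: the store cannot clobber the reload,
    // so the first load's value can be forwarded.
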

>From 8a2a5cbe55493a936c8e0bc35a305ac24625bb60 Mon Sep 17 00:00:00 2001
From: Nathan Corbyn <n_corbyn at apple.com>
Date: Wed, 22 Oct 2025 16:20:55 +0100
Subject: [PATCH 2/2] `getTypeStoreSize` ~> `getTypeAllocSize`

---
 llvm/lib/Analysis/MemoryLocation.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index 947ef598c8c52..edca3871ad593 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -307,7 +307,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
       auto *VT = cast<VectorType>(IsLoad ? II->getType()
                                          : II->getArgOperand(0)->getType());
       assert(Cols != 0 && "Matrix cannot have 0 columns");
-      TypeSize Size = DL.getTypeStoreSize(VT->getScalarType()) *
+      TypeSize Size = DL.getTypeAllocSize(VT->getScalarType()) *
                       (ConstStride * (Cols - 1) + Rows);
 
       // In the unstrided case, we have a precise size, ...
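
The follow-up switches the element size from the store size to the alloc size, presumably because the addresses of the individual elements of a column-major access are computed with GEPs on the scalar type, which advance by the alloc size; for an element type whose store size is smaller than its alloc size, the store size would undercount the footprint. A hypothetical worked example, assuming x86_fp80 with a 10-byte store size and a 16-byte alloc size as on a typical x86-64 data layout, for a 4 x 2 matrix with stride 4:

    store-size based footprint: 10 * (4 * 1 + 4) = 80 bytes
    alloc-size based footprint: 16 * (4 * 1 + 4) = 128 bytes

Only the latter accounts for the padding between adjacent in-memory elements.
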


