[llvm] [MemoryLocation] Support strided matrix loads / stores (PR #163368)
Nathan Corbyn via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 22 08:21:09 PDT 2025
https://github.com/cofibrant updated https://github.com/llvm/llvm-project/pull/163368
>From adbfde0f7be8458737f5c2dbc57377ce5537ee92 Mon Sep 17 00:00:00 2001
From: Nathan Corbyn <n_corbyn at apple.com>
Date: Wed, 22 Oct 2025 14:30:33 +0100
Subject: [PATCH 1/2] [MemoryLocation] Support strided matrix loads / stores
---
llvm/lib/Analysis/MemoryLocation.cpp | 28 +++++++++++++++++++
.../DeadStoreElimination/matrix-intrinsics.ll | 8 +-----
llvm/test/Transforms/GVN/matrix-intrinsics.ll | 12 +++-----
3 files changed, 33 insertions(+), 15 deletions(-)
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index 1c5f08e13498c..947ef598c8c52 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -288,6 +288,34 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
LocationSize::precise(DL.getTypeStoreSize(
II->getArgOperand(1)->getType())),
AATags);
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store: {
+ bool IsLoad = II->getIntrinsicID() == Intrinsic::matrix_column_major_load;
+ assert(ArgIdx == (IsLoad ? 0 : 1) && "Invalid argument index");
+
+ auto *Stride = dyn_cast<ConstantInt>(II->getArgOperand(IsLoad ? 1 : 2));
+ uint64_t Rows =
+ cast<ConstantInt>(II->getArgOperand(IsLoad ? 3 : 4))->getZExtValue();
+ uint64_t Cols =
+ cast<ConstantInt>(II->getArgOperand(IsLoad ? 4 : 5))->getZExtValue();
+
+ // The stride is dynamic, so there's nothing we can say.
+ if (!Stride)
+ return MemoryLocation(Arg, LocationSize::afterPointer(), AATags);
+
+ uint64_t ConstStride = Stride->getZExtValue();
+ auto *VT = cast<VectorType>(IsLoad ? II->getType()
+ : II->getArgOperand(0)->getType());
+ assert(Cols != 0 && "Matrix cannot have 0 columns");
+ TypeSize Size = DL.getTypeStoreSize(VT->getScalarType()) *
+ (ConstStride * (Cols - 1) + Rows);
+
+ // In the unstrided case, we have a precise size, ...
+ if (ConstStride == Rows)
+ return MemoryLocation(Arg, LocationSize::precise(Size), AATags);
+ // otherwise we merely obtain an upper bound.
+ return MemoryLocation(Arg, LocationSize::upperBound(Size), AATags);
+ }
}
assert(
diff --git a/llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll b/llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll
index ae3c7464656df..2eaa275c33211 100644
--- a/llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll
+++ b/llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll
@@ -5,8 +5,8 @@ define void @dead_unstrided_store_non_matrix_load(ptr noalias %src, ptr noalias
; CHECK-LABEL: define void @dead_unstrided_store_non_matrix_load(
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: [[L:%.*]] = load double, ptr [[SRC]], align 8
+; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: ret void
;
entry:
@@ -173,7 +173,6 @@ define void @dead_unstrided_store(ptr noalias %src, ptr noalias %dst) {
; CHECK-LABEL: define void @dead_unstrided_store(
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: ret void
@@ -241,7 +240,6 @@ define void @dead_matrix_store_non_matrix_overwrite_unstrided(ptr noalias %src,
; CHECK-LABEL: define void @dead_matrix_store_non_matrix_overwrite_unstrided(
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: store <8 x double> zeroinitializer, ptr [[DST]], align 64
; CHECK-NEXT: ret void
@@ -257,7 +255,6 @@ define void @dead_matrix_store_non_matrix_overwrite_strided(ptr noalias %src, pt
; CHECK-LABEL: define void @dead_matrix_store_non_matrix_overwrite_strided(
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
; CHECK-NEXT: store <16 x double> zeroinitializer, ptr [[DST]], align 128
; CHECK-NEXT: ret void
@@ -289,7 +286,6 @@ define void @live_matrix_store_non_matrix_overwrite_strided(ptr noalias %src, pt
; CHECK-LABEL: define void @live_matrix_store_non_matrix_overwrite_strided(
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
; CHECK-NEXT: store <8 x double> zeroinitializer, ptr [[DST]], align 64
; CHECK-NEXT: ret void
@@ -305,8 +301,6 @@ define void @dead_matrix_store_dimension_change(ptr noalias %src, ptr noalias %d
; CHECK-LABEL: define void @dead_matrix_store_dimension_change(
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
-; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
-; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v9f64.i32(<9 x double> zeroinitializer, ptr [[DST]], i32 3, i1 false, i32 3, i32 3)
; CHECK-NEXT: ret void
;
diff --git a/llvm/test/Transforms/GVN/matrix-intrinsics.ll b/llvm/test/Transforms/GVN/matrix-intrinsics.ll
index 78dbfe1ef6bd8..03bd45b7fcde7 100644
--- a/llvm/test/Transforms/GVN/matrix-intrinsics.ll
+++ b/llvm/test/Transforms/GVN/matrix-intrinsics.ll
@@ -8,9 +8,8 @@ define void @redundant_unstrided_load(ptr %src) {
; CHECK-NEXT: [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 8
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
-; CHECK-NEXT: [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: call void @use(<8 x double> [[L]])
-; CHECK-NEXT: call void @use(<8 x double> [[L_2]])
+; CHECK-NEXT: call void @use(<8 x double> [[L]])
; CHECK-NEXT: ret void
;
entry:
@@ -30,9 +29,8 @@ define void @redundant_unstrided_load_non_matrix_store(ptr %src) {
; CHECK-NEXT: [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 1
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: store double 4.200000e+01, ptr [[SRC]], align 8
-; CHECK-NEXT: [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
; CHECK-NEXT: call void @use(<8 x double> [[L]])
-; CHECK-NEXT: call void @use(<8 x double> [[L_2]])
+; CHECK-NEXT: call void @use(<8 x double> [[L]])
; CHECK-NEXT: ret void
;
entry:
@@ -52,9 +50,8 @@ define void @redundant_strided_load(ptr %src) {
; CHECK-NEXT: [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 16
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
-; CHECK-NEXT: [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
; CHECK-NEXT: call void @use(<8 x double> [[L]])
-; CHECK-NEXT: call void @use(<8 x double> [[L_2]])
+; CHECK-NEXT: call void @use(<8 x double> [[L]])
; CHECK-NEXT: ret void
;
entry:
@@ -75,9 +72,8 @@ define void @redundant_strided_load_non_matrix_store(ptr %src) {
; CHECK-NEXT: [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 16
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
; CHECK-NEXT: store double 4.200000e+01, ptr [[SRC]], align 8
-; CHECK-NEXT: [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
; CHECK-NEXT: call void @use(<8 x double> [[L]])
-; CHECK-NEXT: call void @use(<8 x double> [[L_2]])
+; CHECK-NEXT: call void @use(<8 x double> [[L]])
; CHECK-NEXT: ret void
;
entry:
>From 8a2a5cbe55493a936c8e0bc35a305ac24625bb60 Mon Sep 17 00:00:00 2001
From: Nathan Corbyn <n_corbyn at apple.com>
Date: Wed, 22 Oct 2025 16:20:55 +0100
Subject: [PATCH 2/2] `getTypeStoreSize` ~> `getTypeAllocSize`
---
llvm/lib/Analysis/MemoryLocation.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index 947ef598c8c52..edca3871ad593 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -307,7 +307,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
auto *VT = cast<VectorType>(IsLoad ? II->getType()
: II->getArgOperand(0)->getType());
assert(Cols != 0 && "Matrix cannot have 0 columns");
- TypeSize Size = DL.getTypeStoreSize(VT->getScalarType()) *
+ TypeSize Size = DL.getTypeAllocSize(VT->getScalarType()) *
(ConstStride * (Cols - 1) + Rows);
// In the unstrided case, we have a precise size, ...
More information about the llvm-commits mailing list