[Mlir-commits] [mlir] [mlir][amdgpu] Fix `gather_to_lds` for 0d memrefs (PR #173421)
Ivan Butygin
llvmlistbot at llvm.org
Tue Dec 23 13:52:47 PST 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/173421
None
From 306b44307070b9024a688e24d9acd7269f88563d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 23 Dec 2025 22:44:54 +0100
Subject: [PATCH] [mlir][amdgpu] Fix `gather_to_lds` for 0d memrefs
---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 2 +-
.../Conversion/AMDGPUToROCDL/load_lds.mlir | 18 ++++++++++++++++++
mlir/test/Dialect/AMDGPU/ops.mlir | 7 +++++++
3 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e77d131509add..ccaea75f8c4c1 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -833,7 +833,7 @@ LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
- if (!dstType.areTrailingDimsContiguous(1))
+ if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
return emitOpError("destination type inner most dim must be contiguous");
auto elemType = srcType.getElementType();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index 30578517be1ca..a24430c5b86cc 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -81,6 +81,24 @@ func.func @global_load_to_rocdl_wg_mem(%global : memref<128x72xf32>) {
func.return
}
+// CHECK-LABEL: func @global_load_to_rocdl_0d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<f32>)
+func.func @global_load_to_rocdl_0d(%global : memref<f32>) {
+ %alloc = memref.alloc() : memref<f32, #gpu.address_space<workgroup>>
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<f32, #gpu.address_space<workgroup>> to !llvm.struct<(ptr<3>, ptr<3>, i64)>
+
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+
+ // CHECK: rocdl.load.to.lds %[[GLOBAL_BASE]], %[[LDS_BASE]], 4
+ amdgpu.gather_to_lds %global[], %alloc[]
+ : f32, memref<f32>, memref<f32, #gpu.address_space<workgroup>>
+ func.return
+}
+
// CHECK-LABEL: func @global_load_to_rocdl_i8
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 09f739c9d98a2..2b3234ef8510d 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -669,6 +669,13 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
func.return
}
+// CHECK-LABEL: func @gather_to_lds_0d
+func.func @gather_to_lds_0d(%mem1 : memref<f16>, %smem1 : memref<f16, #gpu.address_space<workgroup>>) {
+ // CHECK: amdgpu.gather_to_lds %{{.*}}[], %{{.*}}[]
+ amdgpu.gather_to_lds %mem1[], %smem1[] : vector<2xf16>, memref<f16>, memref<f16, #gpu.address_space<workgroup>>
+ func.return
+}
+
// CHECK-LABEL: func @memory_counter_wait
func.func @memory_counter_wait() {
// CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
More information about the Mlir-commits mailing list