[Mlir-commits] [mlir] 9afbcde - [mlir][amdgpu] Fix `gather_to_lds` for 0d memrefs (#173421)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 23 15:00:58 PST 2025
Author: Ivan Butygin
Date: 2025-12-24T02:00:55+03:00
New Revision: 9afbcde1d2938c494c225e93837e3a075db218e7
URL: https://github.com/llvm/llvm-project/commit/9afbcde1d2938c494c225e93837e3a075db218e7
DIFF: https://github.com/llvm/llvm-project/commit/9afbcde1d2938c494c225e93837e3a075db218e7.diff
LOG: [mlir][amdgpu] Fix `gather_to_lds` for 0d memrefs (#173421)
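`dstType.areTrailingDimsContiguous(1)` asserts for a memref of rank 0, so the verifier now performs the contiguity check only when the destination rank is greater than 0.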
Added:
Modified:
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
mlir/test/Dialect/AMDGPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e77d131509add..ccaea75f8c4c1 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -833,7 +833,7 @@ LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
- if (!dstType.areTrailingDimsContiguous(1))
+ if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
return emitOpError("destination type inner most dim must be contiguous");
auto elemType = srcType.getElementType();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index 30578517be1ca..a24430c5b86cc 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -81,6 +81,24 @@ func.func @global_load_to_rocdl_wg_mem(%global : memref<128x72xf32>) {
func.return
}
+// CHECK-LABEL: func @global_load_to_rocdl_0d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<f32>)
+func.func @global_load_to_rocdl_0d(%global : memref<f32>) {
+ %alloc = memref.alloc() : memref<f32, #gpu.address_space<workgroup>>
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<f32, #gpu.address_space<workgroup>> to !llvm.struct<(ptr<3>, ptr<3>, i64)>
+
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+
+ // CHECK: rocdl.load.to.lds %[[GLOBAL_BASE]], %[[LDS_BASE]], 4
+ amdgpu.gather_to_lds %global[], %alloc[]
+ : f32, memref<f32>, memref<f32, #gpu.address_space<workgroup>>
+ func.return
+}
+
// CHECK-LABEL: func @global_load_to_rocdl_i8
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 09f739c9d98a2..2b3234ef8510d 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -669,6 +669,13 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
func.return
}
+// CHECK-LABEL: func @gather_to_lds_0d
+func.func @gather_to_lds_0d(%mem1 : memref<f16>, %smem1 : memref<f16, #gpu.address_space<workgroup>>) {
+ // CHECK: amdgpu.gather_to_lds %{{.*}}[], %{{.*}}[]
+ amdgpu.gather_to_lds %mem1[], %smem1[] : vector<2xf16>, memref<f16>, memref<f16, #gpu.address_space<workgroup>>
+ func.return
+}
+
// CHECK-LABEL: func @memory_counter_wait
func.func @memory_counter_wait() {
// CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
More information about the Mlir-commits mailing list