[Mlir-commits] [mlir] 9afbcde - [mlir][amdgpu] Fix `gather_to_lds` for 0d memrefs (#173421)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 23 15:00:58 PST 2025


Author: Ivan Butygin
Date: 2025-12-24T02:00:55+03:00
New Revision: 9afbcde1d2938c494c225e93837e3a075db218e7

URL: https://github.com/llvm/llvm-project/commit/9afbcde1d2938c494c225e93837e3a075db218e7
DIFF: https://github.com/llvm/llvm-project/commit/9afbcde1d2938c494c225e93837e3a075db218e7.diff

LOG: [mlir][amdgpu] Fix `gather_to_lds` for 0d memrefs (#173421)

`dstType.areTrailingDimsContiguous(1)` asserts when called on a rank-0 memref, so the verifier now skips the innermost-dimension contiguity check for 0-d destinations.

Added: 
    

Modified: 
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
    mlir/test/Dialect/AMDGPU/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e77d131509add..ccaea75f8c4c1 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -833,7 +833,7 @@ LogicalResult GatherToLDSOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
   MemRefType dstType = cast<MemRefType>(getDst().getType());
 
-  if (!dstType.areTrailingDimsContiguous(1))
+  if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
     return emitOpError("destination type inner most dim must be contiguous");
 
   auto elemType = srcType.getElementType();

diff  --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index 30578517be1ca..a24430c5b86cc 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -81,6 +81,24 @@ func.func @global_load_to_rocdl_wg_mem(%global : memref<128x72xf32>) {
   func.return
 }
 
+// CHECK-LABEL: func @global_load_to_rocdl_0d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<f32>)
+func.func @global_load_to_rocdl_0d(%global : memref<f32>) {
+  %alloc = memref.alloc() : memref<f32, #gpu.address_space<workgroup>>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<f32, #gpu.address_space<workgroup>> to !llvm.struct<(ptr<3>, ptr<3>, i64)>
+
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+
+  // CHECK: rocdl.load.to.lds %[[GLOBAL_BASE]], %[[LDS_BASE]], 4
+  amdgpu.gather_to_lds %global[], %alloc[]
+    : f32, memref<f32>, memref<f32, #gpu.address_space<workgroup>>
+  func.return
+}
+
 // CHECK-LABEL: func @global_load_to_rocdl_i8
 // CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
 func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {

diff  --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 09f739c9d98a2..2b3234ef8510d 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -669,6 +669,13 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
   func.return
 }
 
+// CHECK-LABEL: func @gather_to_lds_0d
+func.func @gather_to_lds_0d(%mem1 : memref<f16>, %smem1 : memref<f16, #gpu.address_space<workgroup>>) {
+  // CHECK: amdgpu.gather_to_lds %{{.*}}[], %{{.*}}[]
+  amdgpu.gather_to_lds %mem1[], %smem1[] : vector<2xf16>, memref<f16>, memref<f16, #gpu.address_space<workgroup>>
+  func.return
+}
+
 // CHECK-LABEL: func @memory_counter_wait
 func.func @memory_counter_wait() {
   // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)


        


More information about the Mlir-commits mailing list