[Mlir-commits] [mlir] [mlir][nvgpu] Fix a division by zero crash in OptimizeSharedMemoryPass (PR #174931)

llvmlistbot at llvm.org
Thu Jan 8 01:17:53 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-nvgpu

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

Fixes #173553.

---
Full diff: https://github.com/llvm/llvm-project/pull/174931.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp (+3) 
- (modified) mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir (+11) 


``````````diff
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 957b9632422a6..cb26955cd5fd5 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -166,6 +166,9 @@ mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   // Check if this is necessary given the assumption of 128b accesses:
   // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  if (rowSize == ShapedType::kDynamic || rowSize == 0)
+    return failure();
+
   const int64_t rowsPerLine =
       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
       rowSize;
diff --git a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
index 7477e18728677..596d24b94811e 100644
--- a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
+++ b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
@@ -248,3 +248,14 @@ func.func @test_0_d() -> memref<i32, #gpu.address_space<workgroup>> {
   %alloc = memref.alloc() : memref<i32, #gpu.address_space<workgroup>>
   return %alloc : memref<i32, #gpu.address_space<workgroup>>
 }
+
+// -----
+
+// Ensure that a zero-sized or dynamic dimension does not crash the pass.
+
+// CHECK-LABEL: func @test_dynamic_and_zero_dim
+func.func @test_dynamic_and_zero_dim(%arg0 : index) {
+  %alloc = memref.alloc() : memref<0xf32, 3>
+  %alloc_1 = memref.alloc(%arg0) : memref<?xf32, 3>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/174931

