[Mlir-commits] [mlir] [mlir][nvgpu] Fix a division by zero crash in OptimizeSharedMemoryPass (PR #174931)
llvmlistbot at llvm.org
Thu Jan 8 01:17:53 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-nvgpu
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
Fixes #173553.
---
Full diff: https://github.com/llvm/llvm-project/pull/174931.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp (+3)
- (modified) mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 957b9632422a6..cb26955cd5fd5 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -166,6 +166,9 @@ mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
// Check if this is necessary given the assumption of 128b accesses:
// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+ if (rowSize == ShapedType::kDynamic || rowSize == 0)
+ return failure();
+
const int64_t rowsPerLine =
(8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
rowSize;
diff --git a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
index 7477e18728677..596d24b94811e 100644
--- a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
+++ b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
@@ -248,3 +248,14 @@ func.func @test_0_d() -> memref<i32, #gpu.address_space<workgroup>> {
%alloc = memref.alloc() : memref<i32, #gpu.address_space<workgroup>>
return %alloc : memref<i32, #gpu.address_space<workgroup>>
}
+
+// -----
+
+// Ensure the case with zero or dynamic dim not crash.
+
+// CHECK-LABEL: func @test_dynamic_and_zero_dim
+func.func @test_dynamic_and_zero_dim(%arg0 : index) {
+ %alloc = memref.alloc() : memref<0xf32, 3>
+ %alloc_1 = memref.alloc(%arg0) : memref<?xf32, 3>
+ return
+}
``````````
</details>
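For readers without an MLIR checkout, the guard added above can be sketched in isolation. This is only an illustration, not the pass's code: `kDynamicDim` is a hypothetical stand-in for `ShapedType::kDynamic`, `kLineSizeBytes` assumes the 128-byte line mentioned in the patch comment, and `rowsPerLine` is a hypothetical helper that returns `std::nullopt` where the pass returns `failure()`.

```cpp
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical stand-in for mlir::ShapedType::kDynamic: a sentinel value
// that is never a valid static extent.
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Assumption: a 128-byte shared memory line, as the patch comment states.
constexpr int64_t kLineSizeBytes = 128;

// Returns how many rows fit in one shared memory line, or std::nullopt when
// rowSize is dynamic or zero and so cannot be used as a divisor.
// elementBitWidth is assumed to be positive.
std::optional<int64_t> rowsPerLine(int64_t rowSize, int64_t elementBitWidth) {
  if (rowSize == kDynamicDim || rowSize == 0)
    return std::nullopt;
  return (8 * kLineSizeBytes / elementBitWidth) / rowSize;
}
```

With 32-bit elements, `rowsPerLine(0, 32)` and `rowsPerLine(kDynamicDim, 32)` both bail out instead of dividing, which mirrors the two allocations in the new test case (`memref<0xf32, 3>` and `memref<?xf32, 3>`).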
https://github.com/llvm/llvm-project/pull/174931