[Mlir-commits] [mlir] [mlir][nvgpu] Fix crash in optimize-shared-memory pass with vector element types (PR #179111)
llvmlistbot at llvm.org
Sun Feb 1 06:28:19 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Jueon Park (JueonPark)
<details>
<summary>Changes</summary>
The --nvgpu-optimize-shared-memory pass crashed when processing memrefs with vector element types (e.g., memref<16x1xvector<16xf16>, 3>). This occurred because getElementTypeBitWidth() calls getIntOrFloatBitWidth(), which asserts that the element type is an integer or float.
Add an early-exit guard to return failure() when the memref's element type is not a scalar int or float.
Fixes #177823
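To make the failure mode concrete, here is a small self-contained C++ sketch of the guard ordering. `ElemKind`, `ElemType`, `MemRefDesc`, and `checkMemRef` are hypothetical stand-ins written only for illustration; they are not MLIR's `Type`/`MemRefType` API, and the real pass performs further checks (such as the sub-view check visible in the diff) that are omitted here. The point is that the scalar check has to run before anything queries the element bit width. Otherwise the assertion fires for a vector element type.

```cpp
#include <cassert>
#include <optional>

// Hypothetical, simplified stand-ins for MLIR types (NOT the MLIR API).
enum class ElemKind { Int, Float, Vector };

struct ElemType {
  ElemKind kind;
  unsigned scalarBits; // meaningful only for Int/Float element kinds

  bool isIntOrFloat() const { return kind != ElemKind::Vector; }

  // Models the precondition of getIntOrFloatBitWidth(): it asserts when the
  // type is not a scalar integer or float.
  unsigned getIntOrFloatBitWidth() const {
    assert(isIntOrFloat() && "only integers and floats have a bit width");
    return scalarBits;
  }
};

struct MemRefDesc {
  unsigned rank;
  ElemType elem;
};

// Returns the element bit width the transform would work with, or
// std::nullopt (playing the role of failure()) when the memref is rejected.
std::optional<unsigned> checkMemRef(const MemRefDesc &m) {
  if (m.rank == 0)
    return std::nullopt; // existing guard: 0-D memrefs are not supported
  if (!m.elem.isIntOrFloat())
    return std::nullopt; // new guard: reject vector element types early
  // Safe: the element type is known to be a scalar int or float here.
  return m.elem.getIntOrFloatBitWidth();
}
```

With this ordering, a rank-2 memref of 16-bit floats yields a bit width of 16, while a memref whose element type is a vector is rejected before the asserting accessor is ever reached.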
---
Full diff: https://github.com/llvm/llvm-project/pull/179111.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp (+5)
- (modified) mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index c4d00b48a49d3..e06b15febf145 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -156,6 +156,11 @@ mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   if (memRefType.getRank() == 0)
     return failure();
 
+  // Only support memrefs with scalar element types (i.e., int or float).
+  // Memrefs with vector element types are not supported.
+  if (!memRefType.getElementType().isIntOrFloat())
+    return failure();
+
   // Abort if the given value has any sub-views; we do not do any alias
   // analysis.
   bool hasSubView = false;
diff --git a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
index 596d24b94811e..128c4c5d918d9 100644
--- a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
+++ b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
@@ -259,3 +259,14 @@ func.func @test_dynamic_and_zero_dim(%arg0 : index) {
   %alloc_1 = memref.alloc(%arg0) : memref<?xf32, 3>
   return
 }
+
+// -----
+
+// Ensure memrefs with vector element types do not crash (issue #177823).
+
+// CHECK-LABEL: func @test_vector_element_type
+// CHECK: memref.alloc() : memref<16x1xvector<16xf16>, 3>
+func.func @test_vector_element_type() {
+  %alloc = memref.alloc() : memref<16x1xvector<16xf16>, 3>
+  return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/179111