[Mlir-commits] [mlir] [mlir][nvgpu] Fix crash in optimize-shared-memory pass with vector element types (PR #179111)

llvmlistbot at llvm.org
Sun Feb 1 06:28:19 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Jueon Park (JueonPark)

<details>
<summary>Changes</summary>

The `--nvgpu-optimize-shared-memory` pass crashed when processing memrefs with vector element types (e.g., `memref<16x1xvector<16xf16>, 3>`). The crash occurred because `getElementTypeBitWidth()` calls `getIntOrFloatBitWidth()`, which asserts that the element type is an integer or float.

Add an early-exit guard that returns `failure()` when the memref's element type is not a scalar integer or float.
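To make the failure mode concrete, here is a minimal, self-contained C++ model of the guard. It does not use MLIR at all: `ElemKind`, `ElemType`, `MemRefTypeModel`, and `elementBitWidthIfSupported` are hypothetical stand-ins, and `std::nullopt` plays the role of `failure()`. The model shows only the ordering argument: checking `isIntOrFloat()` before querying the bit width means the assertion inside the bit-width query can no longer fire.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified stand-ins for MLIR element types; NOT the MLIR API.
enum class ElemKind { Integer, Float, Vector };

struct ElemType {
  ElemKind kind;
  unsigned scalarBits; // bit width for Integer/Float; unused for Vector

  bool isIntOrFloat() const {
    return kind == ElemKind::Integer || kind == ElemKind::Float;
  }

  // Models the precondition of getIntOrFloatBitWidth(): in a debug build,
  // calling this on a non-scalar (e.g. vector) element type aborts.
  unsigned getIntOrFloatBitWidth() const {
    assert(isIntOrFloat() && "only integers and floats have a bit width");
    return scalarBits;
  }
};

// Hypothetical model of a memref type: a shape plus an element type.
struct MemRefTypeModel {
  std::vector<int64_t> shape;
  ElemType elementType;

  int64_t getRank() const { return static_cast<int64_t>(shape.size()); }
};

// Hypothetical analogue of the pass's early exits: returns std::nullopt where
// the real code returns failure(), otherwise the element bit width it would use.
std::optional<unsigned> elementBitWidthIfSupported(const MemRefTypeModel &t) {
  if (t.getRank() == 0)
    return std::nullopt; // existing early exit for rank-0 memrefs
  // The guard added by this patch: bail out before the bit-width query.
  if (!t.elementType.isIntOrFloat())
    return std::nullopt;
  return t.elementType.getIntOrFloatBitWidth();
}
```

For example, a `{16, 16}` memref of 16-bit floats yields `16`, while a `{16, 1}` memref whose element kind is `Vector` (standing in for `vector<16xf16>`) yields `std::nullopt` instead of tripping the assertion. This mirrors the new test, which expects the `memref.alloc` to be left unchanged.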

Fixes #177823

---
Full diff: https://github.com/llvm/llvm-project/pull/179111.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp (+5) 
- (modified) mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir (+11) 


``````````diff
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index c4d00b48a49d3..e06b15febf145 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -156,6 +156,11 @@ mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   if (memRefType.getRank() == 0)
     return failure();
 
+  // Only support memrefs with scalar element types (i.e., int or float).
+  // Memrefs with vector element types are not supported.
+  if (!memRefType.getElementType().isIntOrFloat())
+    return failure();
+
   // Abort if the given value has any sub-views; we do not do any alias
   // analysis.
   bool hasSubView = false;
diff --git a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
index 596d24b94811e..128c4c5d918d9 100644
--- a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
+++ b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
@@ -259,3 +259,14 @@ func.func @test_dynamic_and_zero_dim(%arg0 : index) {
   %alloc_1 = memref.alloc(%arg0) : memref<?xf32, 3>
   return
 }
+
+// -----
+
+// Ensure memrefs with vector element types do not crash (issue #177823).
+
+// CHECK-LABEL: func @test_vector_element_type
+// CHECK: memref.alloc() : memref<16x1xvector<16xf16>, 3>
+func.func @test_vector_element_type() {
+  %alloc = memref.alloc() : memref<16x1xvector<16xf16>, 3>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/179111

