[Mlir-commits] [mlir] [mlir][sparse] Fix crash in SparseAssembler when run after SparseTensorCodegen (PR #183896)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 28 02:54:38 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-sparse

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

After --sparse-tensor-codegen, sparse tensor arguments are replaced by memrefs and !sparse_tensor.storage_specifier types. The subsequent --sparse-assembler pass calls getSparseTensorEncoding() to identify the sparse arguments it needs to wrap or unwrap. However, getSparseTensorEncoding() returns a non-null encoding for StorageSpecifierType as well as for sparse RankedTensorType. Because StorageSpecifierType is not a RankedTensorType, the later cast<RankedTensorType> in convTypes() and convVals() hit an assertion failure and crashed.

The fix adds an isa<RankedTensorType>(type) check to the passthrough condition in both convTypes() and convVals(), so that StorageSpecifierType arguments pass through unchanged.
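
To illustrate the condition outside of MLIR, here is a minimal toy model in plain C++. Everything in it is hypothetical: `ToyType`, `ToyRankedTensor`, `ToyStorageSpecifier`, `ToyMemRef`, `hasSparseEncoding()`, `passesThrough()`, and `castToRankedTensor()` are illustrative stand-ins, not the real MLIR classes or helpers, and `dynamic_cast` only approximates `isa<>`/`cast<>`. It is a sketch of why an encoding check alone is not enough, not the pass implementation.

```cpp
#include <cassert>

// Toy stand-ins for MLIR types (assumption: these are NOT the real classes).
struct ToyType {
  virtual ~ToyType() = default;
  // Plays the role of "getSparseTensorEncoding(type) is non-null".
  virtual bool hasSparseEncoding() const { return false; }
};

struct ToyRankedTensor : ToyType {
  explicit ToyRankedTensor(bool isSparse) : isSparse(isSparse) {}
  bool hasSparseEncoding() const override { return isSparse; }
  bool isSparse;
};

// Carries an encoding, like a storage specifier, but is not a tensor.
struct ToyStorageSpecifier : ToyType {
  bool hasSparseEncoding() const override { return true; }
};

// A plain buffer type with no encoding at all.
struct ToyMemRef : ToyType {};

// Condition before the fix: only the encoding is consulted.
bool passesThroughBeforeFix(const ToyType &type) {
  return !type.hasSparseEncoding();
}

// Condition after the fix: anything that is not a ranked tensor also
// passes through, even if it carries an encoding.
bool passesThrough(const ToyType &type) {
  return !type.hasSparseEncoding() ||
         dynamic_cast<const ToyRankedTensor *>(&type) == nullptr;
}

// Approximates an asserting cast: aborts (in builds with assertions
// enabled) when the type is not a ranked tensor.
const ToyRankedTensor &castToRankedTensor(const ToyType &type) {
  const auto *tensor = dynamic_cast<const ToyRankedTensor *>(&type);
  assert(tensor && "cast to incompatible type");
  return *tensor;
}
```

In this model, a `ToyStorageSpecifier` is not filtered out by `passesThroughBeforeFix()`, so it would reach `castToRankedTensor()` and trip the assertion. With `passesThrough()`, it is skipped before any cast is attempted, while a sparse `ToyRankedTensor` is still selected for conversion.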

Fixes #183776

---
Full diff: https://github.com/llvm/llvm-project/pull/183896.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+8-4) 
- (added) mlir/test/Dialect/SparseTensor/external_after_codegen.mlir (+30) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 40c182f9dbb37..0a8f29cd0b417 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -26,8 +26,10 @@ static void convTypes(bool &hasAnnotation, TypeRange types,
                       SmallVectorImpl<Type> &convTypes,
                       SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
-    // All "dense" data passes through unmodified.
-    if (!getSparseTensorEncoding(type)) {
+    // All "dense" data passes through unmodified. Note: getSparseTensorEncoding
+    // also returns non-null for StorageSpecifierType (which is not a
+    // RankedTensorType), so we must check isa<RankedTensorType> as well.
+    if (!getSparseTensorEncoding(type) || !isa<RankedTensorType>(type)) {
       convTypes.push_back(type);
       continue;
     }
@@ -62,8 +64,10 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                      bool directOut) {
   unsigned idx = 0;
   for (auto type : types) {
-    // All "dense" data passes through unmodified.
-    if (!getSparseTensorEncoding(type)) {
+    // All "dense" data passes through unmodified. Note: getSparseTensorEncoding
+    // also returns non-null for StorageSpecifierType (which is not a
+    // RankedTensorType), so we must check isa<RankedTensorType> as well.
+    if (!getSparseTensorEncoding(type) || !isa<RankedTensorType>(type)) {
       toVals.push_back(fromVals[idx++]);
       continue;
     }
diff --git a/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir b/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir
new file mode 100644
index 0000000000000..b217900c64498
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-assembler | FileCheck %s
+
+// Regression test for https://github.com/llvm/llvm-project/issues/183776:
+// Running --sparse-assembler after --sparse-tensor-codegen must not crash.
+// After codegen, sparse tensor arguments are replaced by memrefs and
+// \!sparse_tensor.storage_specifier types. getSparseTensorEncoding() returns
+// non-null for StorageSpecifierType, but convTypes()/convVals() must not
+// attempt cast<RankedTensorType> on it. Instead, non-RankedTensorType types
+// with a sparse encoding should pass through unchanged.
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0: dense, d1: compressed) }>
+
+// Storage_specifier types from codegen must pass through sparse-assembler
+// unchanged (not be treated as sparse tensor arguments to wrap).
+// CHECK-LABEL: func.func @storage_specifier_passthrough(
+// CHECK-SAME:    storage_specifier
+// CHECK-SAME:    storage_specifier
+// CHECK:         return %{{.*}} : tensor<32x32xf32>
+func.func @storage_specifier_passthrough(%arg0: tensor<32x32xf32, #CSR>,
+                                         %arg1: tensor<32x32xf32, #CSR>)
+    -> tensor<32x32xf32> {
+  %cst = arith.constant 0.0 : f32
+  %init = tensor.empty() : tensor<32x32xf32>
+  %out = linalg.fill ins(%cst : f32) outs(%init : tensor<32x32xf32>)
+      -> tensor<32x32xf32>
+  %3 = linalg.add
+      ins(%arg0, %arg1 : tensor<32x32xf32, #CSR>, tensor<32x32xf32, #CSR>)
+      outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32>
+  return %3 : tensor<32x32xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/183896

