[Mlir-commits] [mlir] [mlir][sparse] Fix crash in SparseAssembler when run after SparseTensorCodegen (PR #183896)
llvmlistbot at llvm.org
Sat Feb 28 02:54:38 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-sparse
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
After --sparse-tensor-codegen, sparse tensor arguments are replaced by memrefs and !sparse_tensor.storage_specifier types. The subsequent --sparse-assembler pass calls getSparseTensorEncoding() to identify the sparse arguments it must wrap/unwrap. However, getSparseTensorEncoding() returns non-null for StorageSpecifierType as well as for sparse RankedTensorType. Since StorageSpecifierType is not a RankedTensorType, the subsequent cast<RankedTensorType> in convTypes() and convVals() crashes with an assertion failure.
Fix by also checking isa<RankedTensorType>(type) in the passthrough condition in both convTypes() and convVals(), so that StorageSpecifierType arguments pass through unchanged.
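To make the failure mode concrete, here is a minimal, self-contained C++ sketch of the patched passthrough check. It does not use MLIR: `Kind`, `Type`, `getEncoding()`, `isRankedTensor()` and the three-buffer expansion are hypothetical stand-ins for getSparseTensorEncoding(), isa<RankedTensorType> and the real wrapping logic, assumed only for illustration.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR types; these are NOT the real MLIR classes.
enum class Kind { DenseTensor, SparseRankedTensor, StorageSpecifier, MemRef };

struct Type {
  Kind kind;
  std::string name;
};

// Models the relevant behavior of getSparseTensorEncoding(): it returns a
// non-null encoding for sparse ranked tensors AND for storage specifiers.
const char *getEncoding(const Type &t) {
  static const char kCSR[] = "#CSR";
  if (t.kind == Kind::SparseRankedTensor || t.kind == Kind::StorageSpecifier)
    return kCSR;
  return nullptr;
}

// Models isa<RankedTensorType>(type).
bool isRankedTensor(const Type &t) {
  return t.kind == Kind::DenseTensor || t.kind == Kind::SparseRankedTensor;
}

// Simplified convTypes(): sparse ranked tensors are expanded into
// illustrative (hypothetical) per-buffer entries; everything else passes
// through unchanged.
std::vector<std::string> convTypes(const std::vector<Type> &types) {
  std::vector<std::string> out;
  for (const Type &t : types) {
    // The patched passthrough condition: no encoding, OR not a ranked tensor.
    if (!getEncoding(t) || !isRankedTensor(t)) {
      out.push_back(t.name);
      continue;
    }
    // Only sparse ranked tensors reach this point, so the real code's
    // cast<RankedTensorType> would be safe here.
    assert(isRankedTensor(t));
    out.push_back(t.name + ".positions");
    out.push_back(t.name + ".coordinates");
    out.push_back(t.name + ".values");
  }
  return out;
}
```

With only the original `!getEncoding(t)` condition, the storage-specifier entry would fall through to the expansion branch; in the real pass that is where `cast<RankedTensorType>` runs and asserts.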
Fixes #183776
---
Full diff: https://github.com/llvm/llvm-project/pull/183896.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+8-4)
- (added) mlir/test/Dialect/SparseTensor/external_after_codegen.mlir (+30)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 40c182f9dbb37..0a8f29cd0b417 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -26,8 +26,10 @@ static void convTypes(bool &hasAnnotation, TypeRange types,
                       SmallVectorImpl<Type> &convTypes,
                       SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
-    // All "dense" data passes through unmodified.
-    if (!getSparseTensorEncoding(type)) {
+    // All "dense" data passes through unmodified. Note: getSparseTensorEncoding
+    // also returns non-null for StorageSpecifierType (which is not a
+    // RankedTensorType), so we must check isa<RankedTensorType> as well.
+    if (!getSparseTensorEncoding(type) || !isa<RankedTensorType>(type)) {
       convTypes.push_back(type);
       continue;
     }
@@ -62,8 +64,10 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                      bool directOut) {
   unsigned idx = 0;
   for (auto type : types) {
-    // All "dense" data passes through unmodified.
-    if (!getSparseTensorEncoding(type)) {
+    // All "dense" data passes through unmodified. Note: getSparseTensorEncoding
+    // also returns non-null for StorageSpecifierType (which is not a
+    // RankedTensorType), so we must check isa<RankedTensorType> as well.
+    if (!getSparseTensorEncoding(type) || !isa<RankedTensorType>(type)) {
       toVals.push_back(fromVals[idx++]);
       continue;
     }
diff --git a/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir b/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir
new file mode 100644
index 0000000000000..b217900c64498
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/external_after_codegen.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-assembler | FileCheck %s
+
+// Regression test for https://github.com/llvm/llvm-project/issues/183776:
+// Running --sparse-assembler after --sparse-tensor-codegen must not crash.
+// After codegen, sparse tensor arguments are replaced by memrefs and
+// !sparse_tensor.storage_specifier types. getSparseTensorEncoding() returns
+// non-null for StorageSpecifierType, but convTypes()/convVals() must not
+// attempt cast<RankedTensorType> on it. Instead, non-RankedTensorType types
+// with a sparse encoding should pass through unchanged.
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0: dense, d1: compressed) }>
+
+// Storage_specifier types from codegen must pass through sparse-assembler
+// unchanged (not be treated as sparse tensor arguments to wrap).
+// CHECK-LABEL: func.func @storage_specifier_passthrough(
+// CHECK-SAME: storage_specifier
+// CHECK-SAME: storage_specifier
+// CHECK: return %{{.*}} : tensor<32x32xf32>
+func.func @storage_specifier_passthrough(%arg0: tensor<32x32xf32, #CSR>,
+                                         %arg1: tensor<32x32xf32, #CSR>)
+    -> tensor<32x32xf32> {
+  %cst = arith.constant 0.0 : f32
+  %init = tensor.empty() : tensor<32x32xf32>
+  %out = linalg.fill ins(%cst : f32) outs(%init : tensor<32x32xf32>)
+      -> tensor<32x32xf32>
+  %3 = linalg.add
+      ins(%arg0, %arg1 : tensor<32x32xf32, #CSR>, tensor<32x32xf32, #CSR>)
+      outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32>
+  return %3 : tensor<32x32xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/183896