[llvm-branch-commits] [mlir] [mlir][SparseTensor] Fix type conversion rule (PR #140350)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri May 16 21:48:59 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/140350

A type conversion rule must not make any assumptions about the number of types already present in the `results` vector it appends to.

This commit fixes a failed assertion in a SparseTensor type conversion rule. The failure is only reproducible when type conversion caching is deactivated, and there is currently no way to deactivate it. This commit is in preparation for adding context-aware type conversions, which will deactivate type caching in such cases.

Depends on #140347.
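The failure mode described above can be modeled outside MLIR. The C++ below is a minimal, hypothetical sketch: it assumes a type can be represented as a `std::string`, and `forEachField`, `convertSparse`, and `convertAll` are illustrative stand-ins for `foreachFieldAndTypeInSparseTensor`, the conversion rule, and a caller that converts several types into one shared vector. None of these names are MLIR APIs. Because the second call to the rule appends after entries left by the first, an index check against `results.size()` alone would fail there; checking relative to the size recorded on entry, as the patch does, holds in both calls.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Type; a type name is enough for this model.
using Type = std::string;

// Hypothetical stand-in for foreachFieldAndTypeInSparseTensor: invokes the
// callback once per storage field, passing a 0-based field index.
void forEachField(const std::vector<Type> &fieldTypes,
                  const std::function<bool(Type, std::size_t)> &callback) {
  for (std::size_t i = 0; i < fieldTypes.size(); ++i)
    if (!callback(fieldTypes[i], i))
      return;
}

// Hypothetical conversion rule: appends the fields of one "sparse tensor" to
// `results`, which may already hold types from earlier conversions.
bool convertSparse(const std::vector<Type> &fieldTypes,
                   std::vector<Type> &results) {
  // Record how many types were present before this rule ran.
  const std::size_t numFields = results.size();
  forEachField(fieldTypes, [&](Type fieldType, std::size_t fieldIdx) {
    // The pre-patch check, `fieldIdx == results.size()`, only holds when
    // `results` starts out empty; this offset-relative check always holds.
    assert(numFields + fieldIdx == results.size());
    results.push_back(fieldType);
    return true;
  });
  return true;
}

// Hypothetical caller that converts several types into one shared vector,
// so every call after the first sees a non-empty `results`.
std::vector<Type> convertAll(const std::vector<std::vector<Type>> &inputs) {
  std::vector<Type> results;
  for (const std::vector<Type> &fields : inputs)
    convertSparse(fields, results);
  return results;
}
```

For example, `convertAll({{"pos", "crd"}, {"values"}})` runs the rule twice on the same vector (the field names are made up); the second invocation starts with two entries already present and still passes the assertion.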


From f14a088266fa8df837e38088a0e8b22b93dcbfc5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 17 May 2025 06:44:59 +0200
Subject: [PATCH] [mlir][SparseTensor] Fix type conversion rule

---
 .../Transforms/Utils/SparseTensorDescriptor.cpp          | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 8bbb2cac5efdf..79602a22dc1fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -38,12 +38,13 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
   if (!stt.hasEncoding())
     return std::nullopt;
 
+  unsigned numFields = fields.size();
   foreachFieldAndTypeInSparseTensor(
       stt,
-      [&fields](Type fieldType, FieldIndex fieldIdx,
-                SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
-                LevelType /*lt*/) -> bool {
-        assert(fieldIdx == fields.size());
+      [&](Type fieldType, FieldIndex fieldIdx,
+          SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
+          LevelType /*lt*/) -> bool {
+        assert(numFields + fieldIdx == fields.size());
         fields.push_back(fieldType);
         return true;
       });
