[Mlir-commits] [mlir] [mlir][tosa] Fix crash in TosaInferShapes when while_loop carries sparse tensors (PR #183943)

Mehdi Amini llvmlistbot at llvm.org
Sat Feb 28 12:12:43 PST 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/183943

TypeModificationState::commit() inserted a tensor.cast immediately after the defining operation of the value whose type changed. For block arguments (e.g., loop-carried variables in tosa.while_loop), there is no defining operation, so getDefiningOp() returns nullptr, causing a segmentation fault when the insertion point was set via setInsertionPointAfter(nullptr).

Fix by checking whether the value is defined by an operation; if not (i.e., it is a block argument), insert the cast at the start of the block that owns the argument.

Fixes #181449

From 1fdf25229437935ee8b5a3bc4a69f202ecac62c3 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Feb 2026 09:08:37 -0800
Subject: [PATCH] [mlir][tosa] Fix crash in TosaInferShapes when while_loop
 carries sparse tensors

TypeModificationState::commit() inserted a tensor.cast immediately after
the defining operation of the value whose type changed. For block
arguments (e.g., loop-carried variables in tosa.while_loop), there is no
defining operation, so getDefiningOp() returns nullptr, causing a
segmentation fault when the insertion point was set via
setInsertionPointAfter(nullptr).

Fix by checking whether the value is defined by an operation; if not
(i.e., it is a block argument), insert the cast at the start of the
block that owns the argument.

Fixes #181449
---
 .../Tosa/Transforms/TosaInferShapes.cpp       |  8 ++++++-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 24 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 837ebca572aae..e8f34201b7d18 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -112,7 +112,13 @@ class TypeModificationState {
           // are likely to be further up in the code due to the order in which
           // they appear in the use list.
           OpBuilder builder{value.getContext()};
-          builder.setInsertionPointAfter(value.getDefiningOp());
+          if (Operation *defOp = value.getDefiningOp()) {
+            builder.setInsertionPointAfter(defOp);
+          } else {
+            // For block arguments there is no defining op; insert at the start
+            // of the block that owns the argument.
+            builder.setInsertionPointToStart(value.getParentBlock());
+          }
           castValue =
               tensor::CastOp::create(builder, value.getLoc(), oldType, value);
         }
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 386184b59c8cf..8069877a0dfd2 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1878,3 +1878,27 @@ func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi
   %0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
   return
 }
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/181449
+// Ensure tosa-infer-shapes does not crash when a tosa.while_loop carries a
+// value with a sparse tensor encoding (a block argument without a defining op).
+
+// CHECK-LABEL: @while_with_sparse_tensor_encoding
+func.func @while_with_sparse_tensor_encoding() {
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = sparse_tensor.convert %0 : tensor<1xi32> to tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>
+  // CHECK: tosa.while_loop
+  %2 = tosa.while_loop (%arg0 = %1) : (tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor<1xi32> {
+    %3 = "tosa.const"() <{values = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
+    %4 = sparse_tensor.convert %arg0 : tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>> to tensor<1xi32>
+    %5 = tosa.greater_equal %3, %4 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+    tosa.yield %5 : tensor<1xi1>
+  } do {
+    ^bb0(%arg0: tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>):
+    %3 = "tosa.const"() <{values = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+    tosa.yield %3 : tensor<1xi32>
+  }
+  return
+}