[Mlir-commits] [mlir] [mlir][tosa] Fix crash in TosaInferShapes when while_loop carries sparse tensors (PR #183943)
llvmlistbot at llvm.org
Sat Feb 28 12:13:11 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-tosa
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
TypeModificationState::commit() inserted a tensor.cast immediately after the defining operation of the value whose type changed. Block arguments (e.g., loop-carried variables in tosa.while_loop) have no defining operation, so getDefiningOp() returned nullptr, and the subsequent setInsertionPointAfter(nullptr) call dereferenced it, causing a segmentation fault.
Fix by checking whether the value is defined by an operation; if not (i.e., it is a block argument), insert the cast at the start of the block that owns the argument.
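The insertion-point rule can be illustrated with a minimal, self-contained C++ sketch. It does not use MLIR: `Block`, `Operation`, `Value`, `InsertionPoint`, `castInsertionPoint`, `insertCast` and `demoOrder` are hypothetical, heavily simplified stand-ins. The sketch assumes only that a value is either an op result (it has a defining op) or a block argument (it has an owning block), loosely mirroring MLIR's `OpResult`/`BlockArgument` split.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR's Block / Operation / Value. They model
// only what is needed to choose where a tensor.cast would be inserted.
struct Block;

struct Operation {
  std::string name;
  Block *parent = nullptr;  // block that contains this operation
};

struct Block {
  std::list<Operation *> ops;  // operations in program order
};

// Either an op result (definingOp != nullptr) or a block argument
// (definingOp == nullptr, ownerBlock set).
struct Value {
  Operation *definingOp = nullptr;
  Block *ownerBlock = nullptr;
};

struct InsertionPoint {
  Block *block;
  std::list<Operation *>::iterator pos;  // new ops are inserted before pos
};

// Patched logic: right after the defining op if there is one; otherwise
// (block argument) at the start of the block that owns the argument.
InsertionPoint castInsertionPoint(const Value &v) {
  if (Operation *defOp = v.definingOp) {
    Block *b = defOp->parent;
    auto it = std::find(b->ops.begin(), b->ops.end(), defOp);
    return {b, std::next(it)};
  }
  return {v.ownerBlock, v.ownerBlock->ops.begin()};
}

// Creates a placeholder "tensor.cast" op at the chosen insertion point.
Operation *insertCast(const Value &v, std::list<Operation> &storage) {
  InsertionPoint ip = castInsertionPoint(v);
  storage.push_back({"tensor.cast", ip.block});
  ip.block->ops.insert(ip.pos, &storage.back());
  return &storage.back();
}

// Builds a one-block "loop body", inserts a cast for a block argument and
// for an op result, and returns the resulting op order.
std::vector<std::string> demoOrder() {
  std::list<Operation> storage;  // owns the ops; std::list keeps addresses stable
  Block body;
  storage.push_back({"tosa.const", &body});
  Operation *cst = &storage.back();
  storage.push_back({"tosa.yield", &body});
  Operation *yield = &storage.back();
  body.ops = {cst, yield};

  Value blockArg{nullptr, &body};  // like %arg0: no defining op
  insertCast(blockArg, storage);   // lands at the start of the block
  Value result{cst, nullptr};      // result of tosa.const
  insertCast(result, storage);     // lands right after tosa.const

  std::vector<std::string> names;
  for (Operation *op : body.ops) names.push_back(op->name);
  return names;
}
```

In this model `demoOrder()` yields `tensor.cast, tosa.const, tensor.cast, tosa.yield`: the cast for the block argument is placed first in the block, so it still dominates every use of the argument inside that block, while the cast for the op result follows its producer as before.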
Fixes #181449
---
Full diff: https://github.com/llvm/llvm-project/pull/183943.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+7-1)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+24)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 837ebca572aae..e8f34201b7d18 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -112,7 +112,13 @@ class TypeModificationState {
// are likely to be further up in the code due to the order in which
// they appear in the use list.
OpBuilder builder{value.getContext()};
- builder.setInsertionPointAfter(value.getDefiningOp());
+ if (Operation *defOp = value.getDefiningOp()) {
+ builder.setInsertionPointAfter(defOp);
+ } else {
+ // For block arguments there is no defining op; insert at the start
+ // of the block that owns the argument.
+ builder.setInsertionPointToStart(value.getParentBlock());
+ }
castValue =
tensor::CastOp::create(builder, value.getLoc(), oldType, value);
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 386184b59c8cf..8069877a0dfd2 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1878,3 +1878,27 @@ func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi
%0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
return
}
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/181449
+// Ensure tosa-infer-shapes does not crash when a tosa.while_loop carries a
+// value with a sparse tensor encoding (a block argument without a defining op).
+
+// CHECK-LABEL: @while_with_sparse_tensor_encoding
+func.func @while_with_sparse_tensor_encoding() {
+ %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %1 = sparse_tensor.convert %0 : tensor<1xi32> to tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>
+ // CHECK: tosa.while_loop
+ %2 = tosa.while_loop (%arg0 = %1) : (tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor<1xi32> {
+ %3 = "tosa.const"() <{values = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %4 = sparse_tensor.convert %arg0 : tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>> to tensor<1xi32>
+ %5 = tosa.greater_equal %3, %4 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ tosa.yield %5 : tensor<1xi1>
+ } do {
+ ^bb0(%arg0: tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>):
+ %3 = "tosa.const"() <{values = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+ tosa.yield %3 : tensor<1xi32>
+ }
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/183943
More information about the Mlir-commits mailing list