[Mlir-commits] [mlir] 19be8d6 - [mlir][tosa] Fix crash in TosaInferShapes when while_loop carries sparse tensors (#183943)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 2 05:59:24 PST 2026


Author: Mehdi Amini
Date: 2026-03-02T14:59:19+01:00
New Revision: 19be8d60662be4972cf52ca58a3c35e0983c1d73

URL: https://github.com/llvm/llvm-project/commit/19be8d60662be4972cf52ca58a3c35e0983c1d73
DIFF: https://github.com/llvm/llvm-project/commit/19be8d60662be4972cf52ca58a3c35e0983c1d73.diff

LOG: [mlir][tosa] Fix crash in TosaInferShapes when while_loop carries sparse tensors (#183943)

TypeModificationState::commit() inserted a tensor.cast immediately after
the defining operation of the value whose type changed. Block arguments
(e.g., loop-carried variables in tosa.while_loop) have no defining
operation, so getDefiningOp() returns nullptr. Passing that nullptr to
setInsertionPointAfter() dereferenced a null pointer and caused a
segmentation fault.

Fix by checking whether the value is defined by an operation; if not
(i.e., it is a block argument), insert the cast at the start of the
block that owns the argument.

Fixes #181449

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 837ebca572aae..e8f34201b7d18 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -112,7 +112,13 @@ class TypeModificationState {
           // are likely to be further up in the code due to the order in which
           // they appear in the use list.
           OpBuilder builder{value.getContext()};
-          builder.setInsertionPointAfter(value.getDefiningOp());
+          if (Operation *defOp = value.getDefiningOp()) {
+            builder.setInsertionPointAfter(defOp);
+          } else {
+            // For block arguments there is no defining op; insert at the start
+            // of the block that owns the argument.
+            builder.setInsertionPointToStart(value.getParentBlock());
+          }
           castValue =
               tensor::CastOp::create(builder, value.getLoc(), oldType, value);
         }

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 386184b59c8cf..8069877a0dfd2 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1878,3 +1878,27 @@ func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi
   %0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
   return
 }
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/181449
+// Ensure tosa-infer-shapes does not crash when a tosa.while_loop carries a
+// value with a sparse tensor encoding (a block argument without a defining op).
+
+// CHECK-LABEL: @while_with_sparse_tensor_encoding
+func.func @while_with_sparse_tensor_encoding() {
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = sparse_tensor.convert %0 : tensor<1xi32> to tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>
+  // CHECK: tosa.while_loop
+  %2 = tosa.while_loop (%arg0 = %1) : (tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor<1xi32> {
+    %3 = "tosa.const"() <{values = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
+    %4 = sparse_tensor.convert %arg0 : tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>> to tensor<1xi32>
+    %5 = tosa.greater_equal %3, %4 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+    tosa.yield %5 : tensor<1xi1>
+  } do {
+    ^bb0(%arg0: tensor<1xi32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>):
+    %3 = "tosa.const"() <{values = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+    tosa.yield %3 : tensor<1xi32>
+  }
+  return
+}


        


More information about the Mlir-commits mailing list