[Mlir-commits] [mlir] [mlir][tosa] Robustify Tosa_while_loop op against null dereference and wrong assertion (PR #159910)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 20 00:05:39 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Jasmine Tang (badumbatish)

<details>
<summary>Changes</summary>

Follow-up to #159756

---
Full diff: https://github.com/llvm/llvm-project/pull/159910.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+17-7) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+42) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1c0a6a618fcd2..c5133dfa9609e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4065,16 +4065,26 @@ LogicalResult WhileOp::verify() {
           .failed())
     return failure();
 
-  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
-  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
-                                 "'body_graph' results", getInputList(),
-                                 "'input_list'")
-          .failed())
-    return failure();
+  if (getBodyGraph().front().mightHaveTerminator()) {
+    auto bodyYield =
+        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+    if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
+                                                "'body_graph' results",
+                                                getInputList(), "'input_list'")
+                         .failed())
+      return failure();
+  }
 
   // Condition block output must be a single element tensor with a single bool
   // value.
-  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (!getCondGraph().front().mightHaveTerminator())
+    return success();
+
+  auto condYield =
+      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (!condYield)
+    return success();
+
   if (condYield.getInputs().size() != 1)
     return emitOpError() << "require 'cond_graph' only have one result";
 
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index f58ddb180ce4f..2e18fe46e21c6 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -649,6 +649,48 @@ func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<
   return %0 : tensor<f32>
 }
 
+// -----
+func.func @test_while_loop_wrong_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      // expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      "func.return"(%arg2) : (tensor<i32>) -> ()
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+      %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tosa.yield %2 : tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_cond_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      // expected-error@+1 {{block with no terminator}}
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+      %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tosa.yield %2 : tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_body_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      tosa.yield %1 : tensor<i1>
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      // expected-error@+1 {{block with no terminator}}
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
 // -----
 
 func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/159910


More information about the Mlir-commits mailing list