[Mlir-commits] [mlir] [mlir][tosa] Robustify Tosa_while_loop op against null dereference and wrong assertion (PR #159910)

Jasmine Tang llvmlistbot at llvm.org
Sat Sep 20 00:05:10 PDT 2025


https://github.com/badumbatish created https://github.com/llvm/llvm-project/pull/159910

Follow-up to #159756.

>From 180a49c2305a2e9a7ae37ae3e734c43bcff3d5e3 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Sat, 20 Sep 2025 00:01:40 -0700
Subject: [PATCH] Robustify while op

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 24 +++++++++++-----
 mlir/test/Dialect/Tosa/verifier.mlir | 42 ++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1c0a6a618fcd2..c5133dfa9609e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4065,16 +4065,26 @@ LogicalResult WhileOp::verify() {
           .failed())
     return failure();
 
-  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
-  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
-                                 "'body_graph' results", getInputList(),
-                                 "'input_list'")
-          .failed())
-    return failure();
+  if (getBodyGraph().front().mightHaveTerminator()) {
+    auto bodyYield =
+        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+    if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
+                                                "'body_graph' results",
+                                                getInputList(), "'input_list'")
+                         .failed())
+      return failure();
+  }
 
   // Condition block output must be a single element tensor with a single bool
   // value.
-  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (!getCondGraph().front().mightHaveTerminator())
+    return success();
+
+  auto condYield =
+      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (!condYield)
+    return success();
+
   if (condYield.getInputs().size() != 1)
     return emitOpError() << "require 'cond_graph' only have one result";
 
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index f58ddb180ce4f..2e18fe46e21c6 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -649,6 +649,48 @@ func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<
   return %0 : tensor<f32>
 }
 
+// -----
+func.func @test_while_loop_wrong_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      // expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      "func.return"(%arg2) : (tensor<i32>) -> ()
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+      %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tosa.yield %2 : tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_cond_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      // expected-error@+1 {{block with no terminator}}
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+      %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tosa.yield %2 : tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_body_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      tosa.yield %1 : tensor<i1>
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      // expected-error@+1 {{block with no terminator}}
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
 // -----
 
 func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {



More information about the Mlir-commits mailing list