[Mlir-commits] [mlir] [mlir][tosa] Robustify Tosa_while_loop op against null dereference and wrong assertion (PR #159910)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 20 00:05:39 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Jasmine Tang (badumbatish)
<details>
<summary>Changes</summary>
Follow up to #159756
---
Full diff: https://github.com/llvm/llvm-project/pull/159910.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+17-7)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+42)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1c0a6a618fcd2..c5133dfa9609e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4065,16 +4065,26 @@ LogicalResult WhileOp::verify() {
.failed())
return failure();
- auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
- if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
- "'body_graph' results", getInputList(),
- "'input_list'")
- .failed())
- return failure();
+ if (getBodyGraph().front().mightHaveTerminator()) {
+ auto bodyYield =
+ dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+ if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
+ "'body_graph' results",
+ getInputList(), "'input_list'")
+ .failed())
+ return failure();
+ }
// Condition block output must be a single element tensor with a single bool
// value.
- auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+ if (!getCondGraph().front().mightHaveTerminator())
+ return success();
+
+ auto condYield =
+ dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+ if (!condYield)
+ return success();
+
if (condYield.getInputs().size() != 1)
return emitOpError() << "require 'cond_graph' only have one result";
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index f58ddb180ce4f..2e18fe46e21c6 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -649,6 +649,48 @@ func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<
return %0 : tensor<f32>
}
+// -----
+func.func @test_while_loop_wrong_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+ %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+ // expected-error at +2 {{'func.return' op expects parent op 'func.func'}}
+ %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "func.return"(%arg2) : (tensor<i32>) -> ()
+ } do {
+ ^bb0(%arg2: tensor<i32>):
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %2 : tensor<i32>
+ }
+ return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_cond_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+ %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+ // expected-error at +1 {{block with no terminator}}
+ %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ } do {
+ ^bb0(%arg2: tensor<i32>):
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %2 : tensor<i32>
+ }
+ return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_body_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+ %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+ %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ tosa.yield %1 : tensor<i1>
+ } do {
+ ^bb0(%arg2: tensor<i32>):
+ // expected-error at +1 {{block with no terminator}}
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ }
+ return %0 : tensor<i32>
+}
+
// -----
func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/159910
More information about the Mlir-commits mailing list