[Mlir-commits] [mlir] eede476 - [mlir][tosa] Robustify Tosa_while_loop op against null dereference and wrong assertion (#159910)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 22 08:40:02 PDT 2025
Author: Jasmine Tang
Date: 2025-09-22T16:39:58+01:00
New Revision: eede47656b0cc9c3cff8e1959a6f3d55402f3283
URL: https://github.com/llvm/llvm-project/commit/eede47656b0cc9c3cff8e1959a6f3d55402f3283
DIFF: https://github.com/llvm/llvm-project/commit/eede47656b0cc9c3cff8e1959a6f3d55402f3283.diff
LOG: [mlir][tosa] Robustify Tosa_while_loop op against null dereference and wrong assertion (#159910)
Follow up to #159756
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/verifier.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index aa58fc21fe26f..26ad641128b3d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4073,16 +4073,26 @@ LogicalResult WhileOp::verify() {
.failed())
return failure();
- auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
- if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
- "'body_graph' results", getInputList(),
- "'input_list'")
- .failed())
- return failure();
+ if (getBodyGraph().front().mightHaveTerminator()) {
+ auto bodyYield =
+ dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+ if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
+ "'body_graph' results",
+ getInputList(), "'input_list'")
+ .failed())
+ return failure();
+ }
// Condition block output must be a single element tensor with a single bool
// value.
- auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+ if (!getCondGraph().front().mightHaveTerminator())
+ return success();
+
+ auto condYield =
+ dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+ if (!condYield)
+ return success();
+
if (condYield.getInputs().size() != 1)
return emitOpError() << "require 'cond_graph' only have one result";
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index e5571b6b4412c..0128da729136e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -686,6 +686,48 @@ func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<
return %0 : tensor<f32>
}
+// -----
+func.func @test_while_loop_wrong_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+ %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+ // expected-error at +2 {{'func.return' op expects parent op 'func.func'}}
+ %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "func.return"(%arg2) : (tensor<i32>) -> ()
+ } do {
+ ^bb0(%arg2: tensor<i32>):
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %2 : tensor<i32>
+ }
+ return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_cond_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+ %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+ // expected-error at +1 {{block with no terminator}}
+ %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ } do {
+ ^bb0(%arg2: tensor<i32>):
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %2 : tensor<i32>
+ }
+ return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_body_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+ %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+ %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ tosa.yield %1 : tensor<i1>
+ } do {
+ ^bb0(%arg2: tensor<i32>):
+ // expected-error at +1 {{block with no terminator}}
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ }
+ return %0 : tensor<i32>
+}
+
// -----
func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
More information about the Mlir-commits mailing list