[Mlir-commits] [mlir] [mlir][tosa] Robustify Tosa_IfOp against null dereference and wrong assertion (PR #159756)

Jasmine Tang llvmlistbot at llvm.org
Fri Sep 19 22:48:06 PDT 2025


https://github.com/badumbatish updated https://github.com/llvm/llvm-project/pull/159756

>From 7b6b0cf910132bbdc557f4907ce3b50dd65f878b Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 19 Sep 2025 03:33:47 -0700
Subject: [PATCH 1/2] Robustify Tosa_IfOp against null dereference and wrong
 assertion

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 32 +++++++++++++++++-----------
 mlir/test/Dialect/Tosa/verifier.mlir | 25 ++++++++++++++++++++++
 2 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b4c87a34a0e5a..309ef0bfd4ca3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4025,19 +4025,27 @@ LogicalResult IfOp::verify() {
           .failed())
     return failure();
 
-  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
-  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
-                                 "'then_graph' results", getOutputList(),
-                                 "'output_list'")
-          .failed())
-    return failure();
+  // If the block has no terminator, MLIR's own verifier will report it.
+  if (getThenGraph().front().mightHaveTerminator()) {
+    auto thenYield =
+        dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
+    if (thenYield && errorIfTypeOrShapeMismatch(
+                         *this, thenYield.getInputs(), "'then_graph' results",
+                         getOutputList(), "'output_list'")
+                         .failed())
+      return failure();
+  }
 
-  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
-  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
-                                 "'else_graph' results", getOutputList(),
-                                 "'output_list'")
-          .failed())
-    return failure();
+  // If the block has no terminator, MLIR's own verifier will report it.
+  if (getElseGraph().front().mightHaveTerminator()) {
+    auto elseYield =
+        dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
+    if (elseYield && errorIfTypeOrShapeMismatch(
+                         *this, elseYield.getInputs(), "'else_graph' results",
+                         getOutputList(), "'output_list'")
+                         .failed())
+      return failure();
+  }
 
   auto condType = getCondition().getType();
   if (errorIfShapeNotSizeOne(*this, condType).failed())
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index f58ddb180ce4f..cee2af2ad0c07 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -438,6 +438,31 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x
   return %1 : tensor<10xi8>
 }
 
+// -----
+func.func @test_cond_if_wrong_terminator_op(%arg0: tensor<i1>) -> tensor<i32> {
+  %0 = "tosa.cond_if"(%arg0) ({
+    %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+    "tosa.yield"(%1) : (tensor<i32>) -> ()
+  }, {
+    // expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
+    %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+    "func.return"(%2) : (tensor<i32>) -> ()
+  }) : (tensor<i1>) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_cond_if_missing_terminator(%arg0: tensor<i1>) -> tensor<i32> {
+  %0 = "tosa.cond_if"(%arg0) ({
+    %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+    "tosa.yield"(%1) : (tensor<i32>) -> ()
+  }, {
+    // expected-error@+1 {{block with no terminator}}
+    %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+  }) : (tensor<i1>) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
 // -----
 
 func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {

>From 08be3994b1f7bf200b9855291fd77258dd9cd931 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 19 Sep 2025 22:47:31 -0700
Subject: [PATCH 2/2] Add then test case

---
 mlir/test/Dialect/Tosa/verifier.mlir | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index cee2af2ad0c07..e5571b6b4412c 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -452,7 +452,19 @@ func.func @test_cond_if_wrong_terminator_op(%arg0: tensor<i1>) -> tensor<i32> {
 }
 
 // -----
-func.func @test_cond_if_missing_terminator(%arg0: tensor<i1>) -> tensor<i32> {
+func.func @test_cond_if_missing_then_terminator(%arg0: tensor<i1>) -> tensor<i32> {
+  %0 = "tosa.cond_if"(%arg0) ({
+    // expected-error@+1 {{block with no terminator}}
+    %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+  }, {
+    %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+    "tosa.yield"(%2) : (tensor<i32>) -> ()
+  }) : (tensor<i1>) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_cond_if_missing_else_terminator(%arg0: tensor<i1>) -> tensor<i32> {
   %0 = "tosa.cond_if"(%arg0) ({
     %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
     "tosa.yield"(%1) : (tensor<i32>) -> ()



More information about the Mlir-commits mailing list