[Mlir-commits] [mlir] 370ea51 - [mlir][tosa] Robustify Tosa_IfOp against null dereference and wrong assertion (#159756)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 22 01:34:44 PDT 2025
Author: Jasmine Tang
Date: 2025-09-22T09:34:40+01:00
New Revision: 370ea51f706a043d4a660bda8513f51c288772a6
URL: https://github.com/llvm/llvm-project/commit/370ea51f706a043d4a660bda8513f51c288772a6
DIFF: https://github.com/llvm/llvm-project/commit/370ea51f706a043d4a660bda8513f51c288772a6.diff
LOG: [mlir][tosa] Robustify Tosa_IfOp against null dereference and wrong assertion (#159756)
Fixes #159650.
The current implementation hits an internal compiler error (ICE) when it
accesses an IfOp's terminator and the block doesn't have one. Instead, this
PR defers the job of verifying that a block has a terminator to MLIR's
generic verification.
The current implementation also crashes in cast<YieldOp> if the terminator
is not a YieldOp; this PR likewise defers that verification to the
terminator op's own verifier.
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/verifier.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1c0a6a618fcd2..aa58fc21fe26f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4025,19 +4025,27 @@ LogicalResult IfOp::verify() {
.failed())
return failure();
- auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
- if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
- "'then_graph' results", getOutputList(),
- "'output_list'")
- .failed())
- return failure();
+ // MLIR will verify the absence of the terminator for us if otherwise.
+ if (getThenGraph().front().mightHaveTerminator()) {
+ auto thenYield =
+ dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
+ if (thenYield && errorIfTypeOrShapeMismatch(
+ *this, thenYield.getInputs(), "'then_graph' results",
+ getOutputList(), "'output_list'")
+ .failed())
+ return failure();
+ }
- auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
- if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
- "'else_graph' results", getOutputList(),
- "'output_list'")
- .failed())
- return failure();
+ // MLIR will verify the absence of the terminator for us if otherwise.
+ if (getElseGraph().front().mightHaveTerminator()) {
+ auto elseYield =
+ dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
+ if (elseYield && errorIfTypeOrShapeMismatch(
+ *this, elseYield.getInputs(), "'else_graph' results",
+ getOutputList(), "'output_list'")
+ .failed())
+ return failure();
+ }
auto condType = getCondition().getType();
if (errorIfShapeNotSizeOne(*this, condType).failed())
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index f58ddb180ce4f..e5571b6b4412c 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -438,6 +438,43 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x
return %1 : tensor<10xi8>
}
+// -----
+func.func @test_cond_if_wrong_terminator_op(%arg0: tensor<i1>) -> tensor<i32> {
+ %0 = "tosa.cond_if"(%arg0) ({
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ "tosa.yield"(%1) : (tensor<i32>) -> ()
+ }, {
+ // expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
+ %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+ "func.return"(%2) : (tensor<i32>) -> ()
+ }) : (tensor<i1>) -> tensor<i32>
+ return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_cond_if_missing_then_terminator(%arg0: tensor<i1>) -> tensor<i32> {
+ %0 = "tosa.cond_if"(%arg0) ({
+ // expected-error@+1 {{block with no terminator}}
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ }, {
+ %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+ "tosa.yield"(%2) : (tensor<i32>) -> ()
+ }) : (tensor<i1>) -> tensor<i32>
+ return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_cond_if_missing_else_terminator(%arg0: tensor<i1>) -> tensor<i32> {
+ %0 = "tosa.cond_if"(%arg0) ({
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ "tosa.yield"(%1) : (tensor<i32>) -> ()
+ }, {
+ // expected-error@+1 {{block with no terminator}}
+ %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+ }) : (tensor<i1>) -> tensor<i32>
+ return %0 : tensor<i32>
+}
+
// -----
func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
More information about the Mlir-commits
mailing list