[Mlir-commits] [mlir] [MLIR][TOSA] Update IfOp print/parse to support ranked condition tens… (PR #149791)
Yuvaraj Venkatesh
llvmlistbot at llvm.org
Fri Jul 25 09:25:50 PDT 2025
https://github.com/Yuvaraj-Venkatesh updated https://github.com/llvm/llvm-project/pull/149791
>From 878781071d70a31df9bca2f3ed02954adf7a39aa Mon Sep 17 00:00:00 2001
From: Yuvaraj Venkatesh <yuvaraj.venkatesh at arm.com>
Date: Mon, 14 Jul 2025 10:23:47 +0000
Subject: [PATCH] [MLIR][TOSA] Update IfOp print/parse to support ranked
condition tensor and optional block arguments
This change extends the TOSA `cond_if` operation's print and parse logic to handle the following:
- The condition operand may now have any rank, as long as it contains exactly one element (i.e., the product of its dimensions is 1).
%1 = tosa.cond_if %0 : tensor<1x1x1xi1> -> tensor<4xf32>
- The `then` and `else` regions can now include optional block arguments. The updated IR syntax reflects this:
%1 = tosa.cond_if %0 (%arg2 = %arg0, %arg3 = %arg1) : tensor<i1> (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- Parentheses around a single result type are no longer printed, in line with the `AsmPrinter` conventions.
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 122 ++++++++++++------
.../Tosa/Transforms/TosaValidation.cpp | 14 +-
.../Conversion/TosaToSCF/tosa-to-scf.mlir | 2 +-
mlir/test/Dialect/Tosa/availability.mlir | 2 +-
mlir/test/Dialect/Tosa/controlflow.mlir | 35 +++++
mlir/test/Dialect/Tosa/error_if_check.mlir | 2 +-
mlir/test/Dialect/Tosa/invalid_extension.mlir | 2 +-
mlir/test/Dialect/Tosa/level_check.mlir | 12 +-
mlir/test/Dialect/Tosa/ops.mlir | 2 +-
...tosa-convert-integer-type-to-signless.mlir | 2 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 16 +--
mlir/test/Dialect/Tosa/verifier.mlir | 85 +++++++++++-
12 files changed, 228 insertions(+), 68 deletions(-)
create mode 100644 mlir/test/Dialect/Tosa/controlflow.mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ecd93ff4c6e7b..3cafb199d2db3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
+static void printInitializationList(OpAsmPrinter &parser,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ parser << prefix << '(';
+ llvm::interleaveComma(
+ llvm::zip(blocksArgs, initializers), parser,
+ [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
+ parser << ")";
+}
+
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
@@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
- auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand cond;
- // Create a i1 tensor type for the boolean condition.
- Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+
+ if (parser.parseOperand(cond))
return failure();
- // Parse optional results type list.
- if (parser.parseOptionalArrowTypeList(result.types))
+
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+ // Parse the optional block arguments
+ OptionalParseResult listResult =
+ parser.parseOptionalAssignmentList(regionArgs, operands);
+ if (listResult.has_value() && failed(listResult.value()))
return failure();
+
+ // Parse a colon.
+ if (failed(parser.parseColon()))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Parse the type of the condition operand
+ Type condType;
+ if (failed(parser.parseType(condType)))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Resolve operand with provided type
+ if (failed(parser.resolveOperand(cond, condType, result.operands)))
+ return failure();
+
+ // Parse optional block arg types
+ if (listResult.has_value()) {
+ FunctionType functionType;
+
+ if (failed(parser.parseType(functionType)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected list of types for block arguments "
+ << "followed by arrow type and list of return types";
+
+ result.addTypes(functionType.getResults());
+
+ if (functionType.getNumInputs() != operands.size()) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected as many input types as operands "
+ << "(expected " << operands.size() << " got "
+ << functionType.getNumInputs() << ")";
+ }
+
+ // Resolve input operands.
+ if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+ } else {
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ }
+
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
@@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
-
p << " " << getCondition();
- if (!getResults().empty()) {
- p << " -> (" << getResultTypes() << ")";
- // Print yield explicitly if the op defines values.
- printBlockTerminators = true;
+
+ printInitializationList(p, getThenGraph().front().getArguments(),
+ getInputList(), " ");
+ p << " : ";
+ p << getCondition().getType();
+
+ if (!getInputList().empty()) {
+ p << " (";
+ llvm::interleaveComma(getInputList().getTypes(), p);
+ p << ")";
}
- p << ' ';
- p.printRegion(getThenGraph(),
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printArrowTypeList(getResultTypes());
+ p << " ";
+
+ p.printRegion(getThenGraph());
// Print the 'else' regions if it exists and has a block.
auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
- p.printRegion(elseRegion,
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printRegion(elseRegion);
}
p.printOptionalAttrDict((*this)->getAttrs());
@@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
-static void printInitializationList(OpAsmPrinter &parser,
- Block::BlockArgListType blocksArgs,
- ValueRange initializers,
- StringRef prefix = "") {
- assert(blocksArgs.size() == initializers.size() &&
- "expected same length of arguments and initializers");
- if (initializers.empty())
- return;
-
- parser << prefix << '(';
- llvm::interleaveComma(
- llvm::zip(blocksArgs, initializers), parser,
- [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
- parser << ")";
-}
-
void WhileOp::print(OpAsmPrinter &parser) {
printInitializationList(parser, getCondGraph().front().getArguments(),
getInputList(), " ");
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 32b5fb63a6ece..8ec77654fb896 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
// })
//
// Simplified:
- // %0 = tosa.cond_if %arg2 {
- // tosa.yield %arg0
+ // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg3
// } else {
- // tosa.yield %arg1
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg4
// }
- //
- // Unfortunately, the simplified syntax does not encapsulate values
- // used in then/else regions (see 'simplified' example above), so it
- // must be rewritten to use the generic syntax in order to be conformant
- // to the specification.
+
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
}
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index fa7a91cda0a47..b6f2383ac81fc 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
// CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
// CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
// CHECK: scf.yield [[ARG0]]
tosa.yield %arg0 : tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 0176fc2883518..6398161126e80 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// CHECK: tosa.cond_if profiles: [ ]
// CHECK: tosa.cond_if extensions: [ [controlflow] ]
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000000000..06312c7ba2df6
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
+
+// -----
+
+func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+ %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ // CHECK: } else {
+ } else {
+ %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ // CHECK: } else {
+ // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index eb25011ff3a9d..fad1bec0e3ecc 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1:
func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) {
tosa.yield %arg0 : tensor<f32>
} else {
tosa.yield %arg1 : tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 5630c33639d86..3154f541e0519 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
// -----
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26fb1f85..cbe0056bafe22 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
// -----
func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
- %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
- %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+ %1 = tosa.cond_if %arg3 : tensor<i1>-> tensor<f32> {
+ %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+ %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+ %4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
- %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+ %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
%res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %res : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index ef51197e86d56..30361a882afe5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// -----
// CHECK-LABEL: cond_if
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index 38ac8d8fb66d9..e957bdd15e1ec 100644
--- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK-LABEL: test_regions
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
- // CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
+ // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 9d43f8998da93..ece4bf8017abd 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1166,8 +1166,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
%b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<f32>)
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ // CHECK: -> tensor<f32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
tosa.yield %a : tensor<f32>
} else {
tosa.yield %b : tensor<f32>
@@ -1180,8 +1180,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
// CHECK-LABEL: @if_test_dynamic
func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<?xf32>)
- %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
+ // CHECK: -> tensor<?xf32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> {
tosa.yield %arg0 : tensor<2xf32>
} else {
tosa.yield %arg1 : tensor<3xf32>
@@ -1194,8 +1194,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_unranked
func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<*xf32>)
- %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
+ // CHECK: -> tensor<*xf32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> {
tosa.yield %arg0 : tensor<f32>
} else {
tosa.yield %arg1 : tensor<3xf32>
@@ -1208,8 +1208,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_propagate
func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<f32>)
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ // CHECK: -> tensor<f32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index b3052369b055e..2a937b0a88f28 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar
// -----
+func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg
func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
- %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+ %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
@@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a
func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
@@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg
func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
- %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+ %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tenso
// -----
+// CHECK-LABEL: cond_if_cond_type
+func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+2 {{expected ':'}}
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}}
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ tosa.yield %arg0 : tensor<f32>
+ } else {
+ tosa.yield %arg1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> () -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+2 {{expected non-function type}}
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (%arg3) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
// expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}}
More information about the Mlir-commits
mailing list