[Mlir-commits] [mlir] [MLIR][TOSA] Update IfOp print/parse to support ranked condition tens… (PR #149791)

Yuvaraj Venkatesh llvmlistbot at llvm.org
Fri Jul 25 09:25:50 PDT 2025


https://github.com/Yuvaraj-Venkatesh updated https://github.com/llvm/llvm-project/pull/149791

>From 878781071d70a31df9bca2f3ed02954adf7a39aa Mon Sep 17 00:00:00 2001
From: Yuvaraj Venkatesh <yuvaraj.venkatesh at arm.com>
Date: Mon, 14 Jul 2025 10:23:47 +0000
Subject: [PATCH] [MLIR][TOSA] Update IfOp print/parse to support ranked
 condition tensor and optional block arguments

This change extends the TOSA `cond_if` operation's print and parse logic to handle the following:

- The condition operand may now have any rank, as long as it contains exactly one element (i.e. the product of its dimensions is 1).

  %1 = tosa.cond_if %0 : tensor<1x1x1xi1> -> tensor<4xf32>

- The `then` and `else` regions can now include optional block arguments. The updated IR syntax reflects this:

  %1 = tosa.cond_if %0 (%arg2 = %arg0, %arg3 = %arg1) : tensor<i1> (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

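- Removed parentheses around single result types in the printed representation, aligning with the `AsmPrinter` conventions.

- The condition operand's type is now always spelled after a colon (e.g. `tosa.cond_if %0 : tensor<i1> -> tensor<f32>`). The previous form without it (`tosa.cond_if %0 -> (tensor<f32>)`) is rejected by the parser, so existing IR must be updated.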
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 122 ++++++++++++------
 .../Tosa/Transforms/TosaValidation.cpp        |  14 +-
 .../Conversion/TosaToSCF/tosa-to-scf.mlir     |   2 +-
 mlir/test/Dialect/Tosa/availability.mlir      |   2 +-
 mlir/test/Dialect/Tosa/controlflow.mlir       |  35 +++++
 mlir/test/Dialect/Tosa/error_if_check.mlir    |   2 +-
 mlir/test/Dialect/Tosa/invalid_extension.mlir |   2 +-
 mlir/test/Dialect/Tosa/level_check.mlir       |  12 +-
 mlir/test/Dialect/Tosa/ops.mlir               |   2 +-
 ...tosa-convert-integer-type-to-signless.mlir |   2 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  16 +--
 mlir/test/Dialect/Tosa/verifier.mlir          |  85 +++++++++++-
 12 files changed, 228 insertions(+), 68 deletions(-)
 create mode 100644 mlir/test/Dialect/Tosa/controlflow.mlir

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ecd93ff4c6e7b..3cafb199d2db3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
   return std::nullopt;
 }
 
+static void printInitializationList(OpAsmPrinter &printer,
+                                    Block::BlockArgListType blockArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blockArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+  // Prints `<prefix>(%blockArg0 = %init0, %blockArg1 = %init1, ...)`.
+  printer << prefix << '(';
+  llvm::interleaveComma(
+      llvm::zip(blockArgs, initializers), printer,
+      [&](auto it) { printer << std::get<0>(it) << " = " << std::get<1>(it); });
+  printer << ')';
+}
+
 // parse and print of IfOp refer to the implementation of SCF dialect.
 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
   // Create the regions for 'then'.
@@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
   Region *thenRegion = result.addRegion();
   Region *elseRegion = result.addRegion();
 
-  auto &builder = parser.getBuilder();
   OpAsmParser::UnresolvedOperand cond;
-  // Create a i1 tensor type for the boolean condition.
-  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
-  if (parser.parseOperand(cond) ||
-      parser.resolveOperand(cond, i1Type, result.operands))
+
+  if (parser.parseOperand(cond))
     return failure();
-  // Parse optional results type list.
-  if (parser.parseOptionalArrowTypeList(result.types))
+
+  SmallVector<OpAsmParser::Argument, 4> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+  // Parse the optional block arguments
+  OptionalParseResult listResult =
+      parser.parseOptionalAssignmentList(regionArgs, operands);
+  if (listResult.has_value() && failed(listResult.value()))
     return failure();
+
+  // Parse a colon.
+  if (failed(parser.parseColon()))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected type for condition operand");
+
+  // Parse the type of the condition operand
+  Type condType;
+  if (failed(parser.parseType(condType)))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected type for condition operand");
+
+  // Resolve operand with provided type
+  if (failed(parser.resolveOperand(cond, condType, result.operands)))
+    return failure();
+
+  // Parse optional block arg types
+  if (listResult.has_value()) {
+    FunctionType functionType;
+
+    if (failed(parser.parseType(functionType)))
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected list of types for block arguments "
+             << "followed by arrow type and list of return types";
+
+    result.addTypes(functionType.getResults());
+
+    if (functionType.getNumInputs() != operands.size()) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected as many input types as operands "
+             << "(expected " << operands.size() << " got "
+             << functionType.getNumInputs() << ")";
+    }
+
+    // Resolve input operands.
+    if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+                                      parser.getCurrentLocation(),
+                                      result.operands)))
+      return failure();
+  } else {
+    // Parse optional results type list.
+    if (parser.parseOptionalArrowTypeList(result.types))
+      return failure();
+  }
+
   // Parse the 'then' region.
   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
     return failure();
@@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void IfOp::print(OpAsmPrinter &p) {
-  bool printBlockTerminators = false;
-
   p << " " << getCondition();
-  if (!getResults().empty()) {
-    p << " -> (" << getResultTypes() << ")";
-    // Print yield explicitly if the op defines values.
-    printBlockTerminators = true;
+
+  printInitializationList(p, getThenGraph().front().getArguments(),
+                          getInputList(), " ");
+  p << " : ";
+  p << getCondition().getType();
+
+  if (!getInputList().empty()) {
+    p << " (";
+    llvm::interleaveComma(getInputList().getTypes(), p);
+    p << ")";
   }
-  p << ' ';
-  p.printRegion(getThenGraph(),
-                /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/printBlockTerminators);
+  p.printArrowTypeList(getResultTypes());
+  p << " ";
+
+  p.printRegion(getThenGraph());
 
   // Print the 'else' regions if it exists and has a block.
   auto &elseRegion = getElseGraph();
   if (!elseRegion.empty()) {
     p << " else ";
-    p.printRegion(elseRegion,
-                  /*printEntryBlockArgs=*/false,
-                  /*printBlockTerminators=*/printBlockTerminators);
+    p.printRegion(elseRegion);
   }
 
   p.printOptionalAttrDict((*this)->getAttrs());
@@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
                  parser.parseOptionalAttrDictWithKeyword(result.attributes));
 }
 
-static void printInitializationList(OpAsmPrinter &parser,
-                                    Block::BlockArgListType blocksArgs,
-                                    ValueRange initializers,
-                                    StringRef prefix = "") {
-  assert(blocksArgs.size() == initializers.size() &&
-         "expected same length of arguments and initializers");
-  if (initializers.empty())
-    return;
-
-  parser << prefix << '(';
-  llvm::interleaveComma(
-      llvm::zip(blocksArgs, initializers), parser,
-      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
-  parser << ")";
-}
-
 void WhileOp::print(OpAsmPrinter &parser) {
   printInitializationList(parser, getCondGraph().front().getArguments(),
                           getInputList(), " ");
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 32b5fb63a6ece..8ec77654fb896 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
   // })
   //
   // Simplified:
-  // %0 = tosa.cond_if %arg2 {
-  //   tosa.yield %arg0
+  // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
+  //   ^bb0(%arg3, %arg4):
+  //   tosa.yield %arg3
   // } else {
-  //   tosa.yield %arg1
+  //   ^bb0(%arg3, %arg4):
+  //   tosa.yield %arg4
   // }
-  //
-  // Unfortunately, the simplified syntax does not encapsulate values
-  // used in then/else regions (see 'simplified' example above), so it
-  // must be rewritten to use the generic syntax in order to be conformant
-  // to the specification.
+
   return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
          failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
 }
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index fa7a91cda0a47..b6f2383ac81fc 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
 func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
   // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
   // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
 
   // CHECK:   scf.yield [[ARG0]]
     tosa.yield %arg0 : tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 0176fc2883518..6398161126e80 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // CHECK: tosa.cond_if profiles: [ ]
   // CHECK: tosa.cond_if extensions: [ [controlflow] ]
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000000000..06312c7ba2df6
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
+
+// -----
+
+func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  // CHECK:     } else {
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+  // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  // CHECK:     } else {
+  // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index eb25011ff3a9d..fad1bec0e3ecc 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1:

 func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     tosa.yield %arg0 : tensor<f32>
   } else {
     tosa.yield %arg1 : tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 5630c33639d86..3154f541e0519 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
 // -----
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26fb1f85..cbe0056bafe22 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
 // -----

 func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
-      %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
-        %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
-          %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+    %1 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+      %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+        %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+          %4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
             // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
-            %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+            %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
               %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
               tosa.yield %res : tensor<f32>
             } else {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index ef51197e86d56..30361a882afe5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 // -----
 // CHECK-LABEL: cond_if
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index 38ac8d8fb66d9..e957bdd15e1ec 100644
--- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
 // CHECK-LABEL: test_regions
 // CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
 func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
-  // CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
+  // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
   %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
   ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
     // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 9d43f8998da93..ece4bf8017abd 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1166,8 +1166,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
   %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
 
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  // CHECK: -> tensor<f32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     tosa.yield %a : tensor<f32>
   } else {
     tosa.yield %b : tensor<f32>
@@ -1180,8 +1180,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
 // CHECK-LABEL: @if_test_dynamic
 func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<?xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
+  // CHECK: -> tensor<?xf32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> {
     tosa.yield %arg0 : tensor<2xf32>
   } else {
     tosa.yield %arg1 : tensor<3xf32>
@@ -1194,8 +1194,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
 // CHECK-LABEL: @if_test_unranked
 func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<*xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
+  // CHECK: -> tensor<*xf32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> {
     tosa.yield %arg0 : tensor<f32>
   } else {
     tosa.yield %arg1 : tensor<3xf32>
@@ -1208,8 +1208,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
 // CHECK-LABEL: @if_test_propagate
 func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  // CHECK: -> tensor<f32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index b3052369b055e..2a937b0a88f28 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar

 // -----

+func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
 func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg

 func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+  %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
@@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a

 func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
   } else {
@@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg

 func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+  %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tenso

 // -----

+// CHECK-LABEL: cond_if_cond_type
+func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+2 {{expected ':'}}
+  // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    tosa.yield %arg0 : tensor<f32>
+  } else {
+    tosa.yield %arg1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> () -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+2 {{expected non-function type}}
+  // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (%arg3) -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
 func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
   %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
   // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}}



More information about the Mlir-commits mailing list