[Mlir-commits] [mlir] [mlir][tosa] Stop supporting the custom simplified form of COND_IF (PR #139576)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 12 09:32:49 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: TatWai Chong (tatwaichong)

<details>
<summary>Changes</summary>

Since the spec requires the tensor_list_shape of input_list, output_list, then_graph, and else_graph to be equal, this information must be provided explicitly when the operation is constructed. The current custom simplified form does not meet this requirement: for example, input_list and output_list can be empty in the simplified form. A new, spec-compatible simplified form will be introduced in the future if necessary.

---
Full diff: https://github.com/llvm/llvm-project/pull/139576.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (-1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (-59) 
- (modified) mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir (+9-14) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+7-5) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+5-3) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+73-26) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+5-3) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+32-24) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+26-17) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..70aecfcfa3ec7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2558,7 +2558,6 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
     SizedRegion<1>:$else_graph
   );
 
-  let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 371c6dc27b428..2d7c80cbf7848 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3518,65 +3518,6 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
   return std::nullopt;
 }
 
-// parse and print of IfOp refer to the implementation of SCF dialect.
-ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
-  // Create the regions for 'then'.
-  result.regions.reserve(2);
-  Region *thenRegion = result.addRegion();
-  Region *elseRegion = result.addRegion();
-
-  auto &builder = parser.getBuilder();
-  OpAsmParser::UnresolvedOperand cond;
-  // Create a i1 tensor type for the boolean condition.
-  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
-  if (parser.parseOperand(cond) ||
-      parser.resolveOperand(cond, i1Type, result.operands))
-    return failure();
-  // Parse optional results type list.
-  if (parser.parseOptionalArrowTypeList(result.types))
-    return failure();
-  // Parse the 'then' region.
-  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
-    return failure();
-
-  // If we find an 'else' keyword then parse the 'else' region.
-  if (!parser.parseOptionalKeyword("else")) {
-    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
-      return failure();
-  }
-
-  // Parse the optional attribute list.
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  return success();
-}
-
-void IfOp::print(OpAsmPrinter &p) {
-  bool printBlockTerminators = false;
-
-  p << " " << getCondition();
-  if (!getResults().empty()) {
-    p << " -> (" << getResultTypes() << ")";
-    // Print yield explicitly if the op defines values.
-    printBlockTerminators = true;
-  }
-  p << ' ';
-  p.printRegion(getThenGraph(),
-                /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/printBlockTerminators);
-
-  // Print the 'else' regions if it exists and has a block.
-  auto &elseRegion = getElseGraph();
-  if (!elseRegion.empty()) {
-    p << " else ";
-    p.printRegion(elseRegion,
-                  /*printEntryBlockArgs=*/false,
-                  /*printBlockTerminators=*/printBlockTerminators);
-  }
-
-  p.printOptionalAttrDict((*this)->getAttrs());
-}
-
 LogicalResult IfOp::verify() {
   if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
                                  "'then_graph' arguments", getInputList(),
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index fa7a91cda0a47..78f5040eab97a 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -36,20 +36,15 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
 func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
   // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
   // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-
-  // CHECK:   scf.yield [[ARG0]]
-    tosa.yield %arg0 : tensor<f32>
-
-  // CHECK: } else {
-  } else {
-
-  // CHECK:   scf.yield [[ARG1]]
-    tosa.yield %arg1 : tensor<f32>
-
-  // CHECK: }
-  // CHECK: return [[IF]]
-  }
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    // CHECK:   scf.yield [[ARG0]]
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    // CHECK:   scf.yield [[ARG1]]
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
 
   return %0 : tensor<f32>
 }
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..5381d6c533d01 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,13 +645,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // CHECK: tosa.cond_if profiles: [ ]
   // CHECK: tosa.cond_if extensions: [ [controlflow] ]
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..c688b6592ed9f 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,13 +337,15 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
 // -----
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error at +1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d24c1fa57883d..5b11aa782637a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1503,40 +1503,87 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
 
 // -----
 
-func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
-      %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
-        %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
-          %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>,  %arg3: tensor<i1>) -> tensor<f32> {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+
+  // COM: then graph of IF-1
+  ^bb1(%a1: tensor<f32>, %b1: tensor<f32>):
+    %cond1 = tosa.equal %a1, %b1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    %1 = "tosa.cond_if"(%cond1, %a1, %b1) ({
+
+    // COM: then graph of IF-2
+    ^bb2(%a2: tensor<f32>, %b2: tensor<f32>):
+      %cond2 = tosa.equal %a2, %b2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+      %2 = "tosa.cond_if"(%cond2, %a2, %b2) ({
+
+      // COM: then graph of IF-3
+      ^bb3(%a3: tensor<f32>, %b3: tensor<f32>):
+        %cond3 = tosa.equal %a3, %b3 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+        %3 = "tosa.cond_if"(%cond3, %a3, %b3) ({
+
+        // COM: then graph of IF-4
+        ^bb4(%a4: tensor<f32>, %b4: tensor<f32>):
+          %cond4 = tosa.equal %a4, %b4 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+          %4 = "tosa.cond_if"(%cond4, %a4, %b4) ({
+
+          // COM: then graph of IF-5
+          ^bb5(%a5: tensor<f32>, %b5: tensor<f32>):
+            %cond5 = tosa.equal %a5, %b5 : (tensor<f32>, tensor<f32>) -> tensor<i1>
             // expected-error at +1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
-            %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+            %5 = "tosa.cond_if"(%cond5, %a5, %b5) ({
+
+            // COM: then graph of IF-6
+            ^bb6(%a6: tensor<f32>, %b6: tensor<f32>):
               %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
               tosa.yield %res : tensor<f32>
-            } else {
+            },  {
+
+            // COM: else graph of IF-6
+            ^bb6(%a6: tensor<f32>, %b6: tensor<f32>):
               tosa.yield %arg0 : tensor<f32>
-            }
+            }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
             tosa.yield %5 : tensor<f32>
-          } else {
+          },  {
+
+            // COM: else graph of IF-5
+            ^bb5(%a5: tensor<f32>, %b5: tensor<f32>):
+              tosa.yield %arg0 : tensor<f32>
+          }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+            tosa.yield %4 : tensor<f32>
+        },  {
+
+          // COM: else graph of IF-4
+          ^bb4(%a4: tensor<f32>, %b4: tensor<f32>):
             tosa.yield %arg0 : tensor<f32>
-          }
-          tosa.yield %4 : tensor<f32>
-        } else {
-          tosa.yield %arg0 : tensor<f32>
-        }
-        tosa.yield %3 : tensor<f32>
-      } else {
+        }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+          tosa.yield %3 : tensor<f32>
+      },  {
+
+      // COM: else graph of IF-3
+      ^bb3(%a3: tensor<f32>, %b3: tensor<f32>):
         tosa.yield %arg0 : tensor<f32>
-      }
-      tosa.yield %2 : tensor<f32>
-    } else {
-      tosa.yield %arg0 : tensor<f32>
-    }
-    tosa.yield %1 : tensor<f32>
-  } else {
-    %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+        tosa.yield %2 : tensor<f32>
+    },  {
+
+      // COM: else graph of IF-2
+      ^bb2(%a2: tensor<f32>, %b2: tensor<f32>):
+        tosa.yield %arg0 : tensor<f32>
+    }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+      tosa.yield %1 : tensor<f32>
+  },  {
+
+  // COM: else graph of IF-1
+  ^bb1(%a1: tensor<f32>, %b1: tensor<f32>):
+    %res = tosa.sub %a1, %b1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %res : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
   return %0 : tensor<f32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e327ed900f45f..e3036cf07171f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -781,13 +781,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 // -----
 // CHECK-LABEL: cond_if
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..981e3cc7fc129 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1121,12 +1121,14 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
   %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
 
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    tosa.yield %a : tensor<f32>
-  } else {
-    tosa.yield %b : tensor<f32>
-  }
+  // CHECK: -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %a, %b) ({
+  ^bb0(%a1: tensor<f32>, %b1: tensor<f32>):
+    tosa.yield %a1 : tensor<f32>
+  },  {
+  ^bb0(%a1: tensor<f32>, %b1: tensor<f32>):
+    tosa.yield %b1 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return
 }
 
@@ -1135,12 +1137,14 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
 // CHECK-LABEL: @if_test_dynamic
 func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<?xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
-    tosa.yield %arg0 : tensor<2xf32>
-  } else {
-    tosa.yield %arg1 : tensor<3xf32>
-  }
+  // CHECK: -> tensor<?xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>):
+    tosa.yield %a : tensor<2xf32>
+  },  {
+  ^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>):
+    tosa.yield %b : tensor<3xf32>
+  }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
   return
 }
 
@@ -1149,12 +1153,14 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
 // CHECK-LABEL: @if_test_unranked
 func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<*xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
-    tosa.yield %arg0 : tensor<f32>
-  } else {
-    tosa.yield %arg1 : tensor<3xf32>
-  }
+  // CHECK: -> tensor<*xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%a: tensor<f32>, %b: tensor<3xf32>):
+    tosa.yield %a : tensor<f32>
+  },  {
+  ^bb0(%a: tensor<f32>, %b: tensor<3xf32>):
+    tosa.yield %b : tensor<3xf32>
+  }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
   return
 }
 
@@ -1163,14 +1169,16 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
 // CHECK-LABEL: @if_test_propagate
 func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
   // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+    %1 = tosa.add %a, %b : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  },  {
+  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+    %1 = tosa.sub %a, %b : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..e99608dfbeff4 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -502,14 +502,17 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar
 
 func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error at +1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
-  } else {
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
   return %0 : tensor<f32>
 }
 
@@ -517,13 +520,15 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg
 
 func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error at +1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+  %0, %2 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
   return %0 : tensor<f32>
 }
 
@@ -531,14 +536,16 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a
 
 func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error at +1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.add %1, %arg3 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
@@ -546,14 +553,16 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg
 
 func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   // expected-error at +1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
-    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-    %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0, %2 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1, %2 : tensor<f32>, tensor<f32>
-  } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
-  }
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
   return %0 : tensor<f32>
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/139576


More information about the Mlir-commits mailing list