[Mlir-commits] [mlir] [mlir][tosa] Add error if and level checks for COND_IF & WHILE_LOOP (PR #136194)

TatWai Chong llvmlistbot at llvm.org
Thu Apr 17 13:34:41 PDT 2025


https://github.com/tatwaichong created https://github.com/llvm/llvm-project/pull/136194

Error-if checks: verify that the input list, the output list, and the control-flow blocks agree in length and type.

Level checks: verify that the nesting depth is below MAX_NESTING.

>From a8c744b8236e850e6938f66e681b633742765326 Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Fri, 11 Apr 2025 19:54:49 -0700
Subject: [PATCH] [mlir][tosa] Add error if and level checks for COND_IF &
 WHILE_LOOP

Error-if checks: verify that the input list, the output list, and the
control-flow blocks agree in length and type.

Level checks: verify that the nesting depth is below MAX_NESTING.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |   2 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 130 +++++++
 .../Tosa/Transforms/TosaValidation.cpp        |  42 +++
 mlir/test/Dialect/Tosa/level_check.mlir       | 303 +++++++++++----
 mlir/test/Dialect/Tosa/verifier.mlir          | 349 ++++++++++++++++++
 5 files changed, 747 insertions(+), 79 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c94edad62cac7..70702647bb50a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   );
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   );
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
 }
 
 include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..365ea458d25d5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -428,6 +428,51 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+// Verify that block arguments and tensor list match in length and type.
+static LogicalResult errorIfTensorListShapeMismatch(Operation *op,
+                                                    ValueRange blocksArgs,
+                                                    StringRef blockName,
+                                                    ValueRange tensorList,
+                                                    StringRef listName) {
+  if (blocksArgs.size() != tensorList.size())
+    return op->emitOpError() << "require same number of values in " << blockName
+                             << " (" << blocksArgs.size() << ") and "
+                             << listName << " (" << tensorList.size() << ")";
+
+  for (auto [bbArgType, opArgType] :
+       llvm::zip_equal(blocksArgs.getTypes(), tensorList.getTypes())) {
+    ShapeAdaptor bbShapeAdaptor(bbArgType);
+    ShapeAdaptor opShapeAdaptor(opArgType);
+
+    if (!bbShapeAdaptor.hasRank() || !opShapeAdaptor.hasRank())
+      continue;
+
+    if (!bbShapeAdaptor.hasStaticShape() || !opShapeAdaptor.hasStaticShape())
+      continue;
+
+    if (bbArgType != opArgType)
+      return op->emitOpError()
+             << "require same shapes for " << blockName << " (" << bbArgType
+             << ") and " << listName << " (" << opArgType << ")";
+  }
+
+  return success();
+}
+
+static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
+  ShapeAdaptor shapeAdaptor(type);
+
+  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
+    return success();
+
+  SmallVector<int64_t> shape;
+  shapeAdaptor.getDims(shape);
+  if (!llvm::all_of(shape, [](int64_t dim) { return dim == 1; }))
+    return failure();
+
+  return success();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3321,6 +3366,91 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+LogicalResult IfOp::verify() {
+  if (getThenGraph().empty() || getElseGraph().empty())
+    return emitOpError("require `then_graph` and `else_graph` not be empty");
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getThenGraph().front().getArguments(),
+          "`then_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getElseGraph().front().getArguments(),
+          "`else_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, thenYield.getInputs(),
+                                     "`then_graph` results", getOutputList(),
+                                     "`output_list`")
+          .failed())
+    return failure();
+
+  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, elseYield.getInputs(),
+                                     "`else_graph` results", getOutputList(),
+                                     "`output_list`")
+          .failed())
+    return failure();
+
+  auto condType = getCondition().getType();
+  if (errorIfShapeNotSizeOne(*this, condType).failed())
+    return emitOpError() << "`condition` must be a size 1 tensor, got "
+                         << condType;
+
+  return success();
+}
+
+LogicalResult WhileOp::verify() {
+  if (getCondGraph().empty() || getBodyGraph().empty())
+    return emitOpError(
+        "`cond_graph` and `body_graph` regions must not be empty");
+
+  if (errorIfTensorListShapeMismatch(*this, getInputList(), "`input_list`",
+                                     getOutputList(), "`output_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getCondGraph().front().getArguments(),
+          "`cond_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getBodyGraph().front().getArguments(),
+          "`body_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, bodyYield.getInputs(),
+                                     "`body_graph` results", getInputList(),
+                                     "`input_list`")
+          .failed())
+    return failure();
+
+  // Condition block output must be a single element tensor with a single bool
+  // value.
+  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (condYield.getInputs().size() != 1)
+    return emitOpError() << "require `cond_graph` only have one result";
+
+  auto condOutType = condYield.getInputs()[0].getType();
+  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
+    return emitOpError() << "`cond_graph` result must be a size 1 tensor, got "
+                         << condOutType;
+
+  if (!getElementTypeOrSelf(condOutType).isInteger(1))
+    return emitOpError() << "`cond_graph` result must be a boolean tensor, got "
+                         << condOutType;
+
+  return success();
+}
+
 LogicalResult ReverseOp::verify() {
   if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                              /* outType = */ getOutput().getType())
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index ef9d27f8df0ad..de31198faff2a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -449,6 +449,42 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
+  // Perform depth-first search in Tosa IR structure to find the maximum nesting
+  // depth. Tosa nesting_depth starts at 0 and increases by one each time a new
+  // nested `region` is encountered.
+
+  static int32_t getMaxNestedDepth(Operation *op) {
+    int32_t depth = 0;
+    for (Region &region : op->getRegions())
+      depth = std::max(depth, getMaxNestedDepth(region));
+    return depth;
+  }
+
+  static int32_t getMaxNestedDepth(Block &block) {
+    int32_t depth = 0;
+    for (Operation &op : block.getOperations())
+      depth = std::max(depth, getMaxNestedDepth(&op));
+    return depth;
+  }
+
+  static int32_t getMaxNestedDepth(Region &region) {
+    int32_t depth = 0;
+    for (Block &block : region.getBlocks())
+      depth = std::max(depth, getMaxNestedDepth(block));
+    // Increase the nested depth.
+    return depth + 1;
+  }
+
+  bool levelCheckMaxNesting(Operation *op) {
+    int32_t maxNestedDepth = getMaxNestedDepth(op);
+    if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
+      op->emitOpError() << "failed level check: " << maxNestedDepth
+                        << " >= MAX_NESTING";
+      return false;
+    }
+    return true;
+  }
+
   bool levelCheckListSize(Operation *op) {
     if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
       return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -750,6 +786,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
     return failure();
   }
 
+  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
+    if (!levelCheckMaxNesting(op)) {
+      return failure();
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b48f614770fcb..5cad01d85825d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1327,33 +1327,42 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
   %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
   %1 = "tosa.cond_if"(%arg0,   // condition
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0) ({
+  ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+   ):
+    "tosa.yield"(%64) : (tensor<1xi32>) -> ()
   },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+   ):
+    "tosa.yield"(%01) : (tensor<1xi32>) -> ()
   }) : (
-                    tensor<i1>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  ) -> tensor<1xi32>
-
+       tensor<i1>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+     ) -> tensor<1xi32>
   return
 }
 
@@ -1361,27 +1370,54 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
 
 // CHECK-LABEL: test_if_tensor_list_size_outputs
 func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
-  %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
 
   // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
-  %r:65 = "tosa.cond_if"(%arg0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  %r:65 = "tosa.cond_if"(%arg0, %cst_0) ({
+  ^bb0(%0: tensor<1xi32>):
+    "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
   },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
-  }) : (tensor<i1>) -> (
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  )
-
+  ^bb0(%0: tensor<1xi32>):
+    "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
+  }) : (tensor<i1>, tensor<1xi32>) -> (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    )
   return
 }
 
@@ -1391,25 +1427,57 @@ func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
 func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
   %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
-  %1:2 = "tosa.while_loop"(%0, %arg0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0
+  %1:65 = "tosa.while_loop"(%0, %arg0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0
   ) ({
-  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>,
+       %00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>
+  ):
     %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
     %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
     "tosa.yield"(%3) : (tensor<1xi1>) -> ()
   },  {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
-    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
+  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>,
+       %00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>
+  ):
+    %2 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+    "tosa.yield"(%3, %arg4,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
   }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
@@ -1419,28 +1487,7 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
-  ) -> (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>)
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_while_tensor_list_size_outputs
-func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
-  %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // expected-error@+1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
-  %1:65 = "tosa.while_loop"(%0, %arg0) ({
-  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
-    "tosa.yield"(%3) : (tensor<1xi1>) -> ()
-  },  {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
-    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
-  }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>) -> ( tensor<i32>, tensor<1x1x1x1x1x1x1xf32>,
+  ) -> (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
@@ -1453,3 +1500,101 @@ func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>,
 
   return
 }
+
+// -----
+
+func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
+      %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
+        %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
+          %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+            %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+              %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+              tosa.yield %res : tensor<f32>
+            } else {
+              tosa.yield %arg0 : tensor<f32>
+            }
+            tosa.yield %5 : tensor<f32>
+          } else {
+            tosa.yield %arg0 : tensor<f32>
+          }
+          tosa.yield %4 : tensor<f32>
+        } else {
+          tosa.yield %arg0 : tensor<f32>
+        }
+        tosa.yield %3 : tensor<f32>
+      } else {
+        tosa.yield %arg0 : tensor<f32>
+      }
+      tosa.yield %2 : tensor<f32>
+    } else {
+      tosa.yield %arg0 : tensor<f32>
+    }
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %res : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_while_loop_max_nested_depth(%arg0: tensor<i32>) {
+  %init_0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %cst_1 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+  // expected-error@+1 {{'tosa.while_loop' op failed level check: 6 >= MAX_NESTING}}
+  %1:2 = tosa.while_loop (%arg2 = %init_0, %arg3 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+    %2 = tosa.greater_equal %arg3, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg2b: tensor<i32>):
+    %1:2 = tosa.while_loop (%arg4 = %init_0, %arg5 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+      %2 = tosa.greater_equal %arg5, %arg4 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      tosa.yield %2 : tensor<i1>
+    } do {
+    ^bb0(%arg4: tensor<i32>, %arg4b: tensor<i32>):
+      %1:2 = tosa.while_loop (%arg6 = %init_0, %arg7 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+        %2 = tosa.greater_equal %arg7, %arg6 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+        tosa.yield %2 : tensor<i1>
+      } do {
+      ^bb0(%arg6: tensor<i32>, %arg6b: tensor<i32>):
+        %1:2 = tosa.while_loop (%arg8 = %init_0, %arg9 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+          %2 = tosa.greater_equal %arg9, %arg8 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+          tosa.yield %2 : tensor<i1>
+        } do {
+        ^bb0(%arg8: tensor<i32>, %arg8b: tensor<i32>):
+          %1:2 = tosa.while_loop (%arg10 = %init_0, %arg11 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+            %2 = tosa.greater_equal %arg11, %arg10 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+            tosa.yield %2 : tensor<i1>
+          } do {
+          ^bb0(%arg10: tensor<i32>, %arg10b: tensor<i32>):
+            %1:2 = tosa.while_loop (%arg12 = %init_0, %arg13 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+              %2 = tosa.greater_equal %arg13, %arg12 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+              tosa.yield %2 : tensor<i1>
+            } do {
+            ^bb0(%arg12: tensor<i32>, %arg12b: tensor<i32>):
+              %3 = tosa.add %arg12, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+              tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+            }
+            %3 = tosa.add %arg10, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+            tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+          }
+          %3 = tosa.add %arg8, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+          tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+        }
+        %3 = tosa.add %arg6, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+        tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+      }
+      %3 = tosa.add %arg4, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+    }
+    %3 = tosa.add %arg2, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+  }
+  return
+}
+
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index efdd26a9346fb..df37afa3bed49 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -167,3 +167,352 @@ func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
   %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
   return %2 : tensor<f32>
 }
+
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in `then_graph` arguments (1) and `input_list` (2)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in `then_graph` arguments (2) and `input_list` (1)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in `else_graph` arguments (1) and `input_list` (2)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in `else_graph` arguments (2) and `input_list` (1)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0) ({
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg4: tensor<f32>, %arg3: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in `then_graph` results (2) and `output_list` (1)}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in `then_graph` results (1) and `output_list` (2)}}
+  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in `else_graph` results (2) and `output_list` (1)}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in `else_graph` results (1) and `output_list` (2)}}
+  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2xi1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op `condition` must be a size 1 tensor, got 'tensor<2xi1>'}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<2xi1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in `body_graph` arguments (3) and `input_list` (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg4 : tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_in_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in `body_graph` arguments (2) and `input_list` (3)}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0, %arg4 = %arg0)
+    : (tensor<i32>, tensor<10xi32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_output_list(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in `input_list` (3) and `output_list` (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0, %arg4 = %arg0)
+    : (tensor<i32>, tensor<10xi32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_output_list_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in `input_list` (2) and `output_list` (3)}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0)
+    : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_cond_block(%arg0: tensor<2xf32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in `cond_graph` arguments (3) and `input_list` (2)}}
+  %1:2 = "tosa.while_loop"(%0, %arg0) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>, %arg5: tensor<2xf32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "tosa.yield"(%2) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<2> : tensor<1xi8>} : () -> tensor<1xi8>
+    %4 = "tosa.mul"(%arg3, %2, %3) : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    "tosa.yield"(%4, %arg4) : (tensor<i32>, tensor<2xf32>) -> ()
+  }) : (tensor<i32>, tensor<2xf32>) -> (tensor<i32>, tensor<2xf32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_cond_block_2(%arg0: tensor<2xf32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in `cond_graph` arguments (1) and `input_list` (3)}}
+  %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "tosa.yield"(%2) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<2> : tensor<1xi8>} : () -> tensor<1xi8>
+    %4 = "tosa.mul"(%arg3, %2, %3) : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    "tosa.yield"(%4, %arg4) : (tensor<i32>, tensor<2xf32>) -> ()
+  }) : (tensor<i32>, tensor<2xf32>, tensor<i32>) -> (tensor<i32>, tensor<2xf32>, tensor<i32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_out(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in `body_graph` results (3) and `input_list` (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %2, %3, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_out_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in `body_graph` results (1) and `input_list` (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_type_mismatch(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same shapes for `body_graph` arguments ('tensor<f32>') and `input_list` ('tensor<i32>')}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_type_mismatch_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same shapes for `body_graph` arguments ('tensor<10xi32>') and `input_list` ('tensor<i32>')}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %6 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_cond_output_not_size_one(%arg0: tensor<10xi32>, %arg1: tensor<2xi32>) {
+  %0 = "tosa.const"() {values = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.while_loop' op `cond_graph` result must be a size 1 tensor, got 'tensor<2xi1>'}}
+  %1:3 = tosa.while_loop (%arg2 = %arg0, %arg3 = %0, %arg4 = %arg0) : (tensor<10xi32>, tensor<2xi32>, tensor<10xi32>) -> (tensor<10xi32>, tensor<2xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+    tosa.yield %2 : tensor<2xi1>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<2xi32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<[3, 5]> : tensor<2xi32>} : () -> tensor<2xi32>
+    %4 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+    tosa.yield %4, %3, %arg4 : tensor<10xi32>, tensor<2xi32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<9> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op `cond_graph` result must be a boolean tensor, got 'tensor<i32>'}}
+  %1:3 = tosa.while_loop (%arg2 = %arg0, %arg3 = %0, %arg4 = %arg0) : (tensor<10xi32>, tensor<i32>, tensor<10xi32>) -> (tensor<10xi32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.add %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %2 : tensor<i32>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %4 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+    tosa.yield %4, %2, %arg4 : tensor<10xi32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}



More information about the Mlir-commits mailing list