[Mlir-commits] [mlir] [mlir][tosa] Add error if and level checks for COND_IF & WHILE_LOOP (PR #136194)

TatWai Chong llvmlistbot at llvm.org
Fri Apr 25 00:43:45 PDT 2025


https://github.com/tatwaichong updated https://github.com/llvm/llvm-project/pull/136194

>From d8a3f5e6e6824bb9f2317416eeb93c84fcb87582 Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Fri, 11 Apr 2025 19:54:49 -0700
Subject: [PATCH] [mlir][tosa] Add error if and level checks for COND_IF &
 WHILE_LOOP

Error if checks: verify that the input list, output list, and
control-flow blocks agree in length and type.

Level checks: verify that the nesting depth stays below MAX_NESTING.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |   2 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 129 +++++++
 .../Tosa/Transforms/TosaValidation.cpp        |  42 +++
 mlir/test/Dialect/Tosa/level_check.mlir       | 303 +++++++++++----
 mlir/test/Dialect/Tosa/verifier.mlir          | 348 ++++++++++++++++++
 5 files changed, 745 insertions(+), 79 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c94edad62cac7..70702647bb50a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   );
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   );
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
 }
 
 include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..4bf08cc9e9929 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -572,6 +572,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
   return success();
 }
 
+// Verify that both types share an element type and have compatible shapes.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
+                                                StringRef name1, Type type2,
+                                                StringRef name2) {
+  auto shapeType1 = dyn_cast<ShapedType>(type1);
+  auto shapeType2 = dyn_cast<ShapedType>(type2);
+  if (!shapeType1 || !shapeType2)
+    return failure(); // Not a ShapedType: fails without emitting a diagnostic.
+
+  auto elemType1 = shapeType1.getElementType();
+  auto elemType2 = shapeType2.getElementType();
+  if (elemType1 != elemType2)
+    return op->emitOpError()
+           << "require same element type for " << name1 << " (" << elemType1
+           << ") and " << name2 << " (" << elemType2 << ")";
+
+  if (failed(verifyCompatibleShape(type1, type2)))
+    return op->emitOpError()
+           << "require same shapes for " << name1 << " (" << type1 << ") and "
+           << name2 << " (" << type2 << ")";
+
+  return success();
+}
+
+// Verify that two value lists match in length, element type, and shape.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
+                                                StringRef name1,
+                                                ValueRange list2,
+                                                StringRef name2) {
+  if (list1.size() != list2.size())
+    return op->emitOpError()
+           << "require same number of values in " << name1 << " ("
+           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
+
+  for (auto [type1, type2] :
+       llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
+    if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
+      return failure(); // Stop at the first mismatching pair.
+  }
+
+  return success();
+}
+
+static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
+  ShapeAdaptor shapeAdaptor(type); // `op` is unused; callers emit the error.
+  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
+    return success(); // Unranked or dynamic shapes are accepted unchecked.
+
+  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3397,6 +3448,84 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+LogicalResult IfOp::verify() {
+  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
+                                 "'then_graph' arguments", getInputList(),
+                                 "'input_list'")
+          .failed()) // then-block arguments must match 'input_list'
+    return failure();
+
+  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
+                                 "'else_graph' arguments", getInputList(),
+                                 "'input_list'")
+          .failed()) // else-block arguments must match 'input_list'
+    return failure();
+
+  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
+  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
+                                 "'then_graph' results", getOutputList(),
+                                 "'output_list'")
+          .failed()) // then-block yielded values must match 'output_list'
+    return failure();
+
+  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
+  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
+                                 "'else_graph' results", getOutputList(),
+                                 "'output_list'")
+          .failed()) // else-block yielded values must match 'output_list'
+    return failure();
+
+  auto condType = getCondition().getType();
+  if (errorIfShapeNotSizeOne(*this, condType).failed()) // dynamic shapes pass
+    return emitOpError() << "'condition' must be a size 1 tensor, got "
+                         << condType;
+
+  return success();
+}
+
+LogicalResult WhileOp::verify() {
+  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
+                                 getOutputList(), "'output_list'")
+          .failed()) // 'input_list' must match 'output_list'
+    return failure();
+
+  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
+                                 "'cond_graph' arguments", getInputList(),
+                                 "'input_list'")
+          .failed()) // cond-block arguments must match 'input_list'
+    return failure();
+
+  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
+                                 "'body_graph' arguments", getInputList(),
+                                 "'input_list'")
+          .failed()) // body-block arguments must match 'input_list'
+    return failure();
+
+  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
+                                 "'body_graph' results", getInputList(),
+                                 "'input_list'")
+          .failed()) // body-block yielded values must match 'input_list'
+    return failure();
+
+  // The condition block must yield exactly one value, a single-element tensor
+  // whose element type is i1 (boolean).
+  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (condYield.getInputs().size() != 1)
+    return emitOpError() << "require 'cond_graph' only have one result";
+
+  auto condOutType = condYield.getInputs()[0].getType();
+  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
+    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
+                         << condOutType;
+
+  if (!getElementTypeOrSelf(condOutType).isInteger(1))
+    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
+                         << condOutType;
+
+  return success();
+}
+
 LogicalResult ReverseOp::verify() {
   if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                              /* outType = */ getOutput().getType())
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index baa202833e285..e2026c9d3ae55 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -449,6 +449,42 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
+  // Perform a depth-first search over the Tosa IR structure to find the maximum
+  // nesting depth. The nesting depth starts at 0 and increases by one each time
+  // a new nested `region` is encountered.
+
+  static int32_t getMaxNestedDepth(Operation *op) {
+    int32_t depth = 0; // Stays 0 when the op has no regions.
+    for (Region &region : op->getRegions())
+      depth = std::max(depth, getMaxNestedDepth(region));
+    return depth;
+  }
+
+  static int32_t getMaxNestedDepth(Block &block) {
+    int32_t depth = 0; // Stays 0 for a block with no operations.
+    for (Operation &op : block.getOperations())
+      depth = std::max(depth, getMaxNestedDepth(&op));
+    return depth;
+  }
+
+  static int32_t getMaxNestedDepth(Region &region) {
+    int32_t depth = 0;
+    for (Block &block : region.getBlocks())
+      depth = std::max(depth, getMaxNestedDepth(block));
+    // Count this region itself as one additional nesting level.
+    return depth + 1;
+  }
+
+  bool levelCheckMaxNesting(Operation *op) {
+    int32_t maxNestedDepth = getMaxNestedDepth(op);
+    if (maxNestedDepth >= tosaLevel.MAX_NESTING) { // Limit is exclusive.
+      op->emitOpError() << "failed level check: " << maxNestedDepth
+                        << " >= MAX_NESTING";
+      return false;
+    }
+    return true;
+  }
+
   bool levelCheckListSize(Operation *op) {
     if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
       return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -750,6 +786,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
     return failure();
   }
 
+  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
+    if (!levelCheckMaxNesting(op)) {
+      return failure();
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 5307645324b81..5520f7f2e4691 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1327,33 +1327,42 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
   %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
   %1 = "tosa.cond_if"(%arg0,   // condition
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0) ({
+  ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+   ):
+    "tosa.yield"(%64) : (tensor<1xi32>) -> ()
   },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+   ):
+    "tosa.yield"(%01) : (tensor<1xi32>) -> ()
   }) : (
-                    tensor<i1>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  ) -> tensor<1xi32>
-
+       tensor<i1>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+     ) -> tensor<1xi32>
   return
 }
 
@@ -1361,27 +1370,54 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
 
 // CHECK-LABEL: test_if_tensor_list_size_outputs
 func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
-  %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
 
   // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
-  %r:65 = "tosa.cond_if"(%arg0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  %r:65 = "tosa.cond_if"(%arg0, %cst_0) ({
+  ^bb0(%0: tensor<1xi32>):
+    "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
   },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
-  }) : (tensor<i1>) -> (
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  )
-
+  ^bb0(%0: tensor<1xi32>):
+    "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
+  }) : (tensor<i1>, tensor<1xi32>) -> (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    )
   return
 }
 
@@ -1391,25 +1427,57 @@ func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
 func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
   %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
-  %1:2 = "tosa.while_loop"(%0, %arg0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0
+  %1:65 = "tosa.while_loop"(%0, %arg0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0
   ) ({
-  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>,
+       %00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>
+  ):
     %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
     %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
     "tosa.yield"(%3) : (tensor<1xi1>) -> ()
   },  {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
-    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
+  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>,
+       %00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>
+  ):
+    %2 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+    "tosa.yield"(%3, %arg4,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
   }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
@@ -1419,28 +1487,7 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
-  ) -> (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>)
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_while_tensor_list_size_outputs
-func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
-  %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // expected-error@+1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
-  %1:65 = "tosa.while_loop"(%0, %arg0) ({
-  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
-    "tosa.yield"(%3) : (tensor<1xi1>) -> ()
-  },  {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
-    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
-  }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>) -> ( tensor<i32>, tensor<1x1x1x1x1x1x1xf32>,
+  ) -> (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
                     tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
@@ -1453,3 +1500,101 @@ func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>,
 
   return
 }
+
+// -----
+
+func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
+      %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
+        %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
+          %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+            %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+              %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+              tosa.yield %res : tensor<f32>
+            } else {
+              tosa.yield %arg0 : tensor<f32>
+            }
+            tosa.yield %5 : tensor<f32>
+          } else {
+            tosa.yield %arg0 : tensor<f32>
+          }
+          tosa.yield %4 : tensor<f32>
+        } else {
+          tosa.yield %arg0 : tensor<f32>
+        }
+        tosa.yield %3 : tensor<f32>
+      } else {
+        tosa.yield %arg0 : tensor<f32>
+      }
+      tosa.yield %2 : tensor<f32>
+    } else {
+      tosa.yield %arg0 : tensor<f32>
+    }
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %res : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_while_loop_max_nested_depth(%arg0: tensor<i32>) {
+  %init_0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %cst_1 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+  // expected-error@+1 {{'tosa.while_loop' op failed level check: 6 >= MAX_NESTING}}
+  %1:2 = tosa.while_loop (%arg2 = %init_0, %arg3 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+    %2 = tosa.greater_equal %arg3, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg2b: tensor<i32>):
+    %1:2 = tosa.while_loop (%arg4 = %init_0, %arg5 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+      %2 = tosa.greater_equal %arg5, %arg4 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      tosa.yield %2 : tensor<i1>
+    } do {
+    ^bb0(%arg4: tensor<i32>, %arg4b: tensor<i32>):
+      %1:2 = tosa.while_loop (%arg6 = %init_0, %arg7 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+        %2 = tosa.greater_equal %arg7, %arg6 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+        tosa.yield %2 : tensor<i1>
+      } do {
+      ^bb0(%arg6: tensor<i32>, %arg6b: tensor<i32>):
+        %1:2 = tosa.while_loop (%arg8 = %init_0, %arg9 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+          %2 = tosa.greater_equal %arg9, %arg8 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+          tosa.yield %2 : tensor<i1>
+        } do {
+        ^bb0(%arg8: tensor<i32>, %arg8b: tensor<i32>):
+          %1:2 = tosa.while_loop (%arg10 = %init_0, %arg11 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+            %2 = tosa.greater_equal %arg11, %arg10 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+            tosa.yield %2 : tensor<i1>
+          } do {
+          ^bb0(%arg10: tensor<i32>, %arg10b: tensor<i32>):
+            %1:2 = tosa.while_loop (%arg12 = %init_0, %arg13 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+              %2 = tosa.greater_equal %arg13, %arg12 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+              tosa.yield %2 : tensor<i1>
+            } do {
+            ^bb0(%arg12: tensor<i32>, %arg12b: tensor<i32>):
+              %3 = tosa.add %arg12, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+              tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+            }
+            %3 = tosa.add %arg10, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+            tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+          }
+          %3 = tosa.add %arg8, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+          tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+        }
+        %3 = tosa.add %arg6, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+        tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+      }
+      %3 = tosa.add %arg4, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+    }
+    %3 = tosa.add %arg2, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+  }
+  return
+}
+
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..dd59a8feb8dc3 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,351 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' arguments (1) and 'input_list' (2)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' arguments (2) and 'input_list' (1)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0) ({
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg4: tensor<f32>, %arg3: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
+  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
+  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2xi1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op 'condition' must be a size 1 tensor, got 'tensor<2xi1>'}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<2xi1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg4 : tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_in_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (2) and 'input_list' (3)}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0, %arg4 = %arg0)
+    : (tensor<i32>, tensor<10xi32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_output_list(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'input_list' (3) and 'output_list' (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0, %arg4 = %arg0)
+    : (tensor<i32>, tensor<10xi32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_output_list_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'input_list' (2) and 'output_list' (3)}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0)
+    : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_cond_block(%arg0: tensor<2xf32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'cond_graph' arguments (3) and 'input_list' (2)}}
+  %1:2 = "tosa.while_loop"(%0, %arg0) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>, %arg5: tensor<2xf32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "tosa.yield"(%2) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<2> : tensor<1xi8>} : () -> tensor<1xi8>
+    %4 = "tosa.mul"(%arg3, %2, %3) : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    "tosa.yield"(%4, %arg4) : (tensor<i32>, tensor<2xf32>) -> ()
+  }) : (tensor<i32>, tensor<2xf32>) -> (tensor<i32>, tensor<2xf32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_cond_block_2(%arg0: tensor<2xf32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'cond_graph' arguments (1) and 'input_list' (3)}}
+  %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "tosa.yield"(%2) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<2> : tensor<1xi8>} : () -> tensor<1xi8>
+    %4 = "tosa.mul"(%arg3, %2, %3) : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    "tosa.yield"(%4, %arg4) : (tensor<i32>, tensor<2xf32>) -> ()
+  }) : (tensor<i32>, tensor<2xf32>, tensor<i32>) -> (tensor<i32>, tensor<2xf32>, tensor<i32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_out(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' results (3) and 'input_list' (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %2, %3, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_out_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' results (1) and 'input_list' (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_type_mismatch(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same element type for 'body_graph' arguments ('f32') and 'input_list' ('i32')}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_type_mismatch_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same shapes for 'body_graph' arguments ('tensor<10xi32>') and 'input_list' ('tensor<i32>')}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %6 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_cond_output_not_size_one(%arg0: tensor<10xi32>, %arg1: tensor<2xi32>) {
+  %0 = "tosa.const"() {values = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.while_loop' op 'cond_graph' result must be a size 1 tensor, got 'tensor<2xi1>'}}
+  %1:3 = tosa.while_loop (%arg2 = %arg0, %arg3 = %0, %arg4 = %arg0) : (tensor<10xi32>, tensor<2xi32>, tensor<10xi32>) -> (tensor<10xi32>, tensor<2xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+    tosa.yield %2 : tensor<2xi1>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<2xi32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<[3, 5]> : tensor<2xi32>} : () -> tensor<2xi32>
+    %4 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+    tosa.yield %4, %3, %arg4 : tensor<10xi32>, tensor<2xi32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<9> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op 'cond_graph' result must be a boolean tensor, got 'tensor<i32>'}}
+  %1:3 = tosa.while_loop (%arg2 = %arg0, %arg3 = %0, %arg4 = %arg0) : (tensor<10xi32>, tensor<i32>, tensor<10xi32>) -> (tensor<10xi32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.add %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %2 : tensor<i32>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %4 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+    tosa.yield %4, %2, %arg4 : tensor<10xi32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}



More information about the Mlir-commits mailing list