[Mlir-commits] [mlir] d713a00 - [TOSA] Add level checks and remove Tensor1DTo4D

Eric Kunze llvmlistbot at llvm.org
Wed Jul 12 09:58:29 PDT 2023


Author: Tai Ly
Date: 2023-07-12T16:56:44Z
New Revision: d713a00270cc31e6a4968ff08a84fbf5f64ac830

URL: https://github.com/llvm/llvm-project/commit/d713a00270cc31e6a4968ff08a84fbf5f64ac830
DIFF: https://github.com/llvm/llvm-project/commit/d713a00270cc31e6a4968ff08a84fbf5f64ac830.diff

LOG: [TOSA] Add level checks and remove Tensor1DTo4D

Replace uses of Tosa_Tensor1Dto4D and Tosa_TensorUpto4D (and Tosa_Tensor1Dto6D)
in TOSA op definitions with Tosa_Tensor, and add level checks to the
TosaValidation pass so these limits are validated per the spec.

Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: Icd32137e9f8051f99994cee9f388f20c1a840f4b

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D154273

Added: 
    mlir/test/Dialect/Tosa/level_check.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Removed: 
    


################################################################################
Note: argmax, reduce_*, reverse, slice, tile and transpose now accept Tosa_Tensor
of any rank; their rank limits are enforced by the tosa-validate pass level
checks (MAX_RANK) instead of by these ODS type constraints.
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9a6370591011d5..e5b4e664202f7d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -44,12 +44,12 @@ def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D: $input,
+    Tosa_Tensor: $input,
     I64Attr: $axis
   );
 
   let results = (outs
-    Tosa_TensorUpto4D: $output
+    Tosa_Tensor: $output
   );
 }
 
@@ -1222,12 +1222,12 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1251,12 +1251,12 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1280,12 +1280,12 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1309,12 +1309,12 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1338,12 +1338,12 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1367,12 +1367,12 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1515,12 +1515,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1541,13 +1541,13 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto6D:$input,
+    Tosa_Tensor:$input,
     DenseI64ArrayAttr:$start,
     DenseI64ArrayAttr:$size
   );
 
   let results = (outs
-    Tosa_Tensor1Dto6D:$output
+    Tosa_Tensor:$output
   );
 
   let hasCanonicalizer = 1;
@@ -1568,11 +1568,11 @@ def Tosa_TileOp: Tosa_Op<"tile", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input1,
+    Tosa_Tensor:$input1,
     DenseI64ArrayAttr:$multiples);
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1592,12 +1592,12 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto6D:$input1,
+    Tosa_Tensor:$input1,
     Tosa_Int32Or64Tensor:$perms
   );
 
   let results = (
-    outs Tosa_Tensor1Dto6D:$output
+    outs Tosa_Tensor:$output
   );
 
   let extraClassDeclaration = [{

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 1c3bfbebb1ccc8..11f17dc5e66b75 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -74,6 +74,14 @@ def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile",
   let cppNamespace = "mlir::tosa";
 }
 
+def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
+    [
+      I32EnumAttrCase<"None", 0, "none">, // level checks disabled
+      I32EnumAttrCase<"EightK", 1, "8k">, // the "8k" level limits
+    ]>{
+  let cppNamespace = "mlir::tosa";
+}
+
 def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
   let summary = "Validates TOSA dialect";
   let description = [{
@@ -89,6 +97,9 @@ def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
       Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
              /*default=*/"false",
              "Verify if the properties of certain operations align the spec requirement">,
+      Option<"levelName", "level", "std::string",
+             /*default=*/"\"8k\"",
+             "Validate if operator parameters are within specification for the given level">,
    ];
 }
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 5605080384bd7f..05da6ee6cad95d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -72,6 +72,23 @@ static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
   return success();
 }
 
+struct tosa_level_t { // Limits for one TOSA level; all zeros = no checks.
+  int32_t MAX_RANK = 0;   // max operand/result rank
+  int32_t MAX_KERNEL = 0; // max kernel, dilation, pad and FFT H/W values
+  int32_t MAX_STRIDE = 0; // max stride
+  int32_t MAX_SCALE = 0;  // max resize scale_n / scale_d
+
+  // TODO: add the MAX_LOG2_SIZE value and its checks.
+  // const-qualified so either operand of == may be const.
+  bool operator==(const tosa_level_t &rhs) const {
+    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
+           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
+  }
+};
+
+static constexpr tosa_level_t TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 64};
+static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
+
 //===----------------------------------------------------------------------===//
 // TOSA Validation Pass.
 //===----------------------------------------------------------------------===//
@@ -89,6 +106,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return success();
   }
 
+  // Runs the configured level checks on op; fails if any check fails.
+  LogicalResult applyLevelCheck(Operation *op);
 private:
   void populateConstantOperandChecks() {
     const_checkers.emplace_back(checkConstantOperandPad);
@@ -96,13 +115,320 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     const_checkers.emplace_back(checkConstantOperandFullyConnected);
   }
 
+  // Fails (emitting an error on op) if v > MAX_KERNEL; v is taken as int64_t.
+  bool levelCheckKernel(Operation *op, int64_t v,
+                        const std::string &check_desc) {
+    if (v <= tosa_level.MAX_KERNEL)
+      return true;
+    op->emitOpError() << "failed level check: " << check_desc;
+    return false;
+  }
+
+  // Fails (emitting an error on op) if v > MAX_STRIDE; v is taken as int64_t.
+  bool levelCheckStride(Operation *op, int64_t v,
+                        const std::string &check_desc) {
+    if (v <= tosa_level.MAX_STRIDE)
+      return true;
+    op->emitOpError() << "failed level check: " << check_desc;
+    return false;
+  }
+
+  // Fails (emitting an error on op) if v > MAX_SCALE; v is taken as int64_t.
+  bool levelCheckScale(Operation *op, int64_t v,
+                       const std::string &check_desc) {
+    if (v <= tosa_level.MAX_SCALE)
+      return true;
+    op->emitOpError() << "failed level check: " << check_desc;
+    return false;
+  }
+
+  // Fails (emitting an error on op) if v is ranked and its rank > MAX_RANK.
+  bool levelCheckRank(Operation *op, const Value &v,
+                      const std::string &check_desc) {
+    auto type = dyn_cast<ShapedType>(v.getType());
+    if (type && type.hasRank() && type.getRank() > tosa_level.MAX_RANK) {
+      op->emitOpError() << "failed level check: " << check_desc;
+      return false;
+    }
+    return true;
+  }
+
+  // If op is a T, level check the ranks of all its operands and results.
+  template <typename T>
+  bool levelCheckRanksFor(Operation *op) {
+    if (!isa<T>(op))
+      return true;
+    for (Value v : op->getOperands()) {
+      if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
+        return false;
+    }
+    for (Value v : op->getResults()) {
+      if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
+        return false;
+    }
+    return true;
+  }
+
+  bool levelCheckRanks(Operation *op) {
+#define CHECK_RANKS_FOR(tosa_op)                                               \
+  if (!levelCheckRanksFor<tosa_op##Op>(op))                                    \
+    return false;
+    // CHECK_RANKS_FOR(X): return false if op is an XOp failing a rank check.
+    // tensor operators:
+    CHECK_RANKS_FOR(ArgMax);
+    // all activation functions:
+    CHECK_RANKS_FOR(Clamp);
+    CHECK_RANKS_FOR(Sigmoid);
+    CHECK_RANKS_FOR(Tanh);
+    // all elementwise binary operators:
+    CHECK_RANKS_FOR(Add);
+    CHECK_RANKS_FOR(ArithmeticRightShift);
+    CHECK_RANKS_FOR(BitwiseAnd);
+    CHECK_RANKS_FOR(BitwiseOr);
+    CHECK_RANKS_FOR(BitwiseXor);
+    CHECK_RANKS_FOR(Div);
+    CHECK_RANKS_FOR(LogicalAnd);
+    CHECK_RANKS_FOR(LogicalLeftShift);
+    CHECK_RANKS_FOR(LogicalRightShift);
+    CHECK_RANKS_FOR(LogicalOr);
+    CHECK_RANKS_FOR(LogicalXor);
+    CHECK_RANKS_FOR(Maximum);
+    CHECK_RANKS_FOR(Minimum);
+    CHECK_RANKS_FOR(Mul);
+    CHECK_RANKS_FOR(Pow);
+    CHECK_RANKS_FOR(Sub);
+    CHECK_RANKS_FOR(Table);
+    // all elementwise unary operators:
+    CHECK_RANKS_FOR(Abs);
+    CHECK_RANKS_FOR(BitwiseNot);
+    CHECK_RANKS_FOR(Ceil);
+    CHECK_RANKS_FOR(Clz);
+    CHECK_RANKS_FOR(Exp);
+    CHECK_RANKS_FOR(Floor);
+    CHECK_RANKS_FOR(Log);
+    CHECK_RANKS_FOR(LogicalNot);
+    CHECK_RANKS_FOR(Negate);
+    CHECK_RANKS_FOR(Reciprocal);
+    CHECK_RANKS_FOR(Rsqrt);
+    // all elementwise ternary operators:
+    CHECK_RANKS_FOR(Select);
+    // all comparison operators:
+    CHECK_RANKS_FOR(Equal);
+    CHECK_RANKS_FOR(Greater);
+    CHECK_RANKS_FOR(GreaterEqual);
+    // all reduction operators:
+    CHECK_RANKS_FOR(ReduceAll);
+    CHECK_RANKS_FOR(ReduceAny);
+    CHECK_RANKS_FOR(ReduceMax);
+    CHECK_RANKS_FOR(ReduceMin);
+    CHECK_RANKS_FOR(ReduceProd);
+    CHECK_RANKS_FOR(ReduceSum);
+    // all data layout operators:
+    CHECK_RANKS_FOR(Concat);
+    CHECK_RANKS_FOR(Pad);
+    CHECK_RANKS_FOR(Reshape);
+    CHECK_RANKS_FOR(Reverse);
+    CHECK_RANKS_FOR(Slice);
+    CHECK_RANKS_FOR(Tile);
+    CHECK_RANKS_FOR(Transpose);
+    // all type conversion operators:
+    CHECK_RANKS_FOR(Cast);
+    CHECK_RANKS_FOR(Rescale);
+    // all data nodes operators:
+    CHECK_RANKS_FOR(Const);
+    CHECK_RANKS_FOR(Identity);
+    // Ops not listed above are not rank-checked here.
+#undef CHECK_RANKS_FOR
+    return true;
+  }
+
+  // Pool op T: kernel and pad must be <= MAX_KERNEL, stride <= MAX_STRIDE.
+  template <typename T> // used with tosa::AvgPool2dOp and tosa::MaxPool2dOp
+  bool levelCheckPool(Operation *op) {
+    if (auto pool_op = dyn_cast<T>(op)) {
+      for (auto k : pool_op.getKernel()) {
+        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+      for (auto s : pool_op.getStride()) {
+        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
+          return false;
+        }
+      }
+      for (auto p : pool_op.getPad()) {
+        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  // Conv op T: dilation and pad must be <= MAX_KERNEL, stride <= MAX_STRIDE,
+  // and each dilated kernel size (dilation * K) must be <= MAX_KERNEL.
+  template <typename T>
+  bool levelCheckConv(Operation *op) {
+    auto conv_op = dyn_cast<T>(op);
+    if (!conv_op)
+      return true;
+    for (auto k : conv_op.getDilation())
+      if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))
+        return false;
+    for (auto p : conv_op.getPad())
+      if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL"))
+        return false;
+    for (auto s : conv_op.getStride())
+      if (!levelCheckStride(op, s, "stride <= MAX_STRIDE"))
+        return false;
+    // dilation * K; a dynamic K (ShapedType::kDynamic) counts as 0 so it is
+    // never multiplied. Unranked weights have no shape and are skipped.
+    auto dilatedKernel = [](int64_t d, int64_t k) -> int64_t {
+      return ShapedType::isDynamic(k) ? 0 : d * k;
+    };
+    auto dilation = conv_op.getDilation();
+    auto weight_type = dyn_cast<ShapedType>(op->getOperand(1).getType());
+    if (!weight_type || !weight_type.hasRank())
+      return true;
+    auto shape = weight_type.getShape();
+    if (isa<tosa::Conv2DOp>(op)) {
+      assert(shape.size() == 4);
+      assert(dilation.size() == 2);
+      if (!levelCheckKernel(op, dilatedKernel(dilation[0], shape[1]),
+                            "dilation_y * KH <= MAX_KERNEL") ||
+          !levelCheckKernel(op, dilatedKernel(dilation[1], shape[2]),
+                            "dilation_x * KW <= MAX_KERNEL"))
+        return false;
+    } else if (isa<tosa::Conv3DOp>(op)) {
+      assert(shape.size() == 5);
+      assert(dilation.size() == 3);
+      if (!levelCheckKernel(op, dilatedKernel(dilation[0], shape[1]),
+                            "dilation_d * KD <= MAX_KERNEL") ||
+          !levelCheckKernel(op, dilatedKernel(dilation[1], shape[2]),
+                            "dilation_y * KH <= MAX_KERNEL") ||
+          !levelCheckKernel(op, dilatedKernel(dilation[2], shape[3]),
+                            "dilation_x * KW <= MAX_KERNEL"))
+        return false;
+    } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
+      assert(shape.size() == 4);
+      assert(dilation.size() == 2);
+      if (!levelCheckKernel(op, dilatedKernel(dilation[0], shape[0]),
+                            "dilation_y * KH <= MAX_KERNEL") ||
+          !levelCheckKernel(op, dilatedKernel(dilation[1], shape[1]),
+                            "dilation_x * KW <= MAX_KERNEL"))
+        return false;
+    }
+    return true;
+  }
+
+  // FFT op T: level check H and W (shape [N,H,W]) of every ranked operand.
+  template <typename T>
+  bool levelCheckFFT(Operation *op) {
+    if (!isa<T>(op))
+      return true;
+    for (Value v : op->getOperands()) {
+      auto type = dyn_cast<ShapedType>(v.getType());
+      if (!type || !type.hasRank()) // getShape() asserts on unranked types
+        continue;
+      auto shape = type.getShape();
+      assert(shape.size() == 3);
+      if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
+          !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))
+        return false;
+    }
+    return true;
+  }
+
+  // TransposeConv2d: KH, KW and out_pad <= MAX_KERNEL, stride <= MAX_STRIDE.
+  bool levelCheckTransposeConv2d(Operation *op) {
+    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
+      auto filter_type = dyn_cast<ShapedType>(transpose.getFilter().getType());
+      if (filter_type && filter_type.hasRank()) { // getShape() needs a rank
+        auto shape = filter_type.getShape();
+        assert(shape.size() == 4);
+        // shape[1] is KH and shape[2] is KW
+        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
+            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+      for (auto p : transpose.getOutPad()) {
+        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+      for (auto s : transpose.getStride()) {
+        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  // Resize op: integer ratios scale_*_n / scale_*_d must be <= MAX_SCALE.
+  bool levelCheckResize(Operation *op) {
+    auto resize = dyn_cast<tosa::ResizeOp>(op);
+    if (!resize)
+      return true;
+    // int64_t (not int16_t) so a denominator cannot wrap to 0; skip d <= 0.
+    auto scale = resize.getScale();
+    const int64_t scale_y_n = scale[0], scale_y_d = scale[1];
+    const int64_t scale_x_n = scale[2], scale_x_d = scale[3];
+    if (scale_y_d > 0 && !levelCheckScale(op, scale_y_n / scale_y_d,
+                                          "scale_y_n/scale_y_d <= MAX_SCALE"))
+      return false;
+    if (scale_x_d > 0 && !levelCheckScale(op, scale_x_n / scale_x_d,
+                                          "scale_x_n/scale_x_d <= MAX_SCALE"))
+      return false;
+    return true;
+  }
+
+  // Configure profileType and tosa_level from the pass options profileName
+  // and levelName.
+  void configLevelAndProfile() {
+    profileType = symbolizeEnum<TosaProfileEnum>(profileName);
+    // levelType is std::nullopt when levelName is not a known level.
+    auto levelType = symbolizeEnum<TosaLevelEnum>(levelName);
+    // Only "8k" sets limits; "none" or an unknown name disables level checks.
+    tosa_level = TOSA_LEVEL_NONE;
+    if (levelType == TosaLevelEnum::EightK) {
+      tosa_level = TOSA_LEVEL_EIGHTK;
+    }
+  }
+
   SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
   std::optional<TosaProfileEnum> profileType;
+  tosa_level_t tosa_level; // limits in force; set by configLevelAndProfile()
 };
 
-void TosaValidation::runOnOperation() {
-  profileType = symbolizeEnum<TosaProfileEnum>(profileName);
+LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
+  if (tosa_level == TOSA_LEVEL_NONE) {
+    // level "none" (or an unrecognized level name): no level checks apply
+    return success();
+  }
+  // Each check emits its own error; evaluation stops at the first failure.
+  if (!levelCheckRanks(op)) {
+    return failure();
+  }
+
+  // op-specific limits (kernel, pad, stride, scale) from spec 0.70
+  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
+      !levelCheckConv<tosa::Conv2DOp>(op) ||
+      !levelCheckConv<tosa::Conv3DOp>(op) ||
+      !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
+      !levelCheckFFT<tosa::FFT2dOp>(op) ||
+      !levelCheckPool<tosa::MaxPool2dOp>(op) ||
+      !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
+      !levelCheckResize(op)) {
+    return failure();
+  }
 
+  return success();
+}
+
+void TosaValidation::runOnOperation() {
+  configLevelAndProfile(); // resolve the profile and level pass options first
   getOperation().walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
       if ((profileType == TosaProfileEnum::BaseInference) &&
@@ -117,6 +443,10 @@ void TosaValidation::runOnOperation() {
     // Some uses of TOSA rely on the constant operands of particular operations.
     if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
       signalPassFailure();
+
+    // Level checks: a failure fails the pass; the walk is not interrupted.
+    if (failed(applyLevelCheck(op)))
+      signalPassFailure();
   });
 }
 } // namespace

diff  --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
new file mode 100644
index 00000000000000..40160a9452f5b0
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -0,0 +1,698 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
+// Each case violates one limit of the default "8k" level (rank 7 > 6, 8193 > 8192).
+
+func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
+  // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.argmax"(%arg0) {axis = 4 : i64} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
+  return %0 : tensor<1x1x1x1x29x4xf32>
+}
+
+// -----
+
+func.func @test_reduce_all(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> {
+  // expected-error@+1 {{'tosa.reduce_all' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.reduce_all"(%arg0) {axis = 4 : i64} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1>
+  return %0 : tensor<1x1x1x1x1x21x3xi1>
+}
+
+// -----
+
+func.func @test_reduce_any(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
+  // expected-error@+1 {{'tosa.reduce_any' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
+  return %0 : tensor<1x1x1x1x13x21x3xi1>
+}
+
+// -----
+
+func.func @test_reduce_max(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error@+1 {{'tosa.reduce_max' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_reduce_min(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error@+1 {{'tosa.reduce_min' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_reduce_prod(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error@+1 {{'tosa.reduce_prod' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_reduce_sum(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error@+1 {{'tosa.reduce_sum' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_concat(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32> {
+  // expected-error@+1 {{'tosa.concat' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<1x1x1x13x21x3x8xf32>, tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32>
+  return %0 : tensor<1x1x1x26x21x3x8xf32>
+}
+
+// -----
+
+func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
+  // expected-error@+1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1, 1, 1, 819>} : (tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32>
+  return %0 : tensor<1x1x1x1x1x1x819xf32>
+}
+
+// -----
+
+func.func @test_reverse(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error@+1 {{'tosa.reverse' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32> {
+  // expected-error@+1 {{'tosa.slice' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.slice"(%arg0) {start = array<i64: 0, 0, 0, 0, 6, 8, 0>, size = array<i64: 1, 1, 1, 1, 4, 11, 1>} :
+          (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32>
+  return %0 : tensor<1x1x1x1x4x11x1xf32>
+}
+
+// -----
+
+func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> {
+  // expected-error@+1 {{'tosa.tile' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = "tosa.tile"(%arg0) {multiples = array<i64: 1, 1, 1, 1, 3, 1, 2>} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32>
+  return %0 : tensor<1x1x1x1x39x21x6xf32>
+}
+
+// -----
+
+func.func @test_transpose(%arg0: tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1, 3, 4, 5, 6]> : tensor<7xi32>} : () -> tensor<7xi32>
+  // expected-error@+1 {{'tosa.transpose' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3x1x1x1x1xf32>, tensor<7xi32>) -> tensor<3x13x21x1x1x1x1xf32>
+  return %1 : tensor<3x13x21x1x1x1x1xf32>
+}
+
+// -----
+
+func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
+  // expected-error@+1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RANK}}
+  %0 = "tosa.const"() {value = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32>
+  return %0: tensor<1x1x1x1x1x1x1xi32>
+}
+
+// -----
+// avg_pool2d: each case sets one kernel, stride or pad value to 8193 (> 8192).
+func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+func.func @test_avgpool2d_kernel_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 8193>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+func.func @test_avgpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 8193, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+func.func @test_avgpool2d_stride_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 8193>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+
+// -----
+
+func.func @test_avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 8193, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+func.func @test_avgpool2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 8193, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+func.func @test_avgpool2d_pad_left(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 8193, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 8193>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} :
+            (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// dilation_d = 4097 with the 2-wide kernel dims of the 16x2x2x2x8 weight pushes dilation_d * KD past MAX_KERNEL on tosa.conv3d.
+func.func @test_conv3d_dilation_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: dilation_d * KD <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// dilation_y = 4097 with the 2-wide kernel dims of the 16x2x2x2x8 weight pushes dilation_y * KH past MAX_KERNEL on tosa.conv3d.
+func.func @test_conv3d_dilation_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: dilation_y * KH <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// dilation_x = 4097 with the 2-wide kernel dims of the 16x2x2x2x8 weight pushes dilation_x * KW past MAX_KERNEL on tosa.conv3d.
+func.func @test_conv3d_dilation_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: dilation_x * KW <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 4097>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// A value of 8193 in the first conv3d pad entry (pad_d0) must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_conv3d_pad_d0(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 8193, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// A value of 8193 in the second conv3d pad entry (pad_d1) must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_conv3d_pad_d1(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 8193, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// A value of 8193 in the third conv3d pad entry (pad_top) must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_conv3d_pad_top(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 8193, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// A value of 8193 in the fourth conv3d pad entry (pad_bottom) must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_conv3d_pad_bottom(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 8193, 0, 1>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// A value of 8193 in the fifth conv3d pad entry (pad_left) must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_conv3d_pad_left(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 8193, 1>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// A value of 8193 in the last conv3d pad entry (pad_right) must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_conv3d_pad_right(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 8193>, stride = array<i64: 1, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// A depth stride (first stride entry) of 8193 on tosa.conv3d must be rejected by the stride <= MAX_STRIDE level check.
+func.func @test_conv3d_stride_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 8193, 1, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// A y-stride (second stride entry) of 8193 on tosa.conv3d must be rejected by the stride <= MAX_STRIDE level check.
+func.func @test_conv3d_stride_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 8193, 1>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// An x-stride (last stride entry) of 8193 on tosa.conv3d must be rejected by the stride <= MAX_STRIDE level check.
+func.func @test_conv3d_stride_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 8193>} :
+            (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
+  return %0 : tensor<1x1x32x32x16xf32>
+}
+
+// -----
+
+// dilation_y = 4097 on tosa.depthwise_conv2d must push dilation_y * KH past MAX_KERNEL and fail the level check.
+func.func @test_depthwise_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}}
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
+  return %0 : tensor<1x32x32x64xf32>
+}
+
+// -----
+
+// dilation_x = 4097 on tosa.depthwise_conv2d must push dilation_x * KW past MAX_KERNEL and fail the level check.
+func.func @test_depthwise_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}}
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
+  return %0 : tensor<1x32x32x64xf32>
+}
+
+// -----
+
+// A top pad (first pad entry) of 8193 on tosa.depthwise_conv2d must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_depthwise_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
+  return %0 : tensor<1x32x32x64xf32>
+}
+
+// -----
+
+// A bottom pad (second pad entry) of 8193 on tosa.depthwise_conv2d must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_depthwise_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
+  return %0 : tensor<1x32x32x64xf32>
+}
+
+// -----
+
+// A left pad (third pad entry) of 8193 on tosa.depthwise_conv2d must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_depthwise_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
+  return %0 : tensor<1x32x32x64xf32>
+}
+
+// -----
+
+// A right pad (last pad entry) of 8193 on tosa.depthwise_conv2d must be rejected by the pad <= MAX_KERNEL level check.
+func.func @test_depthwise_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
+  return %0 : tensor<1x32x32x64xf32>
+}
+
+// -----
+
+// A y-stride (first stride entry) of 8193 on tosa.depthwise_conv2d must be rejected by the stride <= MAX_STRIDE level check.
+func.func @test_depthwise_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} :
+            (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
+  return %0 : tensor<1x32x32x64xf32>
+}
+
+// -----
+
+// An x-stride (second stride entry) of 8193 on tosa.depthwise_conv2d must be rejected by the stride <= MAX_STRIDE level check.
+func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
+  // expected-error at +1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} :
+            (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
+  return %0 : tensor<1x32x32x64xf32>
+}
+
+// -----
+
+// A real-input height of 8193 (middle dim of %arg0) must fail the H <= MAX_KERNEL level check on tosa.fft2d.
+func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+  // expected-error at +1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
+  %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
+            (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
+  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+}
+
+// -----
+
+// A real-input width of 8193 (last dim of %arg0) must fail the W <= MAX_KERNEL level check on tosa.fft2d.
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+  // expected-error at +1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
+  %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
+            (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
+  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+}
+
+// -----
+
+// An imaginary-input height of 8193 (middle dim of %arg1) must fail the H <= MAX_KERNEL level check on tosa.fft2d.
+func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+  // expected-error at +1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
+  %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
+            (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
+  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+}
+
+// -----
+
+// An imaginary-input width of 8193 (last dim of %arg1) must fail the W <= MAX_KERNEL level check on tosa.fft2d.
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+  // expected-error at +1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
+  %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
+            (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
+  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+}
+
+// -----
+
+// A kernel height (first kernel entry; test_maxpool2d_kernel_x covers the second) of 8193 must fail the kernel <= MAX_KERNEL level check.
+func.func @test_maxpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error at +1 {{'tosa.max_pool2d' op failed level check: kernel <= MAX_KERNEL}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// A kernel width (second kernel entry) of 8193 on tosa.max_pool2d must fail the kernel <= MAX_KERNEL level check.
+func.func @test_maxpool2d_kernel_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error at +1 {{'tosa.max_pool2d' op failed level check: kernel <= MAX_KERNEL}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 1, 8193>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// A y-stride (first stride entry) of 8193 on tosa.max_pool2d must fail the stride <= MAX_STRIDE level check.
+func.func @test_maxpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error at +1 {{'tosa.max_pool2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 8193, 1>} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// An x-stride (second stride entry) of 8193 on tosa.max_pool2d must fail the stride <= MAX_STRIDE level check.
+func.func @test_maxpool2d_stride_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error at +1 {{'tosa.max_pool2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 8193>} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+
+// -----
+
+// A top pad (first pad entry) of 8193 on tosa.max_pool2d must fail the pad <= MAX_KERNEL level check.
+func.func @test_maxpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error at +1 {{'tosa.max_pool2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 8193, 4, 4, 4>, stride = array<i64: 1, 1>} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// A bottom pad (second pad entry) of 8193 on tosa.max_pool2d must fail the pad <= MAX_KERNEL level check.
+func.func @test_maxpool2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error at +1 {{'tosa.max_pool2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 8193, 4, 4>, stride = array<i64: 1, 1>} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// A left pad (third pad entry) of 8193 on tosa.max_pool2d must fail the pad <= MAX_KERNEL level check.
+func.func @test_maxpool2d_pad_left(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error at +1 {{'tosa.max_pool2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 8193, 4>, stride = array<i64: 1, 1>} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// A right pad (last pad entry) of 8193 on tosa.max_pool2d must fail the pad <= MAX_KERNEL level check.
+func.func @test_maxpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // expected-error at +1 {{'tosa.max_pool2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 8193>, stride = array<i64: 1, 1>} :
+         (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// An input height of 8193 (middle dim of %arg0) must fail the H <= MAX_KERNEL level check on tosa.rfft2d.
+func.func @test_rfft2d_input_h(%arg0: tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+  // expected-error at +1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL}}
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
+  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+}
+
+// -----
+
+// An input width of 8193 (last dim of %arg0) must fail the W <= MAX_KERNEL level check on tosa.rfft2d.
+func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+  // expected-error at +1 {{'tosa.rfft2d' op failed level check: W <= MAX_KERNEL}}
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
+  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+}
+
+// -----
+
+// A weight kernel height of 8193 (second dim of the 16x8193x1x8 weight) must fail the KH <= MAX_KERNEL level check on tosa.transpose_conv2d.
+func.func @test_transpose_conv2d_weight_h(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x8193x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op failed level check: KH <= MAX_KERNEL}}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+              (tensor<1x32x32x8xf32>, tensor<16x8193x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// A weight kernel width of 8193 (third dim of the 16x1x8193x8 weight) must fail the KW <= MAX_KERNEL level check on tosa.transpose_conv2d.
+func.func @test_transpose_conv2d_weight_w(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x8193x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op failed level check: KW <= MAX_KERNEL}}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+              (tensor<1x32x32x8xf32>, tensor<16x1x8193x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// An out_pad top entry (first) of 8193 must fail the level check; out_pad is reported under the pad <= MAX_KERNEL rule.
+func.func @test_transpose_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 8193, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+              (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// An out_pad bottom entry (second) of 8193 must fail the level check; out_pad is reported under the pad <= MAX_KERNEL rule.
+func.func @test_transpose_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 8193, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+              (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// An out_pad left entry (third) of 8193 must fail the level check; out_pad is reported under the pad <= MAX_KERNEL rule.
+func.func @test_transpose_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 8193, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+              (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// An out_pad right entry (last) of 8193 must fail the level check; out_pad is reported under the pad <= MAX_KERNEL rule.
+func.func @test_transpose_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 8193>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+              (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// A y-stride (first stride entry) of 8193 on tosa.transpose_conv2d must fail the stride <= MAX_STRIDE level check.
+func.func @test_transpose_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 8193, 1>} :
+              (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// An x-stride (second stride entry) of 8193 on tosa.transpose_conv2d must fail the stride <= MAX_STRIDE level check.
+func.func @test_transpose_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 8193>} :
+              (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+
+// A y scale ratio of 65/1 (first two scale entries) on tosa.resize must fail the scale_y_n/scale_y_d <= MAX_SCALE level check.
+func.func @test_resize_scale_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
+  // expected-error at +1 {{'tosa.resize' op failed level check: scale_y_n/scale_y_d <= MAX_SCALE}}
+  %1 = "tosa.resize"(%arg0) { scale = array<i64: 65, 1, 4, 2>, offset = array<i64: -1, -1>, border = array<i64: 1, 1>, mode = "BILINEAR"} :
+                (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32>
+  return %1 : tensor<1x64x64x8xf32>
+}
+
+// -----
+
+// An x scale ratio of 65/1 (last two scale entries) on tosa.resize must fail the scale_x_n/scale_x_d <= MAX_SCALE level check.
+func.func @test_resize_scale_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
+  // expected-error at +1 {{'tosa.resize' op failed level check: scale_x_n/scale_x_d <= MAX_SCALE}}
+  %1 = "tosa.resize"(%arg0) { scale = array<i64: 4, 2, 65, 1>, offset = array<i64: -1, -1>, border = array<i64: 1, 1>, mode = "BILINEAR"} :
+                (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32>
+  return %1 : tensor<1x64x64x8xf32>
+}
+
+// -----
+
+// Rank-7 operands flow through both branches of tosa.cond_if with no declared diagnostic, so the pass is expected to accept this op.
+// NOTE(review): presumably control-flow ops are exempt from the rank level limit applied to other ops — confirm in TosaValidation.cpp.
+// CHECK-LABEL: @test_cond_if
+func.func @test_cond_if(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1x1x1x1x1x1x1xf32>, %arg2: tensor<i1>) -> tensor<1x1x1x1x1x1x1xf32> {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<1x1x1x1x1x1x1xf32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+    "tosa.yield"(%arg3) : (tensor<1x1x1x1x1x1x1xf32>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<1x1x1x1x1x1x1xf32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+    "tosa.yield"(%arg4) : (tensor<1x1x1x1x1x1x1xf32>) -> ()
+  }) : (tensor<i1>, tensor<1x1x1x1x1x1x1xf32>, tensor<1x1x1x1x1x1x1xf32>) -> tensor<1x1x1x1x1x1x1xf32>
+  return %0 : tensor<1x1x1x1x1x1x1xf32>
+}
+
+// -----
+
+// Counts %arg3 up from 0 while %arg3 < %arg1 (logical_not of greater_equal), carrying a rank-7 tensor unchanged through
+// tosa.while_loop. No diagnostic is declared, so the pass is expected to accept this op.
+// NOTE(review): presumably control-flow ops are exempt from the rank level limit — confirm in TosaValidation.cpp.
+// CHECK-LABEL: @test_while_loop
+func.func @test_while_loop(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1:2 = "tosa.while_loop"(%0, %arg0) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
+  }) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>)
+  return
+}
+
+// -----
+
+// Passes a rank-7 tensor through tosa.custom with no declared diagnostic, so the pass is expected to accept this op.
+// NOTE(review): presumably custom ops are exempt from the rank level limit — confirm in TosaValidation.cpp.
+// CHECK-LABEL: @test_custom
+func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x10xi32> {
+  %0 = "tosa.custom"(%arg0) {identifier="custom_test", config="tosa_mlir_test", implementation_attrs=""} :
+           (tensor<1x1x1x1x1x1x10xi32>) -> (tensor<1x1x1x1x1x1x10xi32>)
+  return %0 : tensor<1x1x1x1x1x1x10xi32>
+}
+
+


        


More information about the Mlir-commits mailing list