[Mlir-commits] [mlir] [mlir][tosa] Add error if and level checks for COND_IF & WHILE_LOOP (PR #136194)
TatWai Chong
llvmlistbot at llvm.org
Mon Apr 28 10:57:05 PDT 2025
https://github.com/tatwaichong updated https://github.com/llvm/llvm-project/pull/136194
>From 5b8f50740a005f5eb373aa0ad81df26ac5ea6c8d Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Fri, 11 Apr 2025 19:54:49 -0700
Subject: [PATCH] [mlir][tosa] Add error if and level checks for COND_IF &
WHILE_LOOP
ERROR_IF checks: verify that the input list, output list, and control-flow
block arguments/results agree in length, element type, and shape.
Level checks: verify that the nesting depth stays below MAX_NESTING.
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 +
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 129 +++++++
.../Tosa/Transforms/TosaValidation.cpp | 35 ++
mlir/test/Dialect/Tosa/level_check.mlir | 303 +++++++++++----
mlir/test/Dialect/Tosa/verifier.mlir | 350 +++++++++++++++++-
5 files changed, 739 insertions(+), 80 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index cc78aaed911e6..52bb0eb992b69 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
);
let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
);
let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
}
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d535009f34533..b2e471f2bba93 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -562,6 +562,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
return success();
}
+// Check both types are shaped with equal element types and compatible shapes.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
+                                                StringRef name1, Type type2,
+                                                StringRef name2) {
+  auto shapeType1 = dyn_cast<ShapedType>(type1);
+  auto shapeType2 = dyn_cast<ShapedType>(type2);
+  // Verifiers must not fail silently: report non-shaped types explicitly.
+  if (!shapeType1 || !shapeType2)
+    return op->emitOpError() << "require shaped types for " << name1 << " ("
+                             << type1 << ") and " << name2 << " (" << type2
+                             << ")";
+  auto elemType1 = shapeType1.getElementType();
+  auto elemType2 = shapeType2.getElementType();
+  if (elemType1 != elemType2)
+    return op->emitOpError()
+           << "require same element type for " << name1 << " (" << elemType1
+           << ") and " << name2 << " (" << elemType2 << ")";
+  if (failed(verifyCompatibleShape(type1, type2)))
+    return op->emitOpError()
+           << "require same shapes for " << name1 << " (" << type1 << ") and "
+           << name2 << " (" << type2 << ")";
+  return success();
+}
+
+// Verify that two value lists agree in length and, pairwise, in element type
+// and shape. Emits an op error describing the first mismatch found.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
+                                                StringRef name1,
+                                                ValueRange list2,
+                                                StringRef name2) {
+  const size_t size1 = list1.size();
+  const size_t size2 = list2.size();
+  if (size1 != size2)
+    return op->emitOpError()
+           << "require same number of values in " << name1 << " (" << size1
+           << ") and " << name2 << " (" << size2 << ")";
+  for (auto [lhsType, rhsType] :
+       llvm::zip_equal(list1.getTypes(), list2.getTypes()))
+    if (failed(errorIfTypeOrShapeMismatch(op, lhsType, name1, rhsType, name2)))
+      return failure();
+  return success();
+}
+
+// Fails iff `type` is statically shaped with element count != 1; emits nothing.
+static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
+  ShapeAdaptor adaptor(type);
+  // Unranked or dynamically shaped types cannot be checked; accept them.
+  const bool checkable = adaptor.hasRank() && adaptor.hasStaticShape();
+  return (!checkable || adaptor.getNumElements() == 1) ? success() : failure();
+}
+
// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3473,6 +3524,84 @@ void IfOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
}
+LogicalResult IfOp::verify() {
+  Block &thenBlock = getThenGraph().front();
+  Block &elseBlock = getElseGraph().front();
+
+  // Both branches receive the values of 'input_list' as block arguments.
+  if (failed(errorIfTypeOrShapeMismatch(*this, thenBlock.getArguments(),
+                                        "'then_graph' arguments",
+                                        getInputList(), "'input_list'")))
+    return failure();
+  if (failed(errorIfTypeOrShapeMismatch(*this, elseBlock.getArguments(),
+                                        "'else_graph' arguments",
+                                        getInputList(), "'input_list'")))
+    return failure();
+
+  // Both branches must yield values matching 'output_list'.
+  auto thenYield = cast<tosa::YieldOp>(thenBlock.getTerminator());
+  if (failed(errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
+                                        "'then_graph' results",
+                                        getOutputList(), "'output_list'")))
+    return failure();
+  auto elseYield = cast<tosa::YieldOp>(elseBlock.getTerminator());
+  if (failed(errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
+                                        "'else_graph' results",
+                                        getOutputList(), "'output_list'")))
+    return failure();
+
+  // 'condition' must hold exactly one element when its shape is static;
+  // unranked or dynamically shaped conditions are accepted.
+  Type condType = getCondition().getType();
+  if (failed(errorIfShapeNotSizeOne(*this, condType)))
+    return emitOpError() << "'condition' must be a size 1 tensor, got "
+                         << condType;
+  return success();
+}
+
+LogicalResult WhileOp::verify() {
+  // The loop-carried values flow from 'input_list' to 'output_list'.
+  if (failed(errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
+                                        getOutputList(), "'output_list'")))
+    return failure();
+
+  // Both regions receive the loop-carried values as block arguments.
+  Block &condBlock = getCondGraph().front();
+  Block &bodyBlock = getBodyGraph().front();
+  if (failed(errorIfTypeOrShapeMismatch(*this, condBlock.getArguments(),
+                                        "'cond_graph' arguments",
+                                        getInputList(), "'input_list'")))
+    return failure();
+  if (failed(errorIfTypeOrShapeMismatch(*this, bodyBlock.getArguments(),
+                                        "'body_graph' arguments",
+                                        getInputList(), "'input_list'")))
+    return failure();
+
+  // The body yields the values for the next iteration.
+  auto bodyYield = cast<tosa::YieldOp>(bodyBlock.getTerminator());
+  if (failed(errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
+                                        "'body_graph' results",
+                                        getInputList(), "'input_list'")))
+    return failure();
+
+  // 'cond_graph' must yield exactly one boolean tensor of (static) size 1.
+  auto condYield = cast<tosa::YieldOp>(condBlock.getTerminator());
+  if (condYield.getInputs().size() != 1)
+    return emitOpError() << "require 'cond_graph' only have one result";
+
+  Type condOutType = condYield.getInputs().front().getType();
+  if (failed(errorIfShapeNotSizeOne(*this, condOutType)))
+    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
+                         << condOutType;
+
+  const bool isBoolean = getElementTypeOrSelf(condOutType).isInteger(1);
+  if (!isBoolean)
+    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
+                         << condOutType;
+
+  return success();
+}
+
LogicalResult ReverseOp::verify() {
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
/* outType = */ getOutput().getType())
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index df950939645b0..e8b52d48347ab 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -449,6 +449,35 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return true;
}
+  // Add to `depth` the number of ancestor ops of `op`, walking outwards up to
+  // and including the nearest enclosing func.func or builtin.module (or up to
+  // the top-level op if there is none). Each enclosing op, e.g. an outer
+  // tosa.cond_if/while_loop, adds one, as does the enclosing function, so an
+  // op placed directly in a function body gets 1. Adds nothing if `op` is a
+  // func/module itself. `depth` accumulates; callers must initialise it.
+  static void getMaxNestedDepth(Operation *op, int32_t &depth) {
+    if (isa<mlir::func::FuncOp, ModuleOp>(op))
+      return;
+    for (Operation *parent = op->getParentOp(); parent;
+         parent = parent->getParentOp()) {
+      ++depth;
+      if (isa<mlir::func::FuncOp, ModuleOp>(parent))
+        return;
+    }
+  }
+
+  // Level check: the nesting depth computed by getMaxNestedDepth must stay
+  // strictly below tosaLevel.MAX_NESTING; emits an op error otherwise.
+  bool levelCheckMaxNesting(Operation *op) {
+    int32_t depth = 0;
+    getMaxNestedDepth(op, depth);
+
+    if (depth < tosaLevel.MAX_NESTING)
+      return true;
+    op->emitOpError() << "failed level check: " << depth << " >= MAX_NESTING";
+    return false;
+  }
+
bool levelCheckListSize(Operation *op) {
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -750,6 +779,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return failure();
}
+ if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
+ if (!levelCheckMaxNesting(op)) {
+ return failure();
+ }
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 5307645324b81..d24c1fa57883d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1327,33 +1327,42 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
%0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error at +1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
%1 = "tosa.cond_if"(%arg0, // condition
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0) ({
- ^bb0(%arg3: tensor<1xi32>):
- "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0) ({
+ ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+ %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+ %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+ %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+ %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+ %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+ %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+ ):
+ "tosa.yield"(%64) : (tensor<1xi32>) -> ()
}, {
- ^bb0(%arg3: tensor<1xi32>):
- "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+ ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+ %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+ %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+ %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+ %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+ %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+ %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+ ):
+ "tosa.yield"(%01) : (tensor<1xi32>) -> ()
}) : (
- tensor<i1>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>
- ) -> tensor<1xi32>
-
+ tensor<i1>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+ ) -> tensor<1xi32>
return
}
@@ -1361,27 +1370,54 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
// CHECK-LABEL: test_if_tensor_list_size_outputs
func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
- %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %cst_0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error at +1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
- %r:65 = "tosa.cond_if"(%arg0) ({
- ^bb0(%arg3: tensor<1xi32>):
- "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+ %r:65 = "tosa.cond_if"(%arg0, %cst_0) ({
+ ^bb0(%0: tensor<1xi32>):
+ "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0
+ ) : (
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+ ) -> ()
}, {
- ^bb0(%arg3: tensor<1xi32>):
- "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
- }) : (tensor<i1>) -> (
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
- tensor<1xi32>
- )
-
+ ^bb0(%0: tensor<1xi32>):
+ "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0
+ ) : (
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+ ) -> ()
+ }) : (tensor<i1>, tensor<1xi32>) -> (
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+ )
return
}
@@ -1391,25 +1427,57 @@ func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
%0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error at +1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
- %1:2 = "tosa.while_loop"(%0, %arg0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0, %0,
- %0, %0, %0, %0, %0, %0, %0
+ %1:65 = "tosa.while_loop"(%0, %arg0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0
) ({
- ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+ ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>,
+ %00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+ %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+ %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+ %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+ %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+ %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+ %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>
+ ):
%2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
%3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
"tosa.yield"(%3) : (tensor<1xi1>) -> ()
}, {
- ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
- %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
- %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
- "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
+ ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>,
+ %00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+ %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+ %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+ %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+ %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+ %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+ %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>
+ ):
+ %2 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tosa.add"(%arg3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ "tosa.yield"(%3, %arg4,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0
+ ) : (
+ tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+ tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+ ) -> ()
}) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
@@ -1419,28 +1487,7 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
- ) -> (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>)
-
- return
-}
-
-// -----
-
-// CHECK-LABEL: test_while_tensor_list_size_outputs
-func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
- %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
- // expected-error at +1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
- %1:65 = "tosa.while_loop"(%0, %arg0) ({
- ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
- %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
- %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
- "tosa.yield"(%3) : (tensor<1xi1>) -> ()
- }, {
- ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
- %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
- %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
- "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
- }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>) -> ( tensor<i32>, tensor<1x1x1x1x1x1x1xf32>,
+ ) -> (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
@@ -1453,3 +1500,101 @@ func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>,
return
}
+
+// -----
+// Six nested tosa.cond_if ops: the innermost fails the MAX_NESTING level check.
+func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
+ %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
+ %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ // expected-error at +1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
+ %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+ %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %res : tensor<f32>
+ } else {
+ tosa.yield %arg0 : tensor<f32>
+ }
+ tosa.yield %5 : tensor<f32>
+ } else {
+ tosa.yield %arg0 : tensor<f32>
+ }
+ tosa.yield %4 : tensor<f32>
+ } else {
+ tosa.yield %arg0 : tensor<f32>
+ }
+ tosa.yield %3 : tensor<f32>
+ } else {
+ tosa.yield %arg0 : tensor<f32>
+ }
+ tosa.yield %2 : tensor<f32>
+ } else {
+ tosa.yield %arg0 : tensor<f32>
+ }
+ tosa.yield %1 : tensor<f32>
+ } else {
+ %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %res : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+// Six nested tosa.while_loop ops: the innermost fails the MAX_NESTING check.
+func.func @test_while_loop_max_nested_depth(%arg0: tensor<i32>) {
+ %init_0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %cst_1 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+ %1:2 = tosa.while_loop (%arg2 = %init_0, %arg3 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+ %2 = tosa.greater_equal %arg3, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ tosa.yield %2 : tensor<i1>
+ } do {
+ ^bb0(%arg2: tensor<i32>, %arg2b: tensor<i32>):
+ %1:2 = tosa.while_loop (%arg4 = %init_0, %arg5 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+ %2 = tosa.greater_equal %arg5, %arg4 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ tosa.yield %2 : tensor<i1>
+ } do {
+ ^bb0(%arg4: tensor<i32>, %arg4b: tensor<i32>):
+ %1:2 = tosa.while_loop (%arg6 = %init_0, %arg7 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+ %2 = tosa.greater_equal %arg7, %arg6 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ tosa.yield %2 : tensor<i1>
+ } do {
+ ^bb0(%arg6: tensor<i32>, %arg6b: tensor<i32>):
+ %1:2 = tosa.while_loop (%arg8 = %init_0, %arg9 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+ %2 = tosa.greater_equal %arg9, %arg8 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ tosa.yield %2 : tensor<i1>
+ } do {
+ ^bb0(%arg8: tensor<i32>, %arg8b: tensor<i32>):
+ %1:2 = tosa.while_loop (%arg10 = %init_0, %arg11 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+ %2 = tosa.greater_equal %arg11, %arg10 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ tosa.yield %2 : tensor<i1>
+ } do {
+ ^bb0(%arg10: tensor<i32>, %arg10b: tensor<i32>):
+ // expected-error at +1 {{'tosa.while_loop' op failed level check: 6 >= MAX_NESTING}}
+ %1:2 = tosa.while_loop (%arg12 = %init_0, %arg13 = %arg0) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+ %2 = tosa.greater_equal %arg13, %arg12 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ tosa.yield %2 : tensor<i1>
+ } do {
+ ^bb0(%arg12: tensor<i32>, %arg12b: tensor<i32>):
+ %3 = tosa.add %arg12, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+ }
+ %3 = tosa.add %arg10, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+ }
+ %3 = tosa.add %arg8, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+ }
+ %3 = tosa.add %arg6, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+ }
+ %3 = tosa.add %arg4, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+ }
+ %3 = tosa.add %arg2, %cst_1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %arg2, %3: tensor<i32>, tensor<i32>
+ }
+ return
+}
+
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index c669c36e5452f..7ae8ec470c3dd 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -436,4 +436,352 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x
// expected-error at +1 {{invalid padding values at dimension 0: values must be non-negative or -1 for dynamic padding, got [-2, 2]}}
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<10xi8>, !tosa.shape<2>, tensor<1xi8>) -> tensor<10xi8>
return %1 : tensor<10xi8>
-}
\ No newline at end of file
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' arguments (1) and 'input_list' (2)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }, {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' arguments (2) and 'input_list' (1)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }, {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }, {
+  ^bb0(%arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}}
+  %0 = "tosa.cond_if"(%arg2, %arg0) ({
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }, {
+  ^bb0(%arg4: tensor<f32>, %arg3: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
+  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
+  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2xi1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op 'condition' must be a size 1 tensor, got 'tensor<2xi1>'}}
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }, {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<2xi1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg4 : tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_in_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (2) and 'input_list' (3)}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0, %arg4 = %arg0)
+      : (tensor<i32>, tensor<10xi32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_output_list(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'input_list' (3) and 'output_list' (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0, %arg4 = %arg0)
+      : (tensor<i32>, tensor<10xi32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_output_list_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'input_list' (2) and 'output_list' (3)}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0)
+      : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3, %arg3 : tensor<i32>, tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_cond_block(%arg0: tensor<2xf32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'cond_graph' arguments (3) and 'input_list' (2)}}
+  %1:2 = "tosa.while_loop"(%0, %arg0) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>, %arg5: tensor<2xf32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "tosa.yield"(%2) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<2> : tensor<1xi8>} : () -> tensor<1xi8>
+    %4 = "tosa.mul"(%arg3, %2, %3) : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    "tosa.yield"(%4, %arg4) : (tensor<i32>, tensor<2xf32>) -> ()
+  }) : (tensor<i32>, tensor<2xf32>) -> (tensor<i32>, tensor<2xf32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_cond_block_2(%arg0: tensor<2xf32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'cond_graph' arguments (1) and 'input_list' (3)}}
+  %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "tosa.yield"(%2) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<2xf32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<2> : tensor<1xi8>} : () -> tensor<1xi8>
+    %4 = "tosa.mul"(%arg3, %2, %3) : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+    "tosa.yield"(%4, %arg4) : (tensor<i32>, tensor<2xf32>) -> ()
+  }) : (tensor<i32>, tensor<2xf32>, tensor<i32>) -> (tensor<i32>, tensor<2xf32>, tensor<i32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_out(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' results (3) and 'input_list' (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %2, %3, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_input_list_mismatch_body_block_out_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' results (1) and 'input_list' (2)}}
+  %1:2 = tosa.while_loop (%arg2 = %0, %arg3 = %arg0) : (tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg2, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    tosa.yield %2 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_type_mismatch(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same element type for 'body_graph' arguments ('f32') and 'input_list' ('i32')}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_type_mismatch_2(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op require same shapes for 'body_graph' arguments ('tensor<10xi32>') and 'input_list' ('tensor<i32>')}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %6 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_cond_output_not_size_one(%arg0: tensor<10xi32>, %arg1: tensor<2xi32>) {
+  %0 = "tosa.const"() {values = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.while_loop' op 'cond_graph' result must be a size 1 tensor, got 'tensor<2xi1>'}}
+  %1:3 = tosa.while_loop (%arg2 = %arg0, %arg3 = %0, %arg4 = %arg0) : (tensor<10xi32>, tensor<2xi32>, tensor<10xi32>) -> (tensor<10xi32>, tensor<2xi32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+    tosa.yield %2 : tensor<2xi1>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<2xi32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.const"() {values = dense<[3, 5]> : tensor<2xi32>} : () -> tensor<2xi32>
+    %4 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+    tosa.yield %4, %3, %arg4 : tensor<10xi32>, tensor<2xi32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+
+func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<9> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op 'cond_graph' result must be a boolean tensor, got 'tensor<i32>'}}
+  %1:3 = tosa.while_loop (%arg2 = %arg0, %arg3 = %0, %arg4 = %arg0) : (tensor<10xi32>, tensor<i32>, tensor<10xi32>) -> (tensor<10xi32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.add %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %2 : tensor<i32>
+  } do {
+  ^bb0(%arg2: tensor<10xi32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %4 = tosa.add %arg2, %2 : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+    tosa.yield %4, %2, %arg4 : tensor<10xi32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
More information about the Mlir-commits
mailing list