[Mlir-commits] [mlir] [mlir][tosa] Check for compile time constants in the validation pass (PR #131123)
Luke Hutton
llvmlistbot at llvm.org
Thu Mar 13 04:18:40 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/131123
This commit adds the concept of a 'dynamic' extension to the dialect. When the dynamic extension is not loaded, the validation pass now checks that each operator's compile time constant (CTC) operands are constant.
Operands labeled as CTC in the specification that have tosa.shape type (shape_t in the specification) are not checked here. They are always required to be constant, and that requirement is checked elsewhere in the dialect.
>From 02466af1b865c0218781dd4e08d39a324fb049e6 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 10 Dec 2024 14:28:44 +0000
Subject: [PATCH] [mlir][tosa] Check for compile time constants in the
validation pass
This commit adds a concept of the 'dynamic' extension in the Dialect
and checks that compile time constant (CTC) operands for each operator
are constant if the dynamic extension is not loaded.
Operands labeled as CTC in the specification that are of tosa.shape
(shape_t in the specification) type are not checked as they are always
expected to be constant. This requirement is checked elsewhere in the
dialect.
Change-Id: I2477e6a6cd1c01b6a66a5d9c4abc38eb064bab43
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 4 +-
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 1 +
.../Tosa/Transforms/TosaValidation.cpp | 112 +++++++++++++--
mlir/test/Dialect/Tosa/dynamic_extension.mlir | 87 ++++++++++++
mlir/test/Dialect/Tosa/invalid.mlir | 2 +-
mlir/test/Dialect/Tosa/invalid_extension.mlir | 132 +++++++++++++++++-
6 files changed, 322 insertions(+), 16 deletions(-)
create mode 100644 mlir/test/Dialect/Tosa/dynamic_extension.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 52130d31c248d..4cb2e8006ca57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -228,6 +228,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
// CONTROLFLOW : Control Flow operations.
// DOUBLEROUND : Adds double rounding support to the RESCALE operator.
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
+// DYNAMIC : Removes the compile time constant (CTC) requirement for CTC inputs.
//===----------------------------------------------------------------------===//
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -245,12 +246,13 @@ def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
+def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
- Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
+ Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC, Tosa_EXT_NONE
]>;
def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 88f454f63e6f9..69b827fe14dee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -146,6 +146,7 @@ class TosaProfileCompliance {
return {Profile::pro_fp};
case Extension::variable:
case Extension::controlflow:
+ case Extension::dynamic:
return {Profile::pro_fp, Profile::pro_int};
case Extension::none:
return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6b1b651a90cfb..79c13793d7713 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -41,17 +41,91 @@ using namespace mlir::tosa;
namespace {
-static LogicalResult checkConstantOperandPad(Operation *op) {
+static LogicalResult
+checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
+ for (const auto index : operandIndices) {
+ Attribute attr;
+ if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
+ return op->emitOpError("expected compile time resolvable constant, but "
+ "got variable value for operand #")
+ << index;
+ }
+ }
+ return success();
+}
+
+static LogicalResult checkConstantOperandMul(Operation *op,
+ const TargetEnv &env) {
+ if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
+ // Check 'shift' (operand #2)
+ return checkConstantOperands(op, {2});
+ }
+ return success();
+}
+
+static LogicalResult checkConstantOperandTable(Operation *op,
+ const TargetEnv &env) {
+ if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
+ // Check 'table' (operand #1)
+ return checkConstantOperands(op, {1});
+ }
+ return success();
+}
+
+static LogicalResult checkConstantOperandPad(Operation *op,
+ const TargetEnv &env) {
if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
- DenseElementsAttr paddings;
- if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
- return op->emitOpError("padding of pad is not constant");
+ // Assume this op is zero-padding if padConst is not present
+ if (!env.allows(Extension::dynamic) && padOp.getPadConst())
+ // Check 'pad_const' (operand #2)
+ // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
+ return checkConstantOperands(op, {2});
+ }
+ return success();
+}
+
+static LogicalResult checkConstantOperandRescale(Operation *op,
+ const TargetEnv &env) {
+ if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
+ // Check 'multiplier', 'shift', 'input_zp' and 'output_zp' (operands #1-#4)
+ return checkConstantOperands(op, {1, 2, 3, 4});
+ }
+ return success();
+}
+
+template <typename T>
+static LogicalResult checkConstantOperandConvOps(Operation *op,
+ const TargetEnv &env) {
+ if (!env.allows(Extension::dynamic) && isa<T>(op)) {
+ // Check 'input_zp' and 'weight_zp' (operands #3 and #4)
+ return checkConstantOperands(op, {3, 4});
+ }
+ return success();
+}
+
+static LogicalResult checkConstantOperandMatMul(Operation *op,
+ const TargetEnv &env) {
+ if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
+ // Check 'A_zp' and 'B_zp' (operands #2 and #3)
+ return checkConstantOperands(op, {2, 3});
+ }
+ return success();
+}
+
+static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
+ const TargetEnv &env) {
+ if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
+ // Check 'input_zp' and 'output_zp' (operands #1 and #2)
+ return checkConstantOperands(op, {1, 2});
+ }
+ return success();
+}
- DenseElementsAttr padConst;
- // Assume this op is zero-padding if padConst is not presented.
- if (padOp.getPadConst() &&
- !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
- return op->emitOpError("pad_const of pad is not constant");
+static LogicalResult checkConstantOperandNegate(Operation *op,
+ const TargetEnv &env) {
+ if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
+ // Check 'input1_zp' and 'output_zp' (operands #1 and #2)
+ return checkConstantOperands(op, {1, 2});
}
return success();
}
@@ -97,7 +171,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
LogicalResult applyConstantOperandCheck(Operation *op) {
for (auto &checker : constCheckers) {
- if (failed(checker(op)))
+ if (failed(checker(op, targetEnv)))
return failure();
}
return success();
@@ -114,7 +188,19 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
private:
void populateConstantOperandChecks() {
+ constCheckers.emplace_back(checkConstantOperandMul);
+ constCheckers.emplace_back(checkConstantOperandTable);
constCheckers.emplace_back(checkConstantOperandPad);
+ constCheckers.emplace_back(checkConstantOperandRescale);
+ constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
+ constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
+ constCheckers.emplace_back(
+ checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
+ constCheckers.emplace_back(
+ checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
+ constCheckers.emplace_back(checkConstantOperandMatMul);
+ constCheckers.emplace_back(checkConstantOperandAvgPool2d);
+ constCheckers.emplace_back(checkConstantOperandNegate);
}
bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
@@ -436,7 +522,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
llvm::errs() << "unknown TOSA extension name passed in: " << ext
<< ", supported extension are int16, int4, bf16, "
<< "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
- << "doubleround and inexactround\n";
+ << "doubleround, inexactround and dynamic\n";
return signalPassFailure();
}
}
@@ -447,7 +533,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type);
- SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
+ SmallVector<
+ std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
+ constCheckers;
TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
TosaProfileCompliance profileComp;
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
new file mode 100644
index 0000000000000..fd9b3d5f23483
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -0,0 +1,87 @@
+//--------------------------------------------------------
+// Check that non-constant CTC operands are accepted when the dynamic extension is enabled.
+//--------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
+
+// -----
+
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
+ %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+ return
+}
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+ %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
+ %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+ %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_output_zp(%arg0: tensor<13x21x3xi32>, %output_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_matmul_non_const_zps(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+func.func @test_negate_non_const_zps(%arg0: tensor<1xf32>, %input1_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1xf32> {
+ %0 = tosa.negate %arg0, %input1_zp, %output_zp {} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
+ %0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32} :
+ (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a488c051dcd3b..5b591e3c5f45c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -242,7 +242,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
- // expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
+ // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 1a8366528b35c..13952716a9611 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -91,7 +91,7 @@ func.func @test_double_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = DOUBLE_ROUND requires extension [doubleround]}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
}
@@ -103,6 +103,134 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = INEXACT_ROUND requires extension [inexactround]}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
}
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+ %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
+ %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #1}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+ %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #2}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_conv2d_non_const_input_zp(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi8>, %arg2: tensor<8xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x4x8xi32> {
+ %weight_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.conv2d' op expected compile time resolvable constant, but got variable value for operand #3}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xi8>, tensor<8x1x1x4xi8>, tensor<8xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi32>
+ return %0 : tensor<1x4x4x8xi32>
+}
+
+// -----
+
+func.func @test_conv3d_non_const_weight_zp(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> {
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.conv3d' op expected compile time resolvable constant, but got variable value for operand #4}}
+ %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32>
+ return %0 : tensor<1x4x8x21x34xi32>
+}
+
+// -----
+
+func.func @test_depthwise_conv2d_non_const_input_zp(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x4x8xi32> {
+ %weight_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.depthwise_conv2d' op expected compile time resolvable constant, but got variable value for operand #3}}
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi32>
+ return %0 : tensor<1x4x4x8xi32>
+}
+
+// -----
+
+func.func @test_transpose_conv2d_non_const_weight_zp(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x4x8xi32> {
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.transpose_conv2d' op expected compile time resolvable constant, but got variable value for operand #4}}
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi32>
+ return %0 : tensor<1x4x4x8xi32>
+}
+
+// -----
+
+func.func @test_matmul_non_const_a_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
+ // expected-error@+1 {{'tosa.matmul' op expected compile time resolvable constant, but got variable value for operand #2}}
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
+ %a_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32> } : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expected compile time resolvable constant, but got variable value for operand #3}}
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+ // expected-error@+1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
+ // expected-error@+1 {{'tosa.table' op expected compile time resolvable constant, but got variable value for operand #1}}
+ %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+ return
+}
+
+// -----
+
+func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #3}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_negate_non_const_input1_zp(%arg0: tensor<1xf32>, %input_zp: tensor<1xf32>) -> tensor<1xf32> {
+ %output_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.negate' op expected compile time resolvable constant, but got variable value for operand #1}}
+ %0 = tosa.negate %arg0, %input_zp, %output_zp : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @test_avg_pool2d_non_const_output_zp(%arg0: tensor<1x32x32x8xf32>, %output_zp: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
+ %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.avg_pool2d' op expected compile time resolvable constant, but got variable value for operand #2}}
+ %0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32} :
+ (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
More information about the Mlir-commits
mailing list