[Mlir-commits] [mlir] [mlir][tosa] Check for compile time constants in the validation pass (PR #131123)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 13 04:19:16 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit adds the concept of a 'dynamic' extension to the dialect and, when the dynamic extension is not loaded, checks that each operator's compile time constant (CTC) operands are actually constant.

Operands labeled as CTC in the specification whose type is tosa.shape (shape_t in the specification) are not checked here, because they must always be constant; that requirement is already enforced elsewhere in the dialect.

---

Patch is 25.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131123.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+3-1) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+100-12) 
- (added) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+87) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+130-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 52130d31c248d..4cb2e8006ca57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -228,6 +228,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // CONTROLFLOW  : Control Flow operations.
 // DOUBLEROUND  : Adds double rounding support to the RESCALE operator.
 // INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
+// DYNAMIC      : Removes the compile time constant requirement on CTC inputs.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -245,12 +246,13 @@ def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
 def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
 def Tosa_EXT_DOUBLEROUND  : I32EnumAttrCase<"doubleround", 9>;
 def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
+def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
       Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
-      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
+      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 88f454f63e6f9..69b827fe14dee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -146,6 +146,7 @@ class TosaProfileCompliance {
       return {Profile::pro_fp};
     case Extension::variable:
     case Extension::controlflow:
+    case Extension::dynamic:
       return {Profile::pro_fp, Profile::pro_int};
     case Extension::none:
       return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6b1b651a90cfb..79c13793d7713 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -41,17 +41,91 @@ using namespace mlir::tosa;
 
 namespace {
 
-static LogicalResult checkConstantOperandPad(Operation *op) {
+static LogicalResult
+checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
+  for (const auto index : operandIndices) {
+    Attribute attr;
+    if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
+      return op->emitOpError("expected compile time resolvable constant, but "
+                             "got variable value for operand #")
+             << index;
+    }
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandMul(Operation *op,
+                                             const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
+    // Check 'shift'
+    return checkConstantOperands(op, {2});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandTable(Operation *op,
+                                               const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
+    // Check 'table'
+    return checkConstantOperands(op, {1});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandPad(Operation *op,
+                                             const TargetEnv &env) {
   if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
-    DenseElementsAttr paddings;
-    if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
-      return op->emitOpError("padding of pad is not constant");
+    // Assume this op is zero-padding if padConst is not present
+    if (!env.allows(Extension::dynamic) && padOp.getPadConst())
+      // Check 'pad_const'
+      // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
+      return checkConstantOperands(op, {2});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandRescale(Operation *op,
+                                                 const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
+    // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2, 3, 4});
+  }
+  return success();
+}
+
+template <typename T>
+static LogicalResult checkConstantOperandConvOps(Operation *op,
+                                                 const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<T>(op)) {
+    // Check 'input_zp' and 'weight_zp'
+    return checkConstantOperands(op, {3, 4});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandMatMul(Operation *op,
+                                                const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
+    // Check 'A_zp' and 'B_zp'
+    return checkConstantOperands(op, {2, 3});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
+                                                   const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
+    // Check 'input_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2});
+  }
+  return success();
+}
 
-    DenseElementsAttr padConst;
-    // Assume this op is zero-padding if padConst is not presented.
-    if (padOp.getPadConst() &&
-        !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
-      return op->emitOpError("pad_const of pad is not constant");
+static LogicalResult checkConstantOperandNegate(Operation *op,
+                                                const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
+    // Check 'input1_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2});
   }
   return success();
 }
@@ -97,7 +171,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   LogicalResult applyConstantOperandCheck(Operation *op) {
     for (auto &checker : constCheckers) {
-      if (failed(checker(op)))
+      if (failed(checker(op, targetEnv)))
         return failure();
     }
     return success();
@@ -114,7 +188,19 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
 private:
   void populateConstantOperandChecks() {
+    constCheckers.emplace_back(checkConstantOperandMul);
+    constCheckers.emplace_back(checkConstantOperandTable);
     constCheckers.emplace_back(checkConstantOperandPad);
+    constCheckers.emplace_back(checkConstantOperandRescale);
+    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
+    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
+    constCheckers.emplace_back(
+        checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
+    constCheckers.emplace_back(
+        checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
+    constCheckers.emplace_back(checkConstantOperandMatMul);
+    constCheckers.emplace_back(checkConstantOperandAvgPool2d);
+    constCheckers.emplace_back(checkConstantOperandNegate);
   }
 
   bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
@@ -436,7 +522,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
                        << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
-                       << "doubleround and inexactround\n";
+                       << "doubleround, inexactround and dynamic\n";
           return signalPassFailure();
         }
       }
@@ -447,7 +533,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool CheckVariableReadOrWrite(Operation *op);
   bool isValidElementType(Type type);
 
-  SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
+  SmallVector<
+      std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
+      constCheckers;
   TosaLevel tosaLevel;
   DenseMap<StringAttr, mlir::Type> variablesMap;
   TosaProfileCompliance profileComp;
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
new file mode 100644
index 0000000000000..fd9b3d5f23483
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -0,0 +1,87 @@
+//--------------------------------------------------------
+// Check that non-constant CTC operands are accepted when the dynamic extension is enabled.
+//--------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
+
+// -----
+
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
+  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+  return
+}
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_output_zp(%arg0: tensor<13x21x3xi32>, %output_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_matmul_non_const_zps(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+func.func @test_negate_non_const_zps(%arg0: tensor<1xf32>, %input1_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1xf32> {
+  %0 = tosa.negate %arg0, %input1_zp, %output_zp {} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
+  %0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a488c051dcd3b..5b591e3c5f45c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -242,7 +242,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
 
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
-  // expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
+  // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
   %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %1 : tensor<13x21x3xi8>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 1a8366528b35c..13952716a9611 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -91,7 +91,7 @@ func.func @test_double_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = DOUBLE_ROUND requires extension [doubleround]}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %0 : tensor<13x21x3xi8>
 }
 
@@ -103,6 +103,134 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = INEXACT_ROUND requires extension [inexactround]}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %0 : tensor<13x21x3xi8>
 }
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #2}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_conv2d_non_const_input_zp(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi8>, %arg2: tensor<8xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x4x8xi32> {
+  %weight_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv2d' op expected compile time resolvable constant, but got variable value for operand #3}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xi8>, tensor<8x1x1x4xi8>, tensor<8xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi32>
+  return %0 : tensor<1x4x4x8xi32>
+}
+
+// -----
+
+func.func @test_conv3d_non_const_weight_zp(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> {
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.conv3d' op expected compile time resolvable constant, but got variable value for operand #4}}
+  %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (t...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/131123


More information about the Mlir-commits mailing list