[Mlir-commits] [mlir] [mlir][spirv] Add Cast/Rescale ops in TOSA Ext Inst Set (PR #189028)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 27 08:44:18 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Davide Grohmann (davidegrohmann)
<details>
<summary>Changes</summary>
This patch introduces the following operators:
spirv.Tosa.Cast
spirv.Tosa.Rescale
Dialect and serialization round-trip tests have also been added.
---
Patch is 30.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/189028.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+11)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+124)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+14)
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir (+144)
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir (+37)
- (modified) mlir/test/Target/SPIRV/tosa-ops.mlir (+61)
- (modified) mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (+1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 9f9e2f5f9a677..11a91958d7484 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4985,4 +4985,15 @@ def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
I32EnumAttrCase<"Ignore", 2>,
]>;
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtRoundingModeAttr : SPIRV_I32EnumAttr<
+ "TosaExtRoundingModeType", "Tosa Ext Rounding Mode Type",
+ "tosa_ext_rounding_mode_type",
+ [
+ I32EnumAttrCase<"SingleRound", 1>,
+ I32EnumAttrCase<"InexactRound", 2>,
+ I32EnumAttrCase<"DoubleRound", 3>,
+ ]>;
+
#endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 7fc7f86478491..c0061adfeed2c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -2648,4 +2648,128 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
}
+def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
+ AllShapesMatch<["input", "output"]>]> {
+ let summary = "Cast operation.";
+
+ let description = [{
+ Casts a tensor from one data type to another.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_cast
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32>
+ %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaAny_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaAny_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaRescaleOp : SPIRV_TosaOpWithResult<"Rescale", 65, [NoMemoryEffect,
+ AllShapesMatch<["input", "output"]>,
+ AllElementTypesMatch<["input", "input_zp"]>,
+ AllElementTypesMatch<["output", "output_zp"]>,
+ ElementTypeMatchesScale32<"multiplier">,
+ TensorLengthMatchesPerChannel<"multiplier">,
+ TensorLengthMatchesPerChannel<"shift">,
+ TypeConstraintImplicationOn<"input", I8, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I16, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I32, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I64, "output", [I8, I16, I32]>]> {
+ let summary = "Rescale operator.";
+
+ let description = [{
+ Rescale is defined using an integer multiply, add, and shift.
+
+ Rescale supports two precisions of multiplier: 16-bit and 32-bit. The
+ 32-bit multiplier version supports two rounding modes to enable simpler
+ lowering of existing frameworks that use two stage rounding. All arithmetic
+ is designed so that it does not overflow a 64-bit accumulator and that the
+ result fits in 32 bits. In particular, a 48-bit value cannot be scaled with
+ the 32-bit multiplier because the accumulator would need to have 80 bits.
+
+ The shift and value range are limited to allow a variety of implementations.
+ The limit of 62 on shift allows the shift to be decomposed as two right
+ shifts of 31.
+
+    Undefined behaviour may occur if the calculated result underflows or overflows
+    its integer range.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_rescale
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_rescale
+
+ #### Example:
+ ```mlir
+ %9 = spirv.Tosa.Rescale scale32 = true, rounding_mode = <DoubleRound>, per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, %multiplier, %shift, %input_zp, %output_zp : !spirv.arm.tensor<17x29x19xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<17x29x19xi16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_BoolConstAttr: $scale32,
+ SPIRV_TosaExtRoundingModeAttr: $rounding_mode,
+ SPIRV_BoolConstAttr: $per_channel,
+ SPIRV_BoolConstAttr: $input_unsigned,
+ SPIRV_BoolConstAttr: $output_unsigned,
+ SPIRV_TosaInteger_TensorArm: $input,
+ SPIRV_Int16OrInt32_TensorArm1D: $multiplier,
+ SPIRV_Int8_TensorArm1D: $shift,
+ SPIRV_TosaInteger_1DTensorArmOfLength1: $input_zp,
+ SPIRV_TosaInteger_1DTensorArmOfLength1: $output_zp
+ );
+
+ let results = (outs
+ SPIRV_TosaInteger_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ `scale32` `=` $scale32 `,`
+ `rounding_mode` `=` $rounding_mode `,`
+ `per_channel` `=` $per_channel `,`
+ `input_unsigned` `=` $input_unsigned `,`
+ `output_unsigned` `=` $output_unsigned `,`
+ $input `,`
+ $multiplier `,`
+ $shift `,`
+ $input_zp `,`
+ $output_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getMultiplierType() {
+ return cast<::mlir::spirv::TensorArmType>(getMultiplier().getType());
+ }
+ ::mlir::spirv::TensorArmType getShiftType() {
+ return cast<::mlir::spirv::TensorArmType>(getShift().getType());
+ }
+ }];
+}
+
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index f116c4dcdd491..0b6de5bfc1ff9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -38,6 +38,8 @@ class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+def SPIRV_Int8_TensorArm1D : TensorArmRankOf<[SPIRV_Int8], [1]>;
+def SPIRV_Int16OrInt32_TensorArm1D : TensorArmRankOf<[SPIRV_Int16, SPIRV_Int32], [1]>;
def SPIRV_Int32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>;
def SPIRV_Float32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
def SPIRV_TosaInteger_TensorArm1D : TensorArmRankOf<[SPIRV_TosaInteger], [1]>;
@@ -92,6 +94,7 @@ def SPIRV_Int32_1DTensorArmOfLength1To6Attr : ConfinedAttr<
I32ElementsAttr, [SPIRV_DenseElementAttrsWithTensorArmType, Is1DTensorArmAttrOfLength<[1, 2, 3, 4, 5, 6]>]>;
def SPIRV_Int8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_Int8]>;
+def SPIRV_TosaInteger_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaInteger]>;
def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>;
def SPIRV_TosaAny_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaAny]>;
@@ -201,5 +204,16 @@ class VariadicInputAllSameRank<string reference, string input>:
" && ::llvm::cast<::mlir::ShapedType>(t).getRank() == "
# Rank<reference>.result # "; })">>;
+class ElementTypeMatchesScale32<string tensor> :
+ PredOpTrait<tensor # " must have element type i32 when scale32 is true, otherwise i16",
+ CPred<"::llvm::cast<::mlir::ShapedType>($" # tensor # ".getType()).getElementType()."
+ "isInteger(getScale32() ? 32 : 16)">>;
+
+class TensorLengthMatchesPerChannel<string tensor> :
+ PredOpTrait<tensor # " must have length rank(input) - 1 when per_channel is true, otherwise length 1",
+ CPred<"::llvm::cast<::mlir::ShapedType>($" # tensor # ".getType()).getShape()[0] == "
+ "(getPerChannel() ? "
+ "::llvm::cast<::mlir::ShapedType>($input.getType()).getRank() - 1 : 1)">>;
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index f95fedba74307..82ec2b9506bdc 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1940,3 +1940,147 @@ spirv.ARM.Graph @resize_bf16_input_output_element_type_must_be_bf16(%arg0: !spir
%4 = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x48x33x63xbf16>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x753x297x63xf32>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Cast
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @cast_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x5xi32>) {
+ // expected-error @+1 {{op failed to verify that all of {input, output} have same shape}}
+ %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<2x3x5xi32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x5xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Rescale
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @rescale_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x5xi16>) {
+ %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ // expected-error @+1 {{op failed to verify that all of {input, output} have same shape}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x5xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x5xi16>
+}
+
+spirv.ARM.Graph @rescale_input_and_input_zp_element_types_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ // expected-error @+1 {{op failed to verify that all of {input, input_zp} have same element type}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @rescale_output_and_output_zp_element_types_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ // expected-error @+1 {{op failed to verify that all of {output, output_zp} have same element type}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @rescale_scale32_true_requires_i32_multiplier(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ // expected-error @+1 {{op failed to verify that multiplier must have element type i32 when scale32 is true, otherwise i16}}
+ %5 = spirv.Tosa.Rescale scale32 = true, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @rescale_scale32_false_requires_i16_multiplier(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32>
+ %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ // expected-error @+1 {{op failed to verify that multiplier must have element type i32 when scale32 is true, otherwise i16}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @rescale_per_channel_true_requires_multiplier_length_rank_minus_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<[1]> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ // expected-error @+1 {{op failed to verify that multiplier must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @rescale_per_channel_true_requires_shift_length_rank_minus_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<[1, 1]> : !spirv.arm.tensor<2xi16>
+ %2 = spirv.Constant dense<[0]> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ // expected-error @+1 {{op failed to verify that shift must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @rescale_per_channel_false_requires_multiplier_length_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<[1, 1]> : !spirv.arm.tensor<2xi16>
+ %2 = spirv.Constant dense<[0]> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ // expected-error @+1 {{op failed to verify that multiplier must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @rescale_per_channel_false_requires_shift_length_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<[1]> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ // expected-error @+1 {{op failed to verify that shift must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @rescale_i8_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x4xi64>) {
+ %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+ // expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64>
+}
+
+spirv.ARM.Graph @rescale_i16_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi64>) {
+ %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+ // expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64>
+ spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64>
+}
+
+spirv.ARM.Graph @rescale_i32_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi32>) -> (!spirv.arm.tensor<2x3x4xi64>) {
+ %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+ %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi32>
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+ // expected-error @+1 {{op failed to verify that if input has type 32-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_un...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/189028
More information about the Mlir-commits
mailing list