[Mlir-commits] [mlir] [mlir][spirv] Add Gather/Scatter/Resize ops in TOSA Ext Inst Set (PR #188497)
Davide Grohmann
llvmlistbot at llvm.org
Fri Mar 27 02:05:30 PDT 2026
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/188497
>From 71f0e7a199505410b00f03632cddbe3a4d3a1c70 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 28 Jan 2026 13:39:57 +0100
Subject: [PATCH] [mlir][spirv] Add Gather/Scatter/Resize ops in TOSA Ext Inst
Set
This patch introduces the following TOSA scatter/gather and image operators:
spirv.Tosa.Gather
spirv.Tosa.Scatter
spirv.Tosa.Resize
Dialect parsing/verification tests and serialization round-trip tests have also been added.
Change-Id: I873f77d23673e87ba2e303686184021fdd792822
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 10 +
.../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td | 186 ++++++++++++++++++
.../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td | 25 +++
.../SPIRV/IR/tosa-ops-verification.mlir | 87 ++++++++
mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir | 72 +++++++
mlir/test/Target/SPIRV/tosa-ops.mlir | 126 ++++++++++++
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 1 +
7 files changed, 507 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8badb84a879fa..9f9e2f5f9a677 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4965,6 +4965,16 @@ def SPIRV_TosaExtAccTypeAttr : SPIRV_I32EnumAttr<
I32EnumAttrCase<"INT48", 4>,
]>;
+// Interpolation mode selector used by spirv.Tosa.Resize.
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtResizeModeAttr : SPIRV_I32EnumAttr<
+    "TosaExtResizeModeType", "Tosa Ext Resize Mode Type", "tosa_ext_resize_mode_type", [
+      I32EnumAttrCase<"NearestNeighbor", 1>,
+      I32EnumAttrCase<"Bilinear", 2>
+    ]>;
+
// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
// SPIR-V proper.
def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index f2519784b5058..7fc7f86478491 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -2462,4 +2462,190 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
}
+def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
+    AllElementTypesMatch<["values", "output"]>,
+    ValuesIndicesShapesMatch<"values", "indices", "output">]> {
+  // Operand/result layout (enforced by the traits and type constraints):
+  //   values [N,K,C], indices [N,W] (i32), output [N,W,C].
+  // NOTE: NoMemoryEffect rather than Pure, presumably because out-of-range
+  // indices are undefined behaviour and the op must not be speculated.
+  let arguments = (ins
+    SPIRV_TosaNumerical_TensorArm3D:$values,
+    SPIRV_Int32_TensorArm2D:$indices
+  );
+
+  let results = (outs SPIRV_TosaNumerical_TensorArm3D:$output);
+
+  let assemblyFormat = [{
+    $values `,` $indices attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  // Convenience accessors returning the operand types as TensorArmType.
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getValuesType() {
+      return ::llvm::cast<::mlir::spirv::TensorArmType>(getValues().getType());
+    }
+    ::mlir::spirv::TensorArmType getIndicesType() {
+      return ::llvm::cast<::mlir::spirv::TensorArmType>(getIndices().getType());
+    }
+  }];
+
+  let summary = "Gather operation.";
+
+  let description = [{
+    Generate a tensor for which each element in the output is a subtensor of the
+    values tensor based on the indices. Undefined behaviour may occur if the
+    specified indices are out of range.
+
+    References:
+    * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_gather
+    * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_gather
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Gather %values, %indices : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+    %0 = spirv.Tosa.Gather %values, %indices : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+    ```
+  }];
+}
+
+
+def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
+    AllElementTypesMatch<["values_in", "input", "values_out"]>,
+    AllTypesMatch<["values_in", "values_out"]>,
+    ValuesIndicesShapesMatch<"values_in", "indices", "input">]> {
+  // Operand/result layout (enforced by the traits and type constraints):
+  //   values_in [N,K,C], indices [N,W] (i32), input [N,W,C];
+  //   values_out has exactly the type of values_in.
+  // NOTE: NoMemoryEffect rather than Pure, presumably because out-of-range or
+  // duplicate indices are undefined behaviour and the op must not be speculated.
+  let arguments = (ins
+    SPIRV_TosaNumerical_TensorArm3D:$values_in,
+    SPIRV_Int32_TensorArm2D:$indices,
+    SPIRV_TosaNumerical_TensorArm3D:$input
+  );
+
+  let results = (outs SPIRV_TosaNumerical_TensorArm3D:$values_out);
+
+  let assemblyFormat = [{
+    $values_in `,` $indices `,` $input attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  // Convenience accessors returning the operand types as TensorArmType.
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getValuesInType() {
+      return ::llvm::cast<::mlir::spirv::TensorArmType>(getValuesIn().getType());
+    }
+    ::mlir::spirv::TensorArmType getIndicesType() {
+      return ::llvm::cast<::mlir::spirv::TensorArmType>(getIndices().getType());
+    }
+    ::mlir::spirv::TensorArmType getInputType() {
+      return ::llvm::cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+
+  let summary = "Scatter operation.";
+
+  let description = [{
+    The values_out tensor is set to the values_in tensor with data modified as
+    follows: data from the input tensor is inserted at the positions specified
+    by the indices tensor. In use cases that require multiple updates to the
+    same output position, these must be decomposed into multiple scatter
+    operations. Undefined behaviour may occur if the specified indices are
+    out of range or duplicate indices are provided.
+
+    References:
+    * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_scatter
+    * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_scatter
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Scatter %values_in, %indices, %arg0 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+    %0 = spirv.Tosa.Scatter %values_in, %indices, %arg0 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+    ```
+  }];
+}
+
+
+def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
+    TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
+    TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+    TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
+    TypeConstraintImplicationOn<"input", I8, "output", [I8, I32]>,
+    TypeConstraintImplicationOn<"input", I16, "output", [I16, I64]>]> {
+  let summary = "Resize operation, supports various resize/upsample modes.";
+
+  let description = [{
+    Resizes a tensor. Resize is only allowed in the H and W dimensions, given the input
+    shape = [N,H,W,C].
+
+    The input element type must be one of i8, i16, f16, bf16 or f32. The output
+    element type follows from it: f16, bf16 and f32 inputs keep their element
+    type, i8 inputs produce i8 or i32, and i16 inputs produce i16 or i64.
+
+    The height dimension (H) is scaled by factor ($ scale_y_n/scale_y_d $). The width
+    dimension (W) is scaled by factor ($ scale_x_n/scale_x_d $).
+
+    The NearestNeighbor mode returns the value of the input tensor closest to
+    the calculated sample position for both floating-point and integer data
+    formats.
+
+    Floating-point Bilinear mode returns a bilinearly interpolated output value
+    based on the four closest input sample positions.
+
+    For integer Bilinear interpolation mode, the output value must be scaled by
+    $ 1/(scale_y_n * scale_x_n) $ in a following operation to complete the
+    interpolation (for example with a rescale operator).
+
+    The output dimensions can be derived from the input dimensions by inverting
+    the scale. The [border_y, border_x] values adjust the output size to allow
+    fractional sampling beyond integer input position (H - 1,W - 1).
+
+    The limit MAX_SCALE=256 is applied to each scale ratio after reduction of the
+    ratio. Individual scale numerator and denominator values are allowed to be
+    larger than MAX_SCALE.
+
+    References:
+    * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_resize
+    * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_resize
+
+    #### Example:
+    ```mlir
+    %4 = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+    %4 = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+    ```
+  }];
+
+  // The input element type is limited to exactly the types covered by the
+  // TypeConstraintImplicationOn traits above. The generic numerical 4D
+  // constraint would let any other numerical element type (e.g. i32) through
+  // with no constraint at all on the output element type.
+  // NOTE(review): for integer inputs the TOSA RESIZE type table also appears to
+  // tie the output type to $mode (NearestNeighbor keeps the input type,
+  // Bilinear widens it); that dependency is not checked here - confirm.
+  let arguments = (ins
+    SPIRV_TosaExtResizeModeAttr:$mode,
+    TensorArmRankOf<[I8, I16, F16, BF16, F32], [4]>:$input,
+    SPIRV_Int32_1DTensorArmOfLength4:$scale,
+    SPIRV_Int32_1DTensorArmOfLength2:$offset,
+    SPIRV_Int32_1DTensorArmOfLength2:$border
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm4D:$output
+  );
+
+  let assemblyFormat = [{
+    `mode` `=` $mode `,`
+    $input `,`
+    $scale `,`
+    $offset `,`
+    $border
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  // Convenience accessors returning the operand types as TensorArmType.
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getScaleType() {
+      return cast<::mlir::spirv::TensorArmType>(getScale().getType());
+    }
+    ::mlir::spirv::TensorArmType getOffsetType() {
+      return cast<::mlir::spirv::TensorArmType>(getOffset().getType());
+    }
+    ::mlir::spirv::TensorArmType getBorderType() {
+      return cast<::mlir::spirv::TensorArmType>(getBorder().getType());
+    }
+  }];
+}
+
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index d14eeba324adf..f116c4dcdd491 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -38,6 +38,7 @@ class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+// Rank-2 tensorArm of 32-bit integers (e.g. Gather/Scatter indices).
+def SPIRV_Int32_TensorArm2D: TensorArmRankOf<[SPIRV_Int32], [2]>;
def SPIRV_Float32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
def SPIRV_TosaInteger_TensorArm1D : TensorArmRankOf<[SPIRV_TosaInteger], [1]>;
def SPIRV_TosaNumerical_TensorArm1D : TensorArmRankOf<[SPIRV_TosaNumerical], [1]>;
@@ -65,6 +66,9 @@ class SPIRV_1DTensorArmOfLengthAndType<list<int> allowedLengths, list<Type> allo
"rank 1 tensorArm of length " # !interleave(allowedLengths, "/"),
"::mlir::spirv::TensorArmType">;
+// Fixed-length rank-1 int32 tensorArms (e.g. the scale/offset/border operands
+// of spirv.Tosa.Resize).
+def SPIRV_Int32_1DTensorArmOfLength2 :
+    SPIRV_1DTensorArmOfLengthAndType<[2], [SPIRV_Int32]>;
+def SPIRV_Int32_1DTensorArmOfLength4 :
+    SPIRV_1DTensorArmOfLengthAndType<[4], [SPIRV_Int32]>;
+
def SPIRV_Int32_1DTensorArmOfLength1To6 : SPIRV_1DTensorArmOfLengthAndType<[1, 2, 3, 4, 5, 6], [SPIRV_Int32]>;
def SPIRV_Int32_1DTensorArmOfEvenLength2To12 : SPIRV_1DTensorArmOfLengthAndType<[2, 4, 6, 8, 10, 12], [SPIRV_Int32]>;
@@ -147,6 +151,27 @@ class MatchBroadcastableShapes<string input1, string input2, string output>:
"})">]>
>;
+// Holds when dimension `lhsDim` of the named operand/result `lhs` and dimension
+// `rhsDim` of `rhs` have the same size, or when either of them is dynamic.
+class SameDimsOrDynamicPred<string lhs, int lhsDim, string rhs, int rhsDim> :
+  CPred<"[](int64_t lhsExtent, int64_t rhsExtent) { "
+        "  return ::mlir::ShapedType::isDynamic(lhsExtent) || "
+        "         ::mlir::ShapedType::isDynamic(rhsExtent) || "
+        "         lhsExtent == rhsExtent; "
+        "}("
+        "::llvm::cast<::mlir::ShapedType>($" # lhs # ".getType()).getDimSize(" # lhsDim # "), "
+        "::llvm::cast<::mlir::ShapedType>($" # rhs # ".getType()).getDimSize(" # rhsDim # "))">;
+
+// Checks the [N,K,C] / [N,W] / [N,W,C] shape relationship between `values`,
+// `indices` and `tensor`; a dynamic dimension is compatible with any size.
+class ValuesIndicesShapesMatch<string values, string indices, string tensor>:
+  PredOpTrait<"shapes of " # values # ", " # indices # ", and " # tensor #
+              " must satisfy [N,K,C], [N,W], [N,W,C]",
+    And<[
+      // N: leading dimension, pairwise across all three.
+      SameDimsOrDynamicPred<values, 0, indices, 0>,
+      SameDimsOrDynamicPred<values, 0, tensor, 0>,
+      SameDimsOrDynamicPred<indices, 0, tensor, 0>,
+      // W: second dimension of indices and tensor.
+      SameDimsOrDynamicPred<indices, 1, tensor, 1>,
+      // C: last dimension of values and tensor.
+      SameDimsOrDynamicPred<values, 2, tensor, 2>
+    ]>>;
+
class TableSizeConstraint<string input, Type type, int size>:
PredOpTrait<"table must have size " # size # " if " # input # " has element type " # type.summary,
Implies<ElementTypeIsPred<input, type>, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]>
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 5ff51dca91a3b..f95fedba74307 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1853,3 +1853,90 @@ spirv.ARM.Graph @transpose_perms_element_count_not_input_rank(%arg0: !spirv.arm.
%0 = spirv.Tosa.Transpose perms = [1, 0], %arg0 : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<4x2x3xi8>
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x2x3xi8>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @gather_values_output_element_types_not_matching(%arg0: !spirv.arm.tensor<31x11x45xi32>, %arg1: !spirv.arm.tensor<31x15xi32>) -> (!spirv.arm.tensor<31x15x45xi16>) {
+  // i32 values cannot be gathered into an i16 output.
+  // expected-error @+1 {{op failed to verify that all of {values, output} have same element type}}
+  %gathered = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi16>
+  spirv.ARM.GraphOutputs %gathered : !spirv.arm.tensor<31x15x45xi16>
+}
+
+spirv.ARM.Graph @gather_shapes_not_matching(%arg0: !spirv.arm.tensor<31x11x45xi32>, %arg1: !spirv.arm.tensor<31x15xi32>) -> (!spirv.arm.tensor<30x15x44xi32>) {
+  // Output N (30) and C (44) disagree with values N (31) and C (45).
+  // expected-error @+1 {{op failed to verify that shapes of values, indices, and output must satisfy [N,K,C], [N,W], [N,W,C]}}
+  %gathered = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<30x15x44xi32>
+  spirv.ARM.GraphOutputs %gathered : !spirv.arm.tensor<30x15x44xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @scatter_values_in_input_values_out_element_types_not_matching(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x54xi16>) -> (!spirv.arm.tensor<34x28x54xi32>) {
+  // The i16 input does not match the i32 element type of values_in/values_out.
+  // expected-error @+1 {{op failed to verify that all of {values_in, input, values_out} have same element type}}
+  %scattered = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi16> -> !spirv.arm.tensor<34x28x54xi32>
+  spirv.ARM.GraphOutputs %scattered : !spirv.arm.tensor<34x28x54xi32>
+}
+
+spirv.ARM.Graph @scatter_values_in_values_out_types_not_matching(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x54xi32>) -> (!spirv.arm.tensor<35x28x54xi32>) {
+  // values_out (35x28x54) must have exactly the type of values_in (34x28x54).
+  // expected-error @+1 {{op failed to verify that all of {values_in, values_out} have same type}}
+  %scattered = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<35x28x54xi32>
+  spirv.ARM.GraphOutputs %scattered : !spirv.arm.tensor<35x28x54xi32>
+}
+
+spirv.ARM.Graph @scatter_shapes_not_matching(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x55xi32>) -> (!spirv.arm.tensor<34x28x54xi32>) {
+  // input C (55) disagrees with values_in C (54).
+  // expected-error @+1 {{op failed to verify that shapes of values_in, indices, and input must satisfy [N,K,C], [N,W], [N,W,C]}}
+  %scattered = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x55xi32> -> !spirv.arm.tensor<34x28x54xi32>
+  spirv.ARM.GraphOutputs %scattered : !spirv.arm.tensor<34x28x54xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @resize_i8_input_output_element_type_must_be_i8_or_i32(%arg0: !spirv.arm.tensor<1x1x31x55xi8>) -> (!spirv.arm.tensor<1x1x278x55xf16>) {
+  // An i8 input may only produce an i8 or i32 output, not f16.
+  %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %border = spirv.Constant dense<[0, 7]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [8-bit signless integer,32-bit signless integer]}}
+  %output = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x1x278x55xf16>
+}
+
+spirv.ARM.Graph @resize_i16_input_output_element_type_must_be_i16_or_i64(%arg0: !spirv.arm.tensor<1x1x31x55xi16>) -> (!spirv.arm.tensor<1x1x278x55xi32>) {
+  // An i16 input may only produce an i16 or i64 output, not i32.
+  %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %border = spirv.Constant dense<[0, 7]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [16-bit signless integer,64-bit signless integer]}}
+  %output = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x1x31x55xi16>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x1x278x55xi32>
+}
+
+spirv.ARM.Graph @resize_f16_input_output_element_type_must_be_f16(%arg0: !spirv.arm.tensor<1x48x33x63xf16>) -> (!spirv.arm.tensor<1x753x297x63xf32>) {
+  // An f16 input must produce an f16 output, not f32.
+  %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %border = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [16-bit float]}}
+  %output = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x48x33x63xf16>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x753x297x63xf32>
+}
+
+spirv.ARM.Graph @resize_f32_input_output_element_type_must_be_f32(%arg0: !spirv.arm.tensor<1x48x33x63xf32>) -> (!spirv.arm.tensor<1x753x297x63xf16>) {
+  // An f32 input must produce an f32 output, not f16.
+  %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %border = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [32-bit float]}}
+  %output = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf16>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x753x297x63xf16>
+}
+
+spirv.ARM.Graph @resize_bf16_input_output_element_type_must_be_bf16(%arg0: !spirv.arm.tensor<1x48x33x63xbf16>) -> (!spirv.arm.tensor<1x753x297x63xf32>) {
+  // A bf16 input must produce a bf16 output, not f32.
+  %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %border = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [bfloat16 type]}}
+  %output = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x48x33x63xbf16>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x753x297x63xf32>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index baa1372685165..b10724b16a84f 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -1051,3 +1051,75 @@ spirv.ARM.Graph @transpose_fp(%arg0: !spirv.arm.tensor<42x22x49xi1>) -> (!spirv.
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<49x42x22xi1>
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<49x42x22xi1>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @gather_int(%arg0: !spirv.arm.tensor<31x11x45xi32>, %arg1: !spirv.arm.tensor<31x15xi32>) -> (!spirv.arm.tensor<31x15x45xi32>) {
+  // i32 values gathered with i32 indices.
+  // CHECK: {{%.*}} = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+  %gathered = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<31x15x45xi32>
+  spirv.ARM.GraphOutputs %gathered : !spirv.arm.tensor<31x15x45xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @gather_fp(%arg0: !spirv.arm.tensor<59x61x19xf32>, %arg1: !spirv.arm.tensor<59x65xi32>) -> (!spirv.arm.tensor<59x65x19xf32>) {
+  // f32 values gathered with i32 indices.
+  // CHECK: {{%.*}} = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+  %gathered = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<59x65x19xf32>
+  spirv.ARM.GraphOutputs %gathered : !spirv.arm.tensor<59x65x19xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @scatter_int(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x54xi32>) -> (!spirv.arm.tensor<34x28x54xi32>) {
+  // i32 input scattered into i32 values_in.
+  // CHECK: {{%.*}} = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+  %scattered = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<34x28x54xi32>
+  spirv.ARM.GraphOutputs %scattered : !spirv.arm.tensor<34x28x54xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @scatter_fp(%arg0: !spirv.arm.tensor<18x34x25xf16>, %arg1: !spirv.arm.tensor<18x20xi32>, %arg2: !spirv.arm.tensor<18x20x25xf16>) -> (!spirv.arm.tensor<18x34x25xf16>) {
+  // f16 input scattered into f16 values_in.
+  // CHECK: {{%.*}} = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+  %scattered = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x34x25xf16>
+  spirv.ARM.GraphOutputs %scattered : !spirv.arm.tensor<18x34x25xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @resize_int(%arg0: !spirv.arm.tensor<1x1x31x55xi8>) -> (!spirv.arm.tensor<1x1x278x55xi8>) {
+  // NearestNeighbor resize of an i8 tensor that keeps the i8 element type.
+  %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %border = spirv.Constant dense<[0, 7]> : !spirv.arm.tensor<2xi32>
+  // CHECK: {{%.*}} = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+  %output = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x1x278x55xi8>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x1x278x55xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @resize_fp(%arg0: !spirv.arm.tensor<1x48x33x63xf32>) -> (!spirv.arm.tensor<1x753x297x63xf32>) {
+  // Bilinear resize of an f32 tensor.
+  %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %border = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+  // CHECK: {{%.*}} = spirv.Tosa.Resize mode = <Bilinear>, %arg0, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+  %output = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x753x297x63xf32>
+  spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x753x297x63xf32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 5556ff47785fb..ebc6a290a9dc1 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -1838,3 +1838,129 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<49x42x22xi1>
}
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @gather_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<31x11x45xi32>, UniformConstant>
+  spirv.GlobalVariable @gather_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<31x15xi32>, UniformConstant>
+  spirv.GlobalVariable @gather_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<31x15x45xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @gather_int, @gather_int_arg_0, @gather_int_arg_1, @gather_int_res_0
+  spirv.ARM.Graph @gather_int(%arg0: !spirv.arm.tensor<31x11x45xi32>, %arg1: !spirv.arm.tensor<31x15xi32>) -> (!spirv.arm.tensor<31x15x45xi32>) {
+    // Serialization round-trip of an i32 Gather.
+    // CHECK: {{%.*}} = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+    %gathered = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<31x15x45xi32>
+    spirv.ARM.GraphOutputs %gathered : !spirv.arm.tensor<31x15x45xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @gather_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<59x61x19xf32>, UniformConstant>
+  spirv.GlobalVariable @gather_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<59x65xi32>, UniformConstant>
+  spirv.GlobalVariable @gather_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<59x65x19xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @gather_fp, @gather_fp_arg_0, @gather_fp_arg_1, @gather_fp_res_0
+  spirv.ARM.Graph @gather_fp(%arg0: !spirv.arm.tensor<59x61x19xf32>, %arg1: !spirv.arm.tensor<59x65xi32>) -> (!spirv.arm.tensor<59x65x19xf32>) {
+    // Serialization round-trip of an f32 Gather.
+    // CHECK: {{%.*}} = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+    %gathered = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<59x65x19xf32>
+    spirv.ARM.GraphOutputs %gathered : !spirv.arm.tensor<59x65x19xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @scatter_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<34x28x54xi32>, UniformConstant>
+  spirv.GlobalVariable @scatter_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<34x18xi32>, UniformConstant>
+  spirv.GlobalVariable @scatter_int_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<34x18x54xi32>, UniformConstant>
+  spirv.GlobalVariable @scatter_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<34x28x54xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @scatter_int, @scatter_int_arg_0, @scatter_int_arg_1, @scatter_int_arg_2, @scatter_int_res_0
+  spirv.ARM.Graph @scatter_int(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x54xi32>) -> (!spirv.arm.tensor<34x28x54xi32>) {
+    // Serialization round-trip of an i32 Scatter.
+    // CHECK: {{%.*}} = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+    %scattered = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<34x28x54xi32>
+    spirv.ARM.GraphOutputs %scattered : !spirv.arm.tensor<34x28x54xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @scatter_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<18x34x25xf16>, UniformConstant>
+  spirv.GlobalVariable @scatter_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<18x20xi32>, UniformConstant>
+  spirv.GlobalVariable @scatter_fp_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<18x20x25xf16>, UniformConstant>
+  spirv.GlobalVariable @scatter_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<18x34x25xf16>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @scatter_fp, @scatter_fp_arg_0, @scatter_fp_arg_1, @scatter_fp_arg_2, @scatter_fp_res_0
+  spirv.ARM.Graph @scatter_fp(%arg0: !spirv.arm.tensor<18x34x25xf16>, %arg1: !spirv.arm.tensor<18x20xi32>, %arg2: !spirv.arm.tensor<18x20x25xf16>) -> (!spirv.arm.tensor<18x34x25xf16>) {
+    // Serialization round-trip of an f16 Scatter.
+    // CHECK: {{%.*}} = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+    %scattered = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x34x25xf16>
+    spirv.ARM.GraphOutputs %scattered : !spirv.arm.tensor<18x34x25xf16>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @resize_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x1x31x55xi8>, UniformConstant>
+ spirv.GlobalVariable @resize_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x1x278x55xi8>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @resize_int, @resize_int_arg_0, @resize_int_res_0
+ spirv.ARM.Graph @resize_int(%input: !spirv.arm.tensor<1x1x31x55xi8>) -> (!spirv.arm.tensor<1x1x278x55xi8>) {
+ %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+ %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+ %border = spirv.Constant dense<[0, 7]> : !spirv.arm.tensor<2xi32>
+ // CHECK: {{%.*}} = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+ %output = spirv.Tosa.Resize mode = <NearestNeighbor>, %input, %scale, %offset, %border : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x1x278x55xi8>
+ spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x1x278x55xi8>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @resize_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x48x33x63xf32>, UniformConstant>
+ spirv.GlobalVariable @resize_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x753x297x63xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @resize_fp, @resize_fp_arg_0, @resize_fp_res_0
+ spirv.ARM.Graph @resize_fp(%input: !spirv.arm.tensor<1x48x33x63xf32>) -> (!spirv.arm.tensor<1x753x297x63xf32>) {
+ %scale = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+ %offset = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+ %border = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+ // CHECK: {{%.*}} = spirv.Tosa.Resize mode = <Bilinear>, %arg0, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+ %output = spirv.Tosa.Resize mode = <Bilinear>, %input, %scale, %offset, %border : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x753x297x63xf32>
+ spirv.ARM.GraphOutputs %output : !spirv.arm.tensor<1x753x297x63xf32>
+ }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 34edd4df49d8e..2344465f75214 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -502,6 +502,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
"SPIRV_MemorySemanticsAttr",
"SPIRV_MatrixLayoutAttr",
"SPIRV_TosaExtAccTypeAttr",
+ "SPIRV_TosaExtResizeModeAttr",
"SPIRV_TosaExtNaNPropagationModeAttr",
"SPIRV_QuadSwapDirectionAttr",
};
More information about the Mlir-commits
mailing list