[Mlir-commits] [mlir] [mlir][spirv] Add Gather/Scatter/Resize ops in TOSA Ext Inst Set (PR #188497)

Davide Grohmann llvmlistbot at llvm.org
Fri Mar 27 02:05:30 PDT 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/188497

>From 71f0e7a199505410b00f03632cddbe3a4d3a1c70 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 28 Jan 2026 13:39:57 +0100
Subject: [PATCH] [mlir][spirv] Add Gather/Scatter/Resize ops in TOSA Ext Inst
 Set

This patch introduces the following operators:

spirv.Tosa.Gather
spirv.Tosa.Scatter
spirv.Tosa.Resize

Dialect and serialization round-trip tests have also been added.

Change-Id: I873f77d23673e87ba2e303686184021fdd792822
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  10 +
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 186 ++++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   |  25 +++
 .../SPIRV/IR/tosa-ops-verification.mlir       |  87 ++++++++
 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir      |  72 +++++++
 mlir/test/Target/SPIRV/tosa-ops.mlir          | 126 ++++++++++++
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |   1 +
 7 files changed, 507 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8badb84a879fa..9f9e2f5f9a677 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4965,6 +4965,16 @@ def SPIRV_TosaExtAccTypeAttr : SPIRV_I32EnumAttr<
       I32EnumAttrCase<"INT48", 4>,
   ]>;
 
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtResizeModeAttr : SPIRV_I32EnumAttr<
+  "TosaExtResizeModeType", "Tosa Ext Resize Mode Type",
+  "tosa_ext_resize_mode_type",
+  [
+      I32EnumAttrCase<"NearestNeighbor", 1>,
+      I32EnumAttrCase<"Bilinear", 2>,
+  ]>;
+
 // NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
 // SPIR-V proper.
 def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index f2519784b5058..7fc7f86478491 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -2462,4 +2462,190 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
 }
 
 
+def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
+  AllElementTypesMatch<["values", "output"]>,
+  ValuesIndicesShapesMatch<"values", "indices", "output">]> {
+  let summary = "Gather operation.";
+
+  let description = [{
+    Generate a tensor for which each element in the output is a subtensor of the
+    values tensor based on the indices. Undefined behaviour may occur if the
+    specified indices are out of range.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_gather
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_gather
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Gather %values, %indices : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+    %0 = spirv.Tosa.Gather %values, %indices : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaNumerical_TensorArm3D: $values,
+    SPIRV_Int32_TensorArm2D: $indices
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm3D: $output
+  );
+
+  let assemblyFormat = [{
+    $values `,`
+    $indices
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getValuesType() {
+      return cast<::mlir::spirv::TensorArmType>(getValues().getType());
+    }
+    ::mlir::spirv::TensorArmType getIndicesType() {
+      return cast<::mlir::spirv::TensorArmType>(getIndices().getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
+  AllElementTypesMatch<["values_in", "input", "values_out"]>,
+  AllTypesMatch<["values_in", "values_out"]>,
+  ValuesIndicesShapesMatch<"values_in", "indices", "input">]> {
+  let summary = "Scatter operation.";
+
+  let description = [{
+    The values_out tensor is set to the values_in tensor with data modified as
+    follows: data from the input tensor is inserted at the positions specified
+    by the indices tensor. In use cases that require multiple updates to the
+    same output position, these must be decomposed into multiple scatter
+    operations.  Undefined behaviour may occur if the specified indices are
+    out of range or duplicate indices are provided.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_scatter
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_scatter
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Scatter %values_in, %indices, %arg0 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+    %0 = spirv.Tosa.Scatter %values_in, %indices, %arg0 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaNumerical_TensorArm3D: $values_in,
+    SPIRV_Int32_TensorArm2D: $indices,
+    SPIRV_TosaNumerical_TensorArm3D: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm3D: $values_out
+  );
+
+  let assemblyFormat = [{
+    $values_in `,`
+    $indices `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getValuesInType() {
+      return cast<::mlir::spirv::TensorArmType>(getValuesIn().getType());
+    }
+    ::mlir::spirv::TensorArmType getIndicesType() {
+      return cast<::mlir::spirv::TensorArmType>(getIndices().getType());
+    }
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
+  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
+  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
+  TypeConstraintImplicationOn<"input", I8, "output", [I8, I32]>,
+  TypeConstraintImplicationOn<"input", I16, "output", [I16, I64]>]> {
+  let summary = "Resize operation, supports various resize/upsample modes.";
+
+  let description = [{
+    Resizes a tensor. Resize is only allowed in the H and W dimensions, given the input
+    shape = [N,H,W,C].
+
+    The height dimension (H) is scaled by factor ($ scale_y_n/scale_y_d $). The width
+    dimension (W) is scaled by factor ($ scale_x_n/scale_x_d $).
+
+    The NearestNeighbor mode returns the value of the input tensor closest to
+    the calculated sample position for both floating-point and integer data
+    formats.
+
+    Floating-point Bilinear mode returns a bilinearly interpolated output value
+    based on the four closest input sample positions.
+
+    For integer Bilinear interpolation mode, the output value must be scaled by
+    $ 1/(scale_y_n * scale_x_n) $ in a following operation to complete the
+    interpolation (for example with a rescale operator).
+
+    The output dimensions can be derived from the input dimensions by inverting
+    the scale. The [border_y, border_x] values adjust the output size to allow
+    fractional sampling beyond integer input position (H - 1,W - 1).
+
+    The limit MAX_SCALE=256 is applied to each scale ratio after reduction of the
+    ratio. Individual scale numerator and denominator values are allowed to be
+    larger than MAX_SCALE.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_resize
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_resize
+
+    #### Example:
+    ```mlir
+    %4 = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+    %4 = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %scale, %offset, %border : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaExtResizeModeAttr: $mode,
+    SPIRV_TosaNumerical_TensorArm4D: $input,
+    SPIRV_Int32_1DTensorArmOfLength4: $scale,
+    SPIRV_Int32_1DTensorArmOfLength2: $offset,
+    SPIRV_Int32_1DTensorArmOfLength2: $border
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm4D: $output
+  );
+
+  let assemblyFormat = [{
+    `mode` `=` $mode `,`
+    $input `,`
+    $scale `,`
+    $offset `,`
+    $border
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getScaleType() {
+      return cast<::mlir::spirv::TensorArmType>(getScale().getType());
+    }
+    ::mlir::spirv::TensorArmType getOffsetType() {
+      return cast<::mlir::spirv::TensorArmType>(getOffset().getType());
+    }
+    ::mlir::spirv::TensorArmType getBorderType() {
+      return cast<::mlir::spirv::TensorArmType>(getBorder().getType());
+    }
+  }];
+}
+
+
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index d14eeba324adf..f116c4dcdd491 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -38,6 +38,7 @@ class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
       [HasAnyRankOfPred<ranks>],
       !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
 
+def SPIRV_Int32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>;
 def SPIRV_Float32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
 def SPIRV_TosaInteger_TensorArm1D : TensorArmRankOf<[SPIRV_TosaInteger], [1]>;
 def SPIRV_TosaNumerical_TensorArm1D : TensorArmRankOf<[SPIRV_TosaNumerical], [1]>;
@@ -65,6 +66,9 @@ class SPIRV_1DTensorArmOfLengthAndType<list<int> allowedLengths, list<Type> allo
     "rank 1 tensorArm of length " # !interleave(allowedLengths, "/"),
     "::mlir::spirv::TensorArmType">;
 
+def SPIRV_Int32_1DTensorArmOfLength2 : SPIRV_1DTensorArmOfLengthAndType<[2], [SPIRV_Int32]>;
+def SPIRV_Int32_1DTensorArmOfLength4 : SPIRV_1DTensorArmOfLengthAndType<[4], [SPIRV_Int32]>;
+
 def SPIRV_Int32_1DTensorArmOfLength1To6 : SPIRV_1DTensorArmOfLengthAndType<[1, 2, 3, 4, 5, 6], [SPIRV_Int32]>;
 def SPIRV_Int32_1DTensorArmOfEvenLength2To12 : SPIRV_1DTensorArmOfLengthAndType<[2, 4, 6, 8, 10, 12], [SPIRV_Int32]>;
 
@@ -147,6 +151,27 @@ class MatchBroadcastableShapes<string input1, string input2, string output>:
     "})">]>
   >;
 
+class SameDimsOrDynamicPred<string lhs, int lhsDim, string rhs, int rhsDim> :
+  CPred<"[](::mlir::ShapedType lhsType, ::mlir::ShapedType rhsType) { "
+        "  int64_t lhsSize = lhsType.getDimSize(" # lhsDim # "); "
+        "  int64_t rhsSize = rhsType.getDimSize(" # rhsDim # "); "
+        "  return ::mlir::ShapedType::isDynamic(lhsSize) || "
+        "         ::mlir::ShapedType::isDynamic(rhsSize) || lhsSize == rhsSize; "
+        "}("
+        "::llvm::cast<::mlir::ShapedType>($" # lhs # ".getType()), "
+        "::llvm::cast<::mlir::ShapedType>($" # rhs # ".getType()))">;
+
+class ValuesIndicesShapesMatch<string values, string indices, string tensor>:
+  PredOpTrait<"shapes of " # values # ", " # indices # ", and " # tensor #
+                  " must satisfy [N,K,C], [N,W], [N,W,C]",
+    And<[
+      SameDimsOrDynamicPred<values, 0, indices, 0>,
+      SameDimsOrDynamicPred<values, 0, tensor, 0>,
+      SameDimsOrDynamicPred<indices, 0, tensor, 0>,
+      SameDimsOrDynamicPred<indices, 1, tensor, 1>,
+      SameDimsOrDynamicPred<values, 2, tensor, 2>
+    ]>>;
+
 class TableSizeConstraint<string input, Type type, int size>:
   PredOpTrait<"table must have size " # size # " if " # input # " has element type " # type.summary,
       Implies<ElementTypeIsPred<input, type>, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]>
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 5ff51dca91a3b..f95fedba74307 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1853,3 +1853,90 @@ spirv.ARM.Graph @transpose_perms_element_count_not_input_rank(%arg0: !spirv.arm.
   %0 = spirv.Tosa.Transpose perms = [1, 0], %arg0 : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<4x2x3xi8>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x2x3xi8>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @gather_values_output_element_types_not_matching(%arg0: !spirv.arm.tensor<31x11x45xi32>, %arg1: !spirv.arm.tensor<31x15xi32>) -> (!spirv.arm.tensor<31x15x45xi16>) {
+  // expected-error @+1 {{op failed to verify that all of {values, output} have same element type}}
+  %0 = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<31x15x45xi16>
+}
+
+spirv.ARM.Graph @gather_shapes_not_matching(%arg0: !spirv.arm.tensor<31x11x45xi32>, %arg1: !spirv.arm.tensor<31x15xi32>) -> (!spirv.arm.tensor<30x15x44xi32>) {
+  // expected-error @+1 {{op failed to verify that shapes of values, indices, and output must satisfy [N,K,C], [N,W], [N,W,C]}}
+  %0 = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<30x15x44xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<30x15x44xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @scatter_values_in_input_values_out_element_types_not_matching(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x54xi16>) -> (!spirv.arm.tensor<34x28x54xi32>) {
+  // expected-error @+1 {{op failed to verify that all of {values_in, input, values_out} have same element type}}
+  %0 = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi16> -> !spirv.arm.tensor<34x28x54xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<34x28x54xi32>
+}
+
+spirv.ARM.Graph @scatter_values_in_values_out_types_not_matching(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x54xi32>) -> (!spirv.arm.tensor<35x28x54xi32>) {
+  // expected-error @+1 {{op failed to verify that all of {values_in, values_out} have same type}}
+  %0 = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<35x28x54xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<35x28x54xi32>
+}
+
+spirv.ARM.Graph @scatter_shapes_not_matching(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x55xi32>) -> (!spirv.arm.tensor<34x28x54xi32>) {
+  // expected-error @+1 {{op failed to verify that shapes of values_in, indices, and input must satisfy [N,K,C], [N,W], [N,W,C]}}
+  %0 = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x55xi32> -> !spirv.arm.tensor<34x28x54xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<34x28x54xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @resize_i8_input_output_element_type_must_be_i8_or_i32(%arg0: !spirv.arm.tensor<1x1x31x55xi8>) -> (!spirv.arm.tensor<1x1x278x55xf16>) {
+  %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %3 = spirv.Constant dense<[0, 7]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [8-bit signless integer,32-bit signless integer]}}
+  %4 = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xf16>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x1x278x55xf16>
+}
+
+spirv.ARM.Graph @resize_i16_input_output_element_type_must_be_i16_or_i64(%arg0: !spirv.arm.tensor<1x1x31x55xi16>) -> (!spirv.arm.tensor<1x1x278x55xi32>) {
+  %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %3 = spirv.Constant dense<[0, 7]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [16-bit signless integer,64-bit signless integer]}}
+  %4 = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x1x31x55xi16>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi32>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x1x278x55xi32>
+}
+
+spirv.ARM.Graph @resize_f16_input_output_element_type_must_be_f16(%arg0: !spirv.arm.tensor<1x48x33x63xf16>) -> (!spirv.arm.tensor<1x753x297x63xf32>) {
+  %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %3 = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [16-bit float]}}
+  %4 = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x48x33x63xf16>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x753x297x63xf32>
+}
+
+spirv.ARM.Graph @resize_f32_input_output_element_type_must_be_f32(%arg0: !spirv.arm.tensor<1x48x33x63xf32>) -> (!spirv.arm.tensor<1x753x297x63xf16>) {
+  %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %3 = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [32-bit float]}}
+  %4 = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf16>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x753x297x63xf16>
+}
+
+spirv.ARM.Graph @resize_bf16_input_output_element_type_must_be_bf16(%arg0: !spirv.arm.tensor<1x48x33x63xbf16>) -> (!spirv.arm.tensor<1x753x297x63xf32>) {
+  %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %3 = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [bfloat16 type]}}
+  %4 = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x48x33x63xbf16>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x753x297x63xf32>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index baa1372685165..b10724b16a84f 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -1051,3 +1051,75 @@ spirv.ARM.Graph @transpose_fp(%arg0: !spirv.arm.tensor<42x22x49xi1>) -> (!spirv.
   // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<49x42x22xi1>
   spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<49x42x22xi1>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @gather_int(%arg0: !spirv.arm.tensor<31x11x45xi32>, %arg1: !spirv.arm.tensor<31x15xi32>) -> (!spirv.arm.tensor<31x15x45xi32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+  %0 = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<31x15x45xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<31x15x45xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @gather_fp(%arg0: !spirv.arm.tensor<59x61x19xf32>, %arg1: !spirv.arm.tensor<59x65xi32>) -> (!spirv.arm.tensor<59x65x19xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+  %0 = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<59x65x19xf32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<59x65x19xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @scatter_int(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x54xi32>) -> (!spirv.arm.tensor<34x28x54xi32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+  %0 = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<34x28x54xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<34x28x54xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @scatter_fp(%arg0: !spirv.arm.tensor<18x34x25xf16>, %arg1: !spirv.arm.tensor<18x20xi32>, %arg2: !spirv.arm.tensor<18x20x25xf16>) -> (!spirv.arm.tensor<18x34x25xf16>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+  %0 = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x34x25xf16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<18x34x25xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @resize_int(%arg0: !spirv.arm.tensor<1x1x31x55xi8>) -> (!spirv.arm.tensor<1x1x278x55xi8>) {
+  %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %3 = spirv.Constant dense<[0, 7]> : !spirv.arm.tensor<2xi32>
+  // CHECK: {{%.*}} = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+  %4 = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x1x278x55xi8>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x1x278x55xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @resize_fp(%arg0: !spirv.arm.tensor<1x48x33x63xf32>) -> (!spirv.arm.tensor<1x753x297x63xf32>) {
+  %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32>
+  %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32>
+  %3 = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32>
+  // CHECK: {{%.*}} = spirv.Tosa.Resize mode = <Bilinear>, %arg0, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+  %4 = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x753x297x63xf32>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x753x297x63xf32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 5556ff47785fb..ebc6a290a9dc1 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -1838,3 +1838,129 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<49x42x22xi1>
   }
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // spirv.Tosa.Gather case with i32 values and i32 indices
+  spirv.GlobalVariable @gather_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<31x11x45xi32>, UniformConstant>
+  spirv.GlobalVariable @gather_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<31x15xi32>, UniformConstant>
+  spirv.GlobalVariable @gather_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<31x15x45xi32>, UniformConstant> // same tensor type as the graph result below
+  spirv.ARM.GraphEntryPoint @gather_int, @gather_int_arg_0, @gather_int_arg_1, @gather_int_res_0 // graph symbol followed by the two arg variables and the one res variable
+  spirv.ARM.Graph @gather_int(%arg0: !spirv.arm.tensor<31x11x45xi32>, %arg1: !spirv.arm.tensor<31x15xi32>) -> (!spirv.arm.tensor<31x15x45xi32>) { // in this case the result takes dims 0 and 2 from %arg0 (31, 45) and dim 1 from %arg1 (15)
+    // CHECK: {{%.*}} = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+    %0 = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<31x11x45xi32>, !spirv.arm.tensor<31x15xi32> -> !spirv.arm.tensor<31x15x45xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<31x15x45xi32>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<31x15x45xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Gather - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // spirv.Tosa.Gather case with f32 values
+  spirv.GlobalVariable @gather_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<59x61x19xf32>, UniformConstant>
+  spirv.GlobalVariable @gather_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<59x65xi32>, UniformConstant> // second operand is i32 here although the gathered values are f32
+  spirv.GlobalVariable @gather_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<59x65x19xf32>, UniformConstant> // same tensor type as the graph result below
+  spirv.ARM.GraphEntryPoint @gather_fp, @gather_fp_arg_0, @gather_fp_arg_1, @gather_fp_res_0 // graph symbol followed by the two arg variables and the one res variable
+  spirv.ARM.Graph @gather_fp(%arg0: !spirv.arm.tensor<59x61x19xf32>, %arg1: !spirv.arm.tensor<59x65xi32>) -> (!spirv.arm.tensor<59x65x19xf32>) { // in this case the result takes dims 0 and 2 from %arg0 (59, 19) and dim 1 from %arg1 (65)
+    // CHECK: {{%.*}} = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+    %0 = spirv.Tosa.Gather %arg0, %arg1 : !spirv.arm.tensor<59x61x19xf32>, !spirv.arm.tensor<59x65xi32> -> !spirv.arm.tensor<59x65x19xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<59x65x19xf32>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<59x65x19xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // spirv.Tosa.Scatter case with i32 tensors throughout
+  spirv.GlobalVariable @scatter_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<34x28x54xi32>, UniformConstant>
+  spirv.GlobalVariable @scatter_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<34x18xi32>, UniformConstant>
+  spirv.GlobalVariable @scatter_int_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<34x18x54xi32>, UniformConstant> // shape is the 34x18 of arg_1 with the trailing 54 of arg_0
+  spirv.GlobalVariable @scatter_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<34x28x54xi32>, UniformConstant> // same tensor type as arg_0
+  spirv.ARM.GraphEntryPoint @scatter_int, @scatter_int_arg_0, @scatter_int_arg_1, @scatter_int_arg_2, @scatter_int_res_0 // graph symbol followed by the three arg variables and the one res variable
+  spirv.ARM.Graph @scatter_int(%arg0: !spirv.arm.tensor<34x28x54xi32>, %arg1: !spirv.arm.tensor<34x18xi32>, %arg2: !spirv.arm.tensor<34x18x54xi32>) -> (!spirv.arm.tensor<34x28x54xi32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+    %0 = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<34x28x54xi32>, !spirv.arm.tensor<34x18xi32>, !spirv.arm.tensor<34x18x54xi32> -> !spirv.arm.tensor<34x28x54xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<34x28x54xi32>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<34x28x54xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Scatter - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // spirv.Tosa.Scatter case with f16 values
+  spirv.GlobalVariable @scatter_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<18x34x25xf16>, UniformConstant>
+  spirv.GlobalVariable @scatter_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<18x20xi32>, UniformConstant> // second operand is i32 here although the scattered values are f16
+  spirv.GlobalVariable @scatter_fp_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<18x20x25xf16>, UniformConstant> // shape is the 18x20 of arg_1 with the trailing 25 of arg_0
+  spirv.GlobalVariable @scatter_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<18x34x25xf16>, UniformConstant> // same tensor type as arg_0
+  spirv.ARM.GraphEntryPoint @scatter_fp, @scatter_fp_arg_0, @scatter_fp_arg_1, @scatter_fp_arg_2, @scatter_fp_res_0 // graph symbol followed by the three arg variables and the one res variable
+  spirv.ARM.Graph @scatter_fp(%arg0: !spirv.arm.tensor<18x34x25xf16>, %arg1: !spirv.arm.tensor<18x20xi32>, %arg2: !spirv.arm.tensor<18x20x25xf16>) -> (!spirv.arm.tensor<18x34x25xf16>) {
+    // CHECK: {{%.*}} = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+    %0 = spirv.Tosa.Scatter %arg0, %arg1, %arg2 : !spirv.arm.tensor<18x34x25xf16>, !spirv.arm.tensor<18x20xi32>, !spirv.arm.tensor<18x20x25xf16> -> !spirv.arm.tensor<18x34x25xf16>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x34x25xf16>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<18x34x25xf16>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // spirv.Tosa.Resize case in NearestNeighbor mode on i8 tensors
+  spirv.GlobalVariable @resize_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x1x31x55xi8>, UniformConstant>
+  spirv.GlobalVariable @resize_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x1x278x55xi8>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @resize_int, @resize_int_arg_0, @resize_int_res_0 // only the tensor input and the result appear here; the other Resize operands are constants inside the graph
+  spirv.ARM.Graph @resize_int(%arg0: !spirv.arm.tensor<1x1x31x55xi8>) -> (!spirv.arm.tensor<1x1x278x55xi8>) {
+    %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32> // 2nd Resize operand; presumably TOSA scale [y_n, y_d, x_n, x_d] - TODO confirm
+    %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32> // 3rd Resize operand; presumably TOSA offset [y, x] - TODO confirm
+    %3 = spirv.Constant dense<[0, 7]> : !spirv.arm.tensor<2xi32> // 4th Resize operand; presumably TOSA border [y, x] - TODO confirm
+    // CHECK: {{%.*}} = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+    %4 = spirv.Tosa.Resize mode = <NearestNeighbor>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x1x31x55xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x1x278x55xi8>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x1x278x55xi8>
+    spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x1x278x55xi8> // NOTE(review): if the TOSA RESIZE shape rule applies, 1 = 0*16 + 1 and 278 = 30*9 + 7 + 1 - confirm
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Resize - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // spirv.Tosa.Resize case in Bilinear mode on f32 tensors
+  spirv.GlobalVariable @resize_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x48x33x63xf32>, UniformConstant>
+  spirv.GlobalVariable @resize_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x753x297x63xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @resize_fp, @resize_fp_arg_0, @resize_fp_res_0 // only the tensor input and the result appear here; the other Resize operands are constants inside the graph
+  spirv.ARM.Graph @resize_fp(%arg0: !spirv.arm.tensor<1x48x33x63xf32>) -> (!spirv.arm.tensor<1x753x297x63xf32>) {
+    %1 = spirv.Constant dense<[16, 1, 9, 1]> : !spirv.arm.tensor<4xi32> // 2nd Resize operand; presumably TOSA scale [y_n, y_d, x_n, x_d] - TODO confirm
+    %2 = spirv.Constant dense<0> : !spirv.arm.tensor<2xi32> // 3rd Resize operand; presumably TOSA offset [y, x] - TODO confirm
+    %3 = spirv.Constant dense<[0, 8]> : !spirv.arm.tensor<2xi32> // 4th Resize operand; presumably TOSA border [y, x] - TODO confirm
+    // CHECK: {{%.*}} = spirv.Tosa.Resize mode = <Bilinear>, %arg0, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+    %4 = spirv.Tosa.Resize mode = <Bilinear>, %arg0, %1, %2, %3 : !spirv.arm.tensor<1x48x33x63xf32>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<2xi32>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<1x753x297x63xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x753x297x63xf32>
+    spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x753x297x63xf32> // NOTE(review): if the TOSA RESIZE shape rule applies, 753 = 47*16 + 1 and 297 = 32*9 + 8 + 1 - confirm
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 34edd4df49d8e..2344465f75214 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -502,6 +502,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
     "SPIRV_MemorySemanticsAttr",
     "SPIRV_MatrixLayoutAttr",
     "SPIRV_TosaExtAccTypeAttr",
+    "SPIRV_TosaExtResizeModeAttr",
     "SPIRV_TosaExtNaNPropagationModeAttr",
     "SPIRV_QuadSwapDirectionAttr",
 };



More information about the Mlir-commits mailing list