[Mlir-commits] [mlir] [mlir][spirv] Add Conv operations for TOSA Extended Instruction Set (001000.1) (PR #176908)

Davide Grohmann llvmlistbot at llvm.org
Thu Jan 22 05:56:21 PST 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/176908

>From eba638ea1479d5ae05691a8693511b409920551d Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Tue, 20 Jan 2026 10:44:53 +0100
Subject: [PATCH] [mlir][spirv] Add Conv operations for TOSA Extended
 Instruction Set (001000.1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch expands support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended
instruction set provides a standardized set of machine learning
operations designed to be used within `spirv.ARM.Graph` operations
(corresponding to OpGraphARM in SPV_ARM_graph) and typed with
`!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in
SPV_ARM_tensor).

The change introduces:
* Extended dialect plumbing for import, serialization, and
  deserialization of the TOSA extended instruction set.
* The `spirv.Tosa.*Conv*` convolution operations from the TOSA extended
  instruction set, each lowering to the corresponding `OpExtInst`.
* Verification enforcing that the new convolution operations appear only
  within `spirv.ARM.Graph` regions, operate on
  `!spirv.arm.tensor<...>` types, and are well-formed according to the
  TOSA 001000.1 specification.

All convolution operations from the TOSA 001000.1 extended instruction
set are introduced. Parser, printer, verifier, and round-trip tests using
MLIR’s SPIR-V serialization/deserialization infrastructure are
included.

This work aligns with the Khronos SPIR-V TOSA specification.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I32ca642362dbad0cfb172f5738f8ff62b6745b85
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  11 +
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h |   1 +
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h      |  12 +
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 288 ++++++++++++++-
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   |  29 ++
 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp    | 135 ++++++-
 .../SPIRV/IR/tosa-ops-verification.mlir       | 337 ++++++++++++++++++
 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir      | 104 ++++++
 mlir/test/Target/SPIRV/tosa-ops.mlir          | 184 ++++++++++
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |  13 +-
 10 files changed, 1108 insertions(+), 6 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 21010d91dc47c..4ea6d784dd88f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4915,6 +4915,17 @@ def SPIRV_FPFastMathModeAttr :
 // SPIR-V TOSA enum definitions.
 //===----------------------------------------------------------------------===//
 
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtAccTypeAttr : SPIRV_I32EnumAttr<
+  "TosaExtAccType", "Tosa Ext Accumulator Type", "tosa_ext_acc_type",
+  [
+      I32EnumAttrCase<"INT32", 1>,
+      I32EnumAttrCase<"FP16", 2>,
+      I32EnumAttrCase<"FP32", 3>,
+      I32EnumAttrCase<"INT48", 4>,
+  ]>;
+
 // NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
 // SPIR-V proper.
 def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
index 0e1f6e79a3670..4d43c7d7066ed 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
@@ -16,6 +16,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h"
 #include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h
new file mode 100644
index 0000000000000..30ec869e0bbf0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h
@@ -0,0 +1,12 @@
+//===- SPIRVTosaOps.h - MLIR SPIR-V TOSA operations -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares the custom assembly-format helpers used by the SPIR-V TOSA
+// operations defined in SPIRVTosaOps.td.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTOSAOPS_H_
+#define MLIR_DIALECT_SPIRV_IR_SPIRVTOSAOPS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir::spirv {
+
+/// Parser for `custom<SPIRV_I32_1DArmTensor>`: reads a square-bracketed,
+/// comma-separated list of 32-bit integers (e.g. `[1, 0, 0, 0]`) into a dense
+/// i32 attribute typed as a rank-1 `!spirv.arm.tensor`.
+ParseResult parseSPIRV_I32_1DArmTensor(OpAsmParser &parser,
+                                       DenseIntElementsAttr &attr);
+
+/// Printer for `custom<SPIRV_I32_1DArmTensor>`: writes the elements of `attr`
+/// as a square-bracketed, comma-separated list.
+void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *,
+                                DenseIntElementsAttr attr);
+
+} // namespace mlir::spirv
+
+#endif // MLIR_DIALECT_SPIRV_IR_SPIRVTOSAOPS_H_
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 6c6a318db4827..7efe383918c0a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -51,7 +51,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
     #### Example:
     ```mlir
     %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
-    %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+    %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
     ```
   }];
 
@@ -83,4 +83,290 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
   }];
 }
 
+
+def SPIRV_TosaConv2DOp : SPIRV_TosaOp<"Conv2D", 2, [Pure,
+  AllElementTypesMatch<["bias", "output"]>,
+  AllElementTypesMatch<["input", "input_zp"]>,
+  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+  let summary = "2D Convolution operator.";
+
+  let description = [{
+    Performs a 2D convolution over the given tensor input, using the weight
+    tensor. Implementations may choose to skip calculation of multiplies in
+    the padding area.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_conv2d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_conv2d
+
+    #### Example:
+    ```mlir
+    %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+    %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $dilation,
+    SPIRV_TosaExtAccTypeAttr: $acc_type,
+    SPIRV_BoolConstAttr: $local_bound,
+    SPIRV_TosaNumerical_TensorArm4D: $input,
+    SPIRV_TosaNumerical_TensorArm4D: $weight,
+    SPIRV_TosaNumerical_TensorArm1D: $bias,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm4D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+    `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+    `dilation` `=` custom<SPIRV_I32_1DArmTensor>($dilation) `,`
+    `acc_type` `=` $acc_type `,`
+    `local_bound` `=` $local_bound `,`
+    $input `,`
+    $weight `,`
+    $bias `,`
+    $input_zp `,`
+    $weight_zp
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getWeightType() {
+      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+    }
+    ::mlir::spirv::TensorArmType getBiasType() {
+      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultType() {
+      return cast<::mlir::spirv::TensorArmType>(getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaConv3DOp : SPIRV_TosaOp<"Conv3D", 3, [Pure,
+  AllElementTypesMatch<["bias", "output"]>,
+  AllElementTypesMatch<["input", "input_zp"]>,
+  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+  let summary = "3D Convolution operator.";
+
+  let description = [{
+    Performs a 3D convolution over the given input tensor. Implementations
+    may choose to skip calculation of multiplies in the padding area.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_conv3d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_conv3d
+
+    #### Example:
+    ```mlir
+    %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32>
+    %7 = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Int32_1DTensorArmOfLength6Attr: $pad,
+    SPIRV_Int32_1DTensorArmOfLength3Attr: $stride,
+    SPIRV_Int32_1DTensorArmOfLength3Attr: $dilation,
+    SPIRV_TosaExtAccTypeAttr: $acc_type,
+    SPIRV_BoolConstAttr: $local_bound,
+    SPIRV_TosaNumerical_TensorArm5D: $input,
+    SPIRV_TosaNumerical_TensorArm5D: $weight,
+    SPIRV_TosaNumerical_TensorArm1D: $bias,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm5D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+    `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+    `dilation` `=` custom<SPIRV_I32_1DArmTensor>($dilation) `,`
+    `acc_type` `=` $acc_type `,`
+    `local_bound` `=` $local_bound `,`
+    $input `,`
+    $weight `,`
+    $bias `,`
+    $input_zp `,`
+    $weight_zp
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getWeightType() {
+      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+    }
+    ::mlir::spirv::TensorArmType getBiasType() {
+      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultType() {
+      return cast<::mlir::spirv::TensorArmType>(getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOp<"DepthwiseConv2D", 4, [Pure,
+  AllElementTypesMatch<["bias", "output"]>,
+  AllElementTypesMatch<["input", "input_zp"]>,
+  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+  let summary = "Depthwise 2D Convolution operator.";
+
+  let description = [{
+    Performs 2D convolutions separately over each channel of the given tensor
+    input, using the weight tensor. Implementations may choose to skip
+    calculation of multiplies in the padding area.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_depthwise_conv2d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_depthwise_conv2d
+
+    #### Example:
+    ```mlir
+    %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32>
+    %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = <FP32>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $dilation,
+    SPIRV_TosaExtAccTypeAttr: $acc_type,
+    SPIRV_BoolConstAttr: $local_bound,
+    SPIRV_TosaNumerical_TensorArm4D: $input,
+    SPIRV_TosaNumerical_TensorArm4D: $weight,
+    SPIRV_TosaNumerical_TensorArm1D: $bias,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm4D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+    `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+    `dilation` `=` custom<SPIRV_I32_1DArmTensor>($dilation) `,`
+    `acc_type` `=` $acc_type `,`
+    `local_bound` `=` $local_bound `,`
+    $input `,`
+    $weight `,`
+    $bias `,`
+    $input_zp `,`
+    $weight_zp
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getWeightType() {
+      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+    }
+    ::mlir::spirv::TensorArmType getBiasType() {
+      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultType() {
+      return cast<::mlir::spirv::TensorArmType>(getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOp<"TransposeConv2D", 9, [Pure,
+  AllElementTypesMatch<["bias", "output"]>,
+  AllElementTypesMatch<["input", "input_zp"]>,
+  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+  let summary = "Transpose 2D Convolution operator.";
+
+  let description = [{
+    Performs a 2D transposed convolution over the given tensor input, using the
+    weights tensor. Implementations may choose to skip calculation of multiplies
+    by zero at fractional input positions.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_transpose_conv2d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_transpose_conv2d
+
+    #### Example:
+    ```mlir
+    %6 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64>
+    %6 = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Int32_1DTensorArmOfLength4Attr: $out_pad,
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_TosaExtAccTypeAttr: $acc_type,
+    SPIRV_BoolConstAttr: $local_bound,
+    SPIRV_TosaNumerical_TensorArm4D: $input,
+    SPIRV_TosaNumerical_TensorArm4D: $weight,
+    SPIRV_TosaNumerical_TensorArm1D: $bias,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm4D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    `out_pad` `=` custom<SPIRV_I32_1DArmTensor>($out_pad) `,`
+    `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+    `acc_type` `=` $acc_type `,`
+    `local_bound` `=` $local_bound `,`
+    $input `,`
+    $weight `,`
+    $bias `,`
+    $input_zp `,`
+    $weight_zp
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getWeightType() {
+      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+    }
+    ::mlir::spirv::TensorArmType getBiasType() {
+      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultType() {
+      return cast<::mlir::spirv::TensorArmType>(getType());
+    }
+  }];
+}
+
+
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index e731388182eb4..7e2c37f74b437 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
 #define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
 
+include "mlir/IR/CommonAttrConstraints.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
 
 def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
@@ -21,6 +22,7 @@ def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
 def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
 
 def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
+// Boolean attribute used for op flags such as `local_bound`. The constraint
+// list is empty, so any BoolAttr is accepted.
+def SPIRV_BoolConstAttr : ConfinedAttr<BoolAttr, []>;
 
 // TensorARM Types
 
@@ -35,7 +37,34 @@ class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
       [HasAnyRankOfPred<ranks>],
       !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
 
+// Fixed-rank TensorARM types with a TOSA numerical element type: rank 1 is
+// used for convolution bias, ranks 4 and 5 for the 2D and 3D convolution
+// input, weight and result.
+def SPIRV_TosaNumerical_TensorArm1D : TensorArmRankOf<[SPIRV_TosaNumerical], [1]>;
+def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]>;
+def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
+
 def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
 
+// Predicate: `$_self` is a rank-1 TensorArmType whose only dimension equals one
+// of `allowedLengths`. The rank check is listed first so that it guards the
+// `getShape()[0]` access (this relies on `And` short-circuiting).
+class Is1DTensorArmOfLength<list<int> allowedLengths> :
+  And<[HasAnyRankOfPred<[1]>,
+       Or<!foreach(allowedlength, allowedLengths,
+                   CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>($_self).getShape()[0] == }]
+                         # allowedlength>)>]>;
+
+// Type constraint: a rank-1 TensorArmType whose length is one of
+// `allowedLengths` and whose element type is any of `allowedTypes`.
+class SPIRV_1DTensorArmOfLengthAndType<list<int> allowedLengths, list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, Is1DTensorArmOfLength<allowedLengths>,
+    "::llvm::cast<::mlir::spirv::TensorArmType>($_self).getElementType()",
+    "rank 1 tensorArm of length " # !interleave(allowedLengths, "/"),
+    "::mlir::spirv::TensorArmType">;
+
+// Attribute constraint: the attribute's type must be a TensorArmType rather
+// than a builtin tensor type. `$_self` is cast to DenseElementsAttr without a
+// check, so this must only be combined with a base constraint that already
+// guarantees a DenseElementsAttr (as in the ConfinedAttr definitions that
+// follow).
+def SPIRV_DenseElementAttrsWithTensorArmType : AttrConstraint<
+  CPred<"::llvm::isa<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType())">,
+  "Attr with type = spirv::TensorArmType">;
+
+// Dense i32 attributes of a fixed length (2, 3, 4 or 6) whose type is a
+// TensorArmType; used for the convolution pad, out_pad, stride and dilation
+// parameters.
+def SPIRV_Int32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_Int32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_Int32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_Int32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+
+// Length-1 TensorArmType with a TOSA numerical element type; used for the
+// convolution input_zp and weight_zp operands.
+def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index 4f3c91d4a1c12..e0bca91c41f19 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -10,9 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/TypeUtilities.h"
 
 namespace mlir::spirv {
 
@@ -20,6 +17,60 @@ namespace mlir::spirv {
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
+/// Verifies the element-type and accumulator-type combinations shared by the
+/// TOSA convolution operations (Conv2D, Conv3D, DepthwiseConv2D and
+/// TransposeConv2D):
+///
+///   input | result | acc_type
+///   ------+--------+-------------
+///   i8    | i32    | INT32
+///   i16   | i64    | INT48
+///   f16   | f16    | FP16 or FP32
+///   bf16  | bf16   | FP32
+///   f32   | f32    | FP32
+///
+/// Integer inputs of any other width are rejected. On the first violated rule
+/// an op error is emitted on `op` and failure is returned.
+static LogicalResult verifyConvOp(Operation *op, Type inputETy, Type resultETy,
+                                  TosaExtAccType accType) {
+  // Only 8- and 16-bit integer inputs are supported.
+  if (inputETy.isInteger() && !inputETy.isInteger(8) &&
+      !inputETy.isInteger(16)) {
+    return op->emitOpError(
+        "input element type can only be of width 8 or 16 when integer type");
+  }
+
+  // The result element type is fully determined by the input element type.
+  if (inputETy.isInteger(8) && !resultETy.isInteger(32)) {
+    return op->emitOpError("expect result type to be i32, got ") << resultETy;
+  }
+
+  // i16 inputs accumulate in 48 bits (checked below) but produce i64 results.
+  if (inputETy.isInteger(16) && !resultETy.isInteger(64)) {
+    return op->emitOpError("expect result type to be i64, got ") << resultETy;
+  }
+
+  if (inputETy.isF16() && !resultETy.isF16()) {
+    return op->emitOpError("expect result type to be f16, got ") << resultETy;
+  }
+
+  // bf16 accumulation is validated below, so the result must be checked too;
+  // otherwise any result element type would be accepted for bf16 inputs.
+  if (inputETy.isBF16() && !resultETy.isBF16()) {
+    return op->emitOpError("expect result type to be bf16, got ") << resultETy;
+  }
+
+  if (inputETy.isF32() && !resultETy.isF32()) {
+    return op->emitOpError("expect result type to be f32, got ") << resultETy;
+  }
+
+  // The accumulator type must be one permitted for the input element type.
+  if (inputETy.isInteger(8) && accType != TosaExtAccType::INT32) {
+    return op->emitOpError("accumulator type for i8 tensorARM is not i32");
+  }
+
+  if (inputETy.isInteger(16) && accType != TosaExtAccType::INT48) {
+    return op->emitOpError("accumulator type for i16 tensorARM is not i48");
+  }
+
+  if (inputETy.isF16() &&
+      !llvm::is_contained({TosaExtAccType::FP16, TosaExtAccType::FP32},
+                          accType)) {
+    return op->emitOpError(
+        "accumulator type for f16 tensorARM is not f16 or f32");
+  }
+
+  if (inputETy.isBF16() && accType != TosaExtAccType::FP32) {
+    return op->emitOpError("accumulator type for bf16 tensorARM is not f32");
+  }
+
+  if (inputETy.isF32() && accType != TosaExtAccType::FP32) {
+    return op->emitOpError("accumulator type for f32 tensorARM is not f32");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TosaArgmaxOp
 //===----------------------------------------------------------------------===//
@@ -46,4 +97,82 @@ LogicalResult TosaArgMaxOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.TosaConv2DOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the input/result element types and accumulator type of Conv2D
+/// against the rules shared by all TOSA convolution operations.
+LogicalResult TosaConv2DOp::verify() {
+  return verifyConvOp(getOperation(), getInputType().getElementType(),
+                      getResultType().getElementType(), getAccType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaConv3DOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the input/result element types and accumulator type of Conv3D
+/// against the rules shared by all TOSA convolution operations.
+LogicalResult TosaConv3DOp::verify() {
+  return verifyConvOp(getOperation(), getInputType().getElementType(),
+                      getResultType().getElementType(), getAccType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaDepthwiseConv2DOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the input/result element types and accumulator type of
+/// DepthwiseConv2D against the rules shared by all TOSA convolution
+/// operations.
+LogicalResult TosaDepthwiseConv2DOp::verify() {
+  return verifyConvOp(getOperation(), getInputType().getElementType(),
+                      getResultType().getElementType(), getAccType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaTransposeConv2DOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the input/result element types and accumulator type of
+/// TransposeConv2D against the rules shared by all TOSA convolution
+/// operations.
+LogicalResult TosaTransposeConv2DOp::verify() {
+  return verifyConvOp(getOperation(), getInputType().getElementType(),
+                      getResultType().getElementType(), getAccType());
+}
+
+//===----------------------------------------------------------------------===//
+// SPIRV Tosa Custom formatters
+//===----------------------------------------------------------------------===//
+
+/// Parser for `custom<SPIRV_I32_1DArmTensor>`: reads a square-bracketed,
+/// comma-separated list of 32-bit integers such as `[1, 0, 0, 0]` into `attr`.
+/// The result is a dense i32 attribute typed as a rank-1 `!spirv.arm.tensor`
+/// with one element per parsed value. The list length is not checked here;
+/// the ops' attribute constraints (RankedI32ElementsAttr<[N]>) restrict it.
+ParseResult parseSPIRV_I32_1DArmTensor(OpAsmParser &parser,
+                                       DenseIntElementsAttr &attr) {
+  SmallVector<int32_t, 6> elements;
+  auto parseElement = [&]() -> ParseResult {
+    int32_t value = 0;
+    // Record the value only after a successful parse; on failure `value` may
+    // not have been assigned and must not end up in `elements`.
+    if (parser.parseInteger(value))
+      return failure();
+    elements.push_back(value);
+    return success();
+  };
+  if (parser.parseCommaSeparatedList(
+          OpAsmParser::Delimiter::Square, parseElement,
+          "parsing values in integer list attribute")) {
+    return failure();
+  }
+
+  auto i32Type = IntegerType::get(parser.getContext(), 32);
+  auto type = TensorArmType::get(
+      ArrayRef{static_cast<int64_t>(elements.size())}, i32Type);
+  attr = DenseIntElementsAttr::get(type, elements);
+  return success();
+}
+
+/// Printer for `custom<SPIRV_I32_1DArmTensor>`: writes the elements of `attr`
+/// as a square-bracketed, comma-separated list of signed integers, the form
+/// accepted by parseSPIRV_I32_1DArmTensor.
+void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *,
+                                DenseIntElementsAttr attr) {
+  auto printElement = [&](const APInt &element) {
+    printer << element.getSExtValue();
+  };
+  printer << '[';
+  llvm::interleaveComma(attr.getValues<APInt>(), printer, printElement);
+  printer << ']';
+}
+
 } // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index a6496316f9881..2099630aff0fb 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -21,3 +21,340 @@ spirv.ARM.Graph @argmax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.ten
   %2 = spirv.Tosa.ArgMax axis = 4, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
   spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv2D
+//===----------------------------------------------------------------------===//
+
+// Rejects an i32 input: integer inputs must be 8 or 16 bits wide.
+spirv.ARM.Graph @conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32>
+  // expected-error @+1 {{op input element type can only be of width 8 or 16 when integer type}}
+  %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
+}
+
+// Rejects an i16 result for an i8 input; the verifier expects i32.
+spirv.ARM.Graph @conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op expect result type to be i32, got 'i16'}}
+  %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16>
+}
+
+// Rejects an i32 result for an i16 input; the verifier expects i64.
+spirv.ARM.Graph @conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+  // expected-error @+1 {{op expect result type to be i64, got 'i32'}}
+  %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects an f32 result for an f16 input; the verifier expects f16.
+spirv.ARM.Graph @conv2d_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op expect result type to be f16, got 'f32'}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
+}
+
+// Rejects an f16 result for an f32 input; the verifier expects f32.
+spirv.ARM.Graph @conv2d_mismatch_result_element_type_f32_input(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op expect result type to be f32, got 'f16'}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
+}
+
+// Rejects a bias element type (i16) that differs from the result element type (i32).
+spirv.ARM.Graph @conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}}
+  %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects acc_type INT48 for an i8 input; the accumulator must be INT32.
+spirv.ARM.Graph @conv2d_accumulator_must_be_INT32_for_i8_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op accumulator type for i8 tensorARM is not i32}}
+  %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects acc_type INT32 for an i16 input; the accumulator must be INT48.
+spirv.ARM.Graph @conv2d_accumulator_must_be_INT48_for_i16_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+  // expected-error @+1 {{op accumulator type for i16 tensorARM is not i48}}
+  %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi64>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
+}
+
+// Rejects acc_type INT32 for an f16 input; the accumulator must be FP16 or FP32.
+spirv.ARM.Graph @conv2d_accumulator_must_be_either_FP16_or_FP32_for_f16_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16 or f32}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <INT32>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
+}
+
+// Rejects acc_type FP16 for an f32 input; the accumulator must be FP32.
+spirv.ARM.Graph @conv2d_accumulator_must_be_FP32_for_f32_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}}
+  %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
+}
+
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv3D
+//===----------------------------------------------------------------------===//
+
+// Rejects an i32 input: integer inputs must be 8 or 16 bits wide.
+spirv.ARM.Graph @conv3d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7x1xi64>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32>
+  // expected-error @+1 {{op input element type can only be of width 8 or 16 when integer type}}
+  %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi32>, !spirv.arm.tensor<7x1x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7x1xi64>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi64>
+}
+
+// Rejects an i16 result for an i8 input; the verifier expects i32.
+spirv.ARM.Graph @conv3d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7x1xi16>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op expect result type to be i32, got 'i16'}}
+  %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi16>
+}
+
+// Rejects an i32 result for an i16 input; the verifier expects i64.
+spirv.ARM.Graph @conv3d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+  // expected-error @+1 {{op expect result type to be i64, got 'i32'}}
+  %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi16>, !spirv.arm.tensor<7x1x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7x1xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi32>
+}
+
+// Rejects an f32 result for an f16 input; the verifier expects f16.
+spirv.ARM.Graph @conv3d_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x34x18x27x1xf16>, %arg1: !spirv.arm.tensor<11x1x1x27x1xf16>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11x1xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op expect result type to be f16, got 'f32'}}
+  %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27x1xf16>, !spirv.arm.tensor<11x1x1x27x1xf16>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11x1xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf32>
+}
+
+// Rejects an f16 result for an f32 input; the verifier expects f32.
+spirv.ARM.Graph @conv3d_mismatch_result_element_type_f32_input(%arg0: !spirv.arm.tensor<1x34x18x27x1xf32>, %arg1: !spirv.arm.tensor<11x1x1x27x1xf32>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11x1xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op expect result type to be f32, got 'f16'}}
+  %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27x1xf32>, !spirv.arm.tensor<11x1x1x27x1xf32>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11x1xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf16>
+}
+
+// Rejects a bias element type (i16) that differs from the result element type (i32).
+spirv.ARM.Graph @conv3d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}}
+  %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi32>
+}
+
+// Rejects acc_type INT48 for an i8 input; the accumulator must be INT32.
+spirv.ARM.Graph @conv3d_accumulator_must_be_INT32_for_i8_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op accumulator type for i8 tensorARM is not i32}}
+  %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi32>
+}
+
+// Rejects acc_type INT32 for an i16 input; the accumulator must be INT48.
+spirv.ARM.Graph @conv3d_accumulator_must_be_INT48_for_i16_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7x1xi64>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+  // expected-error @+1 {{op accumulator type for i16 tensorARM is not i48}}
+  %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi16>, !spirv.arm.tensor<7x1x1x1x1xi16>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7x1xi64>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi64>
+}
+
+// Rejects acc_type INT32 for an f16 input; the accumulator must be FP16 or FP32.
+spirv.ARM.Graph @conv3d_accumulator_must_be_either_FP16_or_FP32_for_f16_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27x1xf16>, %arg1: !spirv.arm.tensor<11x1x1x27x1xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11x1xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16 or f32}}
+  %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <INT32>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27x1xf16>, !spirv.arm.tensor<11x1x1x27x1xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11x1xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf16>
+}
+
+// Rejects acc_type FP16 for an f32 input; the accumulator must be FP32.
+spirv.ARM.Graph @conv3d_accumulator_must_be_FP32_for_f32_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27x1xf32>, %arg1: !spirv.arm.tensor<11x1x1x27x1xf32>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11x1xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}}
+  %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27x1xf32>, !spirv.arm.tensor<11x1x1x27x1xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11x1xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.DepthwiseConv2D
+//===----------------------------------------------------------------------===//
+
+// Rejects an i32 input: integer inputs must be 8 or 16 bits wide.
+spirv.ARM.Graph @depthwise_conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32>
+  // expected-error @+1 {{op input element type can only be of width 8 or 16 when integer type}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
+}
+
+// Rejects an i16 result for an i8 input; the verifier expects i32.
+spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op expect result type to be i32, got 'i16'}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16>
+}
+
+// Rejects an i32 result for an i16 input; the verifier expects i64.
+spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+  // expected-error @+1 {{op expect result type to be i64, got 'i32'}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects an f32 result for an f16 input; the verifier expects f16.
+spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op expect result type to be f16, got 'f32'}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
+}
+
+// Rejects an f16 result for an f32 input; the verifier expects f32.
+spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_f32_input(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op expect result type to be f32, got 'f16'}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
+}
+
+// Rejects a bias element type (i16) that differs from the result element type (i32).
+spirv.ARM.Graph @depthwise_conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects acc_type INT48 for an i8 input; the accumulator must be INT32.
+spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_INT32_for_i8_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op accumulator type for i8 tensorARM is not i32}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects acc_type INT32 for an i16 input; the accumulator must be INT48.
+spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_INT48_for_i16_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+  // expected-error @+1 {{op accumulator type for i16 tensorARM is not i48}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi64>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
+}
+
+// Rejects acc_type INT32 for an f16 input; the accumulator must be FP16 or FP32.
+spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_either_FP16_or_FP32_for_f16_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16 or f32}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <INT32>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
+}
+
+// Rejects acc_type FP16 for an f32 input; the accumulator must be FP32.
+spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_FP32_for_f32_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}}
+  %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.TransposeConv2D
+//===----------------------------------------------------------------------===//
+
+// Rejects an i32 input: integer inputs must be 8 or 16 bits wide.
+spirv.ARM.Graph @transpose_conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32>
+  // expected-error @+1 {{op input element type can only be of width 8 or 16 when integer type}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
+}
+
+// Rejects an i16 result for an i8 input; the verifier expects i32.
+spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op expect result type to be i32, got 'i16'}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16>
+}
+
+// Rejects an i32 result for an i16 input; the verifier expects i64.
+spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+  // expected-error @+1 {{op expect result type to be i64, got 'i32'}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects an f32 result for an f16 input; the verifier expects f16.
+spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op expect result type to be f16, got 'f32'}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
+}
+
+// Rejects an f16 result for an f32 input; the verifier expects f32.
+spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_f32_input(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op expect result type to be f32, got 'f16'}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
+}
+
+// Rejects a bias element type (i16) that differs from the result element type (i32).
+spirv.ARM.Graph @transpose_conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects acc_type INT48 for an i8 input; the accumulator must be INT32.
+spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_INT32_for_i8_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op accumulator type for i8 tensorARM is not i32}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+// Rejects acc_type INT32 for an i16 input; the accumulator must be INT48.
+spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_INT48_for_i16_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
+  %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
+  %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+  // expected-error @+1 {{op accumulator type for i16 tensorARM is not i48}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi64>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
+}
+
+// Rejects acc_type INT32 for an f16 input; the accumulator must be FP16 or FP32.
+spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_either_FP16_or_FP32_for_f16_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16 or f32}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <INT32>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
+}
+
+// Rejects acc_type FP16 for an f32 input; the accumulator must be FP32.
+spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_FP32_for_f32_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) {
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}}
+  %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf32>
+  spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index c9832b903b79e..45243a7553c56 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -21,3 +21,107 @@ spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.ar
   // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
   spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @conv2d_int(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+  %input_zp = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+  %weight_zp = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+  // CHECK: {{%.*}} = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  %result = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65536x2x7xi32>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x65536x2x7xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @conv2d_fp(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // CHECK: {{%.*}} = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+  %result = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x34x18x11xf16>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x34x18x11xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv3D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @conv3d_int(%arg0: !spirv.arm.tensor<1x9x21x14x1xi8>, %arg1: !spirv.arm.tensor<2x1x2x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x9x20x14x2xi32>) {
+  %input_zp = spirv.Constant dense<123> : !spirv.arm.tensor<1xi8>
+  %weight_zp = spirv.Constant dense<121> : !spirv.arm.tensor<1xi8>
+  // CHECK: {{%.*}} = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32>
+  %result = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x9x20x14x2xi32>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x9x20x14x2xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv3D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @conv3d_fp(%arg0: !spirv.arm.tensor<1x2x65539x1x2xf32>, %arg1: !spirv.arm.tensor<1x1x1x1x2xf32>, %arg2: !spirv.arm.tensor<1xf32>) -> (!spirv.arm.tensor<1x3x65540x2x1xf32>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // CHECK: {{%.*}} = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32>
+  %result = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x3x65540x2x1xf32>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x3x65540x2x1xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.DepthwiseConv2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @depthwiseconv2d_int(%arg0: !spirv.arm.tensor<1x4x65537x1xi8>, %arg1: !spirv.arm.tensor<1x3x1x4xi8>, %arg2: !spirv.arm.tensor<4xi32>) -> (!spirv.arm.tensor<1x4x32762x4xi32>) {
+  %input_zp = spirv.Constant dense<58> : !spirv.arm.tensor<1xi8>
+  %weight_zp = spirv.Constant dense<-106> : !spirv.arm.tensor<1xi8>
+  // CHECK: {{%.*}} = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32>
+  %result = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x4x32762x4xi32>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x4x32762x4xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.DepthwiseConv2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @depthwiseconv2d_fp(%arg0: !spirv.arm.tensor<1x65540x1x3xf32>, %arg1: !spirv.arm.tensor<1x1x3x1xf32>, %arg2: !spirv.arm.tensor<1xf32>) -> (!spirv.arm.tensor<1x65541x2x3xf32>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // CHECK: {{%.*}} = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = <FP32>, local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32>
+  %result = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = <FP32>, local_bound = true, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65541x2x3xf32>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x65541x2x3xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.TransposeConv2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @transposeconv2d_int(%arg0: !spirv.arm.tensor<1x13x33x3xi16>, %arg1: !spirv.arm.tensor<11x1x3x3xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x13x35x11xi64>) {
+  %input_zp = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+  %weight_zp = spirv.Constant dense<88> : !spirv.arm.tensor<1xi8>
+  // CHECK: {{%.*}} = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64>
+  %result = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x13x35x11xi64>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x13x35x11xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.TransposeConv2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @transposeconv2d_fp(%arg0: !spirv.arm.tensor<10x24x9x13xf16>, %arg1: !spirv.arm.tensor<14x1x1x13xf16>, %arg2: !spirv.arm.tensor<14xf16>) -> (!spirv.arm.tensor<10x25x65x14xf16>) {
+  %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // CHECK: {{%.*}} = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16>
+  %result = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x25x65x14xf16>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<10x25x65x14xf16>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 8c0429bca68e4..edaa000c183a8 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -39,3 +39,187 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
   }
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @conv2d_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x65535x3x1xi8>, UniformConstant>
+  spirv.GlobalVariable @conv2d_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<7x1x1x1xi8>, UniformConstant>
+  spirv.GlobalVariable @conv2d_int_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1xi32>, UniformConstant>
+  spirv.GlobalVariable @conv2d_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x65536x2x7xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @conv2d_int, @conv2d_int_arg_0, @conv2d_int_arg_1, @conv2d_int_arg_2, @conv2d_int_res_0
+  spirv.ARM.Graph @conv2d_int(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+    %input_zp = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
+    %weight_zp = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
+    // CHECK: {{%.*}} = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+    %result = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65536x2x7xi32>
+    spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x65536x2x7xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @conv2d_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x34x18x27xf16>, UniformConstant>
+  spirv.GlobalVariable @conv2d_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<11x1x1x27xf16>, UniformConstant>
+  spirv.GlobalVariable @conv2d_fp_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<11xf16>, UniformConstant>
+  spirv.GlobalVariable @conv2d_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x34x18x11xf16>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @conv2d_fp, @conv2d_fp_arg_0, @conv2d_fp_arg_1, @conv2d_fp_arg_2, @conv2d_fp_res_0
+  spirv.ARM.Graph @conv2d_fp(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) {
+    %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+    %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+    // CHECK: {{%.*}} = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+    %result = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x34x18x11xf16>
+    spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x34x18x11xf16>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv3D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @conv3d_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x9x21x14x1xi8>, UniformConstant>
+  spirv.GlobalVariable @conv3d_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x1x2x1x1xi8>, UniformConstant>
+  spirv.GlobalVariable @conv3d_int_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1xi32>, UniformConstant>
+  spirv.GlobalVariable @conv3d_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x9x20x14x2xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @conv3d_int, @conv3d_int_arg_0, @conv3d_int_arg_1, @conv3d_int_arg_2, @conv3d_int_res_0
+  spirv.ARM.Graph @conv3d_int(%arg0: !spirv.arm.tensor<1x9x21x14x1xi8>, %arg1: !spirv.arm.tensor<2x1x2x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x9x20x14x2xi32>) {
+    %input_zp = spirv.Constant dense<123> : !spirv.arm.tensor<1xi8>
+    %weight_zp = spirv.Constant dense<121> : !spirv.arm.tensor<1xi8>
+    // CHECK: {{%.*}} = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32>
+    %result = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x9x20x14x2xi32>
+    spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x9x20x14x2xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Conv3D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @conv3d_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x65539x1x2xf32>, UniformConstant>
+  spirv.GlobalVariable @conv3d_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x1x1x1x2xf32>, UniformConstant>
+  spirv.GlobalVariable @conv3d_fp_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1xf32>, UniformConstant>
+  spirv.GlobalVariable @conv3d_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x3x65540x2x1xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @conv3d_fp, @conv3d_fp_arg_0, @conv3d_fp_arg_1, @conv3d_fp_arg_2, @conv3d_fp_res_0
+  spirv.ARM.Graph @conv3d_fp(%arg0: !spirv.arm.tensor<1x2x65539x1x2xf32>, %arg1: !spirv.arm.tensor<1x1x1x1x2xf32>, %arg2: !spirv.arm.tensor<1xf32>) -> (!spirv.arm.tensor<1x3x65540x2x1xf32>) {
+    %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+    %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+    // CHECK: {{%.*}} = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32>
+    %result = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x3x65540x2x1xf32>
+    spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x3x65540x2x1xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.DepthwiseConv2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @depthwiseconv2d_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x4x65537x1xi8>, UniformConstant>
+  spirv.GlobalVariable @depthwiseconv2d_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x3x1x4xi8>, UniformConstant>
+  spirv.GlobalVariable @depthwiseconv2d_int_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<4xi32>, UniformConstant>
+  spirv.GlobalVariable @depthwiseconv2d_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x4x32762x4xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @depthwiseconv2d_int, @depthwiseconv2d_int_arg_0, @depthwiseconv2d_int_arg_1, @depthwiseconv2d_int_arg_2, @depthwiseconv2d_int_res_0
+  spirv.ARM.Graph @depthwiseconv2d_int(%arg0: !spirv.arm.tensor<1x4x65537x1xi8>, %arg1: !spirv.arm.tensor<1x3x1x4xi8>, %arg2: !spirv.arm.tensor<4xi32>) -> (!spirv.arm.tensor<1x4x32762x4xi32>) {
+    %input_zp = spirv.Constant dense<58> : !spirv.arm.tensor<1xi8>
+    %weight_zp = spirv.Constant dense<-106> : !spirv.arm.tensor<1xi8>
+    // CHECK: {{%.*}} = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32>
+    %result = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x4x32762x4xi32>
+    spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x4x32762x4xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.DepthwiseConv2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @depthwiseconv2d_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x65540x1x3xf32>, UniformConstant>
+  spirv.GlobalVariable @depthwiseconv2d_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x1x3x1xf32>, UniformConstant>
+  spirv.GlobalVariable @depthwiseconv2d_fp_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1xf32>, UniformConstant>
+  spirv.GlobalVariable @depthwiseconv2d_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x65541x2x3xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @depthwiseconv2d_fp, @depthwiseconv2d_fp_arg_0, @depthwiseconv2d_fp_arg_1, @depthwiseconv2d_fp_arg_2, @depthwiseconv2d_fp_res_0
+  spirv.ARM.Graph @depthwiseconv2d_fp(%arg0: !spirv.arm.tensor<1x65540x1x3xf32>, %arg1: !spirv.arm.tensor<1x1x3x1xf32>, %arg2: !spirv.arm.tensor<1xf32>) -> (!spirv.arm.tensor<1x65541x2x3xf32>) {
+    %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+    %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+    // CHECK: {{%.*}} = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = <FP32>, local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32>
+    %result = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = <FP32>, local_bound = true, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65541x2x3xf32>
+    spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x65541x2x3xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.TransposeConv2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @transposeconv2d_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x13x33x3xi16>, UniformConstant>
+  spirv.GlobalVariable @transposeconv2d_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<11x1x3x3xi8>, UniformConstant>
+  spirv.GlobalVariable @transposeconv2d_int_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1xi64>, UniformConstant>
+  spirv.GlobalVariable @transposeconv2d_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x13x35x11xi64>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @transposeconv2d_int, @transposeconv2d_int_arg_0, @transposeconv2d_int_arg_1, @transposeconv2d_int_arg_2, @transposeconv2d_int_res_0
+  spirv.ARM.Graph @transposeconv2d_int(%arg0: !spirv.arm.tensor<1x13x33x3xi16>, %arg1: !spirv.arm.tensor<11x1x3x3xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x13x35x11xi64>) {
+    %input_zp = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+    %weight_zp = spirv.Constant dense<88> : !spirv.arm.tensor<1xi8>
+    // CHECK: {{%.*}} = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64>
+    %result = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x13x35x11xi64>
+    spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<1x13x35x11xi64>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.TransposeConv2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @transposeconv2d_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<10x24x9x13xf16>, UniformConstant>
+  spirv.GlobalVariable @transposeconv2d_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x1x1x13xf16>, UniformConstant>
+  spirv.GlobalVariable @transposeconv2d_fp_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<14xf16>, UniformConstant>
+  spirv.GlobalVariable @transposeconv2d_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<10x25x65x14xf16>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @transposeconv2d_fp, @transposeconv2d_fp_arg_0, @transposeconv2d_fp_arg_1, @transposeconv2d_fp_arg_2, @transposeconv2d_fp_res_0
+  spirv.ARM.Graph @transposeconv2d_fp(%arg0: !spirv.arm.tensor<10x24x9x13xf16>, %arg1: !spirv.arm.tensor<14x1x1x13xf16>, %arg2: !spirv.arm.tensor<14xf16>) -> (!spirv.arm.tensor<10x25x65x14xf16>) {
+    %input_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+    %weight_zp = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+    // CHECK: {{%.*}} = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16>
+    %result = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %input_zp, %weight_zp : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x25x65x14xf16>
+    spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<10x25x65x14xf16>
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index f3327e31aae04..0b1771ffcee71 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -501,6 +501,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
     "SPIRV_KHR_CooperativeMatrixLayoutAttr",
     "SPIRV_MemorySemanticsAttr",
     "SPIRV_MatrixLayoutAttr",
+    "SPIRV_TosaExtAccTypeAttr",
     "SPIRV_TosaExtNaNPropagationModeAttr",
 };
 
@@ -556,11 +557,18 @@ static void emitAttributeSerialization(const Attribute &attr,
     os << tabs << "    return failure();\n";
     os << tabs << "  }\n";
     os << tabs << formatv("  {0}.push_back(attrTypeID);\n", operandList);
-  } else if (attr.getAttrDefName() == "SPIRV_TensorArmAxisAttr") {
+  } else if (llvm::is_contained(
+                 {"SPIRV_BoolConstAttr", "SPIRV_TensorArmAxisAttr"},
+                 attr.getAttrDefName())) {
     os << tabs
        << formatv(
               "  {0}.push_back(prepareConstantScalar({1}.getLoc(), attr));\n",
               operandList, opVar);
+  } else if (attr.getAttrDefName().contains("TensorArm")) {
+    os << tabs
+       << formatv("  {0}.push_back(prepareConstant({1}.getLoc(), "
+                  "llvm::cast<DenseElementsAttr>(attr).getType(), attr));\n",
+                  operandList, opVar);
   } else {
     PrintFatalError(
         loc,
@@ -855,7 +863,8 @@ static void emitAttributeDeserialization(const Attribute &attr,
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "TypeAttr::get(getType({2}[{3}++]))));\n",
                   attrList, attrName, words, wordIndex);
-  } else if (attr.getAttrDefName() == "SPIRV_TensorArmAxisAttr") {
+  } else if (attr.getAttrDefName() == "SPIRV_BoolConstAttr" ||
+             attr.getAttrDefName().contains("TensorArm")) {
     os << tabs
        << formatv("std::optional<std::pair<Attribute, Type>> c = "
                   "getConstant({0}[{1}++]);\n",



More information about the Mlir-commits mailing list