[Mlir-commits] [mlir] [mlir][spirv] Expand support for TOSA Extended Instruction Set (00100… (PR #176908)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 20 04:06:16 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Davide Grohmann (davidegrohmann)
<details>
<summary>Changes</summary>
…0.1)
This patch expands the MLIR SPIR-V dialect's support for the TOSA Extended Instruction Set (001000.1). The TOSA extended instruction set provides a standardized set of machine learning operations designed to be used within `spirv.ARM.Graph` operations (corresponding to OpGraphARM in SPV_ARM_graph) and typed with `!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in SPV_ARM_tensor).
The change introduces:
* Extended dialect plumbing for import, serialization, and deserialization of the TOSA extended instruction set.
* The `spirv.Tosa.*Conv*` convolution operations from the TOSA extended instruction set, each lowering to the corresponding `OpExtInst`.
* Verification enforcing that the new convolution operations appear only within `spirv.ARM.Graph` regions, operate on `!spirv.arm.tensor<...>` types, and are well-formed according to the TOSA 001000.1 specification.
All convolution operations from the TOSA 001000.1 extended instruction set are introduced; parser, printer, verifier, and round-trip tests using MLIR’s SPIR-V serialization/deserialization infrastructure are included.
This work aligns with the Khronos SPIR-V TOSA specification.
Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html
---
Patch is 92.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/176908.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+11)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+250)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+29)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp (+90)
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir (+337)
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir (+104)
- (modified) mlir/test/Target/SPIRV/tosa-ops.mlir (+184)
- (modified) mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (+11-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 21010d91dc47c..4ea6d784dd88f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4915,6 +4915,17 @@ def SPIRV_FPFastMathModeAttr :
// SPIR-V TOSA enum definitions.
//===----------------------------------------------------------------------===//
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+//
+// Accumulator type selected by the `acc_type` argument of the TOSA
+// convolution ops. The case values (1-4) are presumably the encodings defined
+// by the TOSA.001000.1 extended instruction set -- confirm against the spec.
+def SPIRV_TosaExtAccTypeAttr : SPIRV_I32EnumAttr<
+  "TosaExtAccType", "Tosa Ext Accumulator Type", "tosa_ext_acc_type",
+  [
+    I32EnumAttrCase<"INT32", 1>,
+    I32EnumAttrCase<"FP16", 2>,
+    I32EnumAttrCase<"FP32", 3>,
+    I32EnumAttrCase<"INT48", 4>,
+  ]>;
+
// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
// SPIR-V proper.
def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 6c6a318db4827..6fd368af6ec7e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -83,4 +83,254 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
}];
}
+
+// spirv.Tosa.Conv2D: the TOSA CONV2D operator as a SPIR-V extended
+// instruction. The `2` passed to SPIRV_TosaOp appears to be the instruction
+// number in the TOSA.001000.1 set (ArgMax uses 0) -- confirm in the registry.
+def SPIRV_TosaConv2DOp : SPIRV_TosaOp<"Conv2D", 2, [Pure]> {
+ let summary = "2D Convolution operator.";
+
+ let description = [{
+ Performs a 2D convolution over the given tensor input, using the weight
+ tensor. Implementations may choose to skip calculation of multiplies in
+ the padding area.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_conv2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_conv2d
+
+ #### Example:
+ ```mlir
+ %7 = spirv.Tosa.Conv2D pad = dense<[1, 0, 0, 0]> : !spirv.arm.tensor<4xi32>, stride = dense<[1, 2]> : !spirv.arm.tensor<2xi32>, dilation = dense<[7, 1]> : !spirv.arm.tensor<2xi32>, acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+ %7 = spirv.Tosa.Conv2D pad = dense<0> : !spirv.arm.tensor<4xi32>, stride = dense<1> : !spirv.arm.tensor<2xi32>, dilation = dense<1> : !spirv.arm.tensor<2xi32>, acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16>
+ ```
+ }];
+
+ // The first five arguments are attributes (pad: 4 x i32, stride: 2 x i32,
+ // dilation: 2 x i32, all TensorArm-typed dense attributes; acc_type;
+ // local_bound). The last five are SSA tensorARM operands: rank-4 input and
+ // weight, rank-1 bias, and length-1 zero points.
+ let arguments = (ins
+ SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $dilation,
+ SPIRV_TosaExtAccTypeAttr: $acc_type,
+ SPIRV_BoolConstAttr: $local_bound,
+ SPIRV_TosaNumerical_TensorArm4D: $input,
+ SPIRV_TosaNumerical_TensorArm4D: $weight,
+ SPIRV_TosaNumerical_TensorArm1D: $bias,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+ );
+
+ // Single rank-4 numerical tensorARM result.
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm4D: $output
+ );
+
+ // Element-type and accumulator-type rules are checked by verifyConvOp in
+ // SPIRVTosaOps.cpp.
+ let hasVerifier = 1;
+
+ // Attributes print as `name = value` pairs before the five operands; all
+ // operand types, then the result type, follow the colon.
+ let assemblyFormat = [{
+ `pad` `=` $pad `,` `stride` `=` $stride `,`
+ `dilation` `=` $dilation `,` `acc_type` `=` $acc_type `,`
+ `local_bound` `=` $local_bound `,`
+ $input `,` $weight `,` $bias `,` $input_zp `,` $weight_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ // Typed accessors used by verifyConvOp. The casts presume the TensorArm
+ // type constraints on the operands/result above have already held.
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getWeightType() {
+ return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+ }
+ ::mlir::spirv::TensorArmType getBiasType() {
+ return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+
+// spirv.Tosa.Conv3D: the TOSA CONV3D operator as a SPIR-V extended
+// instruction. The `3` passed to SPIRV_TosaOp appears to be the instruction
+// number in the TOSA.001000.1 set -- confirm in the registry.
+def SPIRV_TosaConv3DOp : SPIRV_TosaOp<"Conv3D", 3, [Pure]> {
+ let summary = "3D Convolution operator.";
+
+ let description = [{
+ Performs a 3D convolution over the given input tensor. Implementations
+ may choose to skip calculation of multiplies in the padding area.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_conv3d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_conv3d
+
+ #### Example:
+ ```mlir
+ %7 = spirv.Tosa.Conv3D pad = dense<0> : !spirv.arm.tensor<6xi32>, stride = dense<1> : !spirv.arm.tensor<3xi32>, dilation = dense<1> : !spirv.arm.tensor<3xi32>, acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32>
+ %7 = spirv.Tosa.Conv3D pad = dense<[0, 1, 1, 0, 0, 1]> : !spirv.arm.tensor<6xi32>, stride = dense<1> : !spirv.arm.tensor<3xi32>, dilation = dense<[1, 1, 7]> : !spirv.arm.tensor<3xi32>, acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32>
+ ```
+ }];
+
+ // The first five arguments are attributes (pad: 6 x i32, stride: 3 x i32,
+ // dilation: 3 x i32, all TensorArm-typed dense attributes; acc_type;
+ // local_bound). The last five are SSA tensorARM operands: rank-5 input and
+ // weight, rank-1 bias, and length-1 zero points.
+ let arguments = (ins
+ SPIRV_Int32_1DTensorArmOfLength6Attr: $pad,
+ SPIRV_Int32_1DTensorArmOfLength3Attr: $stride,
+ SPIRV_Int32_1DTensorArmOfLength3Attr: $dilation,
+ SPIRV_TosaExtAccTypeAttr: $acc_type,
+ SPIRV_BoolConstAttr: $local_bound,
+ SPIRV_TosaNumerical_TensorArm5D: $input,
+ SPIRV_TosaNumerical_TensorArm5D: $weight,
+ SPIRV_TosaNumerical_TensorArm1D: $bias,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+ );
+
+ // Single rank-5 numerical tensorARM result.
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm5D: $output
+ );
+
+ // Element-type and accumulator-type rules are checked by verifyConvOp in
+ // SPIRVTosaOps.cpp.
+ let hasVerifier = 1;
+
+ // Same textual layout as spirv.Tosa.Conv2D: named attributes, then the five
+ // operands, then the operand and result types.
+ let assemblyFormat = [{
+ `pad` `=` $pad `,` `stride` `=` $stride `,`
+ `dilation` `=` $dilation `,` `acc_type` `=` $acc_type `,`
+ `local_bound` `=` $local_bound `,`
+ $input `,` $weight `,` $bias `,` $input_zp `,` $weight_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ // Typed accessors used by verifyConvOp. The casts presume the TensorArm
+ // type constraints on the operands/result above have already held.
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getWeightType() {
+ return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+ }
+ ::mlir::spirv::TensorArmType getBiasType() {
+ return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+
+// spirv.Tosa.DepthwiseConv2D: the TOSA DEPTHWISE_CONV2D operator as a SPIR-V
+// extended instruction. The `4` passed to SPIRV_TosaOp appears to be the
+// instruction number in the TOSA.001000.1 set -- confirm in the registry.
+def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOp<"DepthwiseConv2D", 4, [Pure]> {
+ let summary = "Depthwise 2D Convolution operator.";
+
+ let description = [{
+ Performs 2D convolutions separately over each channel of the given tensor
+ input, using the weight tensor. Implementations may choose to skip
+ calculation of multiplies in the padding area.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_depthwise_conv2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_depthwise_conv2d
+
+ #### Example:
+ ```mlir
+ %7 = spirv.Tosa.DepthwiseConv2D pad = dense<0> : !spirv.arm.tensor<4xi32>, stride = dense<[1, 2]> : !spirv.arm.tensor<2xi32>, dilation = dense<7> : !spirv.arm.tensor<2xi32>, acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32>
+ %7 = spirv.Tosa.DepthwiseConv2D pad = dense<[0, 1, 1, 1]> : !spirv.arm.tensor<4xi32>, stride = dense<[1, 2]> : !spirv.arm.tensor<2xi32>, dilation = dense<[1, 7]> : !spirv.arm.tensor<2xi32>, acc_type = <FP32>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32>
+ ```
+ }];
+
+ // The first five arguments are attributes (pad: 4 x i32, stride: 2 x i32,
+ // dilation: 2 x i32, all TensorArm-typed dense attributes; acc_type;
+ // local_bound). The last five are SSA tensorARM operands: rank-4 input and
+ // weight, rank-1 bias, and length-1 zero points.
+ let arguments = (ins
+ SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $dilation,
+ SPIRV_TosaExtAccTypeAttr: $acc_type,
+ SPIRV_BoolConstAttr: $local_bound,
+ SPIRV_TosaNumerical_TensorArm4D: $input,
+ SPIRV_TosaNumerical_TensorArm4D: $weight,
+ SPIRV_TosaNumerical_TensorArm1D: $bias,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+ );
+
+ // Single rank-4 numerical tensorARM result.
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm4D: $output
+ );
+
+ // Element-type and accumulator-type rules are checked by verifyConvOp in
+ // SPIRVTosaOps.cpp.
+ let hasVerifier = 1;
+
+ // Same textual layout as spirv.Tosa.Conv2D: named attributes, then the five
+ // operands, then the operand and result types.
+ let assemblyFormat = [{
+ `pad` `=` $pad `,` `stride` `=` $stride `,`
+ `dilation` `=` $dilation `,` `acc_type` `=` $acc_type `,`
+ `local_bound` `=` $local_bound `,`
+ $input `,` $weight `,` $bias `,` $input_zp `,` $weight_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ // Typed accessors used by verifyConvOp. The casts presume the TensorArm
+ // type constraints on the operands/result above have already held.
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getWeightType() {
+ return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+ }
+ ::mlir::spirv::TensorArmType getBiasType() {
+ return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+
+// spirv.Tosa.TransposeConv2D: the TOSA TRANSPOSE_CONV2D operator as a SPIR-V
+// extended instruction. The `9` passed to SPIRV_TosaOp appears to be the
+// instruction number in the TOSA.001000.1 set -- confirm in the registry.
+def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOp<"TransposeConv2D", 9, [Pure]> {
+ let summary = "Transpose 2D Convolution operator.";
+
+ let description = [{
+ Performs a 2D transposed convolution over the given tensor input, using the
+ weights tensor. Implementations may choose to skip calculation of multiplies
+ by zero at fractional input positions.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_transpose_conv2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_transpose_conv2d
+
+ #### Example:
+ ```mlir
+ %6 = spirv.Tosa.TransposeConv2D out_pad = dense<0> : !spirv.arm.tensor<4xi32>, stride = dense<1> : !spirv.arm.tensor<2xi32>, acc_type = <INT48>, local_bound = false, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64>
+ %6 = spirv.Tosa.TransposeConv2D out_pad = dense<[0, 1, 0, 0]> : !spirv.arm.tensor<4xi32>, stride = dense<[1, 8]> : !spirv.arm.tensor<2xi32>, acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16>
+ ```
+ }];
+
+ // Unlike the other convolution ops there is no dilation and the padding
+ // attribute is `out_pad`. The first four arguments are attributes
+ // (out_pad: 4 x i32 and stride: 2 x i32 TensorArm-typed dense attributes;
+ // acc_type; local_bound). The last five are SSA tensorARM operands.
+ let arguments = (ins
+ SPIRV_Int32_1DTensorArmOfLength4Attr: $out_pad,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_TosaExtAccTypeAttr: $acc_type,
+ SPIRV_BoolConstAttr: $local_bound,
+ SPIRV_TosaNumerical_TensorArm4D: $input,
+ SPIRV_TosaNumerical_TensorArm4D: $weight,
+ SPIRV_TosaNumerical_TensorArm1D: $bias,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+ );
+
+ // Single rank-4 numerical tensorARM result.
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm4D: $output
+ );
+
+ // Element-type and accumulator-type rules are checked by verifyConvOp in
+ // SPIRVTosaOps.cpp.
+ let hasVerifier = 1;
+
+ // Named attributes (out_pad, stride, acc_type, local_bound), then the five
+ // operands, then the operand and result types.
+ let assemblyFormat = [{
+ `out_pad` `=` $out_pad `,` `stride` `=` $stride `,`
+ `acc_type` `=` $acc_type `,` `local_bound` `=` $local_bound `,`
+ $input `,` $weight `,` $bias `,` $input_zp `,` $weight_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ // Typed accessors used by verifyConvOp. The casts presume the TensorArm
+ // type constraints on the operands/result above have already held.
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getWeightType() {
+ return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+ }
+ ::mlir::spirv::TensorArmType getBiasType() {
+ return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index e731388182eb4..7e2c37f74b437 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
#define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+include "mlir/IR/CommonAttrConstraints.td"
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
@@ -21,6 +22,7 @@ def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
+// BoolAttr under a SPIR-V TOSA name. The constraint list is empty, so this
+// accepts exactly what BoolAttr accepts. Used for the `local_bound` argument
+// of the convolution ops.
+def SPIRV_BoolConstAttr : ConfinedAttr<BoolAttr, []>;
// TensorARM Types
@@ -35,7 +37,34 @@ class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+// Numerical tensorARM types restricted to exactly one rank. Used by the
+// convolution ops for the bias (1D), the Conv2D-style input/weight/output
+// (4D), and the Conv3D input/weight/output (5D).
+def SPIRV_TosaNumerical_TensorArm1D : TensorArmRankOf<[SPIRV_TosaNumerical], [1]>;
+def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]>;
+def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
+
def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
+// Predicate: `$_self` is a rank-1 TensorArmType whose only dimension is one of
+// `allowedLengths`. The TensorArmType check is the first conjunct, so the casts
+// that follow are never evaluated on an unrelated type (e.g. a scalar i8 zero
+// point). Those casts are the ShapedType cast performed by HasAnyRankOfPred and
+// the explicit TensorArmType cast below. Without the guard they would assert
+// instead of producing a verification error.
+class Is1DTensorArmOfLength<list<int> allowedLengths> :
+  And<[CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">,
+       HasAnyRankOfPred<[1]>,
+       Or<!foreach(allowedlength, allowedLengths,
+                   CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>($_self).getShape()[0] == }]
+                         # allowedlength>)>]>;
+
+// Type constraint: rank-1 TensorArmType of one of `allowedLengths` whose
+// element type is one of `allowedTypes`. Assumes ContainerType evaluates the
+// container predicate before the element predicate, so getElementType() is
+// only called on a value that passed Is1DTensorArmOfLength.
+class SPIRV_1DTensorArmOfLengthAndType<list<int> allowedLengths, list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, Is1DTensorArmOfLength<allowedLengths>,
+                "::llvm::cast<::mlir::spirv::TensorArmType>($_self).getElementType()",
+                "rank 1 tensorArm of length " # !interleave(allowedLengths, "/"),
+                "::mlir::spirv::TensorArmType">;
+
+// Attribute constraint: the DenseElementsAttr's type must be a TensorArmType
+// rather than, e.g., a builtin tensor type. In the ConfinedAttr defs below it
+// follows the RankedI32ElementsAttr base constraint, which presumably ensures
+// the DenseElementsAttr cast here is valid.
+def SPIRV_DenseElementAttrsWithTensorArmType : AttrConstraint<
+ CPred<"::llvm::isa<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType())">,
+ "Attr with type = spirv::TensorArmType">;
+
+// TensorArm-typed i32 dense attributes with exactly 2, 3, 4 or 6 elements.
+// The convolution ops use them for stride/dilation (2 or 3 elements) and
+// pad/out_pad (4 or 6 elements).
+def SPIRV_Int32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_Int32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_Int32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_Int32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+
+// Single-element numerical tensorARM; used for the input_zp/weight_zp operands.
+def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>;
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index 4f3c91d4a1c12..0a755bf7c1b00 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -20,6 +20,72 @@ namespace mlir::spirv {
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
+/// Verifies the element-type and accumulator-type rules shared by the TOSA
+/// convolution operators (Conv2D, Conv3D, DepthwiseConv2D, TransposeConv2D).
+///
+/// `T` must provide getInputType(), getWeightType(), getBiasType() and
+/// getResultType(), plus the ODS-generated getInputZp(), getWeightZp() and
+/// getAccType() accessors. Checks, in order:
+///  * integer inputs are i8 or i16;
+///  * the result element type follows the input: i8 -> i32, i16 -> i64,
+///    f16 -> f16, bf16 -> bf16, f32 -> f32;
+///  * bias and result share an element type;
+///  * acc_type is legal for the input: i8 -> INT32, i16 -> INT48,
+///    f16 -> FP16 or FP32, bf16 -> FP32, f32 -> FP32;
+///  * integer inputs use i8 weights; floating-point inputs use weights of the
+///    same element type;
+///  * input_zp and weight_zp share the element type of input and weight.
+/// The weight and zero-point checks run last so they never pre-empt the
+/// earlier diagnostics.
+template <typename T>
+static LogicalResult verifyConvOp(T op) {
+  ShapedType inputTy = op.getInputType();
+  ShapedType weightTy = op.getWeightType();
+  ShapedType biasTy = op.getBiasType();
+  ShapedType resultTy = op.getResultType();
+
+  Type inputETy = inputTy.getElementType();
+  Type weightETy = weightTy.getElementType();
+  Type biasETy = biasTy.getElementType();
+  Type resultETy = resultTy.getElementType();
+
+  if (inputETy.isInteger() && !inputETy.isInteger(8) &&
+      !inputETy.isInteger(16)) {
+    return op.emitOpError(
+        "input element type can only be of width 8 or 16 when integer type");
+  }
+
+  if (inputETy.isInteger(8) && !resultETy.isInteger(32)) {
+    return op.emitOpError("expect result type to be i32, got ") << resultETy;
+  }
+
+  if (inputETy.isInteger(16) && !resultETy.isInteger(64)) {
+    return op.emitOpError("expect result type to be i64, got ") << resultETy;
+  }
+
+  if (inputETy.isF16() && !resultETy.isF16()) {
+    return op.emitOpError("expect result type to be f16, got ") << resultETy;
+  }
+
+  // Like f16 and f32, a bf16 input produces a result of its own element type
+  // (it only accumulates in fp32, see the acc_type check below).
+  if (inputETy.isBF16() && !resultETy.isBF16()) {
+    return op.emitOpError("expect result type to be bf16, got ") << resultETy;
+  }
+
+  if (inputETy.isF32() && !resultETy.isF32()) {
+    return op.emitOpError("expect result type to be f32, got ") << resultETy;
+  }
+
+  if (biasETy != resultETy) {
+    return op.emitOpError("element types of bias and result must be the same");
+  }
+
+  TosaExtAccType accType = op.getAccType();
+  if (inputETy.isInteger(8) && accType != TosaExtAccType::INT32) {
+    return op.emitOpError("accumulator type for i8 tensorARM is not i32");
+  }
+
+  if (inputETy.isInteger(16) && accType != TosaExtAccType::INT48) {
+    return op.emitOpError("accumulator type for i16 tensorARM is not i48");
+  }
+
+  if (inputETy.isF16() &&
+      !(accType == TosaExtAccType::FP16 || accType == TosaExtAccType::FP32)) {
+    return op.emitOpError(
+        "accumulator type for f16 tensorARM is not f16 or f32");
+  }
+
+  if (inputETy.isBF16() && accType != TosaExtAccType::FP32) {
+    return op.emitOpError("accumulator type for bf16 tensorARM is not f32");
+  }
+
+  if (inputETy.isF32() && accType != TosaExtAccType::FP32) {
+    return op.emitOpError("accumulator type for f32 tensorARM is not f32");
+  }
+
+  // Integer activations pair with i8 weights (i8 x i8 and i16 x i8). The
+  // width check above guarantees an integer input is i8 or i16 here.
+  if (inputETy.isInteger() && !weightETy.isInteger(8)) {
+    return op.emitOpError("expect weight type to be i8, got ") << weightETy;
+  }
+
+  // Floating-point activations pair with weights of the same element type.
+  if (isa<FloatType>(inputETy) && weightETy != inputETy) {
+    return op.emitOpError(
+        "element types of input and weight must be the same");
+  }
+
+  // Each zero point carries the element type of the tensor it offsets.
+  Type inputZpETy =
+      cast<ShapedType>(op.getInputZp().getType()).getElementType();
+  if (inputZpETy != inputETy) {
+    return op.emitOpError(
+        "element types of input_zp and input must be the same");
+  }
+
+  Type weightZpETy =
+      cast<ShapedType>(op.getWeightZp().getType()).getElementType();
+  if (weightZpETy != weightETy) {
+    return op.emitOpError(
+        "element types of weight_zp and weight must be the same");
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.TosaArgmaxOp
//===----------------------------------------------------------------------===//
@@ -46,4 +112,28 @@ LogicalResult TosaArgMaxOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.TosaConv2DOp
+//===----------------------------------------------------------------------===//
+
+// Each convolution op forwards to the shared templated checker, naming its
+// own op type explicitly.
+LogicalResult TosaConv2DOp::verify() {
+  return verifyConvOp<TosaConv2DOp>(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaConv3DOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaConv3DOp::verify() {
+  return verifyConvOp<TosaConv3DOp>(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaDepthwiseConv2DOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaDepthwiseConv2DOp::verify() {
+  return verifyConvOp<TosaDepthwiseConv2DOp>(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaTransposeConv2DOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaTransposeConv2DOp::verify() {
+  return verifyConvOp<TosaTransposeConv2DOp>(*this);
+}
+
} // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index a6496316f9881..cb7863ba9c1a0 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -21,3 +21,340 @@ spirv.ARM.Graph @argmax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.ten
%2 = spirv.Tosa.ArgMax axis = 4, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
}
+
+//===---------------...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/176908
More information about the Mlir-commits
mailing list