[Mlir-commits] [mlir] [mlir][spirv] Add Pooling, Fourier Transform, and MatMul operations t… (PR #177585)
Davide Grohmann
llvmlistbot at llvm.org
Fri Jan 23 06:47:23 PST 2026
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/177585
>From 3249d9f4f0e2f9790bedb4a023dddde0bdcbe9b3 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 23 Jan 2026 14:17:46 +0100
Subject: [PATCH] [mlir][spirv] Add Pooling, Fourier Transform, and MatMul
operations to TOSA Extended Instruction Set (001000.1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This patch expands support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended
instruction set provides a standardized set of machine learning
operations designed to be used within spirv.ARM.Graph operations
(corresponding to OpGraphARM in SPV_ARM_graph) and typed with
!spirv.arm.tensor<...> (corresponding to OpTypeTensorARM in
SPV_ARM_tensor).
The change introduces:
* Extended dialect plumbing for import, serialization, and
deserialization of the TOSA extended instruction set.
* The spirv.Tosa.*Pool* pooling operations, spirv.Tosa.MatMul, and
spirv.Tosa.*FFT* (Fourier Transform) operations from the TOSA extended
instruction set, each lowering to the corresponding OpExtInst.
* Verification enforcing that the new operations appear only
within spirv.ARM.Graph regions, operate on !spirv.arm.tensor<...>
types, and are well-formed according to the TOSA 001000.1
specification.
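
As a rough illustration of the element-type checks performed by the
new verifiers (AvgPool2D accumulator type, MatMul result type), the
rules can be modeled in standalone C++. The enums and function names
below are hypothetical stand-ins, not the MLIR classes or the code in
this patch; only the rules themselves are taken from the verifiers
added here, and the sketch covers just the element types listed.

```cpp
#include <cassert>

// Hypothetical stand-ins for element and accumulator types (not MLIR types).
enum class ElemType { I8, I16, I32, I64, F16, F32 };
enum class AccType { INT32, INT48, FP16, FP32 };

constexpr bool isInteger(ElemType t) {
  return t == ElemType::I8 || t == ElemType::I16 || t == ElemType::I32 ||
         t == ElemType::I64;
}

// Models the AvgPool2D rule: integer input needs an INT32 accumulator,
// f16 input allows FP16 or FP32, f32 input needs FP32.
constexpr bool isValidAvgPoolAcc(ElemType input, AccType acc) {
  if (isInteger(input))
    return acc == AccType::INT32;
  if (input == ElemType::F16)
    return acc == AccType::FP16 || acc == AccType::FP32;
  return acc == AccType::FP32; // only F32 remains in this sketch
}

// Models the MatMul rule: the result element type depends on operand A.
constexpr bool isValidMatMulResult(ElemType a, ElemType result) {
  switch (a) {
  case ElemType::I8:
    return result == ElemType::I32;
  case ElemType::I16:
    return result == ElemType::I64;
  case ElemType::F16:
    return result == ElemType::F16 || result == ElemType::F32;
  case ElemType::F32:
    return result == ElemType::F32;
  default:
    return true; // other operand types are not constrained by this check
  }
}

// Replays the type combinations used in the tests of this patch.
inline void checkPatchExamples() {
  assert(isValidAvgPoolAcc(ElemType::I8, AccType::INT32));
  assert(!isValidAvgPoolAcc(ElemType::I8, AccType::INT48));
  assert(!isValidAvgPoolAcc(ElemType::F16, AccType::INT32));
  assert(!isValidAvgPoolAcc(ElemType::F32, AccType::FP16));
  assert(isValidMatMulResult(ElemType::I8, ElemType::I32));
  assert(!isValidMatMulResult(ElemType::I16, ElemType::I32));
}
```

The negative cases correspond to the expected-error tests in
tosa-ops-verification.mlir; the positive cases correspond to the
round-trip tests in tosa-ops.mlir.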
For all of these operations from the TOSA 001000.1 extended
instruction set, parser, printer, verifier, and round-trip tests using
MLIR’s SPIR-V serialization/deserialization infrastructure are
included.
This work aligns with Khronos SPIR-V TOSA specifications.
Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I6c8813928c0d52541bb05d03244eaf9b2f6c4128
---
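
Note on the tensor shapes used in the examples and tests: the pooling
output shapes follow the TOSA output-size arithmetic, and the RFFT2D
output width is W/2 + 1. The standalone C++ sketch below reproduces
those numbers. The helper names are hypothetical and are not part of
this patch; the sketch assumes NHWC layout, kernel/stride given as
[y, x], pad given as [top, bottom, left, right], and it does not check
that the division is exact, which the TOSA specification requires.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Shape = std::array<int64_t, 4>; // N, H, W, C
using Pair = std::array<int64_t, 2>;  // [y, x]
using Pad = std::array<int64_t, 4>;   // [top, bottom, left, right]

// Output extent of one pooled dimension:
// (in + pad_before + pad_after - kernel) / stride + 1.
constexpr int64_t pooledExtent(int64_t in, int64_t padBefore, int64_t padAfter,
                               int64_t kernel, int64_t stride) {
  return (in + padBefore + padAfter - kernel) / stride + 1;
}

// Output shape of a 2D pooling op; batch and channel dims are unchanged.
inline Shape poolShape(const Shape &in, const Pair &kernel, const Pair &stride,
                       const Pad &pad) {
  return Shape{in[0],
               pooledExtent(in[1], pad[0], pad[1], kernel[0], stride[0]),
               pooledExtent(in[2], pad[2], pad[3], kernel[1], stride[1]),
               in[3]};
}

// RFFT2D keeps only the first half of the last axis (Hermitian symmetry).
constexpr int64_t rfftOutputWidth(int64_t width) { return width / 2 + 1; }

// Replays the shapes that appear in the examples of this patch.
inline void checkPatchShapes() {
  // AvgPool2D, i8 example.
  assert((poolShape({1, 3, 65537, 1}, {3, 3}, {1, 2}, {0, 1, 0, 0}) ==
          Shape{1, 2, 32768, 1}));
  // AvgPool2D, f32 example.
  assert((poolShape({1, 2, 65533, 2}, {2, 2}, {1, 1}, {1, 0, 0, 0}) ==
          Shape{1, 2, 65532, 2}));
  // MaxPool2D, i8 example.
  assert((poolShape({1, 3, 65537, 1}, {3, 2}, {1, 2}, {1, 0, 0, 1}) ==
          Shape{1, 2, 32769, 1}));
  // RFFT2D example: 1x32x32 input gives 1x32x17 outputs.
  assert(rfftOutputWidth(32) == 17);
}
```

These shape relations are not enforced by the verifiers in this patch;
the sketch only shows why the test shapes are consistent.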
.../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td | 287 +++++++++++++++++-
.../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td | 17 ++
mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp | 56 ++++
.../SPIRV/IR/tosa-ops-verification.mlir | 57 ++++
mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir | 104 +++++++
mlir/test/Target/SPIRV/tosa-ops.mlir | 173 +++++++++++
6 files changed, 693 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 1b2f6a923d01b..be7196ee83b2a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -33,6 +33,8 @@ class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
Extension<[SPV_ARM_graph]>,
Capability<[SPIRV_C_GraphARM]>
];
+
+ let hasVerifier = 0;
}
@@ -68,7 +70,8 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
let hasVerifier = 1;
let assemblyFormat = [{
- `axis` `=` $axis `,` `nan_mode` `=` $nan_mode `,`
+ `axis` `=` $axis `,`
+ `nan_mode` `=` $nan_mode `,`
$input
attr-dict `:` type(operands) `->` type(results)
}];
@@ -84,6 +87,66 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
}
+def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOp<"AvgPool2D", 1, [Pure,
+ AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
+ let summary = "Performs average pooling on the input.";
+
+ let description = [{
+ Performs an average pooling over the given input tensor. A sliding
+ window of size given by <kernel size> is passed over the input tensor, with
+ the mean value being placed in the output tensor. When calculating the
+ average, only the number of valid input tensor values, but not padding, are
+ used to calculate the divisor.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_avg_pool2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_avg_pool2d
+
+ #### Example:
+ ```mlir
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+ %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_TosaExtAccTypeAttr: $acc_type,
+ SPIRV_TosaNumerical_TensorArm4D: $input,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $output_zp
+ );
+
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm4D: $output
+ );
+
+ let hasVerifier = 1;
+
+ let assemblyFormat = [{
+ `kernel` `=` custom<SPIRV_I32_1DArmTensor>($kernel) `,`
+ `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+ `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+ `acc_type` `=` $acc_type `,`
+ $input `,`
+ $input_zp `,`
+ $output_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+
def SPIRV_TosaConv2DOp : SPIRV_TosaOp<"Conv2D", 2, [Pure,
AllElementTypesMatch<["bias", "output"]>,
AllElementTypesMatch<["input", "input_zp"]>,
@@ -299,6 +362,228 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOp<"DepthwiseConv2D", 4, [Pure,
}
+def SPIRV_TosaFFT2DOp : SPIRV_TosaOp<"FFT2D", 5, [Pure]> {
+ let summary = "Performs FFT2D operation on the input.";
+
+ let description = [{
+ Performs a batched complex 2D Fast Fourier Transform over the input. The
+ complex input values are constructed from the corresponding values in the
+ input_real and input_imag tensors. The resulting values in the output are
+ split into the output_real and output_imag tensors. No normalization is
+ applied on either the forward or inverse versions of the operation.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_fft2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_fft2d
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %1 = spirv.CompositeExtract %0[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %2 = spirv.CompositeExtract %0[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_BoolConstAttr: $inverse,
+ SPIRV_BoolConstAttr: $local_bound,
+ SPIRV_Float32_TensorArm3D: $input_real,
+ SPIRV_Float32_TensorArm3D: $input_imag
+ );
+
+ let results = (outs
+ SPIRV_Struct_2_Float32_TensorArm3D: $output
+ );
+
+ let assemblyFormat = [{
+ `inverse` `=` $inverse `,`
+ `local_bound` `=` $local_bound `,`
+ $input_real `,`
+ $input_imag
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputRealType() {
+ return cast<::mlir::spirv::TensorArmType>(getInputReal().getType());
+ }
+ ::mlir::spirv::TensorArmType getInputImagType() {
+ return cast<::mlir::spirv::TensorArmType>(getInputImag().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultRealType() {
+ auto resultType = cast<StructType>(getType());
+ return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(0));
+ }
+ ::mlir::spirv::TensorArmType getResultImagType() {
+ auto resultType = cast<StructType>(getType());
+ return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(1));
+ }
+ }];
+}
+
+
+def SPIRV_TosaMatMulOp : SPIRV_TosaOp<"MatMul", 6, [Pure,
+ AllElementTypesMatch<["A", "A_zp", "B", "B_zp"]>]> {
+ let summary = "Matrix Multiplication operator.";
+
+ let description = [{
+ Performs two dimensional matrix multiplications.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_matmul
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_matmul
+
+ #### Example:
+ ```mlir
+ %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+ %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaNumerical_TensorArm3D: $A,
+ SPIRV_TosaNumerical_TensorArm3D: $B,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $A_zp,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $B_zp
+ );
+
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm3D: $output
+ );
+
+ let hasVerifier = 1;
+
+ let assemblyFormat = [{
+ $A `,`
+ $B `,`
+ $A_zp `,`
+ $B_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getAType() {
+ return cast<::mlir::spirv::TensorArmType>(getA().getType());
+ }
+ ::mlir::spirv::TensorArmType getBType() {
+ return cast<::mlir::spirv::TensorArmType>(getB().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOp<"MaxPool2D", 7, [Pure,
+ AllElementTypesMatch<["input", "output"]>]> {
+ let summary = "Performs max pooling on the input.";
+
+ let description = [{
+ Performs a max pooling over the given input tensor. A sliding window of
+ size given by <kernel size> is passed over the input tensor, with the
+ maximum value being placed in the
+ output tensor.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_max_pool2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_max_pool2d
+
+ #### Example:
+ ```mlir
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+ SPIRV_TosaNumerical_TensorArm4D: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm4D: $output
+ );
+
+ let assemblyFormat = [{
+ `kernel` `=` custom<SPIRV_I32_1DArmTensor>($kernel) `,`
+ `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+ `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+ `nan_mode` `=` $nan_mode `,`
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaRFFT2DOp : SPIRV_TosaOp<"RFFT2D", 8, [Pure]> {
+ let summary = "Performs RFFT2D operation on the input.";
+
+ let description = [{
+ Performs a batched 2D real-valued Fast Fourier Transform over the input where
+ the input tensor consists of real values producing complex valued output. The
+ complex output values will be split into the output_real and output_imag
+ tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only
+ calculate the first half of the final output axis. Implementations may choose
+ to skip calculation of the imaginary values at (0,0), (0,W/2), (H/2,0), and
+ (H/2, W/2). If the calculation is skipped, the result at that location must be
+ zero.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_rfft2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_rfft2d
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ %1 = spirv.CompositeExtract %0[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ %2 = spirv.CompositeExtract %0[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_BoolConstAttr: $local_bound,
+ SPIRV_Float32_TensorArm3D: $input_real
+ );
+
+ let results = (outs
+ SPIRV_Struct_2_Float32_TensorArm3D: $output
+ );
+
+ let assemblyFormat = [{
+ `local_bound` `=` $local_bound `,`
+ $input_real
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputRealType() {
+ return cast<::mlir::spirv::TensorArmType>(getInputReal().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultRealType() {
+ auto resultType = cast<StructType>(getType());
+ return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(0));
+ }
+ ::mlir::spirv::TensorArmType getResultImagType() {
+ auto resultType = cast<StructType>(getType());
+ return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(1));
+ }
+ }];
+}
+
+
def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOp<"TransposeConv2D", 9, [Pure,
AllElementTypesMatch<["bias", "output"]>,
AllElementTypesMatch<["input", "input_zp"]>,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 7e2c37f74b437..c8c3068c69739 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -37,7 +37,9 @@ class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+def SPIRV_Float32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
def SPIRV_TosaNumerical_TensorArm1D : TensorArmRankOf<[SPIRV_TosaNumerical], [1]>;
+def SPIRV_TosaNumerical_TensorArm3D : TensorArmRankOf<[SPIRV_TosaNumerical], [3]>;
def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]>;
def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
@@ -67,4 +69,19 @@ def SPIRV_Int32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6
def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>;
+// Struct type
+
+class IsStructOfLengthPred<int numElements> :
+ And<[SPIRV_IsStructType,
+ CPred<[{::llvm::cast<::mlir::spirv::StructType>($_self).getNumElements()
+ == }]
+ # numElements>]>;
+
+class IsStructOfLengthAndType<int numElements, list<Type> allowedTypes>
+ : MixedContainerType<AnyTypeOf<allowedTypes>, IsStructOfLengthPred<numElements>,
+ "::llvm::cast<::mlir::spirv::StructType>($_self).getElementTypes()",
+ "Struct">;
+
+def SPIRV_Struct_2_Float32_TensorArm3D : IsStructOfLengthAndType<2, [SPIRV_Float32_TensorArm3D]>;
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index 1d3e1084d5a9e..3036f7b4b3765 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -99,6 +99,31 @@ LogicalResult TosaArgMaxOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.TosaAvgPool2DOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaAvgPool2DOp::verify() {
+ Type inputETy = getInputType().getElementType();
+
+ TosaExtAccType accType = getAccType();
+ if (inputETy.isInteger() && accType != TosaExtAccType::INT32) {
+ return emitOpError("accumulator type for integer tensorARM is not i32");
+ }
+
+ if (inputETy.isF16() &&
+ !llvm::is_contained({TosaExtAccType::FP16, TosaExtAccType::FP32},
+ accType)) {
+ return emitOpError("accumulator type for f16 tensorARM is not f16/f32");
+ }
+
+ if (inputETy.isF32() && accType != TosaExtAccType::FP32) {
+ return emitOpError("accumulator type for f32 tensorARM is not f32");
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.TosaConv2DOp
//===----------------------------------------------------------------------===//
@@ -132,6 +157,37 @@ LogicalResult TosaDepthwiseConv2DOp::verify() {
return verifyConvOp(this->getOperation(), inputETy, resultETy, accType);
}
+//===----------------------------------------------------------------------===//
+// spirv.TosaMatMulOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaMatMulOp::verify() {
+ Type aETy = getAType().getElementType();
+ Type resultETy = getResultType().getElementType();
+
+ if (aETy.isInteger(8) && !resultETy.isInteger(32)) {
+ return emitOpError("expect result element type to be i32, got ")
+ << resultETy;
+ }
+
+ if (aETy.isInteger(16) && !resultETy.isInteger(64)) {
+ return emitOpError("expect result element type to be i64, got ")
+ << resultETy;
+ }
+
+ if (aETy.isF16() && !resultETy.isF16() && !resultETy.isF32()) {
+ return emitOpError("expect result element type to be f16 or f32, got ")
+ << resultETy;
+ }
+
+ if (aETy.isF32() && !resultETy.isF32()) {
+ return emitOpError("expect result element type to be f32, got ")
+ << resultETy;
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// SPIRV Tosa TransposeConv2D Ops:
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 2099630aff0fb..dec19a95e0ac6 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -358,3 +358,60 @@ spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_either_FP32_for_f32_input_
%7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf32>
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @avgpool2d_input_output_different_type(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi16>) {
+ %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+ %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+ // expected-error @+1 {{op failed to verify that all of {input, input_zp, output, output_zp} have same element type}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi16>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi16>
+}
+
+spirv.ARM.Graph @avgpool2d_accumulator_should_be_INT32_for_integer_element_types(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
+ %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+ %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+ // expected-error @+1 {{op accumulator type for integer tensorARM is not i32}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT48>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi8>
+}
+
+spirv.ARM.Graph @avgpool2d_accumulator_should_be_either_FP16_or_FP32_for_fp16_element_types(%arg0: !spirv.arm.tensor<1x2x65533x2xf16>) -> (!spirv.arm.tensor<1x2x65532x2xf16>) {
+ %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+ %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+ // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16/f32}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x2x65532x2xf16>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf16>
+}
+
+
+spirv.ARM.Graph @avgpool2d_accumulator_should_be_either_FP32_for_fp32_element_types(%arg0: !spirv.arm.tensor<1x2x65533x2xf32>) -> (!spirv.arm.tensor<1x2x65532x2xf32>) {
+ %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+ %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+ // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP16>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @matmul_invalid_input_output_element_type_combination(%arg0: !spirv.arm.tensor<1x4x4xi16>, %arg1: !spirv.arm.tensor<1x4x4xi16>, %arg2: !spirv.arm.tensor<1xi16>, %arg3: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x4x4xi32>) {
+ // expected-error @+1 {{op expect result element type to be i64, got 'i32'}}
+ %0 = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xi16>, !spirv.arm.tensor<1x4x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x4x4xi32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi16>) {
+ // expected-error @+1 {{op failed to verify that all of {input, output} have same element type}}
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi16>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index 45243a7553c56..1a43e2c95c530 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -22,6 +22,32 @@ spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.ar
spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
}
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @avgpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
+ %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+ %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+ // CHECK: {{%.*}} = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x32768x1xi8>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @avgpool2d_fp(%arg0: !spirv.arm.tensor<1x2x65533x2xf32>) -> (!spirv.arm.tensor<1x2x65532x2xf32>) {
+ %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+ %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+ // CHECK: {{%.*}} = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+ %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x65532x2xf32>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf32>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.Conv2D - PRO-INT
//===----------------------------------------------------------------------===//
@@ -100,6 +126,84 @@ spirv.ARM.Graph @depthwiseconv2d_fp(%arg0: !spirv.arm.tensor<1x65540x1x3xf32>, %
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65541x2x3xf32>
}
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.FFT2D - EXT-FFT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @fft2d_fft(%arg0: !spirv.arm.tensor<1x32x32xf32>, %arg1: !spirv.arm.tensor<1x32x32xf32>) -> (!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %out = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %out0 = spirv.CompositeExtract %out[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %out1 = spirv.CompositeExtract %out[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>
+ spirv.ARM.GraphOutputs %out0, %out1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @matmul_int(%arg0: !spirv.arm.tensor<8x2x3xi8>, %arg1: !spirv.arm.tensor<8x3x8xi8>) -> (!spirv.arm.tensor<8x2x8xi32>) {
+ %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ // CHECK: {{%.*}} = spirv.Tosa.MatMul %arg0, %arg1, {{%.*}}, {{%.*}} : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+ %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<8x2x8xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<8x2x8xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @matmul_fp(%arg0: !spirv.arm.tensor<15x39x50xf16>, %arg1: !spirv.arm.tensor<15x50x24xf16>) -> (!spirv.arm.tensor<15x39x24xf16>) {
+ %0 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+ %1 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+ // CHECK: {{%.*}} = spirv.Tosa.MatMul %arg0, %arg1, {{%.*}}, {{%.*}} : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+ %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<15x39x24xf16>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<15x39x24xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x32769x1xi8>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @maxpool2d_fp(%arg0: !spirv.arm.tensor<1x6x65536x1xf32>) -> (!spirv.arm.tensor<1x3x32769x1xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x3x32769x1xf32>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x3x32769x1xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.RFFT2D - EXT-FFT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @rfft2d_fft(%arg0: !spirv.arm.tensor<1x32x32xf32>) -> (!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ %out = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ %out0 = spirv.CompositeExtract %out[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ %out1 = spirv.CompositeExtract %out[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>
+ spirv.ARM.GraphOutputs %out0, %out1 : !spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.TransposeConv2D - PRO-INT
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index edaa000c183a8..1d219b855bec1 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -42,6 +42,48 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
// -----
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @avgpool2d_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x3x65537x1xi8>, UniformConstant>
+ spirv.GlobalVariable @avgpool2d_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x32768x1xi8>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @avgpool2d_int, @avgpool2d_int_arg_0, @avgpool2d_int_res_0
+ spirv.ARM.Graph @avgpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
+ %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+ %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+ // CHECK: {{%.*}} = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x32768x1xi8>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi8>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @avgpool2d_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x65533x2xf32>, UniformConstant>
+ spirv.GlobalVariable @avgpool2d_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x65532x2xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @avgpool2d_fp, @avgpool2d_fp_arg_0, @avgpool2d_fp_res_0
+ spirv.ARM.Graph @avgpool2d_fp(%arg0: !spirv.arm.tensor<1x2x65533x2xf32>) -> (!spirv.arm.tensor<1x2x65532x2xf32>) {
+ %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+ %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+ // CHECK: {{%.*}} = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+ %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x65532x2xf32>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf32>
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.Conv2D - PRO-INT
//===----------------------------------------------------------------------===//
@@ -180,6 +222,137 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
// -----
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.FFT2D - EXT-FFT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @fft2d_fft_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+ spirv.GlobalVariable @fft2d_fft_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+ spirv.GlobalVariable @fft2d_fft_res_0 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+ spirv.GlobalVariable @fft2d_fft_res_1 bind(0, 3) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @fft2d_fft, @fft2d_fft_arg_0, @fft2d_fft_arg_1, @fft2d_fft_res_0, @fft2d_fft_res_1
+ spirv.ARM.Graph @fft2d_fft(%arg0: !spirv.arm.tensor<1x32x32xf32>, %arg1: !spirv.arm.tensor<1x32x32xf32>) -> (!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %out = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %out0 = spirv.CompositeExtract %out[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %out1 = spirv.CompositeExtract %out[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>
+ spirv.ARM.GraphOutputs %out0, %out1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @matmul_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<8x2x3xi8>, UniformConstant>
+ spirv.GlobalVariable @matmul_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<8x3x8xi8>, UniformConstant>
+ spirv.GlobalVariable @matmul_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<8x2x8xi32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @matmul_int, @matmul_int_arg_0, @matmul_int_arg_1, @matmul_int_res_0
+ spirv.ARM.Graph @matmul_int(%arg0: !spirv.arm.tensor<8x2x3xi8>, %arg1: !spirv.arm.tensor<8x3x8xi8>) -> (!spirv.arm.tensor<8x2x8xi32>) {
+ %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+ // CHECK: {{%.*}} = spirv.Tosa.MatMul %arg0, %arg1, {{%.*}}, {{%.*}} : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+ %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<8x2x8xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<8x2x8xi32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @matmul_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<15x39x50xf16>, UniformConstant>
+ spirv.GlobalVariable @matmul_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<15x50x24xf16>, UniformConstant>
+ spirv.GlobalVariable @matmul_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<15x39x24xf16>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @matmul_fp, @matmul_fp_arg_0, @matmul_fp_arg_1, @matmul_fp_res_0
+ spirv.ARM.Graph @matmul_fp(%arg0: !spirv.arm.tensor<15x39x50xf16>, %arg1: !spirv.arm.tensor<15x50x24xf16>) -> (!spirv.arm.tensor<15x39x24xf16>) {
+ %0 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+ %1 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+ // CHECK: {{%.*}} = spirv.Tosa.MatMul %arg0, %arg1, {{%.*}}, {{%.*}} : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+ %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<15x39x24xf16>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<15x39x24xf16>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @maxpool2d_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x3x65537x1xi8>, UniformConstant>
+ spirv.GlobalVariable @maxpool2d_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x32769x1xi8>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @maxpool2d_int, @maxpool2d_int_arg_0, @maxpool2d_int_res_0
+ spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x32769x1xi8>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi8>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @maxpool2d_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x6x65536x1xf32>, UniformConstant>
+ spirv.GlobalVariable @maxpool2d_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x3x32769x1xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @maxpool2d_fp, @maxpool2d_fp_arg_0, @maxpool2d_fp_res_0
+ spirv.ARM.Graph @maxpool2d_fp(%arg0: !spirv.arm.tensor<1x6x65536x1xf32>) -> (!spirv.arm.tensor<1x3x32769x1xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x3x32769x1xf32>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x3x32769x1xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.RFFT2D - EXT-FFT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @rfft2d_fft_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+ spirv.GlobalVariable @rfft2d_fft_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x32x17xf32>, UniformConstant>
+ spirv.GlobalVariable @rfft2d_fft_res_1 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1x32x17xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @rfft2d_fft, @rfft2d_fft_arg_0, @rfft2d_fft_res_0, @rfft2d_fft_res_1
+ spirv.ARM.Graph @rfft2d_fft(%arg0: !spirv.arm.tensor<1x32x32xf32>) -> (!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ %out = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ %out0 = spirv.CompositeExtract %out[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ %out1 = spirv.CompositeExtract %out[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>
+ spirv.ARM.GraphOutputs %out0, %out1 : !spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.TransposeConv2D - PRO-INT
//===----------------------------------------------------------------------===//
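
Note on the pooling test shapes: the result types in the AvgPool2D and
MaxPool2D cases above follow from the input shape, kernel, stride, and pad
attributes. The sketch below is not part of the patch. It assumes the TOSA
convention out = (in + pad_before + pad_after - kernel) / stride + 1 with an
exact division, kernel = [y, x], stride = [y, x], and
pad = [top, bottom, left, right]. The helper name pool2d_output_hw is
hypothetical and used only for illustration.

```python
def pool2d_output_hw(in_h, in_w, kernel, stride, pad):
    """Return (out_h, out_w) for a 2-D pooling window.

    Assumed attribute layout (taken from the TOSA convention, not from
    this patch): kernel = [k_y, k_x], stride = [s_y, s_x],
    pad = [top, bottom, left, right].
    """
    k_y, k_x = kernel
    s_y, s_x = stride
    top, bottom, left, right = pad

    def extent(size, before, after, k, s):
        # Distance the window can travel across the padded input.
        span = size + before + after - k
        # The division is assumed to be exact; reject shapes where it is not.
        if span < 0 or span % s != 0:
            raise ValueError("window does not tile the padded input exactly")
        return span // s + 1

    return (extent(in_h, top, bottom, k_y, s_y),
            extent(in_w, left, right, k_x, s_x))


if __name__ == "__main__":
    # (input H, input W, kernel, stride, pad) copied from the tests above.
    cases = {
        "avgpool2d_int": (3, 65537, [3, 3], [1, 2], [0, 1, 0, 0]),
        "avgpool2d_fp": (2, 65533, [2, 2], [1, 1], [1, 0, 0, 0]),
        "maxpool2d_int": (3, 65537, [3, 2], [1, 2], [1, 0, 0, 1]),
        "maxpool2d_fp": (6, 65536, [3, 2], [2, 2], [1, 0, 1, 1]),
    }
    for name, args in cases.items():
        print(name, pool2d_output_hw(*args))
```

Under these assumptions the four cases give H x W of 2x32768, 2x65532,
2x32769, and 3x32769, matching the result tensor types used in the tests.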