[Mlir-commits] [mlir] [mlir][spirv] Add Pooling, Fourier Transform, and MatMul operations t… (PR #177585)

Davide Grohmann llvmlistbot at llvm.org
Fri Jan 23 06:47:23 PST 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/177585

>From 3249d9f4f0e2f9790bedb4a023dddde0bdcbe9b3 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 23 Jan 2026 14:17:46 +0100
Subject: [PATCH] [mlir][spirv] Add Pooling, Fourier Transform, and MatMul
 operations to TOSA Extended Instruction Set (001000.1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch expands support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended
instruction set provides a standardized set of machine learning
operations designed to be used within spirv.ARM.Graph operations
(corresponding to OpGraphARM in SPV_ARM_graph) and typed with
!spirv.arm.tensor<...> (corresponding to OpTypeTensorARM in
SPV_ARM_tensor).

The change introduces:

* Extending dialect plumbing for import, serialization, and
  deserialization of the TOSA extended instruction set.
* The spirv.Tosa.*Pool* pooling operations, spirv.Tosa.MatMul, and
  spirv.Tosa.*FFT* (Fourier Transform) operations from the TOSA
  extended instruction set, each lowering to the corresponding
  OpExtInst.
* Verification enforcing that the new operations appear only within
  spirv.ARM.Graph regions, operate on !spirv.arm.tensor<...> types,
  and are well-formed according to the TOSA 001000.1 specification.

All these operations from the TOSA 001000.1 extended instruction set
are introduced; parser, printer, verifier, and round-trip tests using
MLIR’s SPIR-V serialization/deserialization infrastructure are
included.

This work aligns with Khronos SPIR-V TOSA specifications.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I6c8813928c0d52541bb05d03244eaf9b2f6c4128
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 287 +++++++++++++++++-
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   |  17 ++
 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp    |  56 ++++
 .../SPIRV/IR/tosa-ops-verification.mlir       |  57 ++++
 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir      | 104 +++++++
 mlir/test/Target/SPIRV/tosa-ops.mlir          | 173 +++++++++++
 6 files changed, 693 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 1b2f6a923d01b..be7196ee83b2a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -33,6 +33,8 @@ class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
     Extension<[SPV_ARM_graph]>,
     Capability<[SPIRV_C_GraphARM]>
   ];
+
+  let hasVerifier = 0;
 }
 
 
@@ -68,7 +70,8 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
   let hasVerifier = 1;
 
   let assemblyFormat = [{
-    `axis` `=` $axis `,` `nan_mode` `=` $nan_mode `,`
+    `axis` `=` $axis `,`
+    `nan_mode` `=` $nan_mode `,`
     $input
     attr-dict `:` type(operands) `->` type(results)
   }];
@@ -84,6 +87,66 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
 }
 
 
+def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOp<"AvgPool2D", 1, [Pure,
+  AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
+  let summary = "Performs average pooling on the input.";
+
+  let description = [{
+    Performs an average pooling over the given input tensor. A sliding
+    window of size given by <kernel size> is passed over the input tensor, with
+    the mean value being placed in the output tensor. When calculating the
+    average, only the number of valid input tensor values, but not padding, are
+    used to calculate the divisor.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_avg_pool2d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_avg_pool2d
+
+    #### Example:
+    ```mlir
+    %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+    %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_TosaExtAccTypeAttr: $acc_type,
+    SPIRV_TosaNumerical_TensorArm4D: $input,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $output_zp
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm4D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    `kernel` `=` custom<SPIRV_I32_1DArmTensor>($kernel) `,`
+    `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+    `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+    `acc_type` `=` $acc_type `,`
+    $input `,`
+    $input_zp `,`
+    $output_zp
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultType() {
+      return cast<::mlir::spirv::TensorArmType>(getType());
+    }
+  }];
+}
+
+
 def SPIRV_TosaConv2DOp : SPIRV_TosaOp<"Conv2D", 2, [Pure,
   AllElementTypesMatch<["bias", "output"]>,
   AllElementTypesMatch<["input", "input_zp"]>,
@@ -299,6 +362,228 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOp<"DepthwiseConv2D", 4, [Pure,
 }
 
 
+def SPIRV_TosaFFT2DOp : SPIRV_TosaOp<"FFT2D", 5, [Pure]> {
+  let summary = "Performs FFT2D operation on the input.";
+
+  let description = [{
+    Performs a batched complex 2D Fast Fourier Transform over the input. The
+    complex input values are constructed from the corresponding values in the
+    input_real and input_imag tensors. The resulting values in the output are
+    split into the output_real and output_imag tensors. No normalization is
+    applied on either the forward or inverse versions of the operation.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_fft2d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_fft2d
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    %1 = spirv.CompositeExtract %0[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    %2 = spirv.CompositeExtract %0[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_BoolConstAttr: $inverse,
+    SPIRV_BoolConstAttr: $local_bound,
+    SPIRV_Float32_TensorArm3D: $input_real,
+    SPIRV_Float32_TensorArm3D: $input_imag
+  );
+
+  let results = (outs
+    SPIRV_Struct_2_Float32_TensorArm3D: $output
+  );
+
+  let assemblyFormat = [{
+    `inverse` `=` $inverse `,`
+    `local_bound` `=` $local_bound `,`
+    $input_real `,`
+    $input_imag
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputRealType() {
+      return cast<::mlir::spirv::TensorArmType>(getInputReal().getType());
+    }
+    ::mlir::spirv::TensorArmType getInputImagType() {
+      return cast<::mlir::spirv::TensorArmType>(getInputImag().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultRealType() {
+      auto resultType = cast<StructType>(getType());
+      return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(0));
+    }
+    ::mlir::spirv::TensorArmType getResultImagType() {
+      auto resultType = cast<StructType>(getType());
+      return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(1));
+    }
+  }];
+}
+
+
+def SPIRV_TosaMatMulOp : SPIRV_TosaOp<"MatMul", 6, [Pure,
+  AllElementTypesMatch<["A", "A_zp", "B", "B_zp"]>]> {
+  let summary = "Matrix Multiplication operator.";
+
+  let description = [{
+    Performs two dimensional matrix multiplications.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_matmul
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_matmul
+
+    #### Example:
+    ```mlir
+    %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+    %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaNumerical_TensorArm3D: $A,
+    SPIRV_TosaNumerical_TensorArm3D: $B,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $A_zp,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $B_zp
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm3D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    $A `,`
+    $B `,`
+    $A_zp `,`
+    $B_zp
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getAType() {
+      return cast<::mlir::spirv::TensorArmType>(getA().getType());
+    }
+    ::mlir::spirv::TensorArmType getBType() {
+      return cast<::mlir::spirv::TensorArmType>(getB().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultType() {
+      return cast<::mlir::spirv::TensorArmType>(getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOp<"MaxPool2D", 7, [Pure,
+  AllElementTypesMatch<["input", "output"]>]> {
+  let summary = "Performs max pooling on the input.";
+
+  let description = [{
+    Performs a max pooling over the given input tensor. A sliding window of
+    size given by <kernel size> is passed over the input tensor, with the
+    maximum value being placed in the
+    output tensor.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_max_pool2d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_max_pool2d
+
+    #### Example:
+    ```mlir
+    %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+    %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+    SPIRV_TosaNumerical_TensorArm4D: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm4D: $output
+  );
+
+  let assemblyFormat = [{
+    `kernel` `=` custom<SPIRV_I32_1DArmTensor>($kernel) `,`
+    `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+    `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+    `nan_mode` `=` $nan_mode `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultType() {
+      return cast<::mlir::spirv::TensorArmType>(getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaRFFT2DOp : SPIRV_TosaOp<"RFFT2D", 8, [Pure]> {
+  let summary = "Performs RFFT2D operation on the input.";
+
+  let description = [{
+    Performs a batched 2D real-valued Fast Fourier Transform over the input where
+    the input tensor consists of real values producing complex valued output. The
+    complex output values will be split into the output_real and output_imag
+    tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only
+    calculate the first half of the final output axis. Implementations may choose
+    to skip calculation of the imaginary values at (0,0), (0,W/2), (H/2,0), and
+    (H/2, W/2). If the calculation is skipped, the result at that location must be
+    zero.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_rfft2d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_rfft2d
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    %1 = spirv.CompositeExtract %0[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    %2 = spirv.CompositeExtract %0[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_BoolConstAttr: $local_bound,
+    SPIRV_Float32_TensorArm3D: $input_real
+  );
+
+  let results = (outs
+    SPIRV_Struct_2_Float32_TensorArm3D: $output
+  );
+
+  let assemblyFormat = [{
+    `local_bound` `=` $local_bound `,`
+    $input_real
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputRealType() {
+      return cast<::mlir::spirv::TensorArmType>(getInputReal().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultRealType() {
+      auto resultType = cast<StructType>(getType());
+      return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(0));
+    }
+    ::mlir::spirv::TensorArmType getResultImagType() {
+      auto resultType = cast<StructType>(getType());
+      return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(1));
+    }
+  }];
+}
+
+
 def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOp<"TransposeConv2D", 9, [Pure,
   AllElementTypesMatch<["bias", "output"]>,
   AllElementTypesMatch<["input", "input_zp"]>,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 7e2c37f74b437..c8c3068c69739 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -37,7 +37,9 @@ class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
       [HasAnyRankOfPred<ranks>],
       !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
 
+def SPIRV_Float32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
 def SPIRV_TosaNumerical_TensorArm1D : TensorArmRankOf<[SPIRV_TosaNumerical], [1]>;
+def SPIRV_TosaNumerical_TensorArm3D : TensorArmRankOf<[SPIRV_TosaNumerical], [3]>;
 def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]>;
 def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
 
@@ -67,4 +69,19 @@ def SPIRV_Int32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6
 
 def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>;
 
+// Struct type
+
+class IsStructOfLengthPred<int numElements> :
+  And<[SPIRV_IsStructType,
+      CPred<[{::llvm::cast<::mlir::spirv::StructType>($_self).getNumElements()
+              == }]
+            # numElements>]>;
+
+class IsStructOfLengthAndType<int numElements, list<Type> allowedTypes>
+    : MixedContainerType<AnyTypeOf<allowedTypes>, IsStructOfLengthPred<numElements>,
+                         "::llvm::cast<::mlir::spirv::StructType>($_self).getElementTypes()",
+                         "Struct">;
+
+def SPIRV_Struct_2_Float32_TensorArm3D : IsStructOfLengthAndType<2, [SPIRV_Float32_TensorArm3D]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index 1d3e1084d5a9e..3036f7b4b3765 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -99,6 +99,31 @@ LogicalResult TosaArgMaxOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.TosaAvgPool2DOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaAvgPool2DOp::verify() {
+  Type inputETy = getInputType().getElementType();
+
+  TosaExtAccType accType = getAccType();
+  if (inputETy.isInteger() && accType != TosaExtAccType::INT32) {
+    return emitOpError("accumulator type for integer tensorARM is not i32");
+  }
+
+  if (inputETy.isF16() &&
+      !llvm::is_contained({TosaExtAccType::FP16, TosaExtAccType::FP32},
+                          accType)) {
+    return emitOpError("accumulator type for f16 tensorARM is not f16/f32");
+  }
+
+  if (inputETy.isF32() && accType != TosaExtAccType::FP32) {
+    return emitOpError("accumulator type for f32 tensorARM is not f32");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TosaConv2DOp
 //===----------------------------------------------------------------------===//
@@ -132,6 +157,37 @@ LogicalResult TosaDepthwiseConv2DOp::verify() {
   return verifyConvOp(this->getOperation(), inputETy, resultETy, accType);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.TosaMatMulOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaMatMulOp::verify() {
+  Type aETy = getAType().getElementType();
+  Type resultETy = getResultType().getElementType();
+
+  if (aETy.isInteger(8) && !resultETy.isInteger(32)) {
+    return emitOpError("expect result element type to be i32, got ")
+           << resultETy;
+  }
+
+  if (aETy.isInteger(16) && !resultETy.isInteger(64)) {
+    return emitOpError("expect result element type to be i64, got ")
+           << resultETy;
+  }
+
+  if (aETy.isF16() && !resultETy.isF16() && !resultETy.isF32()) {
+    return emitOpError("expect result element type to be f16 or f32, got ")
+           << resultETy;
+  }
+
+  if (aETy.isF32() && !resultETy.isF32()) {
+    return emitOpError("expect result element type to be f32, got ")
+           << resultETy;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SPIRV Tosa TransposeConv2D Ops:
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 2099630aff0fb..dec19a95e0ac6 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -358,3 +358,60 @@ spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_either_FP32_for_f32_input_
   %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf32>
   spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @avgpool2d_input_output_different_type(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi16>) {
+  %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+  %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op failed to verify that all of {input, input_zp, output, output_zp} have same element type}}
+  %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi16>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi16>
+}
+
+spirv.ARM.Graph @avgpool2d_accumulator_should_be_INT32_for_integer_element_types(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
+  %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+  %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op accumulator type for integer tensorARM is not i32}}
+  %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT48>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi8>
+}
+
+spirv.ARM.Graph @avgpool2d_accumulator_should_be_either_FP16_or_FP32_for_fp16_element_types(%arg0: !spirv.arm.tensor<1x2x65533x2xf16>) -> (!spirv.arm.tensor<1x2x65532x2xf16>) {
+  %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16/f32}}
+  %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x2x65532x2xf16>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf16>
+}
+
+
+spirv.ARM.Graph @avgpool2d_accumulator_should_be_either_FP32_for_fp32_element_types(%arg0: !spirv.arm.tensor<1x2x65533x2xf32>) -> (!spirv.arm.tensor<1x2x65532x2xf32>) {
+  %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}}
+  %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP16>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @matmul_invalid_input_output_element_type_combination(%arg0: !spirv.arm.tensor<1x4x4xi16>, %arg1: !spirv.arm.tensor<1x4x4xi16>, %arg2: !spirv.arm.tensor<1xi16>, %arg3: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x4x4xi32>) {
+  // expected-error @+1 {{op expect result element type to be i64, got 'i32'}}
+  %0 = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xi16>, !spirv.arm.tensor<1x4x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x4x4xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi16>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same element type}}
+  %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi16>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index 45243a7553c56..1a43e2c95c530 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -22,6 +22,32 @@ spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.ar
   spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @avgpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
+  %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+  %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+  // CHECK: {{%.*}} = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+  %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x32768x1xi8>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @avgpool2d_fp(%arg0: !spirv.arm.tensor<1x2x65533x2xf32>) -> (!spirv.arm.tensor<1x2x65532x2xf32>) {
+  %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+  // CHECK: {{%.*}} = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+  %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x65532x2xf32>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Conv2D - PRO-INT
 //===----------------------------------------------------------------------===//
@@ -100,6 +126,84 @@ spirv.ARM.Graph @depthwiseconv2d_fp(%arg0: !spirv.arm.tensor<1x65540x1x3xf32>, %
   spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65541x2x3xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.FFT2D - EXT-FFT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @fft2d_fft(%arg0: !spirv.arm.tensor<1x32x32xf32>, %arg1: !spirv.arm.tensor<1x32x32xf32>) -> (!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+  %out = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+  %out0 = spirv.CompositeExtract %out[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] :  !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+  %out1 = spirv.CompositeExtract %out[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>
+  spirv.ARM.GraphOutputs  %out0, %out1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @matmul_int(%arg0: !spirv.arm.tensor<8x2x3xi8>, %arg1: !spirv.arm.tensor<8x3x8xi8>) -> (!spirv.arm.tensor<8x2x8xi32>) {
+  %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+  %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+  // CHECK: {{%.*}} = spirv.Tosa.MatMul %arg0, %arg1, {{%.*}}, {{%.*}} : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+  %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<8x2x8xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<8x2x8xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @matmul_fp(%arg0: !spirv.arm.tensor<15x39x50xf16>, %arg1: !spirv.arm.tensor<15x50x24xf16>) -> (!spirv.arm.tensor<15x39x24xf16>) {
+  %0 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  %1 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+  // CHECK: {{%.*}} = spirv.Tosa.MatMul %arg0, %arg1, {{%.*}}, {{%.*}} : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+  %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<15x39x24xf16>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<15x39x24xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi8>) {
+  // CHECK: {{%.*}} = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+  %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x32769x1xi8>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @maxpool2d_fp(%arg0: !spirv.arm.tensor<1x6x65536x1xf32>) -> (!spirv.arm.tensor<1x3x32769x1xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+  %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x3x32769x1xf32>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x3x32769x1xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.RFFT2D - EXT-FFT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @rfft2d_fft(%arg0: !spirv.arm.tensor<1x32x32xf32>) -> (!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+  %out = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+  %out0 = spirv.CompositeExtract %out[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+  %out1 = spirv.CompositeExtract %out[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>
+  spirv.ARM.GraphOutputs %out0, %out1 : !spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.TransposeConv2D - PRO-INT
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index edaa000c183a8..1d219b855bec1 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -42,6 +42,48 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @avgpool2d_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x3x65537x1xi8>, UniformConstant>
+  spirv.GlobalVariable @avgpool2d_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x32768x1xi8>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @avgpool2d_int, @avgpool2d_int_arg_0, @avgpool2d_int_res_0
+  spirv.ARM.Graph @avgpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
+    %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+    %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+    // CHECK: {{%.*}} = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+    %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x32768x1xi8>
+    spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi8>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.AvgPool2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @avgpool2d_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x65533x2xf32>, UniformConstant>
+  spirv.GlobalVariable @avgpool2d_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x65532x2xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @avgpool2d_fp, @avgpool2d_fp_arg_0, @avgpool2d_fp_res_0
+  spirv.ARM.Graph @avgpool2d_fp(%arg0: !spirv.arm.tensor<1x2x65533x2xf32>) -> (!spirv.arm.tensor<1x2x65532x2xf32>) {
+    %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+    %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32>
+    // CHECK: {{%.*}} = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+    %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x65532x2xf32>
+    spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf32>
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Conv2D - PRO-INT
 //===----------------------------------------------------------------------===//
@@ -180,6 +222,137 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.FFT2D - EXT-FFT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @fft2d_fft_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+  spirv.GlobalVariable @fft2d_fft_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+  spirv.GlobalVariable @fft2d_fft_res_0 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+  spirv.GlobalVariable @fft2d_fft_res_1 bind(0, 3) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @fft2d_fft, @fft2d_fft_arg_0, @fft2d_fft_arg_1, @fft2d_fft_res_0, @fft2d_fft_res_1
+  spirv.ARM.Graph @fft2d_fft(%arg0: !spirv.arm.tensor<1x32x32xf32>, %arg1: !spirv.arm.tensor<1x32x32xf32>) -> (!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    %out = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    %out0 = spirv.CompositeExtract %out[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    %out1 = spirv.CompositeExtract %out[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>
+    spirv.ARM.GraphOutputs %out0, %out1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @matmul_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<8x2x3xi8>, UniformConstant>
+  spirv.GlobalVariable @matmul_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<8x3x8xi8>, UniformConstant>
+  spirv.GlobalVariable @matmul_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<8x2x8xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @matmul_int, @matmul_int_arg_0, @matmul_int_arg_1, @matmul_int_res_0
+  spirv.ARM.Graph @matmul_int(%arg0: !spirv.arm.tensor<8x2x3xi8>, %arg1: !spirv.arm.tensor<8x3x8xi8>) -> (!spirv.arm.tensor<8x2x8xi32>) {
+    %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+    %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+    // CHECK: {{%.*}} = spirv.Tosa.MatMul %arg0, %arg1, {{%.*}}, {{%.*}} : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+    %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<8x2x8xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<8x2x8xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MatMul - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @matmul_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<15x39x50xf16>, UniformConstant>
+  spirv.GlobalVariable @matmul_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<15x50x24xf16>, UniformConstant>
+  spirv.GlobalVariable @matmul_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<15x39x24xf16>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @matmul_fp, @matmul_fp_arg_0, @matmul_fp_arg_1, @matmul_fp_res_0
+  spirv.ARM.Graph @matmul_fp(%arg0: !spirv.arm.tensor<15x39x50xf16>, %arg1: !spirv.arm.tensor<15x50x24xf16>) -> (!spirv.arm.tensor<15x39x24xf16>) {
+    %0 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+    %1 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16>
+    // CHECK: {{%.*}} = spirv.Tosa.MatMul %arg0, %arg1, {{%.*}}, {{%.*}} : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+    %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<15x39x24xf16>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<15x39x24xf16>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @maxpool2d_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x3x65537x1xi8>, UniformConstant>
+  spirv.GlobalVariable @maxpool2d_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x2x32769x1xi8>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @maxpool2d_int, @maxpool2d_int_arg_0, @maxpool2d_int_res_0
+  spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi8>) {
+    // CHECK: {{%.*}} = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+    %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x2x32769x1xi8>
+    spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi8>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.MaxPool2D - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @maxpool2d_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x6x65536x1xf32>, UniformConstant>
+  spirv.GlobalVariable @maxpool2d_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<1x3x32769x1xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @maxpool2d_fp, @maxpool2d_fp_arg_0, @maxpool2d_fp_res_0
+  spirv.ARM.Graph @maxpool2d_fp(%arg0: !spirv.arm.tensor<1x6x65536x1xf32>) -> (!spirv.arm.tensor<1x3x32769x1xf32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+    %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x3x32769x1xf32>
+    spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x3x32769x1xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.RFFT2D - EXT-FFT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @rfft2d_fft_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x32x32xf32>, UniformConstant>
+  spirv.GlobalVariable @rfft2d_fft_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x32x17xf32>, UniformConstant>
+  spirv.GlobalVariable @rfft2d_fft_res_1 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<1x32x17xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @rfft2d_fft, @rfft2d_fft_arg_0, @rfft2d_fft_res_0, @rfft2d_fft_res_1
+  spirv.ARM.Graph @rfft2d_fft(%arg0: !spirv.arm.tensor<1x32x32xf32>) -> (!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    %out = spirv.Tosa.RFFT2D local_bound = false, %arg0 : !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    %out0 = spirv.CompositeExtract %out[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    %out1 = spirv.CompositeExtract %out[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>)>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>
+    spirv.ARM.GraphOutputs %out0, %out1 : !spirv.arm.tensor<1x32x17xf32>, !spirv.arm.tensor<1x32x17xf32>
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.TransposeConv2D - PRO-INT
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list