[Mlir-commits] [mlir] [mlir][spirv] Add Pooling, Fourier Transform, and MatMul operations t… (PR #177585)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Jan 23 05:52:39 PST 2026
================
@@ -299,6 +360,234 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOp<"DepthwiseConv2D", 4, [Pure,
}
+def SPIRV_TosaFFT2DOp : SPIRV_TosaOp<"FFT2D", 5, [Pure]> {
+ let summary = "Performs FFT2D operation on the input.";
+
+ let description = [{
+ Performs a batched complex 2D Fast Fourier Transform over the input. The
+ complex input values are constructed from the corresponding values in the
+ input_real and input_imag tensors. The resulting values in the output are
+ split into the output_real and output_imag tensors. No normalization is
+ applied on either the forward or inverse versions of the operation.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_fft2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_fft2d
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.FFT2D inverse = true, local_bound = false, %arg0, %arg1 : !spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32> -> !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %1 = spirv.CompositeExtract %0[0 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ %2 = spirv.CompositeExtract %0[1 : i32] : !spirv.struct<(!spirv.arm.tensor<1x32x32xf32>, !spirv.arm.tensor<1x32x32xf32>)>
+ ```
+ }];
+
+ // Two boolean constant attributes followed by the real and imaginary parts
+ // of the input, each a 3D f32 tensor.
+ let arguments = (ins
+ SPIRV_BoolConstAttr: $inverse,
+ SPIRV_BoolConstAttr: $local_bound,
+ SPIRV_Float32_TensorArm3D: $input_real,
+ SPIRV_Float32_TensorArm3D: $input_imag
+ );
+
+ // A single struct result holding two 3D f32 tensors; element 0 is the real
+ // part and element 1 the imaginary part (see the accessors below).
+ let results = (outs
+ SPIRV_Struct_2_Float32_TensorArm3D: $output
+ );
+
+ // NOTE(review): no custom verifier, so agreement between the input and
+ // output tensor shapes is not checked by this definition — confirm this is
+ // intended or covered by the base class.
+ let hasVerifier = 0;
+
+ let assemblyFormat = [{
+ `inverse` `=` $inverse `,`
+ `local_bound` `=` $local_bound `,`
+ $input_real `,`
+ $input_imag
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ // Typed accessors for the two input operands.
+ ::mlir::spirv::TensorArmType getInputRealType() {
+ return cast<::mlir::spirv::TensorArmType>(getInputReal().getType());
+ }
+ ::mlir::spirv::TensorArmType getInputImagType() {
+ return cast<::mlir::spirv::TensorArmType>(getInputImag().getType());
+ }
+ // Typed accessors for the members of the struct result: element 0 is the
+ // real part, element 1 the imaginary part.
+ ::mlir::spirv::TensorArmType getResultRealType() {
+ auto resultType = cast<::mlir::spirv::StructType>(getType());
+ return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(0));
+ }
+ ::mlir::spirv::TensorArmType getResultImagType() {
+ auto resultType = cast<::mlir::spirv::StructType>(getType());
+ return cast<::mlir::spirv::TensorArmType>(resultType.getElementType(1));
+ }
+ }];
+}
+
+
+def SPIRV_TosaMatMulOp : SPIRV_TosaOp<"MatMul", 6, [Pure,
+ AllElementTypesMatch<["A", "A_zp", "B", "B_zp"]>]> {
+ let summary = "Matrix Multiplication operator.";
+
+ let description = [{
+ Performs two dimensional matrix multiplications. The leading dimension of
+ A, B and the output is a batch dimension: for each batch index, an [H, C]
+ slice of A is multiplied by a [C, W] slice of B, producing an [H, W] slice
+ of the output. A_zp and B_zp are the zero points for A and B.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_matmul
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_matmul
+
+ #### Example:
+ ```mlir
+ %2 = spirv.Tosa.MatMul %arg0, %arg1, %0, %1 : !spirv.arm.tensor<8x2x3xi8>, !spirv.arm.tensor<8x3x8xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<8x2x8xi32>
+ %5 = spirv.Tosa.MatMul %arg2, %arg3, %3, %4 : !spirv.arm.tensor<15x39x50xf16>, !spirv.arm.tensor<15x50x24xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<15x39x24xf16>
+ ```
+ }];
+
+ // A and B are 3D tensors; the zero points are single-element 1D tensors
+ // whose element type must match A and B (AllElementTypesMatch above).
+ let arguments = (ins
+ SPIRV_TosaNumerical_TensorArm3D: $A,
+ SPIRV_TosaNumerical_TensorArm3D: $B,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $A_zp,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $B_zp
+ );
+
+ // The result element type may differ from the inputs (e.g. i8 -> i32 in
+ // the example), so it is not part of AllElementTypesMatch.
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm3D: $output
+ );
+
+ let hasVerifier = 1;
+
+ let assemblyFormat = [{
+ $A `,`
+ $B `,`
+ $A_zp `,`
+ $B_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ // Typed accessors for the A/B operands and the result.
+ ::mlir::spirv::TensorArmType getAType() {
+ return cast<::mlir::spirv::TensorArmType>(getA().getType());
+ }
+ ::mlir::spirv::TensorArmType getBType() {
+ return cast<::mlir::spirv::TensorArmType>(getB().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOp<"MaxPool2D", 7, [Pure,
+ AllElementTypesMatch<["input", "output"]>]> {
+ let summary = "Performs max pooling on the input.";
+
+ let description = [{
+ Performs a max pooling over the given input tensor. A sliding window of
+ size given by <kernel size> is passed over the input tensor, with the
+ maximum value being placed in the
+ output tensor.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_max_pool2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_max_pool2d
+
+ #### Example:
+ ```mlir
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [2, 2], pad = [1, 0, 1, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x6x65536x1xf32> -> !spirv.arm.tensor<1x3x32769x1xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+ SPIRV_TosaNumerical_TensorArm4D: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm4D: $output
+ );
+
+ let hasVerifier = 0;
----------------
kuhar wrote:
also here
https://github.com/llvm/llvm-project/pull/177585
More information about the Mlir-commits
mailing list