[Mlir-commits] [mlir] [mlir][spirv] Add Pooling, Fourier Transform, and MatMul operations t… (PR #177585)

Davide Grohmann llvmlistbot at llvm.org
Wed Jan 28 05:33:28 PST 2026


================
@@ -65,26 +94,94 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
     SPIRV_Int32_TensorArmUpTo5D: $output
   );
 
-  let hasVerifier = 1;
-
   let assemblyFormat = [{
-    `axis` `=` $axis `,` `nan_mode` `=` $nan_mode `,`
+    `axis` `=` $axis `,`
+    `nan_mode` `=` $nan_mode `,`
     $input
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
     ::mlir::spirv::TensorArmType getInputType() {
       return cast<::mlir::spirv::TensorArmType>(getInput().getType());
     }
-    ::mlir::spirv::TensorArmType getResultType() {
-      return cast<::mlir::spirv::TensorArmType>(getType());
+  }];
+}
+
+
+def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [Pure,
+  TypeImpliesAccType<"input", I8, ["INT32"]>,
+  TypeImpliesAccType<"input", I16, ["INT32"]>,
+  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
+  TypeImpliesAccType<"input", BF16, ["FP32"]>,
+  TypeImpliesAccType<"input", F32, ["FP32"]>,
+  AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
+  let summary = "Performs average pooling on the input.";
+
+  let description = [{
+    Performs an average pooling over the given input tensor. A sliding
+    window of size given by <kernel size> is passed over the input tensor, with
+    the mean value being placed in the output tensor. When calculating the
+    average, only the number of valid input tensor values, but not padding, are
+    used to calculate the divisor.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_avg_pool2d
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_avg_pool2d
+
+    #### Example:
+    ```mlir
+    %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+    %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
+    SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_TosaExtAccTypeAttr: $acc_type,
+    SPIRV_TosaNumerical_TensorArm4D: $input,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+    SPIRV_TosaNumerical_1DTensorArmOfLength1: $output_zp
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm4D: $output
+  );
+
+  let assemblyFormat = [{
+    `kernel` `=` custom<SPIRV_I32_1DArmTensor>($kernel) `,`
+    `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+    `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+    `acc_type` `=` $acc_type `,`
+    $input `,`
+    $input_zp `,`
+    $output_zp
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
     }
   }];
 }
 
 
-def SPIRV_TosaConv2DOp : SPIRV_TosaOp<"Conv2D", 2, [Pure,
+def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
+  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
+  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
+  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
+  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
+  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
+  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
----------------
davidegrohmann wrote:

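This line duplicates the one above it; it should be `TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>`, i.e. when `weight` has an integer element type it must be I8.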

https://github.com/llvm/llvm-project/pull/177585


More information about the Mlir-commits mailing list