[Mlir-commits] [mlir] [mlir][spirv] Add Pooling, Fourier Transform, and MatMul operations t… (PR #177585)
Davide Grohmann
llvmlistbot at llvm.org
Wed Jan 28 05:33:28 PST 2026
================
@@ -65,26 +94,94 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
SPIRV_Int32_TensorArmUpTo5D: $output
);
- let hasVerifier = 1;
-
let assemblyFormat = [{
- `axis` `=` $axis `,` `nan_mode` `=` $nan_mode `,`
+ `axis` `=` $axis `,`
+ `nan_mode` `=` $nan_mode `,`
$input
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = [{
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
::mlir::spirv::TensorArmType getInputType() {
return cast<::mlir::spirv::TensorArmType>(getInput().getType());
}
- ::mlir::spirv::TensorArmType getResultType() {
- return cast<::mlir::spirv::TensorArmType>(getType());
+ }];
+}
+
+
+// AvgPool2D: average pooling over the `input` tensor. The second template
+// argument (1) is presumably the op's number in the TOSA.001000.1 extended
+// instruction set (ArgMax uses 0, Conv2D uses 2 in this file) -- confirm
+// against the SPIRV_TosaOpWithResult definition.
+def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [Pure,
+ // Permitted acc_type values for each input element type; this appears to
+ // mirror the avg_pool2d supported-data-type table of the TOSA 1.0 spec
+ // (NOTE(review): re-check against the spec when it changes).
+ TypeImpliesAccType<"input", I8, ["INT32"]>,
+ TypeImpliesAccType<"input", I16, ["INT32"]>,
+ TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
+ TypeImpliesAccType<"input", BF16, ["FP32"]>,
+ TypeImpliesAccType<"input", F32, ["FP32"]>,
+ // Input, both zero points, and output must share one element type.
+ AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
+ let summary = "Performs average pooling on the input.";
+
+ // NOTE(review): both examples below bind `%6`. They are independent
+ // snippets and would need distinct SSA names to be parsed in one region.
+ let description = [{
+ Performs an average pooling over the given input tensor. A sliding
+ window of size given by <kernel size> is passed over the input tensor, with
+ the mean value being placed in the output tensor. When calculating the
+ average, only the number of valid input tensor values, but not padding, are
+ used to calculate the divisor.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_avg_pool2d
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_avg_pool2d
+
+ #### Example:
+ ```mlir
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+ %6 = spirv.Tosa.AvgPool2D kernel = [2, 2], stride = [1, 1], pad = [1, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x65533x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x2x65532x2xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ // Per the constraint names, kernel and stride are 2-element i32 tensor
+ // attributes and pad is a 4-element one. The example shapes are
+ // consistent with pad = [top, bottom, left, right], as in the TOSA spec.
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_TosaExtAccTypeAttr: $acc_type,
+ SPIRV_TosaNumerical_TensorArm4D: $input,
+ // Single-element zero-point tensors; their element type is tied to
+ // input/output by the AllElementTypesMatch trait.
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
+ SPIRV_TosaNumerical_1DTensorArmOfLength1: $output_zp
+ );
+
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm4D: $output
+ );
+
+ // Attributes print as `name = value`. The i32 tensor attributes go through
+ // the custom SPIRV_I32_1DArmTensor directive (printed as `[a, b, ...]` in
+ // the examples). Then come the three operands, followed by the functional
+ // operand/result types.
+ let assemblyFormat = [{
+ `kernel` `=` custom<SPIRV_I32_1DArmTensor>($kernel) `,`
+ `stride` `=` custom<SPIRV_I32_1DArmTensor>($stride) `,`
+ `pad` `=` custom<SPIRV_I32_1DArmTensor>($pad) `,`
+ `acc_type` `=` $acc_type `,`
+ $input `,`
+ $input_zp `,`
+ $output_zp
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ // Extends the shared extraBaseClassDeclaration (defined elsewhere) with a
+ // typed accessor that casts the input operand's type to TensorArmType.
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
}
}];
}
-def SPIRV_TosaConv2DOp : SPIRV_TosaOp<"Conv2D", 2, [Pure,
+def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
+ TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
+ TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
+ TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
+ TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
+ TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+ TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
+ TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
----------------
davidegrohmann wrote:
This line duplicates the one above it (both constrain `input` to [I8, I16]); it should be `TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>`
https://github.com/llvm/llvm-project/pull/177585
More information about the Mlir-commits
mailing list