[Mlir-commits] [mlir] [mlir][spirv] Introduce a base class for spirv.TOSA convolution ops (PR #183751)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 27 07:26:01 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Davide Grohmann (davidegrohmann)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/183751.diff
1 File Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+38-108)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 406fb43aaa4e8..7d4795ac40e93 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -83,6 +83,40 @@ class SPIRV_TosaElementwiseBinaryOp<string mnemonic, int opcode, list<Trait> tra
AllElementTypesMatch<["input1", "output"]>])> {
}
+class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits = []> : // Common base for the spirv.Tosa Conv2D, Conv3D, DepthwiseConv2D and TransposeConv2D defs in this file.
+ SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure, // Op-specific `traits` are listed first, followed by Pure and the shared convolution constraints below.
+ TypeConstraintImplicationOn<"input", I8, "output", [I32]>, // Input element type -> allowed output element type(s), as named by the constraint arguments.
+ TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
+ TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
+ TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
+ TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+ TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>, // NOTE(review): these three input->weight implications are not in any of the per-op trait lists this class replaces (removed later in this diff), so this is not a pure refactor -- confirm the stricter verification is intended.
+ TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
+ TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
+ TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>, // Reads as: an integer input must be i8 or i16.
+ TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>, // Reads as: an integer weight must be i8.
+ TypeImpliesAccType<"input", I8, ["INT32"]>, // Presumably restricts the op's acc_type attribute per input element type -- TODO confirm against the TypeImpliesAccType definition.
+ TypeImpliesAccType<"input", I16, ["INT48"]>,
+ TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
+ TypeImpliesAccType<"input", BF16, ["FP32"]>,
+ TypeImpliesAccType<"input", F32, ["FP32"]>,
+ AllElementTypesMatch<["bias", "output"]>, // bias matches output; each zero point matches the tensor it belongs to.
+ AllElementTypesMatch<["input", "input_zp"]>,
+ AllElementTypesMatch<["weight", "weight_zp"]>])> {
+ // Typed operand accessors appended to extraBaseClassDeclaration; a derived def that re-lets extraClassDeclaration would drop them. cast<> is only checked in assert-enabled builds.
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getWeightType() {
+ return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+ }
+ ::mlir::spirv::TensorArmType getBiasType() {
+ return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+ }
+ }];
+}
+
def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
OutputRankIsInputRankMinusOne<"input", "output">,
@@ -190,22 +224,7 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [Pure,
}
-def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
- TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
- TypeImpliesAccType<"input", I8, ["INT32"]>,
- TypeImpliesAccType<"input", I16, ["INT48"]>,
- TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
- TypeImpliesAccType<"input", BF16, ["FP32"]>,
- TypeImpliesAccType<"input", F32, ["FP32"]>,
- AllElementTypesMatch<["bias", "output"]>,
- AllElementTypesMatch<["input", "input_zp"]>,
- AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
let summary = "2D Convolution operator.";
let description = [{
@@ -257,36 +276,10 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = extraBaseClassDeclaration#[{
- ::mlir::spirv::TensorArmType getInputType() {
- return cast<::mlir::spirv::TensorArmType>(getInput().getType());
- }
- ::mlir::spirv::TensorArmType getWeightType() {
- return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
- }
- ::mlir::spirv::TensorArmType getBiasType() {
- return cast<::mlir::spirv::TensorArmType>(getBias().getType());
- }
- }];
}
-def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure,
- TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
- TypeImpliesAccType<"input", I8, ["INT32"]>,
- TypeImpliesAccType<"input", I16, ["INT48"]>,
- TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
- TypeImpliesAccType<"input", BF16, ["FP32"]>,
- TypeImpliesAccType<"input", F32, ["FP32"]>,
- AllElementTypesMatch<["bias", "output"]>,
- AllElementTypesMatch<["input", "input_zp"]>,
- AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
let summary = "3D Convolution operator.";
let description = [{
@@ -337,36 +330,10 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure,
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = extraBaseClassDeclaration#[{
- ::mlir::spirv::TensorArmType getInputType() {
- return cast<::mlir::spirv::TensorArmType>(getInput().getType());
- }
- ::mlir::spirv::TensorArmType getWeightType() {
- return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
- }
- ::mlir::spirv::TensorArmType getBiasType() {
- return cast<::mlir::spirv::TensorArmType>(getBias().getType());
- }
- }];
}
-def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [Pure,
- TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
- TypeImpliesAccType<"input", I8, ["INT32"]>,
- TypeImpliesAccType<"input", I16, ["INT48"]>,
- TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
- TypeImpliesAccType<"input", BF16, ["FP32"]>,
- TypeImpliesAccType<"input", F32, ["FP32"]>,
- AllElementTypesMatch<["bias", "output"]>,
- AllElementTypesMatch<["input", "input_zp"]>,
- AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4> {
let summary = "Depthwise 2D Convolution operator.";
let description = [{
@@ -418,17 +385,6 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = extraBaseClassDeclaration#[{
- ::mlir::spirv::TensorArmType getInputType() {
- return cast<::mlir::spirv::TensorArmType>(getInput().getType());
- }
- ::mlir::spirv::TensorArmType getWeightType() {
- return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
- }
- ::mlir::spirv::TensorArmType getBiasType() {
- return cast<::mlir::spirv::TensorArmType>(getBias().getType());
- }
- }];
}
@@ -635,22 +591,7 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> {
}
-def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [Pure,
- TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
- TypeImpliesAccType<"input", I8, ["INT32"]>,
- TypeImpliesAccType<"input", I16, ["INT48"]>,
- TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
- TypeImpliesAccType<"input", BF16, ["FP32"]>,
- TypeImpliesAccType<"input", F32, ["FP32"]>,
- AllElementTypesMatch<["bias", "output"]>,
- AllElementTypesMatch<["input", "input_zp"]>,
- AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9> {
let summary = "Transpose 2D Convolution operator.";
let description = [{
@@ -700,17 +641,6 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = extraBaseClassDeclaration#[{
- ::mlir::spirv::TensorArmType getInputType() {
- return cast<::mlir::spirv::TensorArmType>(getInput().getType());
- }
- ::mlir::spirv::TensorArmType getWeightType() {
- return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
- }
- ::mlir::spirv::TensorArmType getBiasType() {
- return cast<::mlir::spirv::TensorArmType>(getBias().getType());
- }
- }];
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/183751
More information about the Mlir-commits mailing list