[Mlir-commits] [mlir] [mlir][spirv] Introduce a base class for spirv.TOSA convolution ops (PR #183751)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 27 07:26:01 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Davide Grohmann (davidegrohmann)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/183751.diff


1 File Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+38-108) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 406fb43aaa4e8..7d4795ac40e93 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -83,6 +83,40 @@ class SPIRV_TosaElementwiseBinaryOp<string mnemonic, int opcode, list<Trait> tra
     AllElementTypesMatch<["input1", "output"]>])> {
 }
 
+class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits = []> :  // Common base for the TOSA convolution ops (Conv2D, Conv3D, DepthwiseConv2D, TransposeConv2D).
+  SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure,  // Per-op `traits` are placed before the shared list below.
+    TypeConstraintImplicationOn<"input", I8, "output", [I32]>,  // Input element type fixes the output element type.
+    TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
+    TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
+    TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
+    TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+    TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,  // NOTE(review): these three input->weight FP constraints are new; the per-op lists this replaces lacked them. Confirm intended.
+    TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
+    TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
+    TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,  // Integer inputs: i8/i16 only; integer weights: i8 only.
+    TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
+    TypeImpliesAccType<"input", I8, ["INT32"]>,  // Input element type restricts the accumulator type.
+    TypeImpliesAccType<"input", I16, ["INT48"]>,
+    TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
+    TypeImpliesAccType<"input", BF16, ["FP32"]>,
+    TypeImpliesAccType<"input", F32, ["FP32"]>,
+    AllElementTypesMatch<["bias", "output"]>,  // Bias matches output; each zero point matches its tensor.
+    AllElementTypesMatch<["input", "input_zp"]>,
+    AllElementTypesMatch<["weight", "weight_zp"]>])> {
+  // Typed accessors shared by all convolution ops; each casts the operand's type to TensorArmType, assuming it is one.
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getWeightType() {
+      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+    }
+    ::mlir::spirv::TensorArmType getBiasType() {
+      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+    }
+  }];
+}
+
 
 def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
   OutputRankIsInputRankMinusOne<"input", "output">,
@@ -190,22 +224,7 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [Pure,
 }
 
 
-def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
-  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
-  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
-  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
-  TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
-  TypeImpliesAccType<"input", I8, ["INT32"]>,
-  TypeImpliesAccType<"input", I16, ["INT48"]>,
-  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
-  TypeImpliesAccType<"input", BF16, ["FP32"]>,
-  TypeImpliesAccType<"input", F32, ["FP32"]>,
-  AllElementTypesMatch<["bias", "output"]>,
-  AllElementTypesMatch<["input", "input_zp"]>,
-  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
   let summary = "2D Convolution operator.";
 
   let description = [{
@@ -257,36 +276,10 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration#[{
-    ::mlir::spirv::TensorArmType getInputType() {
-      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
-    }
-    ::mlir::spirv::TensorArmType getWeightType() {
-      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
-    }
-    ::mlir::spirv::TensorArmType getBiasType() {
-      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
-    }
-  }];
 }
 
 
-def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure,
-  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
-  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
-  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
-  TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
-  TypeImpliesAccType<"input", I8, ["INT32"]>,
-  TypeImpliesAccType<"input", I16, ["INT48"]>,
-  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
-  TypeImpliesAccType<"input", BF16, ["FP32"]>,
-  TypeImpliesAccType<"input", F32, ["FP32"]>,
-  AllElementTypesMatch<["bias", "output"]>,
-  AllElementTypesMatch<["input", "input_zp"]>,
-  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
   let summary = "3D Convolution operator.";
 
   let description = [{
@@ -337,36 +330,10 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure,
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration#[{
-    ::mlir::spirv::TensorArmType getInputType() {
-      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
-    }
-    ::mlir::spirv::TensorArmType getWeightType() {
-      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
-    }
-    ::mlir::spirv::TensorArmType getBiasType() {
-      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
-    }
-  }];
 }
 
 
-def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [Pure,
-  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
-  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
-  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
-  TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
-  TypeImpliesAccType<"input", I8, ["INT32"]>,
-  TypeImpliesAccType<"input", I16, ["INT48"]>,
-  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
-  TypeImpliesAccType<"input", BF16, ["FP32"]>,
-  TypeImpliesAccType<"input", F32, ["FP32"]>,
-  AllElementTypesMatch<["bias", "output"]>,
-  AllElementTypesMatch<["input", "input_zp"]>,
-  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4> {
   let summary = "Depthwise 2D Convolution operator.";
 
   let description = [{
@@ -418,17 +385,6 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration#[{
-    ::mlir::spirv::TensorArmType getInputType() {
-      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
-    }
-    ::mlir::spirv::TensorArmType getWeightType() {
-      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
-    }
-    ::mlir::spirv::TensorArmType getBiasType() {
-      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
-    }
-  }];
 }
 
 
@@ -635,22 +591,7 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> {
 }
 
 
-def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [Pure,
-  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
-  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
-  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
-  TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
-  TypeImpliesAccType<"input", I8, ["INT32"]>,
-  TypeImpliesAccType<"input", I16, ["INT48"]>,
-  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
-  TypeImpliesAccType<"input", BF16, ["FP32"]>,
-  TypeImpliesAccType<"input", F32, ["FP32"]>,
-  AllElementTypesMatch<["bias", "output"]>,
-  AllElementTypesMatch<["input", "input_zp"]>,
-  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9> {
   let summary = "Transpose 2D Convolution operator.";
 
   let description = [{
@@ -700,17 +641,6 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration#[{
-    ::mlir::spirv::TensorArmType getInputType() {
-      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
-    }
-    ::mlir::spirv::TensorArmType getWeightType() {
-      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
-    }
-    ::mlir::spirv::TensorArmType getBiasType() {
-      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
-    }
-  }];
 }
 
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/183751


More information about the Mlir-commits mailing list