[Mlir-commits] [mlir] 1c33275 - [mlir][spirv] Introduce a base class for spirv.TOSA convolution ops (#183751)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 1 23:55:26 PST 2026


Author: Davide Grohmann
Date: 2026-03-02T08:55:21+01:00
New Revision: 1c3327561977a77aea21956c084d63e0e4fd2860

URL: https://github.com/llvm/llvm-project/commit/1c3327561977a77aea21956c084d63e0e4fd2860
DIFF: https://github.com/llvm/llvm-project/commit/1c3327561977a77aea21956c084d63e0e4fd2860.diff

LOG: [mlir][spirv] Introduce a base class for spirv.TOSA convolution ops (#183751)

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 406fb43aaa4e8..7d4795ac40e93 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -83,6 +83,40 @@ class SPIRV_TosaElementwiseBinaryOp<string mnemonic, int opcode, list<Trait> tra
     AllElementTypesMatch<["input1", "output"]>])> {
 }
 
+class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits = []> : // Shared base for spirv.TOSA Conv2D, Conv3D, DepthwiseConv2D and TransposeConv2D.
+  SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure, // caller-supplied `traits` come first, then the shared list below
+    TypeConstraintImplicationOn<"input", I8, "output", [I32]>, // input element type -> allowed output element type
+    TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
+    TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
+    TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
+    TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+    TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>, // NOTE(review): input->weight float checks were not in the per-op lists this class replaces; confirm intended.
+    TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
+    TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
+    TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>, // an integer input must be i8 or i16
+    TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>, // an integer weight must be i8
+    TypeImpliesAccType<"input", I8, ["INT32"]>, // presumably restricts the accumulator-type attribute per input element type; confirm
+    TypeImpliesAccType<"input", I16, ["INT48"]>,
+    TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
+    TypeImpliesAccType<"input", BF16, ["FP32"]>,
+    TypeImpliesAccType<"input", F32, ["FP32"]>,
+    AllElementTypesMatch<["bias", "output"]>, // bias shares the output element type
+    AllElementTypesMatch<["input", "input_zp"]>, // input zero point shares the input element type
+    AllElementTypesMatch<["weight", "weight_zp"]>])> { // weight zero point shares the weight element type
+  // Typed accessors: each returns the operand's type cast to ::mlir::spirv::TensorArmType.
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getWeightType() {
+      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+    }
+    ::mlir::spirv::TensorArmType getBiasType() {
+      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+    }
+  }];
+}
+
 
 def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
   OutputRankIsInputRankMinusOne<"input", "output">,
@@ -190,22 +224,7 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [Pure,
 }
 
 
-def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
-  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
-  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
-  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
-  TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
-  TypeImpliesAccType<"input", I8, ["INT32"]>,
-  TypeImpliesAccType<"input", I16, ["INT48"]>,
-  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
-  TypeImpliesAccType<"input", BF16, ["FP32"]>,
-  TypeImpliesAccType<"input", F32, ["FP32"]>,
-  AllElementTypesMatch<["bias", "output"]>,
-  AllElementTypesMatch<["input", "input_zp"]>,
-  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
   let summary = "2D Convolution operator.";
 
   let description = [{
@@ -257,36 +276,10 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration#[{
-    ::mlir::spirv::TensorArmType getInputType() {
-      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
-    }
-    ::mlir::spirv::TensorArmType getWeightType() {
-      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
-    }
-    ::mlir::spirv::TensorArmType getBiasType() {
-      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
-    }
-  }];
 }
 
 
-def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure,
-  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
-  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
-  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
-  TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
-  TypeImpliesAccType<"input", I8, ["INT32"]>,
-  TypeImpliesAccType<"input", I16, ["INT48"]>,
-  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
-  TypeImpliesAccType<"input", BF16, ["FP32"]>,
-  TypeImpliesAccType<"input", F32, ["FP32"]>,
-  AllElementTypesMatch<["bias", "output"]>,
-  AllElementTypesMatch<["input", "input_zp"]>,
-  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
   let summary = "3D Convolution operator.";
 
   let description = [{
@@ -337,36 +330,10 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure,
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration#[{
-    ::mlir::spirv::TensorArmType getInputType() {
-      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
-    }
-    ::mlir::spirv::TensorArmType getWeightType() {
-      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
-    }
-    ::mlir::spirv::TensorArmType getBiasType() {
-      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
-    }
-  }];
 }
 
 
-def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [Pure,
-  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
-  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
-  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
-  TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
-  TypeImpliesAccType<"input", I8, ["INT32"]>,
-  TypeImpliesAccType<"input", I16, ["INT48"]>,
-  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
-  TypeImpliesAccType<"input", BF16, ["FP32"]>,
-  TypeImpliesAccType<"input", F32, ["FP32"]>,
-  AllElementTypesMatch<["bias", "output"]>,
-  AllElementTypesMatch<["input", "input_zp"]>,
-  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4> {
   let summary = "Depthwise 2D Convolution operator.";
 
   let description = [{
@@ -418,17 +385,6 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration#[{
-    ::mlir::spirv::TensorArmType getInputType() {
-      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
-    }
-    ::mlir::spirv::TensorArmType getWeightType() {
-      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
-    }
-    ::mlir::spirv::TensorArmType getBiasType() {
-      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
-    }
-  }];
 }
 
 
@@ -635,22 +591,7 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> {
 }
 
 
-def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [Pure,
-  TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
-  TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
-  TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
-  TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
-  TypeImpliesAccType<"input", I8, ["INT32"]>,
-  TypeImpliesAccType<"input", I16, ["INT48"]>,
-  TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
-  TypeImpliesAccType<"input", BF16, ["FP32"]>,
-  TypeImpliesAccType<"input", F32, ["FP32"]>,
-  AllElementTypesMatch<["bias", "output"]>,
-  AllElementTypesMatch<["input", "input_zp"]>,
-  AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9> {
   let summary = "Transpose 2D Convolution operator.";
 
   let description = [{
@@ -700,17 +641,6 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [
     attr-dict `:` type(operands) `->` type(results)
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration#[{
-    ::mlir::spirv::TensorArmType getInputType() {
-      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
-    }
-    ::mlir::spirv::TensorArmType getWeightType() {
-      return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
-    }
-    ::mlir::spirv::TensorArmType getBiasType() {
-      return cast<::mlir::spirv::TensorArmType>(getBias().getType());
-    }
-  }];
 }
 
 


        


More information about the Mlir-commits mailing list