[Mlir-commits] [mlir] 1c33275 - [mlir][spirv] Introduce a base class for spirv.TOSA convolution ops (#183751)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 1 23:55:26 PST 2026
Author: Davide Grohmann
Date: 2026-03-02T08:55:21+01:00
New Revision: 1c3327561977a77aea21956c084d63e0e4fd2860
URL: https://github.com/llvm/llvm-project/commit/1c3327561977a77aea21956c084d63e0e4fd2860
DIFF: https://github.com/llvm/llvm-project/commit/1c3327561977a77aea21956c084d63e0e4fd2860.diff
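LOG: [mlir][spirv] Introduce a base class for spirv.TOSA convolution ops (#183751)
[Reviewer note: in addition to moving the shared traits and the getInputType/getWeightType/getBiasType accessors of Conv2D, Conv3D, DepthwiseConv2D and TransposeConv2D into SPIRV_TosaConvolutionOp, the diff below adds three traits to the base class that none of the removed per-op trait lists contained: TypeConstraintImplicationOn<"input", T, "weight", [T]> for T = BF16, F16 and F32.]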
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 406fb43aaa4e8..7d4795ac40e93 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -83,6 +83,40 @@ class SPIRV_TosaElementwiseBinaryOp<string mnemonic, int opcode, list<Trait> tra
AllElementTypesMatch<["input1", "output"]>])> {
}
+class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits = []> :
+ SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure,
+ TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
+ TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
+ TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
+ TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
+ TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+ TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
+ TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
+ TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
+ TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
+ TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
+ TypeImpliesAccType<"input", I8, ["INT32"]>,
+ TypeImpliesAccType<"input", I16, ["INT48"]>,
+ TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
+ TypeImpliesAccType<"input", BF16, ["FP32"]>,
+ TypeImpliesAccType<"input", F32, ["FP32"]>,
+ AllElementTypesMatch<["bias", "output"]>,
+ AllElementTypesMatch<["input", "input_zp"]>,
+ AllElementTypesMatch<["weight", "weight_zp"]>])> {
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getWeightType() {
+ return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
+ }
+ ::mlir::spirv::TensorArmType getBiasType() {
+ return cast<::mlir::spirv::TensorArmType>(getBias().getType());
+ }
+ }];
+}
+
def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
OutputRankIsInputRankMinusOne<"input", "output">,
@@ -190,22 +224,7 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [Pure,
}
-def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
- TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
- TypeImpliesAccType<"input", I8, ["INT32"]>,
- TypeImpliesAccType<"input", I16, ["INT48"]>,
- TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
- TypeImpliesAccType<"input", BF16, ["FP32"]>,
- TypeImpliesAccType<"input", F32, ["FP32"]>,
- AllElementTypesMatch<["bias", "output"]>,
- AllElementTypesMatch<["input", "input_zp"]>,
- AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
let summary = "2D Convolution operator.";
let description = [{
@@ -257,36 +276,10 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure,
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = extraBaseClassDeclaration#[{
- ::mlir::spirv::TensorArmType getInputType() {
- return cast<::mlir::spirv::TensorArmType>(getInput().getType());
- }
- ::mlir::spirv::TensorArmType getWeightType() {
- return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
- }
- ::mlir::spirv::TensorArmType getBiasType() {
- return cast<::mlir::spirv::TensorArmType>(getBias().getType());
- }
- }];
}
-def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure,
- TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
- TypeImpliesAccType<"input", I8, ["INT32"]>,
- TypeImpliesAccType<"input", I16, ["INT48"]>,
- TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
- TypeImpliesAccType<"input", BF16, ["FP32"]>,
- TypeImpliesAccType<"input", F32, ["FP32"]>,
- AllElementTypesMatch<["bias", "output"]>,
- AllElementTypesMatch<["input", "input_zp"]>,
- AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
let summary = "3D Convolution operator.";
let description = [{
@@ -337,36 +330,10 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure,
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = extraBaseClassDeclaration#[{
- ::mlir::spirv::TensorArmType getInputType() {
- return cast<::mlir::spirv::TensorArmType>(getInput().getType());
- }
- ::mlir::spirv::TensorArmType getWeightType() {
- return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
- }
- ::mlir::spirv::TensorArmType getBiasType() {
- return cast<::mlir::spirv::TensorArmType>(getBias().getType());
- }
- }];
}
-def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [Pure,
- TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
- TypeImpliesAccType<"input", I8, ["INT32"]>,
- TypeImpliesAccType<"input", I16, ["INT48"]>,
- TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
- TypeImpliesAccType<"input", BF16, ["FP32"]>,
- TypeImpliesAccType<"input", F32, ["FP32"]>,
- AllElementTypesMatch<["bias", "output"]>,
- AllElementTypesMatch<["input", "input_zp"]>,
- AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4> {
let summary = "Depthwise 2D Convolution operator.";
let description = [{
@@ -418,17 +385,6 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = extraBaseClassDeclaration#[{
- ::mlir::spirv::TensorArmType getInputType() {
- return cast<::mlir::spirv::TensorArmType>(getInput().getType());
- }
- ::mlir::spirv::TensorArmType getWeightType() {
- return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
- }
- ::mlir::spirv::TensorArmType getBiasType() {
- return cast<::mlir::spirv::TensorArmType>(getBias().getType());
- }
- }];
}
@@ -635,22 +591,7 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> {
}
-def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [Pure,
- TypeConstraintImplicationOn<"input", I8, "output", [I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I64]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
- TypeImpliesAccType<"input", I8, ["INT32"]>,
- TypeImpliesAccType<"input", I16, ["INT48"]>,
- TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
- TypeImpliesAccType<"input", BF16, ["FP32"]>,
- TypeImpliesAccType<"input", F32, ["FP32"]>,
- AllElementTypesMatch<["bias", "output"]>,
- AllElementTypesMatch<["input", "input_zp"]>,
- AllElementTypesMatch<["weight", "weight_zp"]>]> {
+def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9> {
let summary = "Transpose 2D Convolution operator.";
let description = [{
@@ -700,17 +641,6 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [
attr-dict `:` type(operands) `->` type(results)
}];
- let extraClassDeclaration = extraBaseClassDeclaration#[{
- ::mlir::spirv::TensorArmType getInputType() {
- return cast<::mlir::spirv::TensorArmType>(getInput().getType());
- }
- ::mlir::spirv::TensorArmType getWeightType() {
- return cast<::mlir::spirv::TensorArmType>(getWeight().getType());
- }
- ::mlir::spirv::TensorArmType getBiasType() {
- return cast<::mlir::spirv::TensorArmType>(getBias().getType());
- }
- }];
}
More information about the Mlir-commits
mailing list