[Mlir-commits] [mlir] [mlir][spirv] Tighten SPIR-V TOSA convolution verification (PR #194592)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 28 03:55:22 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Davide Grohmann (davidegrohmann)
<details>
<summary>Changes</summary>
Add verifier coverage for SPIR-V TOSA convolution ops against the TOSA shape and type constraints.
This adds shared TableGen shape predicates for Conv2D, Conv3D, DepthwiseConv2D and TransposeConv2D, including batch/channel/bias relationships. It also constrains integer convolution weights so i8 and i16 inputs use i8 weights, matching the SPIR-V TOSA representation.
Add custom verifiers for the convolution output-shape formulas, including the stride-divisibility requirement for regular convolutions and the out_pad bounds for TransposeConv2D. Tighten the attribute constraints so that pad values must be non-negative i32 and stride and dilation values must be positive i32.
---
Patch is 80.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/194592.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+34-20)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+78-10)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp (+194)
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir (+228-60)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 483ce4348971b..e6336bf011e9d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -158,6 +158,8 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16]>,
TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16]>,
+ TypeConstraintImplicationOn<"input", I8, "weight", [I8]>,
+ TypeConstraintImplicationOn<"input", I16, "weight", [I8]>,
TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
@@ -310,9 +312,9 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
}];
let arguments = (ins
- SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
- SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
@@ -344,7 +346,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
}
-def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
+def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2, [
+ Conv2DShapeMatch<"input", "weight", "bias", "output">]> {
let summary = "2D Convolution operator.";
let description = [{
@@ -368,9 +371,9 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
}];
let arguments = (ins
- SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
- SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
+ SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
@@ -397,10 +400,13 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
$weight_zp
attr-dict `:` type(operands) `->` type(results)
}];
+
+ let hasVerifier = 1;
}
-def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
+def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3, [
+ Conv3DShapeMatch<"input", "weight", "bias", "output">]> {
let summary = "3D Convolution operator.";
let description = [{
@@ -423,9 +429,9 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
}];
let arguments = (ins
- SPIRV_I32_1DTensorArmOfLength6Attr: $pad,
- SPIRV_I32_1DTensorArmOfLength3Attr: $stride,
- SPIRV_I32_1DTensorArmOfLength3Attr: $dilation,
+ SPIRV_NonNegativeI32_1DTensorArmOfLength6Attr: $pad,
+ SPIRV_PositiveI32_1DTensorArmOfLength3Attr: $stride,
+ SPIRV_PositiveI32_1DTensorArmOfLength3Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm5D: $input,
@@ -452,10 +458,13 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
$weight_zp
attr-dict `:` type(operands) `->` type(results)
}];
+
+ let hasVerifier = 1;
}
-def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4> {
+def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4, [
+ DepthwiseConv2DShapeMatch<"input", "weight", "bias", "output">]> {
let summary = "Depthwise 2D Convolution operator.";
let description = [{
@@ -479,9 +488,9 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
}];
let arguments = (ins
- SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
- SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
+ SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
@@ -508,6 +517,8 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
$weight_zp
attr-dict `:` type(operands) `->` type(results)
}];
+
+ let hasVerifier = 1;
}
@@ -646,9 +657,9 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
}];
let arguments = (ins
- SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
- SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input
);
@@ -724,7 +735,8 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> {
}
-def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9> {
+def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9, [
+ Conv2DShapeMatch<"input", "weight", "bias", "output">]> {
let summary = "Transpose 2D Convolution operator.";
let description = [{
@@ -749,7 +761,7 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
let arguments = (ins
SPIRV_I32_1DTensorArmOfLength4Attr: $out_pad,
- SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
@@ -775,6 +787,8 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
$weight_zp
attr-dict `:` type(operands) `->` type(results)
}];
+
+ let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 7064930c5864a..bc981d8bf95cf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -121,8 +121,10 @@ class IntElementsAttrAllValuesAtLeast<int minValue> : AttrConstraint<
minValue # "; })">,
"all values must be >= " # minValue>;
-def SPIRV_PositiveInt32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<1>]>;
-def SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<0>]>;
+// Rank-1 i32 elements attribute of `length` elements, stored with a TensorArm
+// type, whose every element is >= `minValue`.
+class SPIRV_I32_1DTensorArmAttrAtLeast<int length, int minValue> :
+  ConfinedAttr<RankedI32ElementsAttr<[length]>,
+               [SPIRV_DenseElementAttrsWithTensorArmType,
+                IntElementsAttrAllValuesAtLeast<minValue>]>;
+
+// Strictly positive (>= 1) element values.
+def SPIRV_PositiveI32_1DTensorArmOfLength2Attr :
+  SPIRV_I32_1DTensorArmAttrAtLeast<2, 1>;
+def SPIRV_PositiveI32_1DTensorArmOfLength3Attr :
+  SPIRV_I32_1DTensorArmAttrAtLeast<3, 1>;
+
+// Non-negative (>= 0) element values.
+def SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr :
+  SPIRV_I32_1DTensorArmAttrAtLeast<4, 0>;
+def SPIRV_NonNegativeI32_1DTensorArmOfLength6Attr :
+  SPIRV_I32_1DTensorArmAttrAtLeast<6, 0>;
class Is1DTensorArmAttrOfLength<list<int> allowedLengths> :
AttrConstraint<And<[CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape().size() == 1 }]>,
@@ -204,15 +206,46 @@ class MatchBroadcastableShapes<string input1, string input2, string output>:
"})">]>
>;
+// C++ expression for the size of dimension `index` of the operand named
+// `operand` (the `$`-prefixed name is resolved in the op-predicate context).
+// Dynamic extents must be detected with ShapedType::isDynamic before being
+// compared numerically.
+class DimOf<string operand, int index> :
+  StrFunc<"::llvm::cast<::mlir::ShapedType>($" # operand #
+          ".getType()).getDimSize(" # index # ")">;
+
+// True when dimension `index` of `operand` is dynamic.
+class DimIsDynamic<string operand, int index> :
+  CPred<"::mlir::ShapedType::isDynamic(" # DimOf<operand, index>.result # ")">;
+
+// True when dimension `index` of `operand` is statically 1.
+class DimIsOne<string operand, int index> :
+  CPred<DimOf<operand, index>.result # " == 1">;
+
+// True when a[aIndex] == b[bIndex].
+class DimsMatch<string a, int aIndex, string b, int bIndex> :
+  CPred<DimOf<a, aIndex>.result # " == " # DimOf<b, bIndex>.result>;
+
+// True when a[aIndex] * b[bIndex] == product[productIndex].
+class ProductDimsMatch<string a, int aIndex, string b, int bIndex,
+                       string product, int productIndex> :
+  CPred<DimOf<a, aIndex>.result # " * " # DimOf<b, bIndex>.result # " == " #
+        DimOf<product, productIndex>.result>;
+
class SameDimsOrDynamicPred<string lhs, int lhsDim, string rhs, int rhsDim> :
- CPred<"[](::mlir::ShapedType lhsType, ::mlir::ShapedType rhsType) { "
- " int64_t lhsSize = lhsType.getDimSize(" # lhsDim # "); "
- " int64_t rhsSize = rhsType.getDimSize(" # rhsDim # "); "
- " return ::mlir::ShapedType::isDynamic(lhsSize) || "
- " ::mlir::ShapedType::isDynamic(rhsSize) || lhsSize == rhsSize; "
- "}("
- "::llvm::cast<::mlir::ShapedType>($" # lhs # ".getType()), "
- "::llvm::cast<::mlir::ShapedType>($" # rhs # ".getType()))">;
+  // Holds when lhs[lhsDim] equals rhs[rhsDim], or when either extent is
+  // dynamic and so cannot be compared statically.
+  Or<[DimsMatch<lhs, lhsDim, rhs, rhsDim>,
+      DimIsDynamic<lhs, lhsDim>,
+      DimIsDynamic<rhs, rhsDim>]>;
+
+// Same as SameDimsOrDynamicPred, but a statically-1 extent on the lhs side is
+// also accepted (for example a bias that broadcasts over the channels).
+class SameDimsOrOneOrDynamicPred<string lhs, int lhsDim, string rhs,
+                                 int rhsDim> :
+  Or<[DimIsOne<lhs, lhsDim>,
+      SameDimsOrDynamicPred<lhs, lhsDim, rhs, rhsDim>]>;
+
+// Holds when output[outputDim] == lhs[lhsDim] * rhs[rhsDim], e.g. the IC * M
+// output channels of a depthwise convolution. The check is skipped when any of
+// the three extents is dynamic.
+//
+// A statically-1 output extent is deliberately NOT accepted on its own: the
+// output channel count must equal the product exactly, matching the
+// [N,OH,OW,IC*M] output shape that DepthwiseConv2DShapeMatch documents.
+// Accepting 1 unconditionally would let e.g. IC=2, M=3 with 1 output channel
+// pass verification. The "OrOne" in the class name is retained only so that
+// existing references keep resolving.
+class ProductDimOrOneOrDynamicPred<string lhs, int lhsDim, string rhs,
+                                   int rhsDim, string output, int outputDim> :
+  Or<[
+    DimIsDynamic<lhs, lhsDim>,
+    DimIsDynamic<rhs, rhsDim>,
+    DimIsDynamic<output, outputDim>,
+    ProductDimsMatch<lhs, lhsDim, rhs, rhsDim, output, outputDim>
+  ]>;
class ValuesIndicesShapesMatch<string values, string indices, string tensor>:
PredOpTrait<"shapes of " # values # ", " # indices # ", and " # tensor #
@@ -234,6 +267,41 @@ class NHWCInputOutputShapeMatch<string input, string output>:
SameDimsOrDynamicPred<input, 3, output, 3>
]>>;
+
+// Batch, channel and bias consistency for 2D convolutions whose weight is laid
+// out [OC,KH,KW,IC] (Conv2D and TransposeConv2D). Spatial extents are not
+// checked here.
+class Conv2DShapeMatch<string inputName, string weightName, string biasName,
+                       string outputName> :
+  PredOpTrait<
+    "shapes of " # inputName # ", " # weightName # ", " # biasName #
+    ", and " # outputName #
+    " must satisfy [N,IH,IW,IC], [OC,KH,KW,IC], [OC/1], [N,OH,OW,OC]",
+    And<[
+      // N: the batch is preserved.
+      SameDimsOrDynamicPred<inputName, 0, outputName, 0>,
+      // IC: input channels agree with the weight.
+      SameDimsOrDynamicPred<inputName, 3, weightName, 3>,
+      // OC: weight output channels agree with the output.
+      SameDimsOrDynamicPred<weightName, 0, outputName, 3>,
+      // Bias is either OC wide or a single element.
+      SameDimsOrOneOrDynamicPred<biasName, 0, outputName, 3>
+    ]>>;
+
+// Batch, channel and bias consistency for 3D convolutions whose weight is laid
+// out [OC,KD,KH,KW,IC]. Spatial extents are not checked here.
+class Conv3DShapeMatch<string inputName, string weightName, string biasName,
+                       string outputName> :
+  PredOpTrait<
+    "shapes of " # inputName # ", " # weightName # ", " # biasName #
+    ", and " # outputName #
+    " must satisfy [N,ID,IH,IW,IC], [OC,KD,KH,KW,IC], " #
+    "[OC/1], [N,OD,OH,OW,OC]",
+    And<[
+      // N: the batch is preserved.
+      SameDimsOrDynamicPred<inputName, 0, outputName, 0>,
+      // IC: input channels agree with the weight.
+      SameDimsOrDynamicPred<inputName, 4, weightName, 4>,
+      // OC: weight output channels agree with the output.
+      SameDimsOrDynamicPred<weightName, 0, outputName, 4>,
+      // Bias is either OC wide or a single element.
+      SameDimsOrOneOrDynamicPred<biasName, 0, outputName, 4>
+    ]>>;
+
+// Batch, channel and bias consistency for depthwise 2D convolution, whose
+// weight is laid out [KH,KW,IC,M] and whose output has IC*M channels.
+// Spatial extents are not checked here.
+class DepthwiseConv2DShapeMatch<string inputName, string weightName,
+                                string biasName, string outputName> :
+  PredOpTrait<
+    "shapes of " # inputName # ", " # weightName # ", " # biasName #
+    ", and " # outputName #
+    " must satisfy [N,IH,IW,IC], [KH,KW,IC,M], [IC*M/1], [N,OH,OW,IC*M]",
+    And<[
+      // N: the batch is preserved.
+      SameDimsOrDynamicPred<inputName, 0, outputName, 0>,
+      // IC: input channels agree with the weight's channel dimension.
+      SameDimsOrDynamicPred<inputName, 3, weightName, 2>,
+      // IC * M: output channels vs. input channels times the multiplier (see
+      // ProductDimOrOneOrDynamicPred for the exact acceptance rule).
+      ProductDimOrOneOrDynamicPred<inputName, 3, weightName, 3, outputName, 3>,
+      // Bias is either IC*M wide or a single element.
+      SameDimsOrOneOrDynamicPred<biasName, 0, outputName, 3>
+    ]>>;
+
// C++ expression that calls `get<AttrName>()` (AttrName being `attrName`
// converted from snake_case to CamelCase) and reads element `idx` of the
// returned integer elements as a sign-extended int64_t via getSExtValue().
class FetchNthIntElementsAttr<string attrName, int idx> :
StrFunc<"get" # snakeCaseToCamelCase<attrName>.ret # "().getValues<APInt>()[" # idx # "].getSExtValue()">;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index a2cc8be54e4f3..610bc843c8c0c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -111,6 +111,176 @@ LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel,
return success();
}
+/// Verifies one spatial dimension of a regular (dilated, strided) convolution.
+///
+/// Nothing is checked when the input or kernel extent is dynamic. Otherwise
+/// the span
+///   inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilationSize
+/// must be an exact multiple of `strideSize`; if it is not, a diagnostic that
+/// spells out the formula (using `dimName`, `dimAxis`, `padBeforeName` and
+/// `padAfterName`) is emitted on `op`. A static `outputSize` must then equal
+/// span / strideSize + 1, otherwise `errorMessage` is emitted.
+LogicalResult verifyConvolutionOutputDim(
+    Operation *op, int64_t inputSize, int64_t kernelSize, int64_t outputSize,
+    int64_t padBefore, int64_t padAfter, int64_t strideSize,
+    int64_t dilationSize, StringRef dimName, StringRef dimAxis,
+    StringRef padBeforeName, StringRef padAfterName, StringRef errorMessage) {
+  const bool extentsAreStatic =
+      !ShapedType::isDynamic(inputSize) && !ShapedType::isDynamic(kernelSize);
+  if (!extentsAreStatic)
+    return success();
+
+  const int64_t paddedInputSpan = inputSize - 1 + padBefore + padAfter;
+  const int64_t dilatedKernelSpan = (kernelSize - 1) * dilationSize;
+  const int64_t span = paddedInputSpan - dilatedKernelSpan;
+
+  // The stride has to tile the span exactly.
+  if (const int64_t remainder = span % strideSize; remainder != 0) {
+    return op->emitOpError("expected input_")
+           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+           << dimAxis << " to be wholly divisible by stride_" << dimAxis
+           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+           << padAfter << " - (" << kernelSize << " - 1) * " << dilationSize
+           << ") / " << strideSize;
+  }
+
+  // A dynamic output extent accepts whatever the formula produces.
+  if (ShapedType::isDynamic(outputSize))
+    return success();
+
+  const int64_t expectedOutputSize = span / strideSize + 1;
+  if (outputSize != expectedOutputSize)
+    return op->emitOpError(errorMessage);
+  return success();
+}
+
+/// Verifies one spatial dimension of a transpose convolution.
+///
+/// The check only applies when the input, kernel and output extents are all
+/// static. The output must then equal
+///   (inputSize - 1) * strideSize + padBefore + padAfter + kernelSize,
+/// otherwise `errorMessage` is emitted on `op`.
+LogicalResult verifyTransposeConvolutionOutputDim(
+    Operation *op, int64_t inputSize, int64_t kernelSize, int64_t outputSize,
+    int64_t padBefore, int64_t padAfter, int64_t strideSize,
+    StringRef errorMessage) {
+  const bool anyDynamic = ShapedType::isDynamic(inputSize) ||
+                          ShapedType::isDynamic(kernelSize) ||
+                          ShapedType::isDynamic(outputSize);
+  if (anyDynamic)
+    return success();
+
+  const int64_t stridedInputExtent = (inputSize - 1) * strideSize;
+  const int64_t expectedOutputSize =
+      stridedInputExtent + padBefore + padAfter + kernelSize;
+  if (outputSize == expectedOutputSize)
+    return success();
+  return op->emitOpError(errorMessage);
+}
+
+/// Verifies the spatial output extents of Conv2D, whose input, weight and
+/// output are laid out [N,IH,IW,*], [*,KH,KW,*] and [N,OH,OW,*].
+///
+/// Height uses tensor dimension 1 with pad[0..1], stride[0] and dilation[0].
+/// Width uses tensor dimension 2 with pad[2..3], stride[1] and dilation[1].
+/// If any of the three types is unranked, nothing is checked.
+LogicalResult verifyConv2DOutputShape(Operation *op, DenseIntElementsAttr pad,
+                                      DenseIntElementsAttr stride,
+                                      DenseIntElementsAttr dilation,
+                                      TensorArmType inputType,
+                                      TensorArmType weightType,
+                                      TensorArmType outputType) {
+  constexpr StringLiteral errorMessage =
+      "failed to verify that shapes of input, weight, and output must satisfy "
+      "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
+      "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
+      "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
+  const bool allRanked =
+      inputType.hasRank() && weightType.hasRank() && outputType.hasRank();
+  if (!allRanked)
+    return success();
+
+  // Checks spatial axis `axis` (0 = height, 1 = width). Input, weight and
+  // output all use tensor dimension axis + 1 for it; pads are stored as
+  // consecutive [before, after] pairs per axis.
+  auto verifyAxis = [&](int axis, StringRef dimName, StringRef dimAxis,
+                        StringRef padBeforeName, StringRef padAfterName) {
+    const int dim = axis + 1;
+    return verifyConvolutionOutputDim(
+        op, inputType.getDimSize(dim), weightType.getDimSize(dim),
+        outputType.getDimSize(dim), getIntValue(pad, 2 * axis),
+        getIntValue(pad, 2 * axis + 1), getIntValue(stride, axis),
+        getIntValue(dilation, axis), dimName, dimAxis, padBeforeName,
+        padAfterName, errorMessage);
+  };
+
+  if (failed(verifyAxis(/*axis=*/0, "height", "y", "top", "bottom")))
+    return failure();
+  return verifyAxis(/*axis=*/1, "width", "x", "left", "right");
+}
+
+/// Verifies the spatial output extents of Conv3D, whose input, weight and
+/// output are laid out [N,ID,IH,IW,*], [*,KD,KH,KW,*] and [N,OD,OH,OW,*].
+///
+/// Depth, height and width use tensor dimensions 1, 2 and 3. Axis i reads
+/// pad[2*i..2*i+1], stride[i] and dilation[i]. Axes are checked in the order
+/// depth, height, width, stopping at the first failure. If any of the three
+/// types is unranked, nothing is checked.
+LogicalResult verifyConv3DOutputShape(Operation *op, DenseIntElementsAttr pad,
+                                      DenseIntElementsAttr stride,
+                                      DenseIntElementsAttr dilation,
+                                      TensorArmType inputType,
+                                      TensorArmType weightType,
+                                      TensorArmType outputType) {
+  constexpr StringLiteral errorMessage =
+      "failed to verify that shapes of input, weight, and output must satisfy "
+      "[N,ID,IH,IW,*], [*,KD,KH,KW,*], [N,OD,OH,OW,*], with OD = ((ID - 1 + "
+      "pad_front + pad_back - (KD - 1) * dilation_d) / stride_d) + 1, OH = "
+      "((IH - 1 + pad_top + pad_bottom - (KH - 1) * dilation_y) / stride_y) "
+      "+ 1 and OW = ((IW - 1 + pad_left + pad_right - (KW - 1) * dilation_x) "
+      "/ stride_x) + 1";
+  const bool allRanked =
+      inputType.hasRank() && weightType.hasRank() && outputType.hasRank();
+  if (!allRanked)
+    return success();
+
+  // Checks spatial axis `axis` (0 = depth, 1 = height, 2 = width). Input,
+  // weight and output all use tensor dimension axis + 1 for it; pads are
+  // stored as consecutive [before, after] pairs per axis.
+  auto verifyAxis = [&](int axis, StringRef dimName, StringRef dimAxis,
+                        StringRef padBeforeName, StringRef padAfterName) {
+    const int dim = axis + 1;
+    return verifyConvolutionOutputDim(
+        op, inputType.getDimSize(dim), weightType.getDimSize(dim),
+        outputType.getDimSize(dim), getIntValue(pad, 2 * axis),
+        getIntValue(pad, 2 * axis + 1), getIntValue(stride, axis),
+        getIntValue(dilation, axis), dimName, dimAxis, padBeforeName,
+        padAfterName, errorMessage);
+  };
+
+  if (failed(verifyAxis(/*axis=*/0, "depth", "d", "front", "back")) ||
+      failed(verifyAxis(/*axis=*/1, "height", "y", "top", "bottom")))
+    return failure();
+  return verifyAxis(/*axis=*/2, "width", "x", "left", "right");
+}
+
+/// Verifies the spatial output extents of DepthwiseConv2D, whose input, weight
+/// and output are laid out [N,IH,IW,*], [KH,KW,*,*] and [N,OH,OW,*].
+///
+/// Unlike Conv2D, the kernel extents sit at weight dimensions 0 (KH) and 1
+/// (KW), while input and output use dimensions 1 (height) and 2 (width).
+/// Height reads pad[0..1], stride[0] and dilation[0]. Width reads pad[2..3],
+/// stride[1] and dilation[1]. If any of the three types is unranked, nothing
+/// is checked.
+LogicalResult verifyDepthwiseConv2DOutputShape(
+    Operation *op, DenseIntElementsAttr pad, DenseIntElementsAttr stride,
+    DenseIntElementsAttr dilation, TensorArmType inputType,
+    TensorArmType weightType, TensorArmType outputType) {
+  constexpr StringLiteral errorMessage =
+      "failed to verify that shapes of input, weight, and output must satisfy "
+      "[N,IH,IW,*], [KH,KW,*,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
+      "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
+      "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
+  const bool allRanked =
+      inputType.hasRank() && weightType.hasRank() && outputType.hasRank();
+  if (!allRanked)
+    return success();
+
+  // Checks spatial axis `axis` (0 = height, 1 = width). The kernel extent is
+  // weight dimension `axis`; input/output use dimension axis + 1. Pads are
+  // stored as consecutive [before, after] pairs per axis.
+  auto verifyAxis = [&](int axis, StringRef dimName, StringRef dimAxis,
+                        StringRef padBeforeName, StringRef padAfterName) {
+    const int activationDim = axis + 1;
+    return verifyConvolutionOutputDim(
+        op, inputType.getDimSize(activationDim), weightType.getDimSize(axis),
+        outputType.getDimSize(activationDim), getIntValue(pad, 2 * axis),
+        getIntValue(pad, 2 * axis + 1), getIntValue(stride, axis),
+        getIntValue(dilation, axis), dimName, dimAxis, padBeforeName,
+        padAfterName, errorMessage);
+  };
+
+  if (failed(verifyAxis(/*axis=*/0, "height", "y", "top", "bottom")))
+    return failure();
+  return verifyAxis(/*axis=*/1, "width", "x", "left", "right");
+}
+
+LogicalResult verifyTransposeConv2DOutputShape(Operation *op,
+ DenseIntElementsAttr outPad,
+ DenseIntElementsAttr stride,
+ TensorArmType inputType,
+ TensorArmType weightType,
+ TensorArmType outputType) {
+ constexpr StringLiteral errorMessage =
+ "failed to verify that shapes of input, weight, and output must satisfy "
+ "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = (IH - 1) * stride_y + "
+ "out_pad_top + out_pad_bottom + KH and OW = (IW - 1) * stride_x + "
+ "out_pad_left + out_pad_right + KW";
+ if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
+ return success();
+
+ const int64_t kernelHeight = weightType.getDimSize(1);
+ if (ShapedType::isStatic(kernelHeight) &&
+ (getIntValue(outPad, 0) <= -kernelHeight ||
+ getIntValue(outPad, 1) <= -kernelHeight))
+ return op->...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/194592
More information about the Mlir-commits mailing list