[Mlir-commits] [mlir] [mlir][spirv] Tighten SPIR-V TOSA convolution verification (PR #194592)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 28 03:55:22 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Davide Grohmann (davidegrohmann)

<details>
<summary>Changes</summary>

Add verifier coverage for SPIR-V TOSA convolution ops against the TOSA shape and type constraints.

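This adds shared TableGen shape predicates for Conv2D, Conv3D and DepthwiseConv2D (TransposeConv2D reuses the Conv2D predicate). The predicates check that the batch dimension matches between input and output, that the input channels match the weight's input-channel dimension, that the output channels match the weight (or IC*M for DepthwiseConv2D), and that the bias length is either the output channel count or 1. It also constrains integer convolution weights so i8 and i16 inputs use i8 weights, matching the SPIR-V TOSA representation.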

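Add custom verifiers for the convolution output shape formulas. For Conv2D, Conv3D and DepthwiseConv2D, they require that each spatial extent (e.g. IH - 1 + pad_top + pad_bottom - (KH - 1) * dilation_y) be wholly divisible by the corresponding stride. For TransposeConv2D, they bound out_pad by the kernel size (e.g. out_pad_top and out_pad_bottom must be greater than -KH). Tighten the pad, stride and dilation attributes to use non-negative or positive i32 attribute constraints where required. The existing SPIRV_PositiveInt32_* and SPIRV_NonNegativeInt32_* attribute constraints are renamed to SPIRV_PositiveI32_* and SPIRV_NonNegativeI32_*, and length-3 and length-6 variants are added for Conv3D.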

---

Patch is 80.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/194592.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+34-20) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+78-10) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp (+194) 
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir (+228-60) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 483ce4348971b..e6336bf011e9d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -158,6 +158,8 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
     TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
     TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16]>,
     TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16]>,
+    TypeConstraintImplicationOn<"input", I8, "weight", [I8]>,
+    TypeConstraintImplicationOn<"input", I16, "weight", [I8]>,
     TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
     TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
     TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
@@ -310,9 +312,9 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
   }];
 
   let arguments = (ins
-    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
-    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
-    SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $kernel,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr: $pad,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
@@ -344,7 +346,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
 }
 
 
-def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
+def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2, [
+  Conv2DShapeMatch<"input", "weight", "bias", "output">]> {
   let summary = "2D Convolution operator.";
 
   let description = [{
@@ -368,9 +371,9 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
   }];
 
   let arguments = (ins
-    SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
-    SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
-    SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
+    SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
@@ -397,10 +400,13 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
     $weight_zp
     attr-dict `:` type(operands) `->` type(results)
   }];
+
+  let hasVerifier = 1;
 }
 
 
-def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
+def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3, [
+  Conv3DShapeMatch<"input", "weight", "bias", "output">]> {
   let summary = "3D Convolution operator.";
 
   let description = [{
@@ -423,9 +429,9 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
   }];
 
   let arguments = (ins
-    SPIRV_I32_1DTensorArmOfLength6Attr: $pad,
-    SPIRV_I32_1DTensorArmOfLength3Attr: $stride,
-    SPIRV_I32_1DTensorArmOfLength3Attr: $dilation,
+    SPIRV_NonNegativeI32_1DTensorArmOfLength6Attr: $pad,
+    SPIRV_PositiveI32_1DTensorArmOfLength3Attr: $stride,
+    SPIRV_PositiveI32_1DTensorArmOfLength3Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm5D: $input,
@@ -452,10 +458,13 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
     $weight_zp
     attr-dict `:` type(operands) `->` type(results)
   }];
+
+  let hasVerifier = 1;
 }
 
 
-def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4> {
+def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4, [
+  DepthwiseConv2DShapeMatch<"input", "weight", "bias", "output">]> {
   let summary = "Depthwise 2D Convolution operator.";
 
   let description = [{
@@ -479,9 +488,9 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
   }];
 
   let arguments = (ins
-    SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
-    SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
-    SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
+    SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
@@ -508,6 +517,8 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
     $weight_zp
     attr-dict `:` type(operands) `->` type(results)
   }];
+
+  let hasVerifier = 1;
 }
 
 
@@ -646,9 +657,9 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
-    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
-    SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $kernel,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr: $pad,
     SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input
   );
@@ -724,7 +735,8 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> {
 }
 
 
-def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9> {
+def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9, [
+  Conv2DShapeMatch<"input", "weight", "bias", "output">]> {
   let summary = "Transpose 2D Convolution operator.";
 
   let description = [{
@@ -749,7 +761,7 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
 
   let arguments = (ins
     SPIRV_I32_1DTensorArmOfLength4Attr: $out_pad,
-    SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_PositiveI32_1DTensorArmOfLength2Attr: $stride,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
@@ -775,6 +787,8 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
     $weight_zp
     attr-dict `:` type(operands) `->` type(results)
   }];
+
+  let hasVerifier = 1;
 }
 
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 7064930c5864a..bc981d8bf95cf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -121,8 +121,10 @@ class IntElementsAttrAllValuesAtLeast<int minValue> : AttrConstraint<
         minValue # "; })">,
   "all values must be >= " # minValue>;
 
-def SPIRV_PositiveInt32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<1>]>;
-def SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<0>]>;
+def SPIRV_PositiveI32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<1>]>;
+def SPIRV_PositiveI32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<1>]>;
+def SPIRV_NonNegativeI32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<0>]>;
+def SPIRV_NonNegativeI32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<0>]>;
 
 class Is1DTensorArmAttrOfLength<list<int> allowedLengths> :
   AttrConstraint<And<[CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape().size() == 1 }]>,
@@ -204,15 +206,46 @@ class MatchBroadcastableShapes<string input1, string input2, string output>:
     "})">]>
   >;
 
+class DimOf<string input, int dim> :
+  StrFunc<"::llvm::cast<::mlir::ShapedType>($" # input #
+          ".getType()).getDimSize(" # dim # ")">;
+
+class DimIsDynamic<string input, int dim> :
+  CPred<"::mlir::ShapedType::isDynamic(" # DimOf<input, dim>.result # ")">;
+
+class DimIsOne<string input, int dim> :
+  CPred<DimOf<input, dim>.result # " == 1">;
+
+class DimsMatch<string lhs, int lhsDim, string rhs, int rhsDim> :
+  CPred<DimOf<lhs, lhsDim>.result # " == " # DimOf<rhs, rhsDim>.result>;
+
+class ProductDimsMatch<string lhs, int lhsDim, string rhs, int rhsDim,
+                       string output, int outputDim> :
+  CPred<DimOf<lhs, lhsDim>.result # " * " # DimOf<rhs, rhsDim>.result #
+        " == " # DimOf<output, outputDim>.result>;
+
 class SameDimsOrDynamicPred<string lhs, int lhsDim, string rhs, int rhsDim> :
-  CPred<"[](::mlir::ShapedType lhsType, ::mlir::ShapedType rhsType) { "
-        "  int64_t lhsSize = lhsType.getDimSize(" # lhsDim # "); "
-        "  int64_t rhsSize = rhsType.getDimSize(" # rhsDim # "); "
-        "  return ::mlir::ShapedType::isDynamic(lhsSize) || "
-        "         ::mlir::ShapedType::isDynamic(rhsSize) || lhsSize == rhsSize; "
-        "}("
-        "::llvm::cast<::mlir::ShapedType>($" # lhs # ".getType()), "
-        "::llvm::cast<::mlir::ShapedType>($" # rhs # ".getType()))">;
+  Or<[
+    DimIsDynamic<lhs, lhsDim>,
+    DimIsDynamic<rhs, rhsDim>,
+    DimsMatch<lhs, lhsDim, rhs, rhsDim>
+  ]>;
+
+class SameDimsOrOneOrDynamicPred<string lhs, int lhsDim, string rhs, int rhsDim> :
+  Or<[
+    SameDimsOrDynamicPred<lhs, lhsDim, rhs, rhsDim>,
+    DimIsOne<lhs, lhsDim>
+  ]>;
+
+class ProductDimOrOneOrDynamicPred<string lhs, int lhsDim, string rhs,
+                                   int rhsDim, string output, int outputDim> :
+  Or<[
+    DimIsDynamic<lhs, lhsDim>,
+    DimIsDynamic<rhs, rhsDim>,
+    DimIsDynamic<output, outputDim>,
+    ProductDimsMatch<lhs, lhsDim, rhs, rhsDim, output, outputDim>,
+    DimIsOne<output, outputDim>
+  ]>;
 
 class ValuesIndicesShapesMatch<string values, string indices, string tensor>:
   PredOpTrait<"shapes of " # values # ", " # indices # ", and " # tensor #
@@ -234,6 +267,41 @@ class NHWCInputOutputShapeMatch<string input, string output>:
       SameDimsOrDynamicPred<input, 3, output, 3>
     ]>>;
 
+
+class Conv2DShapeMatch<string input, string weight, string bias, string output>:
+  PredOpTrait<"shapes of " # input # ", " # weight # ", " # bias # ", and " #
+                  output # " must satisfy [N,IH,IW,IC], [OC,KH,KW,IC], "
+                  "[OC/1], [N,OH,OW,OC]",
+    And<[
+      SameDimsOrDynamicPred<input, 0, output, 0>,
+      SameDimsOrDynamicPred<input, 3, weight, 3>,
+      SameDimsOrDynamicPred<weight, 0, output, 3>,
+      SameDimsOrOneOrDynamicPred<bias, 0, output, 3>
+    ]>>;
+
+class Conv3DShapeMatch<string input, string weight, string bias, string output>:
+  PredOpTrait<"shapes of " # input # ", " # weight # ", " # bias # ", and " #
+                  output # " must satisfy [N,ID,IH,IW,IC], [OC,KD,KH,KW,IC], "
+                  "[OC/1], [N,OD,OH,OW,OC]",
+    And<[
+      SameDimsOrDynamicPred<input, 0, output, 0>,
+      SameDimsOrDynamicPred<input, 4, weight, 4>,
+      SameDimsOrDynamicPred<weight, 0, output, 4>,
+      SameDimsOrOneOrDynamicPred<bias, 0, output, 4>
+    ]>>;
+
+class DepthwiseConv2DShapeMatch<string input, string weight, string bias,
+                                string output>:
+  PredOpTrait<"shapes of " # input # ", " # weight # ", " # bias # ", and " #
+                  output # " must satisfy [N,IH,IW,IC], [KH,KW,IC,M], "
+                  "[IC*M/1], [N,OH,OW,IC*M]",
+    And<[
+      SameDimsOrDynamicPred<input, 0, output, 0>,
+      SameDimsOrDynamicPred<input, 3, weight, 2>,
+      ProductDimOrOneOrDynamicPred<input, 3, weight, 3, output, 3>,
+      SameDimsOrOneOrDynamicPred<bias, 0, output, 3>
+    ]>>;
+
 class FetchNthIntElementsAttr<string attrName, int idx> :
   StrFunc<"get" # snakeCaseToCamelCase<attrName>.ret # "().getValues<APInt>()[" # idx # "].getSExtValue()">;
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index a2cc8be54e4f3..610bc843c8c0c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -111,6 +111,176 @@ LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel,
   return success();
 }
 
+LogicalResult verifyConvolutionOutputDim(
+    Operation *op, int64_t inputSize, int64_t kernelSize, int64_t outputSize,
+    int64_t padBefore, int64_t padAfter, int64_t strideSize,
+    int64_t dilationSize, StringRef dimName, StringRef dimAxis,
+    StringRef padBeforeName, StringRef padAfterName, StringRef errorMessage) {
+  if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
+    return success();
+
+  const int64_t numerator =
+      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilationSize;
+  if (numerator % strideSize != 0)
+    return op->emitOpError("expected input_")
+           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+           << dimAxis << " to be wholly divisible by stride_" << dimAxis
+           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+           << padAfter << " - (" << kernelSize << " - 1) * " << dilationSize
+           << ") / " << strideSize;
+
+  const int64_t calculatedOutput = numerator / strideSize + 1;
+  if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
+    return op->emitOpError(errorMessage);
+
+  return success();
+}
+
+LogicalResult verifyTransposeConvolutionOutputDim(
+    Operation *op, int64_t inputSize, int64_t kernelSize, int64_t outputSize,
+    int64_t padBefore, int64_t padAfter, int64_t strideSize,
+    StringRef errorMessage) {
+  if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
+    return success();
+
+  const int64_t calculatedOutput =
+      (inputSize - 1) * strideSize + padBefore + padAfter + kernelSize;
+  if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
+    return op->emitOpError(errorMessage);
+
+  return success();
+}
+
+LogicalResult verifyConv2DOutputShape(Operation *op, DenseIntElementsAttr pad,
+                                      DenseIntElementsAttr stride,
+                                      DenseIntElementsAttr dilation,
+                                      TensorArmType inputType,
+                                      TensorArmType weightType,
+                                      TensorArmType outputType) {
+  constexpr StringLiteral errorMessage =
+      "failed to verify that shapes of input, weight, and output must satisfy "
+      "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
+      "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
+      "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
+  if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
+    return success();
+
+  if (failed(verifyConvolutionOutputDim(
+          op, inputType.getDimSize(1), weightType.getDimSize(1),
+          outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
+          getIntValue(stride, 0), getIntValue(dilation, 0), "height", "y",
+          "top", "bottom", errorMessage)))
+    return failure();
+
+  return verifyConvolutionOutputDim(
+      op, inputType.getDimSize(2), weightType.getDimSize(2),
+      outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
+      getIntValue(stride, 1), getIntValue(dilation, 1), "width", "x", "left",
+      "right", errorMessage);
+}
+
+LogicalResult verifyConv3DOutputShape(Operation *op, DenseIntElementsAttr pad,
+                                      DenseIntElementsAttr stride,
+                                      DenseIntElementsAttr dilation,
+                                      TensorArmType inputType,
+                                      TensorArmType weightType,
+                                      TensorArmType outputType) {
+  constexpr StringLiteral errorMessage =
+      "failed to verify that shapes of input, weight, and output must satisfy "
+      "[N,ID,IH,IW,*], [*,KD,KH,KW,*], [N,OD,OH,OW,*], with OD = ((ID - 1 + "
+      "pad_front + pad_back - (KD - 1) * dilation_d) / stride_d) + 1, OH = "
+      "((IH - 1 + pad_top + pad_bottom - (KH - 1) * dilation_y) / stride_y) "
+      "+ 1 and OW = ((IW - 1 + pad_left + pad_right - (KW - 1) * dilation_x) "
+      "/ stride_x) + 1";
+  if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
+    return success();
+
+  if (failed(verifyConvolutionOutputDim(
+          op, inputType.getDimSize(1), weightType.getDimSize(1),
+          outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
+          getIntValue(stride, 0), getIntValue(dilation, 0), "depth", "d",
+          "front", "back", errorMessage)) ||
+      failed(verifyConvolutionOutputDim(
+          op, inputType.getDimSize(2), weightType.getDimSize(2),
+          outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
+          getIntValue(stride, 1), getIntValue(dilation, 1), "height", "y",
+          "top", "bottom", errorMessage)))
+    return failure();
+
+  return verifyConvolutionOutputDim(
+      op, inputType.getDimSize(3), weightType.getDimSize(3),
+      outputType.getDimSize(3), getIntValue(pad, 4), getIntValue(pad, 5),
+      getIntValue(stride, 2), getIntValue(dilation, 2), "width", "x", "left",
+      "right", errorMessage);
+}
+
+LogicalResult verifyDepthwiseConv2DOutputShape(
+    Operation *op, DenseIntElementsAttr pad, DenseIntElementsAttr stride,
+    DenseIntElementsAttr dilation, TensorArmType inputType,
+    TensorArmType weightType, TensorArmType outputType) {
+  constexpr StringLiteral errorMessage =
+      "failed to verify that shapes of input, weight, and output must satisfy "
+      "[N,IH,IW,*], [KH,KW,*,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
+      "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
+      "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
+  if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
+    return success();
+
+  if (failed(verifyConvolutionOutputDim(
+          op, inputType.getDimSize(1), weightType.getDimSize(0),
+          outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
+          getIntValue(stride, 0), getIntValue(dilation, 0), "height", "y",
+          "top", "bottom", errorMessage)))
+    return failure();
+
+  return verifyConvolutionOutputDim(
+      op, inputType.getDimSize(2), weightType.getDimSize(1),
+      outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
+      getIntValue(stride, 1), getIntValue(dilation, 1), "width", "x", "left",
+      "right", errorMessage);
+}
+
+LogicalResult verifyTransposeConv2DOutputShape(Operation *op,
+                                               DenseIntElementsAttr outPad,
+                                               DenseIntElementsAttr stride,
+                                               TensorArmType inputType,
+                                               TensorArmType weightType,
+                                               TensorArmType outputType) {
+  constexpr StringLiteral errorMessage =
+      "failed to verify that shapes of input, weight, and output must satisfy "
+      "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = (IH - 1) * stride_y + "
+      "out_pad_top + out_pad_bottom + KH and OW = (IW - 1) * stride_x + "
+      "out_pad_left + out_pad_right + KW";
+  if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
+    return success();
+
+  const int64_t kernelHeight = weightType.getDimSize(1);
+  if (ShapedType::isStatic(kernelHeight) &&
+      (getIntValue(outPad, 0) <= -kernelHeight ||
+       getIntValue(outPad, 1) <= -kernelHeight))
+    return op->...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/194592


More information about the Mlir-commits mailing list