[Mlir-commits] [mlir] [mlir][spirv] Tighten SPIR-V TOSA convolution verification (PR #194592)
Igor Wodiany
llvmlistbot at llvm.org
Tue Apr 28 08:53:08 PDT 2026
================
@@ -111,18 +111,212 @@ LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel,
return success();
}
+/// Verifies one spatial dimension of a forward convolution's output shape.
+///
+/// When both `inputSize` and `kernelSize` are static, checks that
+///   inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilationSize
+/// is wholly divisible by `strideSize`. If `outputSize` is static, it also
+/// checks that `outputSize` equals that quotient plus one.
+///
+/// `dimName`, `dimAxis`, `padBeforeName` and `padAfterName` are used only to
+/// spell the divisibility / stride diagnostics. `errorMessage` is emitted
+/// verbatim when the static output size does not match.
+LogicalResult verifyConvolutionOutputDim(
+    Operation *op, int64_t inputSize, int64_t kernelSize, int64_t outputSize,
+    int64_t padBefore, int64_t padAfter, int64_t strideSize,
+    int64_t dilationSize, StringRef dimName, StringRef dimAxis,
+    StringRef padBeforeName, StringRef padAfterName, StringRef errorMessage) {
+  if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
+    return success();
+
+  // The numerator is divided by the stride below. A zero stride would be a
+  // division by zero (undefined behaviour), so reject non-positive strides
+  // explicitly instead of relying on callers to have validated them.
+  if (strideSize <= 0)
+    return op->emitOpError("expected stride_")
+           << dimAxis << " to be positive, got " << strideSize;
+
+  const int64_t numerator =
+      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilationSize;
+  if (numerator % strideSize != 0)
+    return op->emitOpError("expected input_")
+           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+           << dimAxis << " to be wholly divisible by stride_" << dimAxis
+           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+           << padAfter << " - (" << kernelSize << " - 1) * " << dilationSize
+           << ") / " << strideSize;
+
+  const int64_t calculatedOutput = numerator / strideSize + 1;
+  if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
+    return op->emitOpError(errorMessage);
+
+  return success();
+}
+
+/// Verifies one spatial dimension of a transpose convolution's output shape.
+///
+/// The expected size is
+///   (inputSize - 1) * strideSize + padBefore + padAfter + kernelSize.
+/// No check is done when the input or kernel size is dynamic, and the
+/// comparison is skipped when `outputSize` is dynamic. On a mismatch,
+/// `errorMessage` is emitted verbatim as the op error.
+LogicalResult verifyTransposeConvolutionOutputDim(
+    Operation *op, int64_t inputSize, int64_t kernelSize, int64_t outputSize,
+    int64_t padBefore, int64_t padAfter, int64_t strideSize,
+    StringRef errorMessage) {
+  const bool operandsStatic =
+      !ShapedType::isDynamic(inputSize) && !ShapedType::isDynamic(kernelSize);
+  const bool outputStatic = !ShapedType::isDynamic(outputSize);
+  if (!operandsStatic || !outputStatic)
+    return success();
+
+  const int64_t expectedOutput =
+      (inputSize - 1) * strideSize + padBefore + padAfter + kernelSize;
+  if (expectedOutput == outputSize)
+    return success();
+
+  return op->emitOpError(errorMessage);
+}
+
+/// Verifies the spatial output dimensions of a 2-D convolution.
+///
+/// Operands are laid out as [N,IH,IW,*] (input), [*,KH,KW,*] (weight) and
+/// [N,OH,OW,*] (output). Height is checked first, using pad entries 0/1
+/// (top/bottom), stride entry 0 and dilation entry 0. Width is checked second,
+/// using pad entries 2/3 (left/right), stride entry 1 and dilation entry 1.
+/// The check is skipped unless all three operand types are ranked.
+LogicalResult verifyConv2DOutputShape(Operation *op, DenseIntElementsAttr pad,
+                                      DenseIntElementsAttr stride,
+                                      DenseIntElementsAttr dilation,
+                                      TensorArmType inputType,
+                                      TensorArmType weightType,
+                                      TensorArmType outputType) {
+  constexpr StringLiteral shapeErrorMessage =
+      "failed to verify that shapes of input, weight, and output must satisfy "
+      "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
+      "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
+      "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
+
+  const bool allRanked =
+      inputType.hasRank() && weightType.hasRank() && outputType.hasRank();
+  if (!allRanked)
+    return success();
+
+  // Height (dimension 1 of every operand).
+  if (failed(verifyConvolutionOutputDim(
+          op, inputType.getDimSize(1), weightType.getDimSize(1),
+          outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
+          getIntValue(stride, 0), getIntValue(dilation, 0), "height", "y",
+          "top", "bottom", shapeErrorMessage)))
+    return failure();
+
+  // Width (dimension 2 of every operand).
+  if (failed(verifyConvolutionOutputDim(
+          op, inputType.getDimSize(2), weightType.getDimSize(2),
+          outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
+          getIntValue(stride, 1), getIntValue(dilation, 1), "width", "x",
+          "left", "right", shapeErrorMessage)))
+    return failure();
+
+  return success();
+}
+
+LogicalResult verifyConv3DOutputShape(Operation *op, DenseIntElementsAttr pad,
+ DenseIntElementsAttr stride,
+ DenseIntElementsAttr dilation,
+ TensorArmType inputType,
+ TensorArmType weightType,
+ TensorArmType outputType) {
+ constexpr StringLiteral errorMessage =
+ "failed to verify that shapes of input, weight, and output must satisfy "
+ "[N,ID,IH,IW,*], [*,KD,KH,KW,*], [N,OD,OH,OW,*], with OD = ((ID - 1 + "
+ "pad_front + pad_back - (KD - 1) * dilation_d) / stride_d) + 1, OH = "
+ "((IH - 1 + pad_top + pad_bottom - (KH - 1) * dilation_y) / stride_y) "
+ "+ 1 and OW = ((IW - 1 + pad_left + pad_right - (KW - 1) * dilation_x) "
+ "/ stride_x) + 1";
+ if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
+ return success();
+
+ if (failed(verifyConvolutionOutputDim(
+ op, inputType.getDimSize(1), weightType.getDimSize(1),
+ outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
+ getIntValue(stride, 0), getIntValue(dilation, 0), "depth", "d",
+ "front", "back", errorMessage)) ||
+ failed(verifyConvolutionOutputDim(
----------------
IgWod wrote:
I think splitting it into 2 `if`s would be easier to follow.
https://github.com/llvm/llvm-project/pull/194592
More information about the Mlir-commits
mailing list