[Mlir-commits] [mlir] [mlir][spirv] Tighten SPIR-V TOSA pool constraints (PR #193515)
Davide Grohmann
llvmlistbot at llvm.org
Wed Apr 22 07:51:24 PDT 2026
https://github.com/davidegrohmann created https://github.com/llvm/llvm-project/pull/193515
Tighten AvgPool2D and MaxPool2D verification by constraining kernel, stride, and pad attributes and by checking the input/output NHWC relationship.
Add verification tests for batch/channel mismatches, non-divisible pooled shapes, pad-vs-kernel failures, and incorrect output shapes.
From 5ae67ac06439057a7b61182a60b32de0f7dc24c0 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 22 Apr 2026 10:52:06 +0200
Subject: [PATCH] [mlir][spirv] Tighten SPIR-V TOSA pool constraints
Tighten AvgPool2D and MaxPool2D verification by constraining kernel,
stride, and pad attributes and by checking the input/output NHWC
relationship.
Add verification tests for batch/channel mismatches, non-divisible
pooled shapes, pad-vs-kernel failures, and incorrect output shapes.
Change-Id: Iaf7eb4ba34febd3211835858ee0af0430277e939
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
.../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td | 22 +++--
.../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td | 16 ++++
mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp | 90 +++++++++++++++++++
.../SPIRV/IR/tosa-ops-verification.mlir | 56 ++++++++++++
4 files changed, 176 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index db91e529dc623..12652f982de4b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -285,7 +285,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
TypeImpliesAccType<"input", F32, ["FP32"]>,
TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
- AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
+ AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>,
+ NHWCInputOutputShapeMatch<"input", "output">]> {
let summary = "Performs average pooling on the input.";
let description = [{
@@ -308,9 +309,9 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
}];
let arguments = (ins
- SPIRV_I32_1DTensorArmOfLength2Attr: $kernel,
- SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
@@ -337,6 +338,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
return cast<::mlir::spirv::TensorArmType>(getInput().getType());
}
}];
+
+ let hasVerifier = 1;
}
@@ -619,7 +622,8 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
- AllElementTypesMatch<["input", "output"]>]> {
+ AllElementTypesMatch<["input", "output"]>,
+ NHWCInputOutputShapeMatch<"input", "output">]> {
let summary = "Performs max pooling on the input.";
let description = [{
@@ -640,9 +644,9 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
}];
let arguments = (ins
- SPIRV_I32_1DTensorArmOfLength2Attr: $kernel,
- SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input
);
@@ -665,6 +669,8 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
return cast<::mlir::spirv::TensorArmType>(getInput().getType());
}
}];
+
+ let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 5704911d7f53d..0376ed89d71e8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -115,6 +115,14 @@ def SPIRV_I32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>
def SPIRV_I32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
def SPIRV_I32_1DTensorArmOfLength5Attr : ConfinedAttr<RankedI32ElementsAttr<[5]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
def SPIRV_I32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+class IntElementsAttrAllValuesAtLeast<int minValue> : AttrConstraint<
+ CPred<"::llvm::all_of(::llvm::cast<::mlir::DenseElementsAttr>($_self).getValues<::llvm::APInt>(), "
+ "[](const ::llvm::APInt &value) { return value.getSExtValue() >= " #
+ minValue # "; })">,
+ "all values must be >= " # minValue>;
+
+def SPIRV_PositiveInt32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<1>]>;
+def SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<0>]>;
class Is1DTensorArmAttrOfLength<list<int> allowedLengths> :
AttrConstraint<And<[CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape().size() == 1 }]>,
@@ -217,6 +225,14 @@ class ValuesIndicesShapesMatch<string values, string indices, string tensor>:
SameDimsOrDynamicPred<values, 2, tensor, 2>
]>>;
+class NHWCInputOutputShapeMatch<string input, string output>:
+ PredOpTrait<"shapes of " # input # " and " # output #
+ " must satisfy [N,*,*,C] and [N,*,*,C]",
+ And<[
+ SameDimsOrDynamicPred<input, 0, output, 0>,
+ SameDimsOrDynamicPred<input, 3, output, 3>
+ ]>>;
+
class TableSizeConstraint<string input, Type type, int size>:
PredOpTrait<"table must have size " # size # " if " # input # " has element type " # type.summary,
Implies<ElementTypeIsPred<input, type>, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]>
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index a0591ee31acf8..c078c88f7dd3d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -54,6 +54,96 @@ void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *,
// SPIRV Tosa Custom verifiers
//===----------------------------------------------------------------------===//
+namespace {
+
+int64_t getIntValue(DenseIntElementsAttr attr, size_t idx) {
+ return (*std::next(attr.getValues<APInt>().begin(), idx)).getSExtValue();
+}
+
+LogicalResult verifyPool2DPadValuesLessThanKernel(Operation *op,
+ DenseIntElementsAttr kernel,
+ DenseIntElementsAttr pad) {
+ const int64_t kernelY = getIntValue(kernel, 0);
+ const int64_t kernelX = getIntValue(kernel, 1);
+ const int64_t padTop = getIntValue(pad, 0);
+ const int64_t padBottom = getIntValue(pad, 1);
+ const int64_t padLeft = getIntValue(pad, 2);
+ const int64_t padRight = getIntValue(pad, 3);
+
+ if (padTop < kernelY && padBottom < kernelY && padLeft < kernelX &&
+ padRight < kernelX)
+ return success();
+
+ return op->emitOpError("pad values must satisfy pad_top/pad_bottom < "
+ "kernel_y and pad_left/pad_right < kernel_x");
+}
+
+LogicalResult verifyPool2DOutputDim(Operation *op, int64_t inputSize,
+ int64_t outputSize, int64_t kernelSize,
+ int64_t strideSize, int64_t padBefore,
+ int64_t padAfter, StringRef dimName,
+ StringRef dimAxis, StringRef padBeforeName,
+ StringRef padAfterName) {
+ if (ShapedType::isDynamic(inputSize))
+ return success();
+
+ const int64_t numerator = inputSize + padBefore + padAfter - kernelSize;
+ if (numerator % strideSize != 0)
+ return op->emitOpError("expected input_")
+ << dimName << " + pad_" << padBeforeName << " + pad_" << padAfterName
+ << " - kernel_" << dimAxis << " to be wholly divisible by stride_"
+ << dimAxis << ", got (" << inputSize << " + " << padBefore << " + "
+ << padAfter << " - " << kernelSize << ") / " << strideSize;
+
+ const int64_t calculatedOutput = numerator / strideSize + 1;
+ if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
+ return op->emitOpError("failed to verify that shapes of input and output "
+ "must satisfy [N,IH,IW,C] and [N,OH,OW,C], with "
+ "OH = ((IH + pad_top + pad_bottom - kernel_y) / "
+ "stride_y) + 1 and OW = ((IW + pad_left + "
+ "pad_right - kernel_x) / stride_x) + 1");
+
+ return success();
+}
+
+LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel,
+ DenseIntElementsAttr stride,
+ DenseIntElementsAttr pad, TensorArmType inputType,
+ TensorArmType outputType) {
+
+ if (failed(verifyPool2DPadValuesLessThanKernel(op, kernel, pad)))
+ return failure();
+
+ if (!inputType.hasRank() || !outputType.hasRank())
+ return success();
+
+ if (failed(verifyPool2DOutputDim(
+ op, inputType.getDimSize(1), outputType.getDimSize(1),
+ getIntValue(kernel, 0), getIntValue(stride, 0), getIntValue(pad, 0),
+ getIntValue(pad, 1), "height", "y", "top", "bottom")))
+ return failure();
+
+ if (failed(verifyPool2DOutputDim(
+ op, inputType.getDimSize(2), outputType.getDimSize(2),
+ getIntValue(kernel, 1), getIntValue(stride, 1), getIntValue(pad, 2),
+ getIntValue(pad, 3), "width", "x", "left", "right")))
+ return failure();
+
+ return success();
+}
+
+} // namespace
+
+LogicalResult TosaAvgPool2DOp::verify() {
+ return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(),
+ getInputType(), getResultType());
+}
+
+LogicalResult TosaMaxPool2DOp::verify() {
+ return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(),
+ getInputType(), getResultType());
+}
+
LogicalResult TosaSelectOp::verify() {
TensorArmType condType = getConditionType();
TensorArmType trueValType = getTrueValueType();
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 78ae3b4586004..0489629b98f2f 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -91,6 +91,38 @@ spirv.ARM.Graph @avgpool2d_accumulator_must_be_FP16_for_f8e5m2_element_type(%arg
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x2x2xf8E5M2>
}
+spirv.ARM.Graph @avgpool2d_input_output_batch_or_channel_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x2xi8>) -> (!spirv.arm.tensor<2x2x32768x1xi8>) {
+ %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+ %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+ // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,*,*,C] and [N,*,*,C]}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x2xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x2x32768x1xi8>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<2x2x32768x1xi8>
+}
+
+spirv.ARM.Graph @avgpool2d_input_shape_not_wholly_divisible_by_stride(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x1x1x1xi8>) {
+ %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+ %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+ // expected-error @+1 {{op expected input_height + pad_top + pad_bottom - kernel_y to be wholly divisible by stride_y}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [2, 2], pad = [0, 0, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x4x4x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x1x1x1xi8>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x1x1x1xi8>
+}
+
+spirv.ARM.Graph @avgpool2d_pad_values_must_be_less_than_kernel(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x2x1x1xi8>) {
+ %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+ %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+ // expected-error @+1 {{op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [2, 3], stride = [1, 2], pad = [2, 0, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x4x4x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x1x1xi8>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x1x1xi8>
+}
+
+spirv.ARM.Graph @avgpool2d_input_output_height_width_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi8>) {
+ %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+ %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+ // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,IH,IW,C] and [N,OH,OW,C], with OH = ((IH + pad_top + pad_bottom - kernel_y) / stride_y) + 1 and OW = ((IW + pad_left + pad_right - kernel_x) / stride_x) + 1}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32769x1xi8>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.Conv2D
//===----------------------------------------------------------------------===//
@@ -537,6 +569,30 @@ spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.ar
spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
}
+spirv.ARM.Graph @maxpool2d_input_output_batch_or_channel_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x2xi8>) -> (!spirv.arm.tensor<2x2x32769x1xi8>) {
+ // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,*,*,C] and [N,*,*,C]}}
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x2xi8> -> !spirv.arm.tensor<2x2x32769x1xi8>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<2x2x32769x1xi8>
+}
+
+spirv.ARM.Graph @maxpool2d_input_shape_not_wholly_divisible_by_stride(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x1x1x1xi8>) {
+ // expected-error @+1 {{op expected input_height + pad_top + pad_bottom - kernel_y to be wholly divisible by stride_y}}
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 3], stride = [2, 2], pad = [0, 0, 0, 0], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x4x4x1xi8> -> !spirv.arm.tensor<1x1x1x1xi8>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x1x1x1xi8>
+}
+
+spirv.ARM.Graph @maxpool2d_pad_values_must_be_less_than_kernel(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x2x1x1xi8>) {
+ // expected-error @+1 {{op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x}}
+ %4 = spirv.Tosa.MaxPool2D kernel = [2, 3], stride = [1, 2], pad = [2, 0, 0, 0], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x4x4x1xi8> -> !spirv.arm.tensor<1x2x1x1xi8>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x1x1xi8>
+}
+
+spirv.ARM.Graph @maxpool2d_input_output_height_width_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
+ // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,IH,IW,C] and [N,OH,OW,C], with OH = ((IH + pad_top + pad_bottom - kernel_y) / stride_y) + 1 and OW = ((IW + pad_left + pad_right - kernel_x) / stride_x) + 1}}
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32768x1xi8>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.TransposeConv2D
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list