[Mlir-commits] [mlir] [mlir][spirv] Tighten SPIR-V TOSA pool constraints (PR #193515)

Davide Grohmann llvmlistbot at llvm.org
Wed Apr 22 07:51:24 PDT 2026


https://github.com/davidegrohmann created https://github.com/llvm/llvm-project/pull/193515

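Tighten AvgPool2D and MaxPool2D verification: require positive kernel and stride values and non-negative pad values, require each pad to be smaller than the kernel along its axis, and check that the NHWC output shape matches the input (same batch and channel; height and width given by the pooling formula).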

Add verification tests for batch/channel mismatches, spatial sizes not exactly divisible by the stride, pads not smaller than the kernel, and incorrect output shapes.

>From 5ae67ac06439057a7b61182a60b32de0f7dc24c0 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 22 Apr 2026 10:52:06 +0200
Subject: [PATCH] [mlir][spirv] Tighten SPIR-V TOSA pool constraints

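Tighten AvgPool2D and MaxPool2D verification: require positive kernel
and stride values and non-negative pad values, require each pad to be
smaller than the kernel along its axis, and check that the NHWC output
shape matches the input (same batch and channel; height and width
given by the pooling formula).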

Add verification tests for batch/channel mismatches, spatial sizes not
exactly divisible by the stride, pads not smaller than the kernel, and
incorrect output shapes.

Change-Id: Iaf7eb4ba34febd3211835858ee0af0430277e939
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 22 +++--
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   | 16 ++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp    | 90 +++++++++++++++++++
 .../SPIRV/IR/tosa-ops-verification.mlir       | 56 ++++++++++++
 4 files changed, 176 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index db91e529dc623..12652f982de4b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -285,7 +285,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
   TypeImpliesAccType<"input", F32, ["FP32"]>,
   TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
   TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
-  AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
+  AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>,
+  NHWCInputOutputShapeMatch<"input", "output">]> {
   let summary = "Performs average pooling on the input.";
 
   let description = [{
@@ -308,9 +309,9 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
   }];
 
   let arguments = (ins
-    SPIRV_I32_1DTensorArmOfLength2Attr: $kernel,
-    SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
-    SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
+    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
@@ -337,6 +338,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
       return cast<::mlir::spirv::TensorArmType>(getInput().getType());
     }
   }];
+
+  let hasVerifier = 1;
 }
 
 
@@ -619,7 +622,8 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
 
 
 def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
-  AllElementTypesMatch<["input", "output"]>]> {
+  AllElementTypesMatch<["input", "output"]>,
+  NHWCInputOutputShapeMatch<"input", "output">]> {
   let summary = "Performs max pooling on the input.";
 
   let description = [{
@@ -640,9 +644,9 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_I32_1DTensorArmOfLength2Attr: $kernel,
-    SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
-    SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
+    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel,
+    SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride,
+    SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad,
     SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
     SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input
   );
@@ -665,6 +669,8 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
       return cast<::mlir::spirv::TensorArmType>(getInput().getType());
     }
   }];
+
+  let hasVerifier = 1;
 }
 
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 5704911d7f53d..0376ed89d71e8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -115,6 +115,14 @@ def SPIRV_I32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>
 def SPIRV_I32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
 def SPIRV_I32_1DTensorArmOfLength5Attr : ConfinedAttr<RankedI32ElementsAttr<[5]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
 def SPIRV_I32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+
+// Attribute constraint: every integer element must be >= `lowerBound`.
+class IntElementsAttrAllValuesAtLeast<int lowerBound> : AttrConstraint<
+  CPred<"::llvm::none_of(::llvm::cast<::mlir::DenseElementsAttr>($_self).getValues<::llvm::APInt>(), [](const ::llvm::APInt &v) { return v.getSExtValue() < " # lowerBound # "; })">,
+  "all values must be >= " # lowerBound>;
+
+def SPIRV_PositiveInt32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<1>]>;
+def SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<0>]>;
 
 class Is1DTensorArmAttrOfLength<list<int> allowedLengths> :
   AttrConstraint<And<[CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape().size() == 1 }]>,
@@ -217,6 +225,14 @@ class ValuesIndicesShapesMatch<string values, string indices, string tensor>:
       SameDimsOrDynamicPred<values, 2, tensor, 2>
     ]>>;
 
+// Requires the operands named by `src` and `dst` to have matching extents in
+// dim 0 (N) and dim 3 (C), as checked by SameDimsOrDynamicPred.
+class NHWCInputOutputShapeMatch<string src, string dst>:
+  PredOpTrait<"shapes of " # src # " and " # dst #
+              " must satisfy [N,*,*,C] and [N,*,*,C]",
+    And<[SameDimsOrDynamicPred<src, 0, dst, 0>,
+         SameDimsOrDynamicPred<src, 3, dst, 3>]>>;
+
 class TableSizeConstraint<string input, Type type, int size>:
   PredOpTrait<"table must have size " # size # " if " # input # " has element type " # type.summary,
       Implies<ElementTypeIsPred<input, type>, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]>
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index a0591ee31acf8..c078c88f7dd3d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -54,6 +54,96 @@ void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *,
 // SPIRV Tosa Custom verifiers
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Returns element `idx` of `attr`, sign-extended to int64_t.
+int64_t getIntValue(DenseIntElementsAttr attr, size_t idx) {
+  return (*std::next(attr.getValues<APInt>().begin(), idx)).getSExtValue();
+}
+
+/// Enforces pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x,
+/// where kernel = [y, x] and pad = [top, bottom, left, right].
+LogicalResult verifyPool2DPadValuesLessThanKernel(Operation *op,
+                                                  DenseIntElementsAttr kernel,
+                                                  DenseIntElementsAttr pad) {
+  const int64_t kernelY = getIntValue(kernel, 0);
+  const int64_t kernelX = getIntValue(kernel, 1);
+  const int64_t padTop = getIntValue(pad, 0);
+  const int64_t padBottom = getIntValue(pad, 1);
+  const int64_t padLeft = getIntValue(pad, 2);
+  const int64_t padRight = getIntValue(pad, 3);
+  if (padTop < kernelY && padBottom < kernelY && padLeft < kernelX &&
+      padRight < kernelX)
+    return success();
+  return op->emitOpError("pad values must satisfy pad_top/pad_bottom < "
+                         "kernel_y and pad_left/pad_right < kernel_x");
+}
+
+/// Checks one spatial dim: (input + padBefore + padAfter - kernel) must be an
+/// exact multiple of stride, and output must equal that quotient + 1. A
+/// dynamic input skips both checks; a dynamic output skips only the second.
+LogicalResult verifyPool2DOutputDim(Operation *op, int64_t inputSize,
+                                    int64_t outputSize, int64_t kernelSize,
+                                    int64_t strideSize, int64_t padBefore,
+                                    int64_t padAfter, StringRef dimName,
+                                    StringRef dimAxis, StringRef padBeforeName,
+                                    StringRef padAfterName) {
+  if (ShapedType::isDynamic(inputSize))
+    return success();
+
+  // `stride` values are constrained to be >= 1, so the % and / are safe.
+  const int64_t numerator = inputSize + padBefore + padAfter - kernelSize;
+  if (numerator % strideSize != 0)
+    return op->emitOpError("expected input_")
+           << dimName << " + pad_" << padBeforeName << " + pad_" << padAfterName
+           << " - kernel_" << dimAxis << " to be wholly divisible by stride_"
+           << dimAxis << ", got (" << inputSize << " + " << padBefore << " + "
+           << padAfter << " - " << kernelSize << ") / " << strideSize;
+  const int64_t calculatedOutput = numerator / strideSize + 1;
+  if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
+    return op->emitOpError("failed to verify that shapes of input and output "
+                           "must satisfy [N,IH,IW,C] and [N,OH,OW,C], with "
+                           "OH = ((IH + pad_top + pad_bottom - kernel_y) / "
+                           "stride_y) + 1 and OW = ((IW + pad_left + "
+                           "pad_right - kernel_x) / stride_x) + 1")
+           << " (output " << dimName << ": expected " << calculatedOutput
+           << ", got " << outputSize << ")";
+  return success();
+}
+
+/// Shared AvgPool2D/MaxPool2D verifier: checks pads against the kernel, then
+/// the pooled height (dim 1) and width (dim 2) of the input/output shapes.
+LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel,
+                             DenseIntElementsAttr stride,
+                             DenseIntElementsAttr pad, TensorArmType inputType,
+                             TensorArmType outputType) {
+  if (failed(verifyPool2DPadValuesLessThanKernel(op, kernel, pad)))
+    return failure();
+  if (!inputType.hasRank() || !outputType.hasRank())
+    return success();
+  if (failed(verifyPool2DOutputDim(
+          op, inputType.getDimSize(1), outputType.getDimSize(1),
+          getIntValue(kernel, 0), getIntValue(stride, 0), getIntValue(pad, 0),
+          getIntValue(pad, 1), "height", "y", "top", "bottom")))
+    return failure();
+  return verifyPool2DOutputDim(
+      op, inputType.getDimSize(2), outputType.getDimSize(2),
+      getIntValue(kernel, 1), getIntValue(stride, 1), getIntValue(pad, 2),
+      getIntValue(pad, 3), "width", "x", "left", "right");
+}
+
+} // namespace
+
+LogicalResult TosaAvgPool2DOp::verify() {
+  return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(),
+                        getInputType(), getResultType());
+}
+
+LogicalResult TosaMaxPool2DOp::verify() {
+  return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(),
+                        getInputType(), getResultType());
+}
+
 LogicalResult TosaSelectOp::verify() {
   TensorArmType condType = getConditionType();
   TensorArmType trueValType = getTrueValueType();
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 78ae3b4586004..0489629b98f2f 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -91,6 +91,38 @@ spirv.ARM.Graph @avgpool2d_accumulator_must_be_FP16_for_f8e5m2_element_type(%arg
   spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x2x2xf8E5M2>
 }
 
+spirv.ARM.Graph @avgpool2d_input_output_batch_or_channel_mismatch(%input: !spirv.arm.tensor<1x3x65537x2xi8>) -> (!spirv.arm.tensor<2x2x32768x1xi8>) {
+  %input_zp = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+  %output_zp = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,*,*,C] and [N,*,*,C]}}
+  %pooled = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %input, %input_zp, %output_zp : !spirv.arm.tensor<1x3x65537x2xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x2x32768x1xi8>
+  spirv.ARM.GraphOutputs %pooled : !spirv.arm.tensor<2x2x32768x1xi8>
+}
+
+spirv.ARM.Graph @avgpool2d_input_shape_not_wholly_divisible_by_stride(%input: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x1x1x1xi8>) {
+  %input_zp = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+  %output_zp = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op expected input_height + pad_top + pad_bottom - kernel_y to be wholly divisible by stride_y}}
+  %pooled = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [2, 2], pad = [0, 0, 0, 0], acc_type = <INT32>, %input, %input_zp, %output_zp : !spirv.arm.tensor<1x4x4x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x1x1x1xi8>
+  spirv.ARM.GraphOutputs %pooled : !spirv.arm.tensor<1x1x1x1xi8>
+}
+
+spirv.ARM.Graph @avgpool2d_pad_values_must_be_less_than_kernel(%input: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x2x1x1xi8>) {
+  %input_zp = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+  %output_zp = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x}}
+  %pooled = spirv.Tosa.AvgPool2D kernel = [2, 3], stride = [1, 2], pad = [2, 0, 0, 0], acc_type = <INT32>, %input, %input_zp, %output_zp : !spirv.arm.tensor<1x4x4x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x1x1xi8>
+  spirv.ARM.GraphOutputs %pooled : !spirv.arm.tensor<1x2x1x1xi8>
+}
+
+spirv.ARM.Graph @avgpool2d_input_output_height_width_mismatch(%input: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi8>) {
+  %input_zp = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8>
+  %output_zp = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
+  // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,IH,IW,C] and [N,OH,OW,C], with OH = ((IH + pad_top + pad_bottom - kernel_y) / stride_y) + 1 and OW = ((IW + pad_left + pad_right - kernel_x) / stride_x) + 1}}
+  %pooled = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %input, %input_zp, %output_zp : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8>
+  spirv.ARM.GraphOutputs %pooled : !spirv.arm.tensor<1x2x32769x1xi8>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Conv2D
 //===----------------------------------------------------------------------===//
@@ -537,6 +569,30 @@ spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.ar
   spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
 }
 
+spirv.ARM.Graph @maxpool2d_input_output_batch_or_channel_mismatch(%input: !spirv.arm.tensor<1x3x65537x2xi8>) -> (!spirv.arm.tensor<2x2x32769x1xi8>) {
+  // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,*,*,C] and [N,*,*,C]}}
+  %pooled = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %input : !spirv.arm.tensor<1x3x65537x2xi8> -> !spirv.arm.tensor<2x2x32769x1xi8>
+  spirv.ARM.GraphOutputs %pooled : !spirv.arm.tensor<2x2x32769x1xi8>
+}
+
+spirv.ARM.Graph @maxpool2d_input_shape_not_wholly_divisible_by_stride(%input: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x1x1x1xi8>) {
+  // expected-error @+1 {{op expected input_height + pad_top + pad_bottom - kernel_y to be wholly divisible by stride_y}}
+  %pooled = spirv.Tosa.MaxPool2D kernel = [3, 3], stride = [2, 2], pad = [0, 0, 0, 0], nan_mode = <Propagate>, %input : !spirv.arm.tensor<1x4x4x1xi8> -> !spirv.arm.tensor<1x1x1x1xi8>
+  spirv.ARM.GraphOutputs %pooled : !spirv.arm.tensor<1x1x1x1xi8>
+}
+
+spirv.ARM.Graph @maxpool2d_pad_values_must_be_less_than_kernel(%input: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x2x1x1xi8>) {
+  // expected-error @+1 {{op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x}}
+  %pooled = spirv.Tosa.MaxPool2D kernel = [2, 3], stride = [1, 2], pad = [2, 0, 0, 0], nan_mode = <Propagate>, %input : !spirv.arm.tensor<1x4x4x1xi8> -> !spirv.arm.tensor<1x2x1x1xi8>
+  spirv.ARM.GraphOutputs %pooled : !spirv.arm.tensor<1x2x1x1xi8>
+}
+
+spirv.ARM.Graph @maxpool2d_input_output_height_width_mismatch(%input: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
+  // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,IH,IW,C] and [N,OH,OW,C], with OH = ((IH + pad_top + pad_bottom - kernel_y) / stride_y) + 1 and OW = ((IW + pad_left + pad_right - kernel_x) / stride_x) + 1}}
+  %pooled = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %input : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8>
+  spirv.ARM.GraphOutputs %pooled : !spirv.arm.tensor<1x2x32768x1xi8>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.TransposeConv2D
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list