[Mlir-commits] [mlir] [mlir][spirv] Allow dynamic rescale parameter lengths (PR #200155)

Davide Grohmann llvmlistbot at llvm.org
Thu May 28 03:45:12 PDT 2026


https://github.com/davidegrohmann created https://github.com/llvm/llvm-project/pull/200155

The SPIR-V TOSA rescale verifier checked multiplier and shift lengths with a direct equality against the input channel dimension. That rejects otherwise valid operations when either side of the shape comparison is dynamic.

Express the check with reusable dimension predicates so that unranked tensors and dynamic dimensions pass, while static dimensions still enforce the length requirements: input_shape[rank(input) - 1] when per_channel is true, and 1 otherwise. Add dedicated dynamic-shape op tests for dynamic input channel dimensions and dynamic multiplier/shift lengths.

>From 8478783e1e950aeecde38e08ca5f0b9ba79181b4 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 28 May 2026 11:50:33 +0200
Subject: [PATCH] [mlir][spirv] Allow dynamic rescale parameter lengths

The SPIR-V TOSA rescale verifier checked multiplier and shift lengths
with a direct equality against the input channel dimension. That
rejects otherwise valid operations when either side of the shape
comparison is dynamic.

Express the check with reusable dimension predicates so that unranked
tensors and dynamic dimensions pass, while static dimensions still
enforce the length requirements: input_shape[rank(input) - 1] when
per_channel is true, and 1 otherwise. Add dedicated dynamic-shape op
tests for dynamic input channel dimensions and dynamic multiplier/shift
lengths.

Change-Id: Ifbb8a62a7c7480eaca7b3b7e349c315ed5e24bb7
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   | 21 +++++++++++++---
 .../Dialect/SPIRV/IR/tosa-ops-dynamic.mlir    | 25 +++++++++++++++++++
 2 files changed, 43 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 8497f4f0c4b46..777db957b4d19 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -195,6 +195,15 @@ class ShapedTypeOf<string input> :
 class DimOf<string input, int dim> :
   DimOfType<ShapedTypeOf<input>.result, dim>;

+class LastDimOf<string input> : // Last shape extent of `input` via .back(); NOTE(review): unchecked, callers must rule out unranked/rank-0 `input`.
+  StrFunc<Shape<input>.result # ".back()">;
+
+class LastDimIsDynamic<string input> : // True when the last extent of `input` is dynamic ('?').
+  CPred<"::mlir::ShapedType::isDynamic(" # LastDimOf<input>.result # ")">;
+
+class DimMatchesLastDim<string lhs, int lhsDim, string rhs> : // lhs[lhsDim] == rhs.back(); plain equality, pair with *IsDynamic to tolerate '?'.
+  CPred<DimOf<lhs, lhsDim>.result # " == " # LastDimOf<rhs>.result>;
+
 class DimIsDynamic<string input, int dim> :
   CPred<"::mlir::ShapedType::isDynamic(" # DimOf<input, dim>.result # ")">;
 
@@ -484,9 +493,15 @@ class ElementTypeMatchesScale32<string tensor> :
 
 class TensorLengthMatchesPerChannel<string tensor> :
   PredOpTrait<tensor # " must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1",
-    CPred<"::llvm::cast<::mlir::ShapedType>($" # tensor # ".getType()).getShape()[0] == "
-          "(getPerChannel() ? "
-          "::llvm::cast<::mlir::ShapedType>($input.getType()).getShape().back() : 1)">>;
+    Or<[ // Accept whenever a length cannot be checked statically; otherwise enforce it.
+      Neg<CPred<HasRank<"input">.result>>,
+      Neg<CPred<HasRank<tensor>.result>>,
+      DimIsDynamic<tensor, 0>,
+      And<[CPred<"getPerChannel()">, Neg<CPred<Shape<"input">.result # ".empty()">>, // reject rank 0 before .back()
+           Or<[LastDimIsDynamic<"input">,
+               DimMatchesLastDim<tensor, 0, "input">]>]>,
+      And<[Neg<CPred<"getPerChannel()">>, DimIsOne<tensor, 0>]> // not per-channel: length must be 1
+    ]>>;
 
 
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir
new file mode 100644
index 0000000000000..a2973c73ab362
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Rescale
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @rescale_per_channel_dynamic_input_last_dimension(%arg0: !spirv.arm.tensor<?x?x?xi16>) -> (!spirv.arm.tensor<?x?x?xi16>) {
+  %1 = spirv.Constant dense<[1, 1]> : !spirv.arm.tensor<2xi16> // Per-channel multiplier and shift must share one length (the input's last dim).
+  %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8>
+  %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+  %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+  // CHECK: {{%.*}} = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<?x?x?xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<?x?x?xi16>
+  %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<?x?x?xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<?x?x?xi16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<?x?x?xi16>
+  spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<?x?x?xi16>
+}
+
+spirv.ARM.Graph @rescale_per_channel_dynamic_multiplier_and_shift_length(%arg0: !spirv.arm.tensor<2x3x4xi16>, %multiplier: !spirv.arm.tensor<?xi16>, %shift: !spirv.arm.tensor<?xi8>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+  %input_zp = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> // Static input; only the multiplier/shift lengths are dynamic.
+  %output_zp = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+  // CHECK: {{%.*}} = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<?xi16>, !spirv.arm.tensor<?xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+  %result = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %multiplier, %shift, %input_zp, %output_zp : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<?xi16>, !spirv.arm.tensor<?xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x3x4xi16>
+  spirv.ARM.GraphOutputs %result : !spirv.arm.tensor<2x3x4xi16>
+}



More information about the Mlir-commits mailing list