[Mlir-commits] [mlir] [mlir][spirv] Allow dynamic rescale parameter lengths (PR #200155)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 28 03:45:52 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Davide Grohmann (davidegrohmann)

<details>
<summary>Changes</summary>

The SPIR-V TOSA rescale verifier checked the multiplier and shift lengths with a direct equality comparison against the input channel dimension. That rejects otherwise valid operations when either side of the shape comparison is dynamic.

Express the check with reusable dimension predicates so that unranked tensors and dynamic dimensions pass, while static dimensions still enforce the per-channel and scalar length requirements. Add dedicated dynamic-shape test coverage for dynamic input channel dimensions and dynamic multiplier/shift lengths.

---
Full diff: https://github.com/llvm/llvm-project/pull/200155.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+18-3) 
- (added) mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir (+25) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 8497f4f0c4b46..777db957b4d19 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -195,6 +195,15 @@ class ShapedTypeOf<string input> :
 class DimOf<string input, int dim> :
   DimOfType<ShapedTypeOf<input>.result, dim>;
 
+class LastDimOf<string input> :
+  StrFunc<Shape<input>.result # ".back()">;
+
+class LastDimIsDynamic<string input> :
+  CPred<"::mlir::ShapedType::isDynamic(" # LastDimOf<input>.result # ")">;
+
+class DimMatchesLastDim<string lhs, int lhsDim, string rhs> :
+  CPred<DimOf<lhs, lhsDim>.result # " == " # LastDimOf<rhs>.result>;
+
 class DimIsDynamic<string input, int dim> :
   CPred<"::mlir::ShapedType::isDynamic(" # DimOf<input, dim>.result # ")">;
 
@@ -484,9 +493,15 @@ class ElementTypeMatchesScale32<string tensor> :
 
 class TensorLengthMatchesPerChannel<string tensor> :
   PredOpTrait<tensor # " must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1",
-    CPred<"::llvm::cast<::mlir::ShapedType>($" # tensor # ".getType()).getShape()[0] == "
-          "(getPerChannel() ? "
-          "::llvm::cast<::mlir::ShapedType>($input.getType()).getShape().back() : 1)">>;
+    Or<[
+      Neg<CPred<HasRank<"input">.result>>,
+      Neg<CPred<HasRank<tensor>.result>>,
+      DimIsDynamic<tensor, 0>,
+      And<[CPred<"getPerChannel()">,
+           Or<[LastDimIsDynamic<"input">,
+               DimMatchesLastDim<tensor, 0, "input">]>]>,
+      And<[Neg<CPred<"getPerChannel()">>, DimIsOne<tensor, 0>]>
+    ]>>;
 
 
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir
new file mode 100644
index 0000000000000..a2973c73ab362
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Rescale
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @rescale_per_channel_dynamic_input_last_dimension(%arg0: !spirv.arm.tensor<?x?x?xi16>) -> (!spirv.arm.tensor<?x?x?xi16>) {
+  %1 = spirv.Constant dense<[1]> : !spirv.arm.tensor<1xi16>
+  %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8>
+  %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+  %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+  // CHECK: {{%.*}} = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor<?x?x?xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<?x?x?xi16>
+  %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<?x?x?xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<?x?x?xi16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<?x?x?xi16>
+  spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<?x?x?xi16>
+}
+
+spirv.ARM.Graph @rescale_per_channel_dynamic_multiplier_and_shift_length(%arg0: !spirv.arm.tensor<2x3x4xi16>, %multiplier: !spirv.arm.tensor<?xi16>, %shift: !spirv.arm.tensor<?xi8>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+  %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+  %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+  // CHECK: {{%.*}} = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<?xi16>, !spirv.arm.tensor<?xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+  %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %multiplier, %shift, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<?xi16>, !spirv.arm.tensor<?xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x3x4xi16>
+  spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/200155


More information about the Mlir-commits mailing list