[Mlir-commits] [mlir] 546787e - [mlir][spirv] Fix SPIRV TOSA per-channel rescale length verification (#190748)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 7 03:52:34 PDT 2026


Author: Davide Grohmann
Date: 2026-04-07T12:52:30+02:00
New Revision: 546787ec97e292391c27fe55ef4549e7ba022afe

URL: https://github.com/llvm/llvm-project/commit/546787ec97e292391c27fe55ef4549e7ba022afe
DIFF: https://github.com/llvm/llvm-project/commit/546787ec97e292391c27fe55ef4549e7ba022afe.diff

LOG: [mlir][spirv] Fix SPIRV TOSA per-channel rescale length verification (#190748)

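`TensorLengthMatchesPerChannel` compared the length of the per-channel
tensor (multiplier/shift) against `rank(input) - 1` instead of
`input_shape[rank(input) - 1]`, i.e. the size of the input's last
dimension. Fix the predicate and update the rescale verifier tests
accordingly.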

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
    mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 1fc3bdad48e75..5a610aaa45cef 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -217,10 +217,10 @@ class ElementTypeMatchesScale32<string tensor> :
           "isInteger(getScale32() ? 32 : 16)">>;
 
 class TensorLengthMatchesPerChannel<string tensor> :
-  PredOpTrait<tensor # " must have length rank(input) - 1 when per_channel is true, otherwise length 1",
+  PredOpTrait<tensor # " must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1",
     CPred<"::llvm::cast<::mlir::ShapedType>($" # tensor # ".getType()).getShape()[0] == "
           "(getPerChannel() ? "
-          "::llvm::cast<::mlir::ShapedType>($input.getType()).getRank() - 1 : 1)">>;
+          "::llvm::cast<::mlir::ShapedType>($input.getType()).getShape().back() : 1)">>;
 
 
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES

diff  --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index f981fe7853636..057aa353b4ee1 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -2005,23 +2005,23 @@ spirv.ARM.Graph @rescale_scale32_false_requires_i16_multiplier(%arg0: !spirv.arm
   spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
 }
 
-spirv.ARM.Graph @rescale_per_channel_true_requires_multiplier_length_rank_minus_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+spirv.ARM.Graph @rescale_per_channel_true_requires_multiplier_length_last_dimension(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
   %1 = spirv.Constant dense<[1]> : !spirv.arm.tensor<1xi16>
   %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8>
   %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
   %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
-  // expected-error @+1 {{op failed to verify that multiplier must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+  // expected-error @+1 {{op failed to verify that multiplier must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1}}
   %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
   spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
 }
 
-spirv.ARM.Graph @rescale_per_channel_true_requires_shift_length_rank_minus_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
-  %1 = spirv.Constant dense<[1, 1]> : !spirv.arm.tensor<2xi16>
+spirv.ARM.Graph @rescale_per_channel_true_requires_shift_length_last_dimension(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+  %1 = spirv.Constant dense<[1, 1, 1, 1]> : !spirv.arm.tensor<4xi16>
   %2 = spirv.Constant dense<[0]> : !spirv.arm.tensor<1xi8>
   %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
   %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
-  // expected-error @+1 {{op failed to verify that shift must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
-  %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+  // expected-error @+1 {{op failed to verify that shift must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1}}
+  %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<4xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
   spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
 }
 
@@ -2030,7 +2030,7 @@ spirv.ARM.Graph @rescale_per_channel_false_requires_multiplier_length_one(%arg0:
   %2 = spirv.Constant dense<[0]> : !spirv.arm.tensor<1xi8>
   %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
   %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
-  // expected-error @+1 {{op failed to verify that multiplier must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+  // expected-error @+1 {{op failed to verify that multiplier must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1}}
   %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
   spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
 }
@@ -2040,7 +2040,7 @@ spirv.ARM.Graph @rescale_per_channel_false_requires_shift_length_one(%arg0: !spi
   %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8>
   %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
   %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
-  // expected-error @+1 {{op failed to verify that shift must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+  // expected-error @+1 {{op failed to verify that shift must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1}}
   %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
   spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
 }


        


More information about the Mlir-commits mailing list