[Mlir-commits] [mlir] 546787e - [mlir][spirv] Fix SPIRV TOSA per-channel rescale length verification (#190748)
llvmlistbot at llvm.org
Tue Apr 7 03:52:34 PDT 2026
Author: Davide Grohmann
Date: 2026-04-07T12:52:30+02:00
New Revision: 546787ec97e292391c27fe55ef4549e7ba022afe
URL: https://github.com/llvm/llvm-project/commit/546787ec97e292391c27fe55ef4549e7ba022afe
DIFF: https://github.com/llvm/llvm-project/commit/546787ec97e292391c27fe55ef4549e7ba022afe.diff
LOG: [mlir][spirv] Fix SPIRV TOSA per-channel rescale length verification (#190748)
`TensorLengthMatchesPerChannel` was checking the multiplier/shift length against
`rank(input) - 1` instead of `input_shape[rank(input) - 1]` (the size of the
input's last dimension). Fix the predicate and update the rescale verifier
tests accordingly.
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 1fc3bdad48e75..5a610aaa45cef 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -217,10 +217,10 @@ class ElementTypeMatchesScale32<string tensor> :
"isInteger(getScale32() ? 32 : 16)">>;
class TensorLengthMatchesPerChannel<string tensor> :
- PredOpTrait<tensor # " must have length rank(input) - 1 when per_channel is true, otherwise length 1",
+ PredOpTrait<tensor # " must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1",
CPred<"::llvm::cast<::mlir::ShapedType>($" # tensor # ".getType()).getShape()[0] == "
"(getPerChannel() ? "
- "::llvm::cast<::mlir::ShapedType>($input.getType()).getRank() - 1 : 1)">>;
+ "::llvm::cast<::mlir::ShapedType>($input.getType()).getShape().back() : 1)">>;
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index f981fe7853636..057aa353b4ee1 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -2005,23 +2005,23 @@ spirv.ARM.Graph @rescale_scale32_false_requires_i16_multiplier(%arg0: !spirv.arm
spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
}
-spirv.ARM.Graph @rescale_per_channel_true_requires_multiplier_length_rank_minus_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+spirv.ARM.Graph @rescale_per_channel_true_requires_multiplier_length_last_dimension(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
%1 = spirv.Constant dense<[1]> : !spirv.arm.tensor<1xi16>
%2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8>
%3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
%4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
- // expected-error @+1 {{op failed to verify that multiplier must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+ // expected-error @+1 {{op failed to verify that multiplier must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1}}
%5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
}
-spirv.ARM.Graph @rescale_per_channel_true_requires_shift_length_rank_minus_one(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
- %1 = spirv.Constant dense<[1, 1]> : !spirv.arm.tensor<2xi16>
+spirv.ARM.Graph @rescale_per_channel_true_requires_shift_length_last_dimension(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ %1 = spirv.Constant dense<[1, 1, 1, 1]> : !spirv.arm.tensor<4xi16>
%2 = spirv.Constant dense<[0]> : !spirv.arm.tensor<1xi8>
%3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
%4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
- // expected-error @+1 {{op failed to verify that shift must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
- %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ // expected-error @+1 {{op failed to verify that shift must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1}}
+ %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<4xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
}
@@ -2030,7 +2030,7 @@ spirv.ARM.Graph @rescale_per_channel_false_requires_multiplier_length_one(%arg0:
%2 = spirv.Constant dense<[0]> : !spirv.arm.tensor<1xi8>
%3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
%4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
- // expected-error @+1 {{op failed to verify that multiplier must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+ // expected-error @+1 {{op failed to verify that multiplier must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1}}
%5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<2xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
}
@@ -2040,7 +2040,7 @@ spirv.ARM.Graph @rescale_per_channel_false_requires_shift_length_one(%arg0: !spi
%2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8>
%3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
%4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
- // expected-error @+1 {{op failed to verify that shift must have length rank(input) - 1 when per_channel is true, otherwise length 1}}
+ // expected-error @+1 {{op failed to verify that shift must have length input_shape[rank(input) - 1] when per_channel is true, otherwise length 1}}
%5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16>
spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
}
More information about the Mlir-commits mailing list