[Mlir-commits] [mlir] [mlir][spirv] Improve type constraints for SPIR-V Tosa CastOp (PR #192227)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 15 03:16:18 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Davide Grohmann (davidegrohmann)
<details>
<summary>Changes</summary>
Improve the type constraints for the SPIR-V Tosa CastOp, and also disallow I64 as an input/output element type in several other SPIR-V Tosa operators.
---
Patch is 26.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/192227.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+69-2)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+4-1)
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir (+125)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 2c086c9e48ebb..f245c55f36fa2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -64,7 +64,8 @@ class SPIRV_TosaOpWithComplexResult<string mnemonic, int opcode, list<Trait> tra
class SPIRV_TosaElementwiseUnaryOp<string mnemonic, int opcode, list<Trait> traits = []> :
SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits,
- [AllTypesMatch<["input1", "output"]>])> {
+ [AllTypesMatch<["input1", "output"]>,
+ ElementTypeIsNot<"input1", I64>])> {
let extraClassDeclaration = extraBaseClassDeclaration#[{
::mlir::spirv::TensorArmType getInput1Type() {
@@ -93,6 +94,7 @@ class SPIRV_TosaFloatElementwiseUnaryOp<string mnemonic, int opcode, list<Trait>
class SPIRV_TosaBinaryOp<string mnemonic, int opcode, list<Trait> traits = []> :
SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [
AllElementTypesMatch<["input1", "input2"]>,
+ ElementTypeIsNot<"input1", I64>,
AllRanksMatch<["input1", "input2", "output"]>,
MatchBroadcastableShapes<"input1", "input2", "output">])> {
@@ -186,6 +188,7 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
class SPIRV_TosaComparisonOp<string mnemonic, int opcode, list<Trait> traits = []> :
SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure,
AllElementTypesMatch<["input1", "input2"]>,
+ ElementTypeIsNot<"input1", I64>,
MatchBroadcastableShapes<"input1", "input2", "output">])> {
let arguments = (ins
@@ -216,6 +219,7 @@ class SPIRV_TosaComparisonOp<string mnemonic, int opcode, list<Trait> traits = [
class SPIRV_TosaReductionOp<string mnemonic, int opcode, list<Trait> traits = []> :
SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [
AllElementTypesMatch<["input", "output"]>,
+ ElementTypeIsNot<"input", I64>,
AllRanksMatch<["input", "output"]>,
AxisValueLessThanRankOf<"input">])> {
@@ -227,6 +231,7 @@ class SPIRV_TosaReductionOp<string mnemonic, int opcode, list<Trait> traits = []
}
def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
+ ElementTypeIsNot<"input", I64>,
OutputRankIsInputRankMinusOne<"input", "output">,
AxisValueLessThanRankOf<"input">]> {
let summary = "Perform argmax on the input.";
@@ -274,6 +279,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffect,
+ ElementTypeIsNot<"input", I64>,
TypeImpliesAccType<"input", I8, ["INT32"]>,
TypeImpliesAccType<"input", I16, ["INT32"]>,
TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
@@ -554,6 +560,7 @@ def SPIRV_TosaFFT2DOp : SPIRV_TosaOpWithComplexResult<"FFT2D", 5, [Pure]> {
def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
+ ElementTypeIsNot<"A", I64>,
TypeConstraintImplicationOn<"A", I8, "output", [I32]>,
TypeConstraintImplicationOn<"A", I16, "output", [I64]>,
TypeConstraintImplicationOn<"A", BF16, "output", [F32]>,
@@ -611,6 +618,7 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
+ ElementTypeIsNot<"input", I64>,
AllElementTypesMatch<["input", "output"]>]> {
let summary = "Performs max pooling on the input.";
@@ -763,6 +771,7 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
def SPIRV_TosaClampOp : SPIRV_TosaOpWithResult<"Clamp", 10, [Pure,
+ ElementTypeIsNot<"input", I64>,
AllTypesMatch<["input", "output"]>,
AllElementTypesMatch<["input", "output", "min_val", "max_val"]>]> {
let summary = "Computes Clamp(min, max).";
@@ -1423,6 +1432,7 @@ def SPIRV_TosaSubOp : SPIRV_TosaElementwiseBinaryOp<"Sub", 29, [NoMemoryEffect,
def SPIRV_TosaTableOp : SPIRV_TosaOpWithResult<"Table", 30, [NoMemoryEffect,
+ ElementTypeIsNot<"input1", I64>,
AllElementTypesMatch<["input1", "table"]>,
AllShapesMatch<["input1", "output"]>,
TypeConstraintImplicationOn<"input1", I8, "output", [I8]>,
@@ -1797,6 +1807,7 @@ def SPIRV_TosaSinOp : SPIRV_TosaFloatElementwiseUnaryOp<"Sin", 43> {
def SPIRV_TosaSelectOp : SPIRV_TosaOpWithResult<"Select", 44, [Pure,
+ ElementTypeIsNot<"true_value", I64>,
AllElementTypesMatch<["true_value", "false_value", "output"]>,
AllRanksMatch<["condition", "true_value", "false_value", "output"]>,
DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
@@ -2138,6 +2149,7 @@ def SPIRV_TosaReduceSumOp : SPIRV_TosaReductionOp<"ReduceSum", 53, [NoMemoryEffe
def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
+ ElementTypeIsNot<"output", I64>,
VariadicInputWithMinSize<"input1", 1>,
VariadicInputAllSameElementType<"output", "input1">,
VariadicInputAllSameRank<"output", "input1">,
@@ -2183,6 +2195,7 @@ def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
+ ElementTypeIsNot<"input1", I64>,
AllElementTypesMatch<["input1", "pad_const", "output"]>,
AllRanksMatch<["input1", "output"]>,
ShapeConstraintFromInputRank<"input1", "padding", 2>]> {
@@ -2236,6 +2249,7 @@ def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
+ ElementTypeIsNot<"input1", I64>,
AllElementTypesMatch<["input1", "output"]>,
AllElementCountsMatch<["input1", "output"]>,
ShapeConstraintFromInputRank<"output", "shape">]> {
@@ -2284,6 +2298,7 @@ def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,
+ ElementTypeIsNot<"input1", I64>,
AllTypesMatch<["input1", "output"]>,
AxisValueLessThanRankOf<"input1">]> {
let summary = "Reverse operator.";
@@ -2328,6 +2343,7 @@ def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,
def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
+ ElementTypeIsNot<"input1", I64>,
AllElementTypesMatch<["input1", "output"]>,
ShapeConstraintFromInputRank<"input1", "start">,
ShapeConstraintFromInputRank<"input1", "size">]> {
@@ -2381,6 +2397,7 @@ def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
+ ElementTypeIsNot<"input1", I64>,
AllElementTypesMatch<["input1", "output"]>,
AllRanksMatch<["input1", "output"]>,
ShapeConstraintFromInputRank<"input1", "multiples">]> {
@@ -2429,6 +2446,7 @@ def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
+ ElementTypeIsNot<"input1", I64>,
AllElementTypesMatch<["input1", "output"]>,
AllRanksMatch<["input1", "output"]>,
AllElementCountsMatch<["input1", "output"]>,
@@ -2475,6 +2493,7 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
+ ElementTypeIsNot<"values", I64>,
AllElementTypesMatch<["values", "output"]>,
ValuesIndicesShapesMatch<"values", "indices", "output">]> {
let summary = "Gather operation.";
@@ -2522,6 +2541,7 @@ def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
+ ElementTypeIsNot<"values_in", I64>,
AllElementTypesMatch<["values_in", "input", "values_out"]>,
AllTypesMatch<["values_in", "values_out"]>,
ValuesIndicesShapesMatch<"values_in", "indices", "input">]> {
@@ -2578,6 +2598,7 @@ def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
+ ElementTypeIsNot<"input", I64>,
TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
@@ -2661,11 +2682,57 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
- AllShapesMatch<["input", "output"]>]> {
+ AllShapesMatch<["input", "output"]>,
+ ElementTypeIsNot<"input", I64>,
+ TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
+ TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
+ TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
+ TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
+ TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
+ TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
+ TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
let summary = "Cast operation.";
let description = [{
Casts a tensor from one data type to another.
+ Valid casting combinations are defined in the following table:
+
+ | Mode | Input | Output |
+ |-------------------|---------|---------|
+ | fp16 to fp32 | float16 | float32 |
+ | fp16 to int16 | float16 | int16 |
+ | fp16 to int32 | float16 | int32 |
+ | fp16 to int8 | float16 | int8 |
+ | fp32 to fp16 | float32 | float16 |
+ | fp32 to int16 | float32 | int16 |
+ | fp32 to int32 | float32 | int32 |
+ | fp32 to int8 | float32 | int8 |
+ | int16 to fp16 | int16 | float16 |
+ | int16 to fp32 | int16 | float32 |
+ | int32 to fp16 | int32 | float16 |
+ | int32 to fp32 | int32 | float32 |
+ | int8 to fp16 | int8 | float16 |
+ | int8 to fp32 | int8 | float32 |
+ | bool to int16 | Boolean | int16 |
+ | bool to int32 | Boolean | int32 |
+ | bool to int8 | Boolean | int8 |
+ | int16 to bool | int16 | Boolean |
+ | int16 to int32 | int16 | int32 |
+ | int16 to int8 | int16 | int8 |
+ | int32 to bool | int32 | Boolean |
+ | int32 to int16 | int32 | int16 |
+ | int32 to int8 | int32 | int8 |
+ | int8 to bool | int8 | Boolean |
+ | int8 to int16 | int8 | int16 |
+ | int8 to int32 | int8 | int32 |
+ | bf16 to fp32 | bf16 | float32 |
+ | bf16 to int16 | bf16 | int16 |
+ | bf16 to int32 | bf16 | int32 |
+ | bf16 to int8 | bf16 | int8 |
+ | fp32 to bf16 | float32 | bf16 |
+ | int16 to bf16 | int16 | bf16 |
+ | int32 to bf16 | int32 | bf16 |
+ | int8 to bf16 | int8 | bf16 |
References:
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 5a610aaa45cef..70545ce0884fa 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -124,6 +124,10 @@ class TypeConstraintImplicationOn<string name, Type type, string other, list<Typ
Implies<ElementTypeIsPred<name, type>,
!foreach(allowedType, allowedTypes, ElementTypeIsPred<other, allowedType>)>>;
+class ElementTypeIsNot<string name, Type type> :
+ PredOpTrait<name # " must not have type " # type.summary,
+ Neg<ElementTypeIsPred<name, type>>>;
+
class BoolAttrTypeConstraintImplicationOn<string boolAttr, string other, list<Type> allowedTypes>:
PredOpTrait<"if " # boolAttr # " is true then " #
other # " must have a type in [" #
@@ -222,5 +226,4 @@ class TensorLengthMatchesPerChannel<string tensor> :
"(getPerChannel() ? "
"::llvm::cast<::mlir::ShapedType>($input.getType()).getShape().back() : 1)">>;
-
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 057aa353b4ee1..9dbae90a7f1eb 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -22,6 +22,12 @@ spirv.ARM.Graph @argmax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.ten
spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
}
+spirv.ARM.Graph @argmax_input_must_not_be_i64(%arg0: !spirv.arm.tensor<3x28x17x17xi64>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+ // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+ %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi64> -> !spirv.arm.tensor<3x28x17xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.AvgPool2D
//===----------------------------------------------------------------------===//
@@ -34,6 +40,14 @@ spirv.ARM.Graph @avgpool2d_input_output_different_elemnt_type(%arg0: !spirv.arm.
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi16>
}
+spirv.ARM.Graph @avgpool2d_input_must_not_be_i64(%arg0: !spirv.arm.tensor<1x3x65537x1xi64>) -> (!spirv.arm.tensor<1x2x32768x1xi64>) {
+ %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+ %5 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+ // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+ %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi64>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<1x2x32768x1xi64>
+ spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi64>
+}
+
spirv.ARM.Graph @avgpool2d_input_input_zero_point_different_elemnt_type(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
%4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi16>
%5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
@@ -389,6 +403,12 @@ spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.ar
spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
}
+spirv.ARM.Graph @maxpool2d_input_must_not_be_i64(%arg0: !spirv.arm.tensor<1x3x65537x1xi64>) -> (!spirv.arm.tensor<1x2x32769x1xi64>) {
+ // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+ %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi64> -> !spirv.arm.tensor<1x2x32769x1xi64>
+ spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi64>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.TransposeConv2D
//===----------------------------------------------------------------------===//
@@ -489,6 +509,12 @@ spirv.ARM.Graph @clamp_max_val_different_element_type_wrt_input_output(%arg0: !s
spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
}
+spirv.ARM.Graph @clamp_input_must_not_be_i64(%arg0: !spirv.arm.tensor<27x44x55xi64>) -> (!spirv.arm.tensor<27x44x55xi64>) {
+ // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+ %3 = spirv.Tosa.Clamp min_val = -102 : i64, max_val = -100 : i64, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi64> -> !spirv.arm.tensor<27x44x55xi64>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi64>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.Erf
//===----------------------------------------------------------------------===//
@@ -571,6 +597,12 @@ spirv.ARM.Graph @add_output_shape_does_not_match_broadcast_shape(%arg0: !spirv.a
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x10x6x6xi32>
}
+spirv.ARM.Graph @add_input_must_not_be_i64(%arg0: !spirv.arm.tensor<6x10x6x6xi64>, %arg1: !spirv.arm.tensor<1x10x6x6xi64>) -> (!spirv.arm.tensor<6x10x6x6xi64>) {
+ // expected-error @+1 {{op failed to verify that input1 must not have type 64-bit signless integer}}
+ %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi64>, !spirv.arm.tensor<1x10x6x6xi64> -> !spirv.arm.tensor<6x10x6x6xi64>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi64>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.ArithmeticRightShift
//===----------------------------------------------------------------------===//
@@ -1475,6 +1507,12 @@ spirv.ARM.Graph @select_output_rank_not_matching_inputs(%arg0: !spirv.arm.tensor
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x5xi8>
}
+spirv.ARM.Graph @select_true_value_must_not_be_i64(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi64>, %arg2: !spirv.arm.tensor<4x6x4x5xi64>) -> (!spirv.arm.tensor<4x6x4x5xi64>) {
+ // expected-error @+1 {{op failed to verify that true_value must not have type 64-bit signless integer}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi64>, !spirv.arm.tensor<4x6x4x5xi64> -> !spirv.arm.tensor<4x6x4x5xi64>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi64>
+}
+
spirv.ARM.Graph @select_condition_true_value_not_broadcastable_false_value_compatible(%arg0: !spirv.arm.tensor<4x2x4x5xi1>, %arg1: !spirv.arm.tensor<4x3x4x5xi8>, %arg2: !spirv.arm.tensor<4x1x4x5xi8>) -> (!spirv.arm.tensor<4x3x4x5xi8>) {
// expected-error @+1 {{op failed to verify that the shape of inputs: condition, true_value, and false_value are compatible for broadcasting}}
%0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x2x4x5xi1>, !spirv.arm.tensor<4x3x4x5xi8>, !spirv.arm.tensor<4x1x4x5xi8> -> !spirv.arm.tensor<4x3x4x5xi8>
@@ -1619,6 +1657,12 @@ spirv.ARM.Graph @reducemax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<8x30x12x3xi8>
}
+spirv.ARM.Graph @reducemax_input_must_not_be_i64(%arg0: !spirv.arm.tensor<8x30x12x3xi64>) -> (!spirv.arm.tensor<8x30x12x3xi64>) {
+ // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+ %0 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi64> -> !spirv.arm.tensor<8x30x12x3xi64>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<8x30x12x3xi64>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.ReduceMin
//===----------------------------------------------------------------------===//
@@ -1641,6 +1685,12 @@ spirv.ARM.Graph @reducemin_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<27x10x25x9xf16>
}
+spirv.ARM.Graph @reducemin_input_must_not_be_i64(%arg0: !spirv.arm.tensor<27x10x25x9xi64>) -> (!spirv.arm.tensor<27x10x25x9xi64>) {
+ // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+ %0 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xi64> -> !spirv.arm.tensor<27x10x25x9xi64>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<27x10x25x9xi64>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.ReduceProduct
//===----------------------------------------------------------------------===//
@@ -1685,6 +1735,12 @@ spirv.ARM.Graph @reducesum_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<20x24x22xi32>
}
+spirv.ARM.Graph @reducesum_input_must_not_be_i64(%arg0: !spirv.arm.tensor<20x24x22xi64>) -> (!spirv.arm.tensor<20x24x22xi64>) {
+ // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+ %0 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<20x24x22xi64> -> !spirv.arm.tensor<20x24x22xi64>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<20x24x22xi64>
+}
+
//===---------------------------------------------------------------------...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/192227
More information about the Mlir-commits
mailing list