[Mlir-commits] [mlir] [mlir][spirv] Improve type constraints for SPIR-V Tosa CastOp (PR #192227)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 15 03:16:18 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Davide Grohmann (davidegrohmann)

<details>
<summary>Changes</summary>

Improves the type constraints for the SPIR-V Tosa CastOp, and disallows I64 as an input/output element type in several other SPIR-V Tosa operators.

---

Patch is 26.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/192227.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+69-2) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+4-1) 
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir (+125) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 2c086c9e48ebb..f245c55f36fa2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -64,7 +64,8 @@ class SPIRV_TosaOpWithComplexResult<string mnemonic, int opcode, list<Trait> tra
 
 class SPIRV_TosaElementwiseUnaryOp<string mnemonic, int opcode, list<Trait> traits = []> :
   SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits,
-    [AllTypesMatch<["input1", "output"]>])> {
+    [AllTypesMatch<["input1", "output"]>,
+     ElementTypeIsNot<"input1", I64>])> {
 
   let extraClassDeclaration = extraBaseClassDeclaration#[{
     ::mlir::spirv::TensorArmType getInput1Type() {
@@ -93,6 +94,7 @@ class SPIRV_TosaFloatElementwiseUnaryOp<string mnemonic, int opcode, list<Trait>
 class SPIRV_TosaBinaryOp<string mnemonic, int opcode, list<Trait> traits = []> :
   SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [
     AllElementTypesMatch<["input1", "input2"]>,
+    ElementTypeIsNot<"input1", I64>,
     AllRanksMatch<["input1", "input2", "output"]>,
     MatchBroadcastableShapes<"input1", "input2", "output">])> {
 
@@ -186,6 +188,7 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
 class SPIRV_TosaComparisonOp<string mnemonic, int opcode, list<Trait> traits = []> :
   SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure,
   AllElementTypesMatch<["input1", "input2"]>,
+  ElementTypeIsNot<"input1", I64>,
   MatchBroadcastableShapes<"input1", "input2", "output">])> {
 
   let arguments = (ins
@@ -216,6 +219,7 @@ class SPIRV_TosaComparisonOp<string mnemonic, int opcode, list<Trait> traits = [
 class SPIRV_TosaReductionOp<string mnemonic, int opcode, list<Trait> traits = []> :
   SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [
   AllElementTypesMatch<["input", "output"]>,
+  ElementTypeIsNot<"input", I64>,
   AllRanksMatch<["input", "output"]>,
   AxisValueLessThanRankOf<"input">])> {
 
@@ -227,6 +231,7 @@ class SPIRV_TosaReductionOp<string mnemonic, int opcode, list<Trait> traits = []
 }
 
 def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
+  ElementTypeIsNot<"input", I64>,
   OutputRankIsInputRankMinusOne<"input", "output">,
   AxisValueLessThanRankOf<"input">]> {
   let summary = "Perform argmax on the input.";
@@ -274,6 +279,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
 
 
 def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffect,
+  ElementTypeIsNot<"input", I64>,
   TypeImpliesAccType<"input", I8, ["INT32"]>,
   TypeImpliesAccType<"input", I16, ["INT32"]>,
   TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
@@ -554,6 +560,7 @@ def SPIRV_TosaFFT2DOp : SPIRV_TosaOpWithComplexResult<"FFT2D", 5, [Pure]> {
 
 
 def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
+  ElementTypeIsNot<"A", I64>,
   TypeConstraintImplicationOn<"A", I8, "output", [I32]>,
   TypeConstraintImplicationOn<"A", I16, "output", [I64]>,
   TypeConstraintImplicationOn<"A", BF16, "output", [F32]>,
@@ -611,6 +618,7 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
 
 
 def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
+  ElementTypeIsNot<"input", I64>,
   AllElementTypesMatch<["input", "output"]>]> {
   let summary = "Performs max pooling on the input.";
 
@@ -763,6 +771,7 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
 
 
 def SPIRV_TosaClampOp : SPIRV_TosaOpWithResult<"Clamp", 10, [Pure,
+  ElementTypeIsNot<"input", I64>,
   AllTypesMatch<["input", "output"]>,
   AllElementTypesMatch<["input", "output", "min_val", "max_val"]>]> {
   let summary = "Computes Clamp(min, max).";
@@ -1423,6 +1432,7 @@ def SPIRV_TosaSubOp : SPIRV_TosaElementwiseBinaryOp<"Sub", 29, [NoMemoryEffect,
 
 
 def SPIRV_TosaTableOp : SPIRV_TosaOpWithResult<"Table", 30, [NoMemoryEffect,
+  ElementTypeIsNot<"input1", I64>,
   AllElementTypesMatch<["input1", "table"]>,
   AllShapesMatch<["input1", "output"]>,
   TypeConstraintImplicationOn<"input1", I8, "output", [I8]>,
@@ -1797,6 +1807,7 @@ def SPIRV_TosaSinOp : SPIRV_TosaFloatElementwiseUnaryOp<"Sin", 43> {
 
 
 def SPIRV_TosaSelectOp : SPIRV_TosaOpWithResult<"Select", 44, [Pure,
+  ElementTypeIsNot<"true_value", I64>,
   AllElementTypesMatch<["true_value", "false_value", "output"]>,
   AllRanksMatch<["condition", "true_value", "false_value", "output"]>,
   DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
@@ -2138,6 +2149,7 @@ def SPIRV_TosaReduceSumOp : SPIRV_TosaReductionOp<"ReduceSum", 53, [NoMemoryEffe
 
 
 def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
+  ElementTypeIsNot<"output", I64>,
   VariadicInputWithMinSize<"input1", 1>,
   VariadicInputAllSameElementType<"output", "input1">,
   VariadicInputAllSameRank<"output", "input1">,
@@ -2183,6 +2195,7 @@ def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
 
 
 def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
+  ElementTypeIsNot<"input1", I64>,
   AllElementTypesMatch<["input1", "pad_const", "output"]>,
   AllRanksMatch<["input1", "output"]>,
   ShapeConstraintFromInputRank<"input1", "padding", 2>]> {
@@ -2236,6 +2249,7 @@ def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
 
 
 def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
+  ElementTypeIsNot<"input1", I64>,
   AllElementTypesMatch<["input1", "output"]>,
   AllElementCountsMatch<["input1", "output"]>,
   ShapeConstraintFromInputRank<"output", "shape">]> {
@@ -2284,6 +2298,7 @@ def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
 
 
 def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,
+  ElementTypeIsNot<"input1", I64>,
   AllTypesMatch<["input1", "output"]>,
   AxisValueLessThanRankOf<"input1">]> {
   let summary = "Reverse operator.";
@@ -2328,6 +2343,7 @@ def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,
 
 
 def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
+  ElementTypeIsNot<"input1", I64>,
   AllElementTypesMatch<["input1", "output"]>,
   ShapeConstraintFromInputRank<"input1", "start">,
   ShapeConstraintFromInputRank<"input1", "size">]> {
@@ -2381,6 +2397,7 @@ def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
 
 
 def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
+  ElementTypeIsNot<"input1", I64>,
   AllElementTypesMatch<["input1", "output"]>,
   AllRanksMatch<["input1", "output"]>,
   ShapeConstraintFromInputRank<"input1", "multiples">]> {
@@ -2429,6 +2446,7 @@ def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
 
 
 def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
+  ElementTypeIsNot<"input1", I64>,
   AllElementTypesMatch<["input1", "output"]>,
   AllRanksMatch<["input1", "output"]>,
   AllElementCountsMatch<["input1", "output"]>,
@@ -2475,6 +2493,7 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
 
 
 def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
+  ElementTypeIsNot<"values", I64>,
   AllElementTypesMatch<["values", "output"]>,
   ValuesIndicesShapesMatch<"values", "indices", "output">]> {
   let summary = "Gather operation.";
@@ -2522,6 +2541,7 @@ def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
 
 
 def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
+  ElementTypeIsNot<"values_in", I64>,
   AllElementTypesMatch<["values_in", "input", "values_out"]>,
   AllTypesMatch<["values_in", "values_out"]>,
   ValuesIndicesShapesMatch<"values_in", "indices", "input">]> {
@@ -2578,6 +2598,7 @@ def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
 
 
 def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
+  ElementTypeIsNot<"input", I64>,
   TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
   TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
   TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
@@ -2661,11 +2682,57 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
 
 
 def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
-  AllShapesMatch<["input", "output"]>]> {
+  AllShapesMatch<["input", "output"]>,
+  ElementTypeIsNot<"input", I64>,
+  TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
+  TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
+  TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
+  TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
+  TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
+  TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
+  TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
   let summary = "Cast operation.";
 
   let description = [{
     Casts a tensor from one data type to another.
+    Valid casting combinations are defined in the following table:
+
+    | Mode              | Input   | Output  |
+    |-------------------|---------|---------|
+    | fp16 to fp32      | float16 | float32 |
+    | fp16 to int16     | float16 | int16   |
+    | fp16 to int32     | float16 | int32   |
+    | fp16 to int8      | float16 | int8    |
+    | fp32 to fp16      | float32 | float16 |
+    | fp32 to int16     | float32 | int16   |
+    | fp32 to int32     | float32 | int32   |
+    | fp32 to int8      | float32 | int8    |
+    | int16 to fp16     | int16   | float16 |
+    | int16 to fp32     | int16   | float32 |
+    | int32 to fp16     | int32   | float16 |
+    | int32 to fp32     | int32   | float32 |
+    | int8 to fp16      | int8    | float16 |
+    | int8 to fp32      | int8    | float32 |
+    | bool to int16     | Boolean | int16   |
+    | bool to int32     | Boolean | int32   |
+    | bool to int8      | Boolean | int8    |
+    | int16 to bool     | int16   | Boolean |
+    | int16 to int32    | int16   | int32   |
+    | int16 to int8     | int16   | int8    |
+    | int32 to bool     | int32   | Boolean |
+    | int32 to int16    | int32   | int16   |
+    | int32 to int8     | int32   | int8    |
+    | int8 to bool      | int8    | Boolean |
+    | int8 to int16     | int8    | int16   |
+    | int8 to int32     | int8    | int32   |
+    | bf16 to fp32      | bf16    | float32 |
+    | bf16 to int16     | bf16    | int16   |
+    | bf16 to int32     | bf16    | int32   |
+    | bf16 to int8      | bf16    | int8    |
+    | fp32 to bf16      | float32 | bf16    |
+    | int16 to bf16     | int16   | bf16    |
+    | int32 to bf16     | int32   | bf16    |
+    | int8 to bf16      | int8    | bf16    |
 
     References:
       * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 5a610aaa45cef..70545ce0884fa 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -124,6 +124,10 @@ class TypeConstraintImplicationOn<string name, Type type, string other, list<Typ
     Implies<ElementTypeIsPred<name, type>,
     !foreach(allowedType, allowedTypes, ElementTypeIsPred<other, allowedType>)>>;
 
+class ElementTypeIsNot<string name, Type type> :
+  PredOpTrait<name # " must not have type " # type.summary,
+              Neg<ElementTypeIsPred<name, type>>>;
+
 class BoolAttrTypeConstraintImplicationOn<string boolAttr, string other, list<Type> allowedTypes>:
   PredOpTrait<"if " # boolAttr # " is true then " #
               other # " must have a type in [" #
@@ -222,5 +226,4 @@ class TensorLengthMatchesPerChannel<string tensor> :
           "(getPerChannel() ? "
           "::llvm::cast<::mlir::ShapedType>($input.getType()).getShape().back() : 1)">>;
 
-
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 057aa353b4ee1..9dbae90a7f1eb 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -22,6 +22,12 @@ spirv.ARM.Graph @argmax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.ten
   spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
 }
 
+spirv.ARM.Graph @argmax_input_must_not_be_i64(%arg0: !spirv.arm.tensor<3x28x17x17xi64>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+  // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+  %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi64> -> !spirv.arm.tensor<3x28x17xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.AvgPool2D
 //===----------------------------------------------------------------------===//
@@ -34,6 +40,14 @@ spirv.ARM.Graph @avgpool2d_input_output_different_elemnt_type(%arg0: !spirv.arm.
   spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi16>
 }
 
+spirv.ARM.Graph @avgpool2d_input_must_not_be_i64(%arg0: !spirv.arm.tensor<1x3x65537x1xi64>) -> (!spirv.arm.tensor<1x2x32768x1xi64>) {
+  %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+  %5 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+  // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+  %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = <INT32>, %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi64>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<1x2x32768x1xi64>
+  spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32768x1xi64>
+}
+
 spirv.ARM.Graph @avgpool2d_input_input_zero_point_different_elemnt_type(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) {
   %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi16>
   %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8>
@@ -389,6 +403,12 @@ spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.ar
   spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
 }
 
+spirv.ARM.Graph @maxpool2d_input_must_not_be_i64(%arg0: !spirv.arm.tensor<1x3x65537x1xi64>) -> (!spirv.arm.tensor<1x2x32769x1xi64>) {
+  // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+  %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi64> -> !spirv.arm.tensor<1x2x32769x1xi64>
+  spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi64>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.TransposeConv2D
 //===----------------------------------------------------------------------===//
@@ -489,6 +509,12 @@ spirv.ARM.Graph @clamp_max_val_different_element_type_wrt_input_output(%arg0: !s
   spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
 }
 
+spirv.ARM.Graph @clamp_input_must_not_be_i64(%arg0: !spirv.arm.tensor<27x44x55xi64>) -> (!spirv.arm.tensor<27x44x55xi64>) {
+  // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+  %3 = spirv.Tosa.Clamp min_val = -102 : i64, max_val = -100 : i64, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi64> -> !spirv.arm.tensor<27x44x55xi64>
+  spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi64>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Erf
 //===----------------------------------------------------------------------===//
@@ -571,6 +597,12 @@ spirv.ARM.Graph @add_output_shape_does_not_match_broadcast_shape(%arg0: !spirv.a
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x10x6x6xi32>
 }
 
+spirv.ARM.Graph @add_input_must_not_be_i64(%arg0: !spirv.arm.tensor<6x10x6x6xi64>, %arg1: !spirv.arm.tensor<1x10x6x6xi64>) -> (!spirv.arm.tensor<6x10x6x6xi64>) {
+  // expected-error @+1 {{op failed to verify that input1 must not have type 64-bit signless integer}}
+  %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi64>, !spirv.arm.tensor<1x10x6x6xi64> -> !spirv.arm.tensor<6x10x6x6xi64>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi64>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.ArithmeticRightShift
 //===----------------------------------------------------------------------===//
@@ -1475,6 +1507,12 @@ spirv.ARM.Graph @select_output_rank_not_matching_inputs(%arg0: !spirv.arm.tensor
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x5xi8>
 }
 
+spirv.ARM.Graph @select_true_value_must_not_be_i64(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi64>, %arg2: !spirv.arm.tensor<4x6x4x5xi64>) -> (!spirv.arm.tensor<4x6x4x5xi64>) {
+  // expected-error @+1 {{op failed to verify that true_value must not have type 64-bit signless integer}}
+  %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi64>, !spirv.arm.tensor<4x6x4x5xi64> -> !spirv.arm.tensor<4x6x4x5xi64>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi64>
+}
+
 spirv.ARM.Graph @select_condition_true_value_not_broadcastable_false_value_compatible(%arg0: !spirv.arm.tensor<4x2x4x5xi1>, %arg1: !spirv.arm.tensor<4x3x4x5xi8>, %arg2: !spirv.arm.tensor<4x1x4x5xi8>) -> (!spirv.arm.tensor<4x3x4x5xi8>) {
   // expected-error @+1 {{op failed to verify that the shape of inputs: condition, true_value, and false_value are compatible for broadcasting}}
   %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x2x4x5xi1>, !spirv.arm.tensor<4x3x4x5xi8>, !spirv.arm.tensor<4x1x4x5xi8> -> !spirv.arm.tensor<4x3x4x5xi8>
@@ -1619,6 +1657,12 @@ spirv.ARM.Graph @reducemax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<8x30x12x3xi8>
 }
 
+spirv.ARM.Graph @reducemax_input_must_not_be_i64(%arg0: !spirv.arm.tensor<8x30x12x3xi64>) -> (!spirv.arm.tensor<8x30x12x3xi64>) {
+  // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+  %0 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi64> -> !spirv.arm.tensor<8x30x12x3xi64>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<8x30x12x3xi64>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.ReduceMin
 //===----------------------------------------------------------------------===//
@@ -1641,6 +1685,12 @@ spirv.ARM.Graph @reducemin_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<27x10x25x9xf16>
 }
 
+spirv.ARM.Graph @reducemin_input_must_not_be_i64(%arg0: !spirv.arm.tensor<27x10x25x9xi64>) -> (!spirv.arm.tensor<27x10x25x9xi64>) {
+  // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+  %0 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xi64> -> !spirv.arm.tensor<27x10x25x9xi64>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<27x10x25x9xi64>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.ReduceProduct
 //===----------------------------------------------------------------------===//
@@ -1685,6 +1735,12 @@ spirv.ARM.Graph @reducesum_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<20x24x22xi32>
 }
 
+spirv.ARM.Graph @reducesum_input_must_not_be_i64(%arg0: !spirv.arm.tensor<20x24x22xi64>) -> (!spirv.arm.tensor<20x24x22xi64>) {
+  // expected-error @+1 {{op failed to verify that input must not have type 64-bit signless integer}}
+  %0 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<20x24x22xi64> -> !spirv.arm.tensor<20x24x22xi64>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<20x24x22xi64>
+}
+
 //===---------------------------------------------------------------------...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/192227


More information about the Mlir-commits mailing list