[Mlir-commits] [mlir] [mlir][spirv] Add comparison and elementwise ternary ops in TOSA Ext Inst Set (PR #186356)
Davide Grohmann
llvmlistbot at llvm.org
Fri Mar 13 08:43:26 PDT 2026
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/186356
>From dc4531d13e7fed6b7e638a8873a1bc8ea0d0d65e Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 28 Jan 2026 13:39:57 +0100
Subject: [PATCH] [mlir][spirv] Add comparison and elementwise ternary ops in
TOSA Ext Inst Set
This patch introduces the following comparison and elementwise ternary operators:
spirv.Tosa.Select
spirv.Tosa.Equal
spirv.Tosa.Greater
spirv.Tosa.GreaterEqual
Also dialect and serialization round-trip tests have been added.
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ib7dd5c061ba49aca5d8532190598e4fdb75ae8d9
---
.../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td | 159 +++++++++++++++++
.../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td | 18 +-
.../SPIRV/IR/tosa-ops-verification.mlir | 142 +++++++++++++++
mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir | 88 ++++++++++
mlir/test/Target/SPIRV/tosa-ops.mlir | 162 ++++++++++++++++++
5 files changed, 568 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 9fb3a53286bdf..5339769cadbf3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -183,6 +183,35 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
}];
}
+class SPIRV_TosaComparisonOp<string mnemonic, int opcode, list<Trait> traits = []> :
+ SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure,
+ AllElementTypesMatch<["input1", "input2"]>,
+ MatchBroadcastableShapes<"input1", "input2", "output">])> {
+
+ let arguments = (ins
+ SPIRV_TosaNumerical_TensorArm: $input1,
+ SPIRV_TosaNumerical_TensorArm: $input2
+ );
+
+ let results = (outs
+ SPIRV_Bool_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input1 `,`
+ $input2
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInput1Type() {
+ return cast<::mlir::spirv::TensorArmType>(getInput1().getType());
+ }
+ ::mlir::spirv::TensorArmType getInput2Type() {
+ return cast<::mlir::spirv::TensorArmType>(getInput2().getType());
+ }
+ }];
+}
def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
OutputRankIsInputRankMinusOne<"input", "output">,
@@ -1742,4 +1771,134 @@ def SPIRV_TosaSinOp : SPIRV_TosaFloatElementwiseUnaryOp<"Sin", 43> {
}
+def SPIRV_TosaSelectOp : SPIRV_TosaOpWithResult<"Select", 44, [Pure,
+ AllElementTypesMatch<["true_value", "false_value", "output"]>,
+ AllRanksMatch<["condition", "true_value", "false_value", "output"]>,
+ TernaryMatchBroadcastableShapes<"condition", "true_value", "false_value", "output">,
+ DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
+ let summary = "Select operator.";
+
+ let description = [{
+ Elementwise Select of the output based on a condition.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_select
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_select
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_Bool_TensorArm: $condition,
+ SPIRV_TosaAny_TensorArm: $true_value,
+ SPIRV_TosaAny_TensorArm: $false_value
+ );
+
+ let results = (outs
+ SPIRV_TosaAny_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $condition `,`
+ $true_value `,`
+ $false_value
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getConditionType() {
+ return cast<::mlir::spirv::TensorArmType>(getCondition().getType());
+ }
+ ::mlir::spirv::TensorArmType getTrueValueType() {
+ return cast<::mlir::spirv::TensorArmType>(getTrueValue().getType());
+ }
+ ::mlir::spirv::TensorArmType getFalseValueType() {
+ return cast<::mlir::spirv::TensorArmType>(getFalseValue().getType());
+ }
+ ::mlir::Value getInput1() {
+ return getCondition();
+ }
+ ::mlir::Value getInput2() {
+ return getTrueValue();
+ }
+ ::mlir::Value getInput3() {
+ return getFalseValue();
+ }
+ ::mlir::spirv::TensorArmType getInput1Type() {
+ return getConditionType();
+ }
+ ::mlir::spirv::TensorArmType getInput2Type() {
+ return getTrueValueType();
+ }
+ ::mlir::spirv::TensorArmType getInput3Type() {
+ return getFalseValueType();
+ }
+ }];
+}
+
+
+def SPIRV_TosaEqualOp : SPIRV_TosaComparisonOp<"Equal", 45> {
+ let summary = "Equal comparison operation";
+
+ let description = [{
+ Elementwise Equal comparison operation: returns the truth value of
+ (input1 == input2) element-wise.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_equal
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_equal
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ ```
+ }];
+}
+
+
+def SPIRV_TosaGreaterOp : SPIRV_TosaComparisonOp<"Greater", 46> {
+ let summary = "Greater comparison operation";
+
+ let description = [{
+ Elementwise Greater than comparison operation: returns the truth value of
+ (input1 > input2) element-wise.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_greater
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_greater
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ ```
+ }];
+}
+
+
+def SPIRV_TosaGreaterEqualOp : SPIRV_TosaComparisonOp<"GreaterEqual", 47> {
+ let summary = "Greater or Equal comparison operation";
+
+ let description = [{
+ Elementwise Greater than or Equal comparison operation: returns the truth value of
+ (input1 >= input2) element-wise.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_greater_equal
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_greater_equal
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ ```
+ }];
+}
+
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 89d242781f5f7..32d6af9a6742c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -45,6 +45,7 @@ def SPIRV_TosaNumerical_TensorArm3D : TensorArmRankOf<[SPIRV_TosaNumerical], [3]
def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]>;
def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
+def SPIRV_TosaAny_TensorArm : TensorArmRankOf<[SPIRV_TosaAny], [1, 2, 3, 4, 5, 6]>;
def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
def SPIRV_TosaInteger_TensorArm : TensorArmRankOf<[SPIRV_TosaInteger], [1, 2, 3, 4, 5, 6]>;
def SPIRV_TosaFloat_TensorArm : TensorArmRankOf<[SPIRV_TosaFloat], [1, 2, 3, 4, 5, 6]>;
@@ -121,7 +122,7 @@ class TypeImpliesAccType<string input, Type type, list<string> allowedAccTypes>:
Implies<ElementTypeIsPred<input, type>, [AccTypeIn<allowedAccTypes>]>>;
class MatchBroadcastableShapes<string input1, string input2, string output>:
- PredOpTrait<"the shape of " # input1 # " and " # input2 # " are compatible for broadcasting and the broadcast shape is equal to the output shape",
+ PredOpTrait<"the shape of " # input1 # " and " # input2 # " are compatible for broadcasting and the broadcast shape is equal to the " # output # " shape",
Implies<And<[CPred<HasRank<input1>.result>, CPred<HasRank<input2>.result>, CPred<HasRank<output>.result>,
CPred<Rank<input1>.result # " == " # Rank<input2>.result # " && " # Rank<input1>.result # " == " # Rank<output>.result>]>,
[CPred<"llvm::all_of_zip(" # Shape<input1>.result # ", " # Shape<input2>.result # ", " # Shape<output>.result # ", " #
@@ -133,6 +134,21 @@ class MatchBroadcastableShapes<string input1, string input2, string output>:
"})">]>
>;
+class TernaryMatchBroadcastableShapes<string input1, string input2, string input3, string output>:
+ PredOpTrait<"the shape of " # input1 # ", " # input2 # ", and " # input3 # " are compatible for broadcasting and the broadcast shape is equal to the " # output # " shape",
+ Implies<And<[CPred<HasRank<input1>.result>, CPred<HasRank<input2>.result>, CPred<HasRank<input3>.result>, CPred<HasRank<output>.result>,
+ CPred<Rank<input1>.result # " == " # Rank<input2>.result # " && " # Rank<input1>.result # " == " # Rank<input3>.result # " && " # Rank<input1>.result # " == " # Rank<output>.result>]>,
+ [CPred<"llvm::all_of_zip(" # Shape<input1>.result # ", " # Shape<input2>.result # ", " # Shape<input3>.result # ", " # Shape<output>.result # ", " #
+ "[](int64_t input1Dim, int64_t input2Dim, int64_t input3Dim, int64_t outputDim) { " #
+ " bool dynamic = ShapedType::isDynamic(input1Dim) || ShapedType::isDynamic(input2Dim) || ShapedType::isDynamic(input3Dim) || ShapedType::isDynamic(outputDim);"
+ " bool broadcastableInputs = (input1Dim == input2Dim || input1Dim == 1 || input2Dim == 1) && "
+ " (input1Dim == input3Dim || input1Dim == 1 || input3Dim == 1) && "
+ " (input2Dim == input3Dim || input2Dim == 1 || input3Dim == 1);"
+ " bool broacastDimMatchesOutputDim = std::max(input1Dim, std::max(input2Dim, input3Dim)) == outputDim;"
+ " return dynamic || (broadcastableInputs && broacastDimMatchesOutputDim);" #
+ "})">]>
+ >;
+
class TableSizeConstraint<string input, Type type, int size>:
PredOpTrait<"table must have size " # size # " if " # input # " has element type " # type.summary,
Implies<ElementTypeIsPred<input, type>, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]>
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 8242a61020ce0..cb62f3d7c7be3 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1422,3 +1422,145 @@ spirv.ARM.Graph @sin_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<4
%0 = spirv.Tosa.Sin %arg0 : !spirv.arm.tensor<49x38x58xf16> -> !spirv.arm.tensor<49x38x59xf16>
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<49x38x59xf16>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @select_true_value_output_element_types_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi16>) {
+ // expected-error @+1 {{op failed to verify that all of {true_value, false_value, output} have same element type}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi16>
+}
+
+spirv.ARM.Graph @select_true_value_false_value_element_types_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi16>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {true_value, false_value, output} have same element type}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi16> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_false_value_output_element_types_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi16>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi16>) {
+ // expected-error @+1 {{op failed to verify that all of {true_value, false_value, output} have same element type}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi16>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi16>
+}
+
+spirv.ARM.Graph @select_condition_must_be_bool_tensor(%arg0: !spirv.arm.tensor<4x6x4x5xi8>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op operand #0 must be 1D/2D/3D/4D/5D/6D tensorArm of bool values, but got '!spirv.arm.tensor<4x6x4x5xi8>'}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_condition_rank_not_matching(%arg0: !spirv.arm.tensor<4x6x4xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {condition, true_value, false_value, output} have same rank}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_true_value_rank_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {condition, true_value, false_value, output} have same rank}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_false_value_rank_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {condition, true_value, false_value, output} have same rank}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_output_rank_not_matching_inputs(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x1x4x5xi8>) -> (!spirv.arm.tensor<4x6x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {condition, true_value, false_value, output} have same rank}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x1x4x5xi8> -> !spirv.arm.tensor<4x6x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x5xi8>
+}
+
+spirv.ARM.Graph @select_condition_true_value_not_broadcastable_false_value_compatible(%arg0: !spirv.arm.tensor<4x2x4x5xi1>, %arg1: !spirv.arm.tensor<4x3x4x5xi8>, %arg2: !spirv.arm.tensor<4x1x4x5xi8>) -> (!spirv.arm.tensor<4x3x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that the shape of condition, true_value, and false_value are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x2x4x5xi1>, !spirv.arm.tensor<4x3x4x5xi8>, !spirv.arm.tensor<4x1x4x5xi8> -> !spirv.arm.tensor<4x3x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x3x4x5xi8>
+}
+
+spirv.ARM.Graph @select_condition_false_value_not_broadcastable_true_value_compatible(%arg0: !spirv.arm.tensor<4x2x4x5xi1>, %arg1: !spirv.arm.tensor<4x1x4x5xi8>, %arg2: !spirv.arm.tensor<4x3x4x5xi8>) -> (!spirv.arm.tensor<4x3x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that the shape of condition, true_value, and false_value are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x2x4x5xi1>, !spirv.arm.tensor<4x1x4x5xi8>, !spirv.arm.tensor<4x3x4x5xi8> -> !spirv.arm.tensor<4x3x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x3x4x5xi8>
+}
+
+spirv.ARM.Graph @select_true_value_false_value_not_broadcastable_condition_compatible(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x2x4x5xi8>, %arg2: !spirv.arm.tensor<4x3x4x5xi8>) -> (!spirv.arm.tensor<4x3x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that the shape of condition, true_value, and false_value are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x2x4x5xi8>, !spirv.arm.tensor<4x3x4x5xi8> -> !spirv.arm.tensor<4x3x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x3x4x5xi8>
+}
+
+spirv.ARM.Graph @select_inputs_broadcastable_output_shape_not_broadcast_shape(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x1x4x5xi8>) -> (!spirv.arm.tensor<4x1x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that the shape of condition, true_value, and false_value are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x1x4x5xi8> -> !spirv.arm.tensor<4x1x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x1x4x5xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @equal_input_element_types_not_matching(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x11x5x3xf16>) -> (!spirv.arm.tensor<16x11x5x3xi1>) {
+ // expected-error @+1 {{op failed to verify that all of {input1, input2} have same element type}}
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x11x5x3xf16> -> !spirv.arm.tensor<16x11x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x11x5x3xi1>
+}
+
+spirv.ARM.Graph @equal_input_shapes_not_broadcastable(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x7x5x3xf32>) -> (!spirv.arm.tensor<16x11x5x3xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x7x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x11x5x3xi1>
+}
+
+spirv.ARM.Graph @equal_output_shape_not_broadcast_shape(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x1x5x3xf32>) -> (!spirv.arm.tensor<16x1x5x3xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x1x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x1x5x3xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greater_input_element_types_not_matching(%arg0: !spirv.arm.tensor<11x10x10x2xi32>, %arg1: !spirv.arm.tensor<11x10x10x2xf16>) -> (!spirv.arm.tensor<11x10x10x2xi1>) {
+ // expected-error @+1 {{op failed to verify that all of {input1, input2} have same element type}}
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x2xf16> -> !spirv.arm.tensor<11x10x10x2xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x10x10x2xi1>
+}
+
+spirv.ARM.Graph @greater_output_shape_not_broadcast_shape(%arg0: !spirv.arm.tensor<11x10x10x2xi32>, %arg1: !spirv.arm.tensor<11x10x10x1xi32>) -> (!spirv.arm.tensor<11x10x10x1xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x1xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x10x10x1xi1>
+}
+
+spirv.ARM.Graph @greater_output_shape_not_broadcast_shape_batch_dim(%arg0: !spirv.arm.tensor<11x10x10x2xf16>, %arg1: !spirv.arm.tensor<1x10x10x2xf16>) -> (!spirv.arm.tensor<1x10x10x2xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xf16>, !spirv.arm.tensor<1x10x10x2xf16> -> !spirv.arm.tensor<1x10x10x2xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x10x10x2xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greaterequal_input_element_types_not_matching(%arg0: !spirv.arm.tensor<10x17x7x16xi32>, %arg1: !spirv.arm.tensor<10x17x7x16xf32>) -> (!spirv.arm.tensor<10x17x7x16xi1>) {
+ // expected-error @+1 {{op failed to verify that all of {input1, input2} have same element type}}
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x16xi32>, !spirv.arm.tensor<10x17x7x16xf32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<10x17x7x16xi1>
+}
+
+spirv.ARM.Graph @greaterequal_input_shapes_not_broadcastable(%arg0: !spirv.arm.tensor<10x17x7x16xf32>, %arg1: !spirv.arm.tensor<10x17x5x16xf32>) -> (!spirv.arm.tensor<10x17x7x16xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x16xf32>, !spirv.arm.tensor<10x17x5x16xf32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<10x17x7x16xi1>
+}
+
+spirv.ARM.Graph @greaterequal_output_shape_not_broadcast_shape(%arg0: !spirv.arm.tensor<10x17x7x16xf32>, %arg1: !spirv.arm.tensor<1x17x7x16xf32>) -> (!spirv.arm.tensor<1x17x7x16xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x16xf32>, !spirv.arm.tensor<1x17x7x16xf32> -> !spirv.arm.tensor<1x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x17x7x16xi1>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index 8d51dd49e7ae0..a4a26cb394603 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -698,3 +698,91 @@ spirv.ARM.Graph @sin_fp(%arg0: !spirv.arm.tensor<49x38x58xf16>) -> (!spirv.arm.t
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<49x38x58xf16>
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<49x38x58xf16>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @select_int(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @select_fp(%arg0: !spirv.arm.tensor<9x2x15x8xi1>, %arg1: !spirv.arm.tensor<9x2x15x8xf16>, %arg2: !spirv.arm.tensor<9x1x15x8xf16>) -> (!spirv.arm.tensor<9x2x15x8xf16>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<9x2x15x8xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<9x2x15x8xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @equal_int(%arg0: !spirv.arm.tensor<51x28x59xi32>, %arg1: !spirv.arm.tensor<51x1x59xi32>) -> (!spirv.arm.tensor<51x28x59xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<51x28x59xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<51x28x59xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @equal_fp(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x1x5x3xf32>) -> (!spirv.arm.tensor<16x11x5x3xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<16x11x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x11x5x3xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greater_int(%arg0: !spirv.arm.tensor<11x10x10x2xi32>, %arg1: !spirv.arm.tensor<11x10x10x1xi32>) -> (!spirv.arm.tensor<11x10x10x2xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<11x10x10x2xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x10x10x2xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greater_fp(%arg0: !spirv.arm.tensor<6x3x12x4xf16>, %arg1: !spirv.arm.tensor<6x3x1x4xf16>) -> (!spirv.arm.tensor<6x3x12x4xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<6x3x12x4xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x3x12x4xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greaterequal_int(%arg0: !spirv.arm.tensor<10x17x7x1xi32>, %arg1: !spirv.arm.tensor<10x17x7x16xi32>) -> (!spirv.arm.tensor<10x17x7x16xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<10x17x7x16xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greaterequal_fp(%arg0: !spirv.arm.tensor<3x17x6x3xf32>, %arg1: !spirv.arm.tensor<1x17x6x3xf32>) -> (!spirv.arm.tensor<3x17x6x3xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x17x6x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<3x17x6x3xi1>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 9c58bad93258f..856fde46a4866 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -1222,3 +1222,165 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<49x38x58xf16>
}
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @select_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<4x1x4x5xi1>, UniformConstant>
+ spirv.GlobalVariable @select_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<4x6x4x5xi8>, UniformConstant>
+ spirv.GlobalVariable @select_int_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<4x6x4x5xi8>, UniformConstant>
+ spirv.GlobalVariable @select_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<4x6x4x5xi8>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @select_int, @select_int_arg_0, @select_int_arg_1, @select_int_arg_2, @select_int_res_0
+ spirv.ARM.Graph @select_int(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @select_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<9x2x15x8xi1>, UniformConstant>
+ spirv.GlobalVariable @select_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<9x2x15x8xf16>, UniformConstant>
+ spirv.GlobalVariable @select_fp_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<9x1x15x8xf16>, UniformConstant>
+ spirv.GlobalVariable @select_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<9x2x15x8xf16>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @select_fp, @select_fp_arg_0, @select_fp_arg_1, @select_fp_arg_2, @select_fp_res_0
+ spirv.ARM.Graph @select_fp(%arg0: !spirv.arm.tensor<9x2x15x8xi1>, %arg1: !spirv.arm.tensor<9x2x15x8xf16>, %arg2: !spirv.arm.tensor<9x1x15x8xf16>) -> (!spirv.arm.tensor<9x2x15x8xf16>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<9x2x15x8xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<9x2x15x8xf16>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @equal_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<51x28x59xi32>, UniformConstant>
+ spirv.GlobalVariable @equal_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<51x1x59xi32>, UniformConstant>
+ spirv.GlobalVariable @equal_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<51x28x59xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @equal_int, @equal_int_arg_0, @equal_int_arg_1, @equal_int_res_0
+ spirv.ARM.Graph @equal_int(%arg0: !spirv.arm.tensor<51x28x59xi32>, %arg1: !spirv.arm.tensor<51x1x59xi32>) -> (!spirv.arm.tensor<51x28x59xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<51x28x59xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<51x28x59xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @equal_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<16x11x5x3xf32>, UniformConstant>
+ spirv.GlobalVariable @equal_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<16x1x5x3xf32>, UniformConstant>
+ spirv.GlobalVariable @equal_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<16x11x5x3xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @equal_fp, @equal_fp_arg_0, @equal_fp_arg_1, @equal_fp_res_0
+ spirv.ARM.Graph @equal_fp(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x1x5x3xf32>) -> (!spirv.arm.tensor<16x11x5x3xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<16x11x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x11x5x3xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @greater_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<11x10x10x2xi32>, UniformConstant>
+ spirv.GlobalVariable @greater_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<11x10x10x1xi32>, UniformConstant>
+ spirv.GlobalVariable @greater_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<11x10x10x2xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @greater_int, @greater_int_arg_0, @greater_int_arg_1, @greater_int_res_0
+ spirv.ARM.Graph @greater_int(%arg0: !spirv.arm.tensor<11x10x10x2xi32>, %arg1: !spirv.arm.tensor<11x10x10x1xi32>) -> (!spirv.arm.tensor<11x10x10x2xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<11x10x10x2xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x10x10x2xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @greater_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<6x3x12x4xf16>, UniformConstant>
+ spirv.GlobalVariable @greater_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<6x3x1x4xf16>, UniformConstant>
+ spirv.GlobalVariable @greater_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<6x3x12x4xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @greater_fp, @greater_fp_arg_0, @greater_fp_arg_1, @greater_fp_res_0
+ spirv.ARM.Graph @greater_fp(%arg0: !spirv.arm.tensor<6x3x12x4xf16>, %arg1: !spirv.arm.tensor<6x3x1x4xf16>) -> (!spirv.arm.tensor<6x3x12x4xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<6x3x12x4xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x3x12x4xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @greaterequal_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<10x17x7x1xi32>, UniformConstant>
+ spirv.GlobalVariable @greaterequal_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<10x17x7x16xi32>, UniformConstant>
+ spirv.GlobalVariable @greaterequal_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<10x17x7x16xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @greaterequal_int, @greaterequal_int_arg_0, @greaterequal_int_arg_1, @greaterequal_int_res_0
+ spirv.ARM.Graph @greaterequal_int(%arg0: !spirv.arm.tensor<10x17x7x1xi32>, %arg1: !spirv.arm.tensor<10x17x7x16xi32>) -> (!spirv.arm.tensor<10x17x7x16xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<10x17x7x16xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @greaterequal_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x17x6x3xf32>, UniformConstant>
+ spirv.GlobalVariable @greaterequal_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x17x6x3xf32>, UniformConstant>
+ spirv.GlobalVariable @greaterequal_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x17x6x3xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @greaterequal_fp, @greaterequal_fp_arg_0, @greaterequal_fp_arg_1, @greaterequal_fp_res_0
+ spirv.ARM.Graph @greaterequal_fp(%arg0: !spirv.arm.tensor<3x17x6x3xf32>, %arg1: !spirv.arm.tensor<1x17x6x3xf32>) -> (!spirv.arm.tensor<3x17x6x3xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x17x6x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<3x17x6x3xi1>
+ }
+}
More information about the Mlir-commits
mailing list