[Mlir-commits] [mlir] [mlir][spirv] Add comparison and elementwise ternary ops in TOSA Ext Inst Set (PR #186356)
Davide Grohmann
llvmlistbot at llvm.org
Tue Mar 17 04:55:14 PDT 2026
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/186356
>From 46fe967084f3dcf54e98d30c3da1f7209737ada3 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 28 Jan 2026 13:39:57 +0100
Subject: [PATCH] [mlir][spirv] Add comparison and elementwise ternary ops in
TOSA Ext Inst Set
This patch introduces the following comparison and elementwise ternary operators:
spirv.Tosa.Select
spirv.Tosa.Equal
spirv.Tosa.Greater
spirv.Tosa.GreaterEqual
Also dialect and serialization round-trip tests have been added.
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ib7dd5c061ba49aca5d8532190598e4fdb75ae8d9
---
.../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td | 160 +++++++++++++++++
.../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td | 3 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp | 65 +++++++
.../SPIRV/IR/tosa-ops-verification.mlir | 142 +++++++++++++++
mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir | 88 ++++++++++
mlir/test/Target/SPIRV/tosa-ops.mlir | 162 ++++++++++++++++++
6 files changed, 619 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 9fb3a53286bdf..bc26e0bb9cb0a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -183,6 +183,35 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
}];
}
+class SPIRV_TosaComparisonOp<string mnemonic, int opcode, list<Trait> traits = []> :
+ SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure,
+ AllElementTypesMatch<["input1", "input2"]>,
+ MatchBroadcastableShapes<"input1", "input2", "output">])> {
+
+ let arguments = (ins
+ SPIRV_TosaNumerical_TensorArm: $input1,
+ SPIRV_TosaNumerical_TensorArm: $input2
+ );
+
+ let results = (outs
+ SPIRV_Bool_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input1 `,`
+ $input2
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInput1Type() {
+ return cast<::mlir::spirv::TensorArmType>(getInput1().getType());
+ }
+ ::mlir::spirv::TensorArmType getInput2Type() {
+ return cast<::mlir::spirv::TensorArmType>(getInput2().getType());
+ }
+ }];
+}
def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
OutputRankIsInputRankMinusOne<"input", "output">,
@@ -1742,4 +1771,135 @@ def SPIRV_TosaSinOp : SPIRV_TosaFloatElementwiseUnaryOp<"Sin", 43> {
}
+def SPIRV_TosaSelectOp : SPIRV_TosaOpWithResult<"Select", 44, [Pure,
+ AllElementTypesMatch<["true_value", "false_value", "output"]>,
+ AllRanksMatch<["condition", "true_value", "false_value", "output"]>,
+ DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
+ let summary = "Select operator.";
+
+ let description = [{
+ Elementwise Select of the output based on a condition.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_select
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_select
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Select %cond, %trueVal, %falseVal : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ %0 = spirv.Tosa.Select %cond, %trueVal, %falseVal : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_Bool_TensorArm: $condition,
+ SPIRV_TosaAny_TensorArm: $true_value,
+ SPIRV_TosaAny_TensorArm: $false_value
+ );
+
+ let results = (outs
+ SPIRV_TosaAny_TensorArm: $output
+ );
+
+ let hasVerifier = 1;
+
+ let assemblyFormat = [{
+ $condition `,`
+ $true_value `,`
+ $false_value
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getConditionType() {
+ return cast<::mlir::spirv::TensorArmType>(getCondition().getType());
+ }
+ ::mlir::spirv::TensorArmType getTrueValueType() {
+ return cast<::mlir::spirv::TensorArmType>(getTrueValue().getType());
+ }
+ ::mlir::spirv::TensorArmType getFalseValueType() {
+ return cast<::mlir::spirv::TensorArmType>(getFalseValue().getType());
+ }
+ ::mlir::Value getInput1() {
+ return getCondition();
+ }
+ ::mlir::Value getInput2() {
+ return getTrueValue();
+ }
+ ::mlir::Value getInput3() {
+ return getFalseValue();
+ }
+ ::mlir::spirv::TensorArmType getInput1Type() {
+ return getConditionType();
+ }
+ ::mlir::spirv::TensorArmType getInput2Type() {
+ return getTrueValueType();
+ }
+ ::mlir::spirv::TensorArmType getInput3Type() {
+ return getFalseValueType();
+ }
+ }];
+}
+
+
+def SPIRV_TosaEqualOp : SPIRV_TosaComparisonOp<"Equal", 45> {
+ let summary = "Equal comparison operation";
+
+ let description = [{
+ Elementwise Equal comparison operation: returns the truth value of
+ (input1 == input2) element-wise.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_equal
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_equal
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ ```
+ }];
+}
+
+
+def SPIRV_TosaGreaterOp : SPIRV_TosaComparisonOp<"Greater", 46> {
+ let summary = "Greater comparison operation";
+
+ let description = [{
+ Elementwise Greater than comparison operation: returns the truth value of
+ (input1 > input2) element-wise.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_greater
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_greater
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ ```
+ }];
+}
+
+
+def SPIRV_TosaGreaterEqualOp : SPIRV_TosaComparisonOp<"GreaterEqual", 47> {
+ let summary = "Greater or Equal comparison operation";
+
+ let description = [{
+    Elementwise Greater than or Equal comparison operation: returns the truth value of
+ (input1 >= input2) element-wise.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_greater_equal
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_greater_equal
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ ```
+ }];
+}
+
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 89d242781f5f7..8d862dd87c12a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -45,6 +45,7 @@ def SPIRV_TosaNumerical_TensorArm3D : TensorArmRankOf<[SPIRV_TosaNumerical], [3]
def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]>;
def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
+def SPIRV_TosaAny_TensorArm : TensorArmRankOf<[SPIRV_TosaAny], [1, 2, 3, 4, 5, 6]>;
def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
def SPIRV_TosaInteger_TensorArm : TensorArmRankOf<[SPIRV_TosaInteger], [1, 2, 3, 4, 5, 6]>;
def SPIRV_TosaFloat_TensorArm : TensorArmRankOf<[SPIRV_TosaFloat], [1, 2, 3, 4, 5, 6]>;
@@ -121,7 +122,7 @@ class TypeImpliesAccType<string input, Type type, list<string> allowedAccTypes>:
Implies<ElementTypeIsPred<input, type>, [AccTypeIn<allowedAccTypes>]>>;
class MatchBroadcastableShapes<string input1, string input2, string output>:
- PredOpTrait<"the shape of " # input1 # " and " # input2 # " are compatible for broadcasting and the broadcast shape is equal to the output shape",
+ PredOpTrait<"the shape of " # input1 # " and " # input2 # " are compatible for broadcasting and the broadcast shape is equal to the " # output # " shape",
Implies<And<[CPred<HasRank<input1>.result>, CPred<HasRank<input2>.result>, CPred<HasRank<output>.result>,
CPred<Rank<input1>.result # " == " # Rank<input2>.result # " && " # Rank<input1>.result # " == " # Rank<output>.result>]>,
[CPred<"llvm::all_of_zip(" # Shape<input1>.result # ", " # Shape<input2>.result # ", " # Shape<output>.result # ", " #
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
index 5116fef6201df..a0591ee31acf8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -11,7 +11,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/InterleavedRange.h"
+#include <algorithm>
namespace mlir::spirv {
@@ -48,4 +50,67 @@ void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *,
[](const APInt &a) { return a.getSExtValue(); }));
}
+//===----------------------------------------------------------------------===//
+// SPIRV Tosa Custom verifiers
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaSelectOp::verify() {
+ TensorArmType condType = getConditionType();
+ TensorArmType trueValType = getTrueValueType();
+ TensorArmType falseValType = getFalseValueType();
+ TensorArmType resultType = getResultType();
+
+ if (llvm::any_of(ArrayRef<TensorArmType>{condType, trueValType, falseValType,
+ resultType},
+ [](TensorArmType type) { return !type.hasRank(); }))
+ return success();
+
+ ArrayRef<int64_t> condShape = condType.getShape();
+ ArrayRef<int64_t> trueValShape = trueValType.getShape();
+ ArrayRef<int64_t> falseValShape = falseValType.getShape();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+
+ if (!llvm::all_equal({condShape.size(), trueValShape.size(),
+ falseValShape.size(), resultShape.size()})) {
+    // The AllRanksMatch predicate enforces that all ranks are equal.
+    // This is just an extra safeguard for the code that follows, which
+    // assumes that all ranks are equal.
+ return failure();
+ }
+
+ for (auto dims :
+ llvm::zip_equal(condShape, trueValShape, falseValShape, resultShape)) {
+ auto [condDim, trueValDim, falseValDim, resultDim] = dims;
+
+ if (llvm::any_of(
+ ArrayRef<int64_t>{condDim, trueValDim, falseValDim, resultDim},
+ [](int64_t dim) { return ShapedType::isDynamic(dim); })) {
+ continue;
+ }
+
+ auto isPairBroadcastable = [](int64_t lhs, int64_t rhs) {
+ return lhs == rhs || lhs == 1 || rhs == 1;
+ };
+
+ if (!isPairBroadcastable(condDim, trueValDim) ||
+ !isPairBroadcastable(condDim, falseValDim) ||
+ !isPairBroadcastable(trueValDim, falseValDim)) {
+ return emitOpError(
+ "failed to verify that the shape of inputs: condition, "
+ "true_value, and false_value are compatible for "
+ "broadcasting");
+ }
+
+ int64_t bradcastedInputDim =
+ std::max(condDim, std::max(trueValDim, falseValDim));
+ if (bradcastedInputDim != resultDim) {
+ return emitOpError(
+ "failed to verify that the broadcast shape of inputs: condition, "
+ "true_value, and false_value is equal to "
+ "the output shape");
+ }
+ }
+ return success();
+}
+
} // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 8242a61020ce0..7b9c88fe60d80 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1422,3 +1422,145 @@ spirv.ARM.Graph @sin_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<4
%0 = spirv.Tosa.Sin %arg0 : !spirv.arm.tensor<49x38x58xf16> -> !spirv.arm.tensor<49x38x59xf16>
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<49x38x59xf16>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @select_true_value_output_element_types_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi16>) {
+ // expected-error @+1 {{op failed to verify that all of {true_value, false_value, output} have same element type}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi16>
+}
+
+spirv.ARM.Graph @select_true_value_false_value_element_types_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi16>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {true_value, false_value, output} have same element type}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi16> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_false_value_output_element_types_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi16>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi16>) {
+ // expected-error @+1 {{op failed to verify that all of {true_value, false_value, output} have same element type}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi16>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi16>
+}
+
+spirv.ARM.Graph @select_condition_must_be_bool_tensor(%arg0: !spirv.arm.tensor<4x6x4x5xi8>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op operand #0 must be 1D/2D/3D/4D/5D/6D tensorArm of bool values, but got '!spirv.arm.tensor<4x6x4x5xi8>'}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_condition_rank_not_matching(%arg0: !spirv.arm.tensor<4x6x4xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {condition, true_value, false_value, output} have same rank}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_true_value_rank_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {condition, true_value, false_value, output} have same rank}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_false_value_rank_not_matching(%arg0: !spirv.arm.tensor<4x6x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {condition, true_value, false_value, output} have same rank}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x6x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+spirv.ARM.Graph @select_output_rank_not_matching_inputs(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x1x4x5xi8>) -> (!spirv.arm.tensor<4x6x5xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {condition, true_value, false_value, output} have same rank}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x1x4x5xi8> -> !spirv.arm.tensor<4x6x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x5xi8>
+}
+
+spirv.ARM.Graph @select_condition_true_value_not_broadcastable_false_value_compatible(%arg0: !spirv.arm.tensor<4x2x4x5xi1>, %arg1: !spirv.arm.tensor<4x3x4x5xi8>, %arg2: !spirv.arm.tensor<4x1x4x5xi8>) -> (!spirv.arm.tensor<4x3x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that the shape of inputs: condition, true_value, and false_value are compatible for broadcasting}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x2x4x5xi1>, !spirv.arm.tensor<4x3x4x5xi8>, !spirv.arm.tensor<4x1x4x5xi8> -> !spirv.arm.tensor<4x3x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x3x4x5xi8>
+}
+
+spirv.ARM.Graph @select_condition_false_value_not_broadcastable_true_value_compatible(%arg0: !spirv.arm.tensor<4x2x4x5xi1>, %arg1: !spirv.arm.tensor<4x1x4x5xi8>, %arg2: !spirv.arm.tensor<4x3x4x5xi8>) -> (!spirv.arm.tensor<4x3x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that the shape of inputs: condition, true_value, and false_value are compatible for broadcasting}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x2x4x5xi1>, !spirv.arm.tensor<4x1x4x5xi8>, !spirv.arm.tensor<4x3x4x5xi8> -> !spirv.arm.tensor<4x3x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x3x4x5xi8>
+}
+
+spirv.ARM.Graph @select_true_value_false_value_not_broadcastable_condition_compatible(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x2x4x5xi8>, %arg2: !spirv.arm.tensor<4x3x4x5xi8>) -> (!spirv.arm.tensor<4x3x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that the shape of inputs: condition, true_value, and false_value are compatible for broadcasting}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x2x4x5xi8>, !spirv.arm.tensor<4x3x4x5xi8> -> !spirv.arm.tensor<4x3x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x3x4x5xi8>
+}
+
+spirv.ARM.Graph @select_inputs_broadcastable_output_shape_not_broadcast_shape(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x1x4x5xi8>) -> (!spirv.arm.tensor<4x1x4x5xi8>) {
+ // expected-error @+1 {{op failed to verify that the broadcast shape of inputs: condition, true_value, and false_value is equal to the output shape}}
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x1x4x5xi8> -> !spirv.arm.tensor<4x1x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x1x4x5xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @equal_input_element_types_not_matching(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x11x5x3xf16>) -> (!spirv.arm.tensor<16x11x5x3xi1>) {
+ // expected-error @+1 {{op failed to verify that all of {input1, input2} have same element type}}
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x11x5x3xf16> -> !spirv.arm.tensor<16x11x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x11x5x3xi1>
+}
+
+spirv.ARM.Graph @equal_input_shapes_not_broadcastable(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x7x5x3xf32>) -> (!spirv.arm.tensor<16x11x5x3xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x7x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x11x5x3xi1>
+}
+
+spirv.ARM.Graph @equal_output_shape_not_broadcast_shape(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x1x5x3xf32>) -> (!spirv.arm.tensor<16x1x5x3xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x1x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x1x5x3xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greater_input_element_types_not_matching(%arg0: !spirv.arm.tensor<11x10x10x2xi32>, %arg1: !spirv.arm.tensor<11x10x10x2xf16>) -> (!spirv.arm.tensor<11x10x10x2xi1>) {
+ // expected-error @+1 {{op failed to verify that all of {input1, input2} have same element type}}
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x2xf16> -> !spirv.arm.tensor<11x10x10x2xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x10x10x2xi1>
+}
+
+spirv.ARM.Graph @greater_output_shape_not_broadcast_shape(%arg0: !spirv.arm.tensor<11x10x10x2xi32>, %arg1: !spirv.arm.tensor<11x10x10x1xi32>) -> (!spirv.arm.tensor<11x10x10x1xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x1xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x10x10x1xi1>
+}
+
+spirv.ARM.Graph @greater_output_shape_not_broadcast_shape_batch_dim(%arg0: !spirv.arm.tensor<11x10x10x2xf16>, %arg1: !spirv.arm.tensor<1x10x10x2xf16>) -> (!spirv.arm.tensor<1x10x10x2xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xf16>, !spirv.arm.tensor<1x10x10x2xf16> -> !spirv.arm.tensor<1x10x10x2xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x10x10x2xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greaterequal_input_element_types_not_matching(%arg0: !spirv.arm.tensor<10x17x7x16xi32>, %arg1: !spirv.arm.tensor<10x17x7x16xf32>) -> (!spirv.arm.tensor<10x17x7x16xi1>) {
+ // expected-error @+1 {{op failed to verify that all of {input1, input2} have same element type}}
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x16xi32>, !spirv.arm.tensor<10x17x7x16xf32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<10x17x7x16xi1>
+}
+
+spirv.ARM.Graph @greaterequal_input_shapes_not_broadcastable(%arg0: !spirv.arm.tensor<10x17x7x16xf32>, %arg1: !spirv.arm.tensor<10x17x5x16xf32>) -> (!spirv.arm.tensor<10x17x7x16xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x16xf32>, !spirv.arm.tensor<10x17x5x16xf32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<10x17x7x16xi1>
+}
+
+spirv.ARM.Graph @greaterequal_output_shape_not_broadcast_shape(%arg0: !spirv.arm.tensor<10x17x7x16xf32>, %arg1: !spirv.arm.tensor<1x17x7x16xf32>) -> (!spirv.arm.tensor<1x17x7x16xi1>) {
+ // expected-error @+1 {{op failed to verify that the shape of input1 and input2 are compatible for broadcasting and the broadcast shape is equal to the output shape}}
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x16xf32>, !spirv.arm.tensor<1x17x7x16xf32> -> !spirv.arm.tensor<1x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x17x7x16xi1>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index 8d51dd49e7ae0..a4a26cb394603 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -698,3 +698,91 @@ spirv.ARM.Graph @sin_fp(%arg0: !spirv.arm.tensor<49x38x58xf16>) -> (!spirv.arm.t
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<49x38x58xf16>
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<49x38x58xf16>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @select_int(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @select_fp(%arg0: !spirv.arm.tensor<9x2x15x8xi1>, %arg1: !spirv.arm.tensor<9x2x15x8xf16>, %arg2: !spirv.arm.tensor<9x1x15x8xf16>) -> (!spirv.arm.tensor<9x2x15x8xf16>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<9x2x15x8xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<9x2x15x8xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @equal_int(%arg0: !spirv.arm.tensor<51x28x59xi32>, %arg1: !spirv.arm.tensor<51x1x59xi32>) -> (!spirv.arm.tensor<51x28x59xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<51x28x59xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<51x28x59xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @equal_fp(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x1x5x3xf32>) -> (!spirv.arm.tensor<16x11x5x3xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<16x11x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x11x5x3xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greater_int(%arg0: !spirv.arm.tensor<11x10x10x2xi32>, %arg1: !spirv.arm.tensor<11x10x10x1xi32>) -> (!spirv.arm.tensor<11x10x10x2xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<11x10x10x2xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x10x10x2xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greater_fp(%arg0: !spirv.arm.tensor<6x3x12x4xf16>, %arg1: !spirv.arm.tensor<6x3x1x4xf16>) -> (!spirv.arm.tensor<6x3x12x4xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<6x3x12x4xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x3x12x4xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greaterequal_int(%arg0: !spirv.arm.tensor<10x17x7x1xi32>, %arg1: !spirv.arm.tensor<10x17x7x16xi32>) -> (!spirv.arm.tensor<10x17x7x16xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<10x17x7x16xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @greaterequal_fp(%arg0: !spirv.arm.tensor<3x17x6x3xf32>, %arg1: !spirv.arm.tensor<1x17x6x3xf32>) -> (!spirv.arm.tensor<3x17x6x3xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x17x6x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<3x17x6x3xi1>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 9c58bad93258f..856fde46a4866 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -1222,3 +1222,165 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<49x38x58xf16>
}
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @select_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<4x1x4x5xi1>, UniformConstant>
+ spirv.GlobalVariable @select_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<4x6x4x5xi8>, UniformConstant>
+ spirv.GlobalVariable @select_int_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<4x6x4x5xi8>, UniformConstant>
+ spirv.GlobalVariable @select_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<4x6x4x5xi8>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @select_int, @select_int_arg_0, @select_int_arg_1, @select_int_arg_2, @select_int_res_0
+ spirv.ARM.Graph @select_int(%arg0: !spirv.arm.tensor<4x1x4x5xi1>, %arg1: !spirv.arm.tensor<4x6x4x5xi8>, %arg2: !spirv.arm.tensor<4x6x4x5xi8>) -> (!spirv.arm.tensor<4x6x4x5xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x1x4x5xi1>, !spirv.arm.tensor<4x6x4x5xi8>, !spirv.arm.tensor<4x6x4x5xi8> -> !spirv.arm.tensor<4x6x4x5xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<4x6x4x5xi8>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x6x4x5xi8>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Select - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @select_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<9x2x15x8xi1>, UniformConstant>
+ spirv.GlobalVariable @select_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<9x2x15x8xf16>, UniformConstant>
+ spirv.GlobalVariable @select_fp_arg_2 bind(0, 2) : !spirv.ptr<!spirv.arm.tensor<9x1x15x8xf16>, UniformConstant>
+ spirv.GlobalVariable @select_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<9x2x15x8xf16>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @select_fp, @select_fp_arg_0, @select_fp_arg_1, @select_fp_arg_2, @select_fp_res_0
+ spirv.ARM.Graph @select_fp(%arg0: !spirv.arm.tensor<9x2x15x8xi1>, %arg1: !spirv.arm.tensor<9x2x15x8xf16>, %arg2: !spirv.arm.tensor<9x1x15x8xf16>) -> (!spirv.arm.tensor<9x2x15x8xf16>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ %0 = spirv.Tosa.Select %arg0, %arg1, %arg2 : !spirv.arm.tensor<9x2x15x8xi1>, !spirv.arm.tensor<9x2x15x8xf16>, !spirv.arm.tensor<9x1x15x8xf16> -> !spirv.arm.tensor<9x2x15x8xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<9x2x15x8xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<9x2x15x8xf16>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @equal_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<51x28x59xi32>, UniformConstant>
+ spirv.GlobalVariable @equal_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<51x1x59xi32>, UniformConstant>
+ spirv.GlobalVariable @equal_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<51x28x59xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @equal_int, @equal_int_arg_0, @equal_int_arg_1, @equal_int_res_0
+ spirv.ARM.Graph @equal_int(%arg0: !spirv.arm.tensor<51x28x59xi32>, %arg1: !spirv.arm.tensor<51x1x59xi32>) -> (!spirv.arm.tensor<51x28x59xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<51x28x59xi32>, !spirv.arm.tensor<51x1x59xi32> -> !spirv.arm.tensor<51x28x59xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<51x28x59xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<51x28x59xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Equal - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @equal_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<16x11x5x3xf32>, UniformConstant>
+ spirv.GlobalVariable @equal_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<16x1x5x3xf32>, UniformConstant>
+ spirv.GlobalVariable @equal_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<16x11x5x3xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @equal_fp, @equal_fp_arg_0, @equal_fp_arg_1, @equal_fp_res_0
+ spirv.ARM.Graph @equal_fp(%arg0: !spirv.arm.tensor<16x11x5x3xf32>, %arg1: !spirv.arm.tensor<16x1x5x3xf32>) -> (!spirv.arm.tensor<16x11x5x3xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ %0 = spirv.Tosa.Equal %arg0, %arg1 : !spirv.arm.tensor<16x11x5x3xf32>, !spirv.arm.tensor<16x1x5x3xf32> -> !spirv.arm.tensor<16x11x5x3xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<16x11x5x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x11x5x3xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @greater_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<11x10x10x2xi32>, UniformConstant>
+ spirv.GlobalVariable @greater_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<11x10x10x1xi32>, UniformConstant>
+ spirv.GlobalVariable @greater_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<11x10x10x2xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @greater_int, @greater_int_arg_0, @greater_int_arg_1, @greater_int_res_0
+ spirv.ARM.Graph @greater_int(%arg0: !spirv.arm.tensor<11x10x10x2xi32>, %arg1: !spirv.arm.tensor<11x10x10x1xi32>) -> (!spirv.arm.tensor<11x10x10x2xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<11x10x10x2xi32>, !spirv.arm.tensor<11x10x10x1xi32> -> !spirv.arm.tensor<11x10x10x2xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<11x10x10x2xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<11x10x10x2xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Greater - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @greater_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<6x3x12x4xf16>, UniformConstant>
+ spirv.GlobalVariable @greater_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<6x3x1x4xf16>, UniformConstant>
+ spirv.GlobalVariable @greater_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<6x3x12x4xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @greater_fp, @greater_fp_arg_0, @greater_fp_arg_1, @greater_fp_res_0
+ spirv.ARM.Graph @greater_fp(%arg0: !spirv.arm.tensor<6x3x12x4xf16>, %arg1: !spirv.arm.tensor<6x3x1x4xf16>) -> (!spirv.arm.tensor<6x3x12x4xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ %0 = spirv.Tosa.Greater %arg0, %arg1 : !spirv.arm.tensor<6x3x12x4xf16>, !spirv.arm.tensor<6x3x1x4xf16> -> !spirv.arm.tensor<6x3x12x4xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<6x3x12x4xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x3x12x4xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @greaterequal_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<10x17x7x1xi32>, UniformConstant>
+ spirv.GlobalVariable @greaterequal_int_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<10x17x7x16xi32>, UniformConstant>
+ spirv.GlobalVariable @greaterequal_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<10x17x7x16xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @greaterequal_int, @greaterequal_int_arg_0, @greaterequal_int_arg_1, @greaterequal_int_res_0
+ spirv.ARM.Graph @greaterequal_int(%arg0: !spirv.arm.tensor<10x17x7x1xi32>, %arg1: !spirv.arm.tensor<10x17x7x16xi32>) -> (!spirv.arm.tensor<10x17x7x16xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x1xi32>, !spirv.arm.tensor<10x17x7x16xi32> -> !spirv.arm.tensor<10x17x7x16xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x17x7x16xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<10x17x7x16xi1>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.GreaterEqual - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @greaterequal_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x17x6x3xf32>, UniformConstant>
+ spirv.GlobalVariable @greaterequal_fp_arg_1 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x17x6x3xf32>, UniformConstant>
+ spirv.GlobalVariable @greaterequal_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x17x6x3xi1>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @greaterequal_fp, @greaterequal_fp_arg_0, @greaterequal_fp_arg_1, @greaterequal_fp_res_0
+ spirv.ARM.Graph @greaterequal_fp(%arg0: !spirv.arm.tensor<3x17x6x3xf32>, %arg1: !spirv.arm.tensor<1x17x6x3xf32>) -> (!spirv.arm.tensor<3x17x6x3xi1>) {
+ // CHECK: {{%.*}} = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<3x17x6x3xf32>, !spirv.arm.tensor<1x17x6x3xf32> -> !spirv.arm.tensor<3x17x6x3xi1>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x17x6x3xi1>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<3x17x6x3xi1>
+ }
+}
More information about the Mlir-commits mailing list