[Mlir-commits] [mlir] b36d14d - [mlir][spirv] Add Activation operators to TOSA Extended Instruction S… (#178620)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 3 05:56:37 PST 2026
Author: Davide Grohmann
Date: 2026-02-03T08:56:31-05:00
New Revision: b36d14d619822c1e8d049e94e102768acb69c513
URL: https://github.com/llvm/llvm-project/commit/b36d14d619822c1e8d049e94e102768acb69c513
DIFF: https://github.com/llvm/llvm-project/commit/b36d14d619822c1e8d049e94e102768acb69c513.diff
LOG: [mlir][spirv] Add Activation operators to TOSA Extended Instruction Set (001000.1) (#178620)
In detail, the Activation operators introduced are:
spirv.Tosa.{Clamp,Erf,Sigmoid,Tanh}, along with dialect and
serialization round-trip tests.
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
mlir/include/mlir/IR/CommonAttrConstraints.td
mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
mlir/test/Target/SPIRV/tosa-ops.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index d69e215e05205..61e8ea2c9ebc8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -524,8 +524,7 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
let description = [{
Performs a max pooling over the given input tensor. A sliding window of
size given by <kernel size> is passed over the input tensor, with the
- maximum value being placed in the
- output tensor.
+ maximum value being placed in the output tensor.
References:
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_max_pool2d
@@ -574,7 +573,7 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> {
Performs a batched 2D real-valued Fast Fourier Transform over the input where
the input tensor consists of real values producing complex valued output. The
complex output values will be split into the output_real and output_imag
- tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only
+ tensor arguments. This operator takes advantage of Hermitian symmetry to only
calculate the first half of the final output axis. Implementations may choose
to skip calculation of the imaginary values at (0,0), (0,W/2), (H/2,0), and
(H/2, W/2). If the calculation is skipped, the result at that location must be
@@ -694,4 +693,174 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [
}
+def SPIRV_TosaClampOp : SPIRV_TosaOpWithResult<"Clamp", 10, [Pure,
+ AllTypesMatch<["input", "output"]>,
+ AllElementTypesMatch<["input", "output", "min_val", "max_val"]>]> {
+ let summary = "Computes Clamp(min, max).";
+
+ let description = [{
+ Clamp to an arbitrary minimum and maximum value.
+ Maximum and minimum values are specified as values in the range of the
+ input type.
+ No zero point subtraction is done to the values, thus to clamp to the zero
+ point value, the zero point itself should be supplied as the minimum value.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_clamp
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_clamp
+
+ #### Example:
+ ```mlir
+ %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaNumericalAttr: $min_val,
+ SPIRV_TosaNumericalAttr: $max_val,
+ SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+ SPIRV_TosaNumerical_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ `min_val` `=` $min_val `,`
+ `max_val` `=` $max_val `,`
+ `nan_mode` `=` $nan_mode `,`
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaErfOp : SPIRV_TosaOpWithResult<"Erf", 11, [Pure,
+ AllTypesMatch<["input", "output"]>]> {
+ let summary = "Gauss Error Function.";
+
+ let description = [{
+ Gauss Error Function: $ erf(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt $
+ For quantized integer data types, the `spirv.Tosa.Table` operator should be used instead.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_erf
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_erf
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaFloat_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaFloat_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaSigmoidOp : SPIRV_TosaOpWithResult<"Sigmoid", 12, [Pure,
+ AllTypesMatch<["input", "output"]>]> {
+ let summary = "Sigmoid operator.";
+
+ let description = [{
+ Applies the sigmoid logistic function to each element of the input tensor:
+ $ sigmoid(x) = \frac{1}{1 + e^{-x}} $.
+
+ For quantized integer data types, the `spirv.Tosa.Table` operator should be used instead.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_sigmoid
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_sigmoid
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaFloat_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaFloat_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaTanhOp : SPIRV_TosaOpWithResult<"Tanh", 13, [Pure,
+ AllTypesMatch<["input", "output"]>]> {
+ let summary = "Hyperbolic Tangent operator.";
+
+ let description = [{
+ Elementwise Parameterized Hyperbolic Tangent: $ tanh(x) = \frac{1 - e^{-2x}}{1 + e^{-2x}} $.
+
+ For quantized integer data types, the `spirv.Tosa.Table` operator should be used instead.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_tanh
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_tanh
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaFloat_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaFloat_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index db4ad8064fc11..5fe3bc53618f4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -23,6 +23,7 @@ def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
def SPIRV_BoolConstAttr : ConfinedAttr<BoolAttr, []>;
+def SPIRV_TosaNumericalAttr: AnyAttrOf<[I8Attr, I16Attr, I32Attr, I64Attr, F16Attr, F32Attr, BF16Attr]>;
// TensorARM Types
@@ -44,6 +45,7 @@ def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]
def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_TosaFloat_TensorArm : TensorArmRankOf<[SPIRV_TosaFloat], [1, 2, 3, 4, 5, 6]>;
def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
class Is1DTensorArmOfLength<list<int> allowedLengths> :
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 8ac1a2ea21422..ba6cf55a8fb9e 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -334,9 +334,17 @@ class FloatAttrBase<F attrValType, string descr> :
let returnType = [{ ::llvm::APFloat }];
}
+def F16Attr : FloatAttrBase<F16, "16-bit float attribute">;
def F32Attr : FloatAttrBase<F32, "32-bit float attribute">;
def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;
+def BF16Attr : TypedAttrBase<BF16, "FloatAttr",
+ And<[CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
+ CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isBF16()">]>,
+ "16-bit bfloat attribute"> {
+ let returnType = [{ ::llvm::APFloat }];
+}
+
// An attribute backed by a string type.
class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
let constBuilderCall = "$_builder.getStringAttr($0)";
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 56cd6d6900fdb..dd18a3a2ae788 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -410,8 +410,24 @@ spirv.ARM.Graph @matmul_invalid_input_output_element_type_combination(%arg0: !sp
// spirv.TOSA.MaxPool2D
//===----------------------------------------------------------------------===//
-spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi16>) {
+spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi16>) {
// expected-error @+1 {{op failed to verify that all of {input, output} have same element type}}
%4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi16>
spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @clamp_min_val_different_element_type_wrt_input_output(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {input, output, min_val, max_val} have same element type}}
+ %3 = spirv.Tosa.Clamp min_val = -102 : i16, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+}
+
+spirv.ARM.Graph @clamp_max_val_different_element_type_wrt_input_output(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+ // expected-error @+1 {{op failed to verify that all of {input, output, min_val, max_val} have same element type}}
+ %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i16, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index 1a43e2c95c530..a9f7bc2b8ef7d 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -229,3 +229,58 @@ spirv.ARM.Graph @transposeconv2d_fp(%arg0: !spirv.arm.tensor<10x24x9x13xf16>, %a
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x25x65x14xf16>
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @clamp_int(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x44x55xi8>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @clamp_fp(%arg0: !spirv.arm.tensor<18x5x17x6xf32>) -> (!spirv.arm.tensor<18x5x17x6xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x5x17x6xf32>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<18x5x17x6xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Erf - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @erf_fp(%arg0: !spirv.arm.tensor<47x38x51xf32>) -> (!spirv.arm.tensor<47x38x51xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<47x38x51xf32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<47x38x51xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Sigmoid - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @sigmoid_fp(%arg0: !spirv.arm.tensor<28x43x45xf32>) -> (!spirv.arm.tensor<28x43x45xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<28x43x45xf32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<28x43x45xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Tanh - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @tanh_fp(%arg0: !spirv.arm.tensor<46x50x36xf16>) -> (!spirv.arm.tensor<46x50x36xf16>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+ %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<46x50x36xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<46x50x36xf16>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 1d219b855bec1..9f2ff1c31cbc5 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -396,3 +396,98 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16>
}
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @clamp_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<27x44x55xi8>, UniformConstant>
+ spirv.GlobalVariable @clamp_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<27x44x55xi8>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @clamp_int, @clamp_int_arg_0, @clamp_int_res_0
+ spirv.ARM.Graph @clamp_int(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x44x55xi8>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @clamp_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<18x5x17x6xf32>, UniformConstant>
+ spirv.GlobalVariable @clamp_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<18x5x17x6xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @clamp_fp, @clamp_fp_arg_0, @clamp_fp_res_0
+ spirv.ARM.Graph @clamp_fp(%arg0: !spirv.arm.tensor<18x5x17x6xf32>) -> (!spirv.arm.tensor<18x5x17x6xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x5x17x6xf32>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<18x5x17x6xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Erf - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @erf_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<47x38x51xf32>, UniformConstant>
+ spirv.GlobalVariable @erf_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<47x38x51xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @erf_fp, @erf_fp_arg_0, @erf_fp_res_0
+ spirv.ARM.Graph @erf_fp(%arg0: !spirv.arm.tensor<47x38x51xf32>) -> (!spirv.arm.tensor<47x38x51xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<47x38x51xf32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<47x38x51xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Sigmoid - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @sigmoid_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<28x43x45xf32>, UniformConstant>
+ spirv.GlobalVariable @sigmoid_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<28x43x45xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @sigmoid_fp, @sigmoid_fp_arg_0, @sigmoid_fp_res_0
+ spirv.ARM.Graph @sigmoid_fp(%arg0: !spirv.arm.tensor<28x43x45xf32>) -> (!spirv.arm.tensor<28x43x45xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<28x43x45xf32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<28x43x45xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Tanh - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @tanh_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<46x50x36xf16>, UniformConstant>
+ spirv.GlobalVariable @tanh_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<46x50x36xf16>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @tanh_fp, @tanh_fp_arg_0, @tanh_fp_res_0
+ spirv.ARM.Graph @tanh_fp(%arg0: !spirv.arm.tensor<46x50x36xf16>) -> (!spirv.arm.tensor<46x50x36xf16>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+ %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<46x50x36xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<46x50x36xf16>
+ }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 9cb48934b2c10..34edd4df49d8e 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -558,9 +558,10 @@ static void emitAttributeSerialization(const Attribute &attr,
os << tabs << " return failure();\n";
os << tabs << " }\n";
os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList);
- } else if (llvm::is_contained(
- {"SPIRV_BoolConstAttr", "SPIRV_TensorArmAxisAttr"},
- attr.getAttrDefName())) {
+ } else if (llvm::is_contained({"SPIRV_BoolConstAttr",
+ "SPIRV_TensorArmAxisAttr",
+ "SPIRV_TosaNumericalAttr"},
+ attr.getAttrDefName())) {
os << tabs
<< formatv(
" {0}.push_back(prepareConstantScalar({1}.getLoc(), attr));\n",
@@ -864,7 +865,9 @@ static void emitAttributeDeserialization(const Attribute &attr,
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"TypeAttr::get(getType({2}[{3}++]))));\n",
attrList, attrName, words, wordIndex);
- } else if (attr.getAttrDefName() == "SPIRV_BoolConstAttr" ||
+ } else if (llvm::is_contained(
+ {"SPIRV_BoolConstAttr", "SPIRV_TosaNumericalAttr"},
+ attr.getAttrDefName()) ||
attr.getAttrDefName().contains("TensorArm")) {
os << tabs
<< formatv("std::optional<std::pair<Attribute, Type>> c = "
More information about the Mlir-commits
mailing list