[Mlir-commits] [mlir] [mlir][spirv] Add Activation operators to TOSA Extended Instruction S… (PR #178620)

Davide Grohmann llvmlistbot at llvm.org
Thu Jan 29 02:36:34 PST 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/178620

>From 47b06d600786b321efb00c9041ff2e020832c25d Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 28 Jan 2026 13:39:57 +0100
Subject: [PATCH] [mlir][spirv] Add Activation operators to TOSA Extended
 Instruction Set (001000.1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch adds the Activation operators of the TOSA Extended
Instruction Set (001000.1) to the SPIR-V dialect in MLIR. The TOSA
extended instruction set provides a standardized set of machine
learning operations designed to be used within `spirv.ARM.Graph`
operations (corresponding to OpGraphARM in SPV_ARM_graph) and typed
with `!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in
SPV_ARM_tensor).

The change introduces:
* Dialect plumbing for import, serialization, and deserialization of
  the TOSA extended instruction set.
* spirv.Tosa.{Clamp,Erf,Sigmoid,Tanh} operations, one per TOSA extended
  instruction, each lowering to the corresponding `OpExtInst`.
* Verification enforcing that all the activation ops appear only
  within `spirv.ARM.Graph` regions, operate on
  `!spirv.arm.tensor<...>` types, and are well-formed according to the
  TOSA 001000.1 specification.

For each of these TOSA 001000.1 extended instructions, a parser,
printer, and verifier are introduced, together with round-trip tests
using MLIR’s SPIR-V serialization/deserialization infrastructure.

This work completes support for expressing TOSA extended instructions
inside SPIR-V graphs in MLIR, aligning with Khronos SPIR-V TOSA
specifications.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html
Change-Id: I0e7bed3f5a2b0098a4d532ba1d577f1d82507aa0
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 178 ++++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   |   2 +
 mlir/include/mlir/IR/CommonAttrConstraints.td |   8 +
 .../SPIRV/IR/tosa-ops-verification.mlir       |  18 +-
 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir      |  55 ++++++
 mlir/test/Target/SPIRV/tosa-ops.mlir          |  95 ++++++++++
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |  11 +-
 7 files changed, 362 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index d69e215e05205..e8947cecab848 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -694,4 +694,182 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [
 }
 
 
+def SPIRV_TosaClampOp : SPIRV_TosaOpWithResult<"Clamp", 10, [Pure,
+  AllTypesMatch<["input", "output"]>,
+  AllElementTypesMatch<["input", "output", "min_val", "max_val"]>]> {
+  let summary = "Computes Clamp(min, max).";
+
+  let description = [{
+    Clamp to an arbitrary minimum and maximum value.
+    Maximum and minimum values are specified as values in the range of the
+    input type.
+    No zero point subtraction is done to the values, thus to clamp to the zero
+    point value, the zero point itself should be supplied as the minimum value.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_clamp
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_clamp
+
+    #### Example:
+    ```mlir
+    %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+    %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaNumericalAttr: $min_val,
+    SPIRV_TosaNumericalAttr: $max_val,
+    SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+    SPIRV_TosaNumerical_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    `min_val` `=` $min_val `,`
+    `max_val` `=` $max_val `,`
+    `nan_mode` `=` $nan_mode `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaErfOp : SPIRV_TosaOpWithResult<"Erf", 11, [Pure,
+  AllTypesMatch<["input", "output"]>]> {
+  let summary = "Computes Gauss Error Function of input.";
+
+  let description = [{
+    Gauss Error Function: $ erf(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt $
+    For quantized integer data types, the table operator should be used instead
+    with the following definition. The ERF table has 513 entries each of
+    16-bit precision and covering the input range -4.0 to +4.0 in steps of 1/64.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_erf
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_erf
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaFloat_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaFloat_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaSigmoidOp : SPIRV_TosaOpWithResult<"Sigmoid", 12, [Pure,
+  AllTypesMatch<["input", "output"]>]> {
+  let summary = "Computes elementwise sigmoid of input.";
+
+  let description = [{
+    Applies the sigmoid logistic function to each element of the input tensor:
+    $ sigmoid(x) = \frac{1}{1 + e^{-x}} $.
+
+    For quantized integer data types, the table operator should be used instead.
+    Each implementation may choose an appropriate table given the scale and zero
+    point of the input data. Eight or sixteen bit precision tables may be used
+    based on the input tensor to the sigmoid function.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_sigmoid
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_sigmoid
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaFloat_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaFloat_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaTanhOp : SPIRV_TosaOpWithResult<"Tanh", 13, [Pure,
+  AllTypesMatch<["input", "output"]>]> {
+  let summary = "Computes elementwise Hyperbolic Tangent of input.";
+
+  let description = [{
+    Parameterized Hyperbolic Tangent: $ tanh(x) = \frac{1 - e^{-2x}}{1 + e^{-2x}} $.
+
+    For quantized integer data types, the table operator should be used instead.
+    Each implementation may choose an appropriate table given the scale and zero
+    point of the input data. Eight or sixteen bit precision tables may be used
+    based on the input tensor to the tanh function.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_tanh
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_tanh
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaFloat_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaFloat_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+}
+
+
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index db4ad8064fc11..5fe3bc53618f4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -23,6 +23,7 @@ def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
 
 def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
 def SPIRV_BoolConstAttr : ConfinedAttr<BoolAttr, []>;
+def SPIRV_TosaNumericalAttr: AnyAttrOf<[I8Attr, I16Attr, I32Attr, I64Attr, F16Attr, F32Attr, BF16Attr]>;
 
 // TensorARM Types
 
@@ -44,6 +45,7 @@ def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]
 def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
 
 def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_TosaFloat_TensorArm : TensorArmRankOf<[SPIRV_TosaFloat], [1, 2, 3, 4, 5, 6]>;
 def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
 
 class Is1DTensorArmOfLength<list<int> allowedLengths> :
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 8ac1a2ea21422..ba6cf55a8fb9e 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -334,9 +334,17 @@ class FloatAttrBase<F attrValType, string descr> :
   let returnType = [{ ::llvm::APFloat }];
 }
 
+def F16Attr : FloatAttrBase<F16, "16-bit float attribute">;
 def F32Attr : FloatAttrBase<F32, "32-bit float attribute">;
 def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;
 
+def BF16Attr : TypedAttrBase<BF16, "FloatAttr",
+              And<[CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
+                     CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isBF16()">]>,
+              "16-bit bfloat attribute"> {
+  let returnType = [{ ::llvm::APFloat }];
+}
+
 // An attribute backed by a string type.
 class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
   let constBuilderCall = "$_builder.getStringAttr($0)";
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 56cd6d6900fdb..dd18a3a2ae788 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -410,8 +410,24 @@ spirv.ARM.Graph @matmul_invalid_input_output_element_type_combination(%arg0: !sp
 // spirv.TOSA.MaxPool2D
 //===----------------------------------------------------------------------===//
 
-spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi16>) {
+spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi16>) {
   // expected-error @+1 {{op failed to verify that all of {input, output} have same element type}}
   %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi16>
   spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @clamp_min_val_different_element_type_wrt_input_output(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output, min_val, max_val} have same element type}}
+  %3 = spirv.Tosa.Clamp min_val = -102 : i16, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+  spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+}
+
+spirv.ARM.Graph @clamp_max_val_different_element_type_wrt_input_output(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output, min_val, max_val} have same element type}}
+  %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i16, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+  spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index 1a43e2c95c530..a9f7bc2b8ef7d 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -229,3 +229,58 @@ spirv.ARM.Graph @transposeconv2d_fp(%arg0: !spirv.arm.tensor<10x24x9x13xf16>, %a
   // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x25x65x14xf16>
   spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @clamp_int(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+  %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x44x55xi8>
+  spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @clamp_fp(%arg0: !spirv.arm.tensor<18x5x17x6xf32>) -> (!spirv.arm.tensor<18x5x17x6xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+  %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x5x17x6xf32>
+  spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<18x5x17x6xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Erf - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @erf_fp(%arg0: !spirv.arm.tensor<47x38x51xf32>) -> (!spirv.arm.tensor<47x38x51xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+  %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<47x38x51xf32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<47x38x51xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Sigmoid - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @sigmoid_fp(%arg0: !spirv.arm.tensor<28x43x45xf32>) -> (!spirv.arm.tensor<28x43x45xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+  %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<28x43x45xf32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<28x43x45xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Tanh - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @tanh_fp(%arg0: !spirv.arm.tensor<46x50x36xf16>) -> (!spirv.arm.tensor<46x50x36xf16>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+  %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<46x50x36xf16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<46x50x36xf16>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 1d219b855bec1..9f2ff1c31cbc5 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -396,3 +396,98 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16>
   }
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @clamp_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<27x44x55xi8>, UniformConstant>
+  spirv.GlobalVariable @clamp_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<27x44x55xi8>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @clamp_int, @clamp_int_arg_0, @clamp_int_res_0
+  spirv.ARM.Graph @clamp_int(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+    // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+    %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x44x55xi8>
+    spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @clamp_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<18x5x17x6xf32>, UniformConstant>
+  spirv.GlobalVariable @clamp_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<18x5x17x6xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @clamp_fp, @clamp_fp_arg_0, @clamp_fp_res_0
+  spirv.ARM.Graph @clamp_fp(%arg0: !spirv.arm.tensor<18x5x17x6xf32>) -> (!spirv.arm.tensor<18x5x17x6xf32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+    %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x5x17x6xf32>
+    spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<18x5x17x6xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Erf - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @erf_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<47x38x51xf32>, UniformConstant>
+  spirv.GlobalVariable @erf_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<47x38x51xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @erf_fp, @erf_fp_arg_0, @erf_fp_res_0
+  spirv.ARM.Graph @erf_fp(%arg0: !spirv.arm.tensor<47x38x51xf32>) -> (!spirv.arm.tensor<47x38x51xf32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+    %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<47x38x51xf32>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<47x38x51xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Sigmoid - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @sigmoid_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<28x43x45xf32>, UniformConstant>
+  spirv.GlobalVariable @sigmoid_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<28x43x45xf32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @sigmoid_fp, @sigmoid_fp_arg_0, @sigmoid_fp_res_0
+  spirv.ARM.Graph @sigmoid_fp(%arg0: !spirv.arm.tensor<28x43x45xf32>) -> (!spirv.arm.tensor<28x43x45xf32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+    %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<28x43x45xf32>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<28x43x45xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Tanh - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @tanh_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<46x50x36xf16>, UniformConstant>
+  spirv.GlobalVariable @tanh_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<46x50x36xf16>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @tanh_fp, @tanh_fp_arg_0, @tanh_fp_res_0
+  spirv.ARM.Graph @tanh_fp(%arg0: !spirv.arm.tensor<46x50x36xf16>) -> (!spirv.arm.tensor<46x50x36xf16>) {
+    // CHECK: {{%.*}} = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+    %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<46x50x36xf16>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<46x50x36xf16>
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 0b1771ffcee71..d5859944f5f59 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -557,9 +557,10 @@ static void emitAttributeSerialization(const Attribute &attr,
     os << tabs << "    return failure();\n";
     os << tabs << "  }\n";
     os << tabs << formatv("  {0}.push_back(attrTypeID);\n", operandList);
-  } else if (llvm::is_contained(
-                 {"SPIRV_BoolConstAttr", "SPIRV_TensorArmAxisAttr"},
-                 attr.getAttrDefName())) {
+  } else if (llvm::is_contained({"SPIRV_BoolConstAttr",
+                                 "SPIRV_TensorArmAxisAttr",
+                                 "SPIRV_TosaNumericalAttr"},
+                                attr.getAttrDefName())) {
     os << tabs
        << formatv(
               "  {0}.push_back(prepareConstantScalar({1}.getLoc(), attr));\n",
@@ -863,7 +864,9 @@ static void emitAttributeDeserialization(const Attribute &attr,
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "TypeAttr::get(getType({2}[{3}++]))));\n",
                   attrList, attrName, words, wordIndex);
-  } else if (attr.getAttrDefName() == "SPIRV_BoolConstAttr" ||
+  } else if (llvm::is_contained(
+                 {"SPIRV_BoolConstAttr", "SPIRV_TosaNumericalAttr"},
+                 attr.getAttrDefName()) ||
              attr.getAttrDefName().contains("TensorArm")) {
     os << tabs
        << formatv("std::optional<std::pair<Attribute, Type>> c = "



More information about the Mlir-commits mailing list