[Mlir-commits] [mlir] [mlir][spirv] Initial support for TOSA Extended Instruction Set (0010… (PR #174402)

Davide Grohmann llvmlistbot at llvm.org
Mon Jan 19 06:46:42 PST 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/174402

>From 2af5501bcc464878fec4b25bd6d9e6c8ceb6c053 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 7 Nov 2025 15:29:46 +0100
Subject: [PATCH] [mlir][spirv] Initial support for TOSA Extended Instruction
 Set (001000.1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch adds initial support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended
instruction set provides a standardized set of machine learning
operations designed to be used within `spirv.ARM.Graph` operations
(corresponding to OpGraphARM in SPV_ARM_graph) and typed with
`!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in
SPV_ARM_tensor).

The change introduces:
* Dialect plumbing for import, serialization, and deserialization of
  the TOSA extended instruction set.
* The `spirv.Tosa.ArgMax` operation from the TOSA extended instruction
  set, lowering to the corresponding `OpExtInst`.
* Verification enforcing that `spirv.Tosa.ArgMax` appears only within
  `spirv.ARM.Graph` regions, operates on `!spirv.arm.tensor<...>`
  types, and is well-formed according to the TOSA 001000.1
  specification.

Only the ArgMax operation from the TOSA 001000.1 extended instruction set
is introduced, in order to showcase the work needed: parser, printer,
verifier, and round-trip tests using MLIR’s SPIR-V
serialization/deserialization infrastructure are included.

This work aligns with Khronos SPIR-V TOSA specifications.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html

Change-Id: Ibf2aad7c9e86a9dc28c6133f5d6cb0cd67163ddf
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 16 ++++
 .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td |  1 +
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 86 +++++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   | 41 +++++++++
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt      |  1 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp    | 49 +++++++++++
 .../SPIRV/IR/tosa-ops-verification.mlir       | 23 +++++
 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir      | 23 +++++
 mlir/test/Target/SPIRV/tosa-ops.mlir          | 41 +++++++++
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      | 32 ++++++-
 10 files changed, 310 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
 create mode 100644 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
 create mode 100644 mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
 create mode 100644 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
 create mode 100644 mlir/test/Target/SPIRV/tosa-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 97ee9e15a68ef..21010d91dc47c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4233,8 +4233,10 @@ def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_s
 def SPIRV_Void : TypeAlias<NoneType, "void">;
 def SPIRV_Bool : TypeAlias<I1, "bool">;
 def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int8 : TypeAlias<I8, "Int8">;
 def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
+def SPIRV_Float16 : TypeAlias<F16, "Float16">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
 def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
@@ -4909,4 +4911,18 @@ def SPIRV_FPFastMathModeAttr :
       SPIRV_FPFMM_AllowReassocINTEL
     ]>;
 
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA enum definitions.
+//===----------------------------------------------------------------------===//
+
+// NaN propagation mode used by TOSA ops. NOTE: This is an attribute in the
+// SPIR-V *dialect* but a constant (<id>) in SPIR-V proper.
+def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
+  "TosaExtNaNPropagationModeType", "Tosa Ext NaN Propagation Mode Type",
+  "tosa_ext_nan_propagation_mode_type",
+  [
+      I32EnumAttrCase<"Propagate", 1>,
+      I32EnumAttrCase<"Ignore", 2>,
+  ]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 96ef035eda37a..3ef9699154cd1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -45,6 +45,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 #endif // MLIR_DIALECT_SPIRV_IR_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
new file mode 100644
index 0000000000000..6c6a318db4827
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -0,0 +1,86 @@
+//===- SPIRVTosaOps.td - TOSA extended insts spec file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of TOSA extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+#define MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all TOSA ops.
+class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
+  SPIRV_ExtInstOp<mnemonic, "Tosa", "TOSA.001000.1", opcode,
+  !listconcat(traits, [InGraphScope])> {
+
+  let availability = [
+    MinVersion<SPIRV_V_1_5>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_ARM_graph]>,
+    Capability<[SPIRV_C_GraphARM]>
+  ];
+}
+
+
+def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
+  let summary = "Perform argmax on the input.";
+
+  let description = [{
+    Returns the index with the largest value across the given axis of the
+    input tensor. If multiple locations have equal values, returns the first
+    match along the search axis.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_argmax
+
+    #### Example:
+    ```mlir
+    %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+    %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TensorArmAxisAttr: $axis,
+    SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+    SPIRV_TosaNumerical_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_Int32_TensorArmUpTo5D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    `axis` `=` $axis `,` `nan_mode` `=` $nan_mode `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+    ::mlir::spirv::TensorArmType getResultType() {
+      return cast<::mlir::spirv::TensorArmType>(getType());
+    }
+  }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
new file mode 100644
index 0000000000000..e731388182eb4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -0,0 +1,41 @@
+//===- SPIRVTosaTypes.td - Tosa Types insts spec file --------*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This specifies Tosa types used by the Graph Extension and Tosa Ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+#define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
+def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
+def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
+
+def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
+
+// TensorARM Types
+
+class RankedTensorArmOf<list<Type> allowedTypes, list<Pred> preds = [],
+                     string summary = "ranked tensorArm">
+  : ShapedContainerType<
+      allowedTypes, And<!listconcat([SPIRV_IsTensorArmType], preds)>,
+      summary, "::mlir::spirv::TensorArmType">;
+
+class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
+  : RankedTensorArmOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>],
+      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+
+def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 60d705d940cfc..f05f596aa9f23 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   SPIRVOpDefinition.cpp
   SPIRVOps.cpp
   SPIRVParsingUtils.cpp
+  SPIRVTosaOps.cpp
   SPIRVTypes.cpp
   TargetAndABI.cpp
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
new file mode 100644
index 0000000000000..4f3c91d4a1c12
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -0,0 +1,49 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V Tosa operations ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Tosa operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir::spirv {
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaArgMaxOp::verify() {
+  ShapedType inputTy = getInputType();
+  ShapedType resultTy = getResultType();
+  // The result rank must equal max(1, input rank - 1).
+  if (inputTy.hasRank() && resultTy.hasRank() &&
+      resultTy.getRank() !=
+          (inputTy.getRank() > 1 ? inputTy.getRank() - 1 : 1)) {
+    return emitOpError(
+               "result rank must be max of 1 and (input rank - 1), got ")
+           << resultTy.getRank();
+  }
+  // The axis must index an existing input dimension, i.e. axis < input rank.
+  const uint32_t axis = getAxis();
+  if (inputTy.hasRank() && axis >= inputTy.getRank()) {
+    return emitOpError(
+               "specified axis is greater than the rank of input, got axis = ")
+           << axis << " and input rank = " << inputTy.getRank();
+  }
+
+  return success();
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
new file mode 100644
index 0000000000000..a6496316f9881
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt --verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_non_i32_result_element_type(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi16>) {
+  // expected-error @+1 {{op result #0 must be 1D/2D/3D/4D/5D tensorArm of Int32 values}}
+  %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi16>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi16>
+}
+
+spirv.ARM.Graph @argmax_incorrect_output_rank(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28xi32>) {
+  // expected-error @+1 {{op result rank must be max of 1 and (input rank - 1), got 2}}
+  %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28xi32>
+}
+
+spirv.ARM.Graph @argmax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+  // expected-error @+1 {{op specified axis is greater than the rank of input, got axis = 4 and input rank = 4}}
+  %2 = spirv.Tosa.ArgMax axis = 4, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
new file mode 100644
index 0000000000000..c9832b903b79e
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+  %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+  %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
new file mode 100644
index 0000000000000..8c0429bca68e4
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module  --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @argmax_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17x17xi8>, UniformConstant>
+  spirv.GlobalVariable @argmax_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @argmax_int, @argmax_int_arg_0, @argmax_int_res_0
+  spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+    %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @argmax_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x7x14xf32>, UniformConstant>
+  spirv.GlobalVariable @argmax_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x14xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @argmax_fp, @argmax_fp_arg_0, @argmax_fp_res_0
+  spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+    %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ca291b57f4344..f3327e31aae04 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -496,9 +496,13 @@ static mlir::GenRegistration
 // directly use the constant value as attribute in SPIR-V dialect. So need
 // to handle them separately from normal enum attributes.
 constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
-    "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
-    "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
-    "SPIRV_MatrixLayoutAttr"};
+    "SPIRV_ScopeAttr",
+    "SPIRV_KHR_CooperativeMatrixUseAttr",
+    "SPIRV_KHR_CooperativeMatrixLayoutAttr",
+    "SPIRV_MemorySemanticsAttr",
+    "SPIRV_MatrixLayoutAttr",
+    "SPIRV_TosaExtNaNPropagationModeAttr",
+};
 
 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
 /// generates code extracts the attribute with name `attrName` from
@@ -552,6 +556,11 @@ static void emitAttributeSerialization(const Attribute &attr,
     os << tabs << "    return failure();\n";
     os << tabs << "  }\n";
     os << tabs << formatv("  {0}.push_back(attrTypeID);\n", operandList);
+  } else if (attr.getAttrDefName() == "SPIRV_TensorArmAxisAttr") {
+    os << tabs
+       << formatv(
+              "  {0}.push_back(prepareConstantScalar({1}.getLoc(), attr));\n",
+              operandList, opVar);
   } else {
     PrintFatalError(
         loc,
@@ -846,6 +855,23 @@ static void emitAttributeDeserialization(const Attribute &attr,
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "TypeAttr::get(getType({2}[{3}++]))));\n",
                   attrList, attrName, words, wordIndex);
+  } else if (attr.getAttrDefName() == "SPIRV_TensorArmAxisAttr") {
+    os << tabs
+       << formatv("std::optional<std::pair<Attribute, Type>> c = "
+                  "getConstant({0}[{1}++]);\n",
+                  words, wordIndex);
+    os << tabs << "if (!c.has_value()) {\n";
+    os << tabs
+       << formatv("  "
+                  "return emitError(unknownLoc, \"could not fetch "
+                  "constant attribute for {0}\") << "
+                  "{1} << \" of \" << {2}.size() << \" processed\";\n",
+                  attrName, wordIndex, words);
+    os << tabs << "}\n";
+    os << tabs
+       << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+                  "c.value().first));\n",
+                  attrList, attrName);
   } else {
     PrintFatalError(
         loc, llvm::Twine(



More information about the Mlir-commits mailing list