[Mlir-commits] [mlir] [mlir][spirv] Initial support for TOSA Extended Instruction Set (0010… (PR #174402)

Davide Grohmann llvmlistbot at llvm.org
Mon Jan 5 05:14:34 PST 2026


https://github.com/davidegrohmann created https://github.com/llvm/llvm-project/pull/174402

…00.1)

This patch adds initial support for the TOSA Extended Instruction Set (001000.1) to the SPIR-V dialect in MLIR. The TOSA extended instruction set provides a standardized set of machine learning operations designed to be used within `spirv.ARM.Graph` operations (corresponding to OpGraphARM in SPV_ARM_graph) and typed with `!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in SPV_ARM_tensors).

The change introduces:
* Dialect plumbing for import, serialization, and deserialization of the TOSA extended instruction set.
* The `spirv.Tosa.ArgMax` operation from the TOSA extended instruction set, which lowers to the corresponding `OpExtInst`.
* Verification enforcing that `spirv.Tosa.ArgMax` appears only within `spirv.ARM.Graph` regions, operates on `!spirv.arm.tensor<...>` types, and is well-formed according to the TOSA 001000.1 specification.
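
The rank and axis checks performed by the verifier in this patch can be sketched in standalone C++. This is a minimal illustration that does not use any MLIR API. The helper names `expectedArgMaxResultRank`, `isAxisInRange`, and `dropAxis` are hypothetical and do not appear in the patch. The shape computation in `dropAxis` is an assumption inferred from the shapes used in the tests; the verifier itself only checks the result rank, not the individual dimensions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: the result rank the verifier accepts,
// i.e. max(1, inputRank - 1).
int64_t expectedArgMaxResultRank(int64_t inputRank) {
  return inputRank > 1 ? inputRank - 1 : 1;
}

// Hypothetical helper: the axis must satisfy 0 <= axis < inputRank.
bool isAxisInRange(int64_t axis, int64_t inputRank) {
  return axis >= 0 && axis < inputRank;
}

// Hypothetical helper (assumption based on the test shapes): the result
// shape is the input shape with the reduced axis removed. Only meaningful
// for inputs of rank greater than 1.
std::vector<int64_t> dropAxis(const std::vector<int64_t> &shape,
                              int64_t axis) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  assert(rank > 1 && isAxisInRange(axis, rank));
  std::vector<int64_t> result;
  for (int64_t i = 0; i < rank; ++i) {
    if (i != axis)
      result.push_back(shape[i]);
  }
  return result;
}
```

Under these assumptions, an input of shape 3x28x17x17 with axis 3 yields a rank-3 result of shape 3x28x17, which matches the integer test case in the patch.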

Only the ArgMax operation from the TOSA 001000.1 extended instruction set is introduced, in order to showcase the work needed. Parser, printer, verifier, and round-trip tests using MLIR’s SPIR-V serialization/deserialization infrastructure are included.

This work aligns with Khronos SPIR-V TOSA specifications.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html


Change-Id: Ibf2aad7c9e86a9dc28c6133f5d6cb0cd67163ddf

From bfcf19fcaf5338d6c87889519cebf2986cd34cb9 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 7 Nov 2025 15:29:46 +0100
Subject: [PATCH] [mlir][spirv] Initial support for TOSA Extended Instruction
 Set (001000.1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch adds initial support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended
instruction set provides a standardized set of machine learning
operations designed to be used within `spirv.ARM.Graph` operations
(corresponding to OpGraphARM in SPV_ARM_graph) and typed with
`!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in
SPV_ARM_tensors).

The change introduces:
* Dialect plumbing for import, serialization, and deserialization of
  the TOSA extended instruction set.
* The `spirv.Tosa.ArgMax` operation from the TOSA extended instruction
  set, which lowers to the corresponding `OpExtInst`.
* Verification enforcing that `spirv.Tosa.ArgMax` appears only within
  `spirv.ARM.Graph` regions, operates on `!spirv.arm.tensor<...>`
  types, and is well-formed according to the TOSA 001000.1
  specification.

Only the ArgMax operation from the TOSA 001000.1 extended instruction
set is introduced, in order to showcase the work needed. Parser,
printer, verifier, and round-trip tests using MLIR’s SPIR-V
serialization/deserialization infrastructure are included.

This work aligns with Khronos SPIR-V TOSA specifications.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ibf2aad7c9e86a9dc28c6133f5d6cb0cd67163ddf
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 14 ++++
 .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td |  1 +
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 64 +++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   | 39 ++++++++++
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt      |  1 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp    | 72 +++++++++++++++++++
 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir      | 27 +++++++
 mlir/test/Target/SPIRV/tosa-ops.mlir          | 43 +++++++++++
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      | 10 ++-
 9 files changed, 268 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
 create mode 100644 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
 create mode 100644 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
 create mode 100644 mlir/test/Target/SPIRV/tosa-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ecbbf39a534e1..5677f825a8aec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4233,6 +4233,7 @@ def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_s
 def SPIRV_Void : TypeAlias<NoneType, "void">;
 def SPIRV_Bool : TypeAlias<I1, "bool">;
 def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int8 : TypeAlias<I8, "Int8">;
 def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
@@ -4908,4 +4909,17 @@ def SPIRV_FPFastMathModeAttr :
       SPIRV_FPFMM_AllowReassocINTEL
     ]>;
 
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA enum definitions.
+//===----------------------------------------------------------------------===//
+
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtNanPropagationModeAttr : SPIRV_I32EnumAttr<
+  "TosaExtNanPropagationModeType", "Tosa Ext NAN Propagation Mode Type", "tosa_ext_nan_propagation_mode_type",
+  [
+      I32EnumAttrCase<"PROPAGATE", 1>,
+      I32EnumAttrCase<"IGNORE", 2>,
+  ]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 96ef035eda37a..3ef9699154cd1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -45,6 +45,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 #endif // MLIR_DIALECT_SPIRV_IR_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
new file mode 100644
index 0000000000000..f545c4c2fa1ae
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -0,0 +1,64 @@
+//===- SPIRVTosaOps.td - TOSA extended insts spec file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of TOSA extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+#define MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all TOSA ops.
+class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
+  SPIRV_ExtInstOp<mnemonic, "Tosa", "TOSA.001000.1", opcode, !listconcat(traits, [InGraphScope])> {
+
+  let availability = [
+    MinVersion<SPIRV_V_1_5>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
+    Capability<[SPIRV_C_GraphARM]>
+  ];
+}
+
+
+def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
+ let summary = "ArgMax - TOSA extended instruction set 001000.1";
+
+  let description = [{
+    https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
+  }];
+
+
+  let arguments = (ins
+    SPIRV_Int32: $axis,
+    SPIRV_TosaExtNanPropagationModeAttr: $nan_mode,
+    SPIRV_TosaNumerical_TensorArm: $input
+  );
+
+
+  let results = (outs
+    SPIRV_TosaInteger_TensorArmUpTo5D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    operands attr-dict `:` `(` type(operands) `)` `->` type(results)
+  }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
new file mode 100644
index 0000000000000..9081c5fa6c156
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -0,0 +1,39 @@
+//===- SPIRVTosaTypes.td - Tosa Types insts spec file ----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This specifies Tosa types used by the Graph Extension and Tosa Ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+#define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float16or32]>;
+def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
+def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
+
+// TensorARM Types
+
+class RankedTensorArmOf<list<Type> allowedTypes, list<Pred> preds = [],
+                     string summary = "ranked tensorArm">
+  : ShapedContainerType<
+      allowedTypes, And<!listconcat([SPIRV_IsTensorArmType], preds)>,
+      summary, "::mlir::spirv::TensorArmType">;
+
+class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
+  : RankedTensorArmOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>],
+      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+
+def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_TosaInteger_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_TosaInteger], [1, 2, 3, 4, 5]>;
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 60d705d940cfc..f05f596aa9f23 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   SPIRVOpDefinition.cpp
   SPIRVOps.cpp
   SPIRVParsingUtils.cpp
+  SPIRVTosaOps.cpp
   SPIRVTypes.cpp
   TargetAndABI.cpp
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
new file mode 100644
index 0000000000000..31f7734dd8495
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -0,0 +1,72 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V operations --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Tosa operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+// Get value attr from spirv::ConstantOp or
+// spirv::EXTConstantCompositeReplicateOp
+template <typename TAttr>
+static LogicalResult getConstAttr(Value value, TAttr &valAttr) {
+  if (auto constOp = value.template getDefiningOp<spirv::ConstantOp>()) {
+    valAttr = dyn_cast<TAttr>(constOp.getValue());
+  } else if (auto constCompositeReplicateOp =
+                 value.template getDefiningOp<
+                     spirv::EXTConstantCompositeReplicateOp>()) {
+    auto splatAttr = constCompositeReplicateOp.getValue();
+    auto denseValAttr = SplatElementsAttr::get(
+        cast<ShapedType>(constCompositeReplicateOp.getType()), splatAttr);
+    valAttr = dyn_cast<TAttr>(denseValAttr);
+  }
+
+  return valAttr ? success() : failure();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaArgMaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::TosaArgMaxOp::verify() {
+  auto inputTy = cast<ShapedType>(getInput().getType());
+  auto resultTy = cast<ShapedType>(getType());
+
+  if (inputTy.hasRank() && resultTy.hasRank() &&
+      resultTy.getRank() !=
+          (inputTy.getRank() > 1 ? inputTy.getRank() - 1 : 1)) {
+    return emitOpError("result rank must be max of 1 and (input rank - 1)");
+  }
+
+  auto resultETy = resultTy.getElementType();
+  if (!resultETy.isIntOrIndex()) {
+    return emitOpError("result is not of integer type");
+  }
+
+  IntegerAttr axisAttr;
+  if (getConstAttr(getAxis(), axisAttr).failed()) {
+    return emitOpError("axis type must be a constant integer");
+  }
+
+  const int axis = axisAttr.getInt();
+  if (inputTy.hasRank() && ((axis < 0) || axis >= inputTy.getRank())) {
+    return emitOpError("specified axis is outside the rank of input");
+  }
+
+  return success();
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
new file mode 100644
index 0000000000000..48a3735c2c596
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+  %0 = spirv.Constant 3 : i32
+  // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+  %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+  %0 = spirv.Constant 2 : i32
+  // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+  %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
new file mode 100644
index 0000000000000..53e92050f5d32
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --verify-diagnostics --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module  --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @argmax_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17x17xi8>, UniformConstant>
+  spirv.GlobalVariable @argmax_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @argmax_int, @argmax_int_arg_0, @argmax_int_res_0
+  spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+    %0 = spirv.Constant 3 : i32
+    // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+    %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<3x28x17x17xi8>) -> !spirv.arm.tensor<3x28x17xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @argmax_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x7x14xf32>, UniformConstant>
+  spirv.GlobalVariable @argmax_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x14xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @argmax_fp, @argmax_fp_arg_0, @argmax_fp_res_0
+  spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+    %0 = spirv.Constant 2 : i32
+    // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+    %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<PROPAGATE>} : (i32, !spirv.arm.tensor<2x2x7x14xf32>) -> !spirv.arm.tensor<2x2x14xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ca291b57f4344..83d3322ebfe13 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -496,9 +496,13 @@ static mlir::GenRegistration
 // directly use the constant value as attribute in SPIR-V dialect. So need
 // to handle them separately from normal enum attributes.
 constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
-    "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
-    "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
-    "SPIRV_MatrixLayoutAttr"};
+    "SPIRV_ScopeAttr",
+    "SPIRV_KHR_CooperativeMatrixUseAttr",
+    "SPIRV_KHR_CooperativeMatrixLayoutAttr",
+    "SPIRV_MemorySemanticsAttr",
+    "SPIRV_MatrixLayoutAttr",
+    "SPIRV_TosaExtNanPropagationModeAttr",
+};
 
 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
 /// generates code extracts the attribute with name `attrName` from


