[Mlir-commits] [mlir] [mlir][spirv] Initial support for TOSA Extended Instruction Set (0010… (PR #174402)

Davide Grohmann llvmlistbot at llvm.org
Fri Jan 9 05:54:52 PST 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/174402

From 26b0e67b76ae9e947452141c183105b5fe754d56 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 7 Nov 2025 15:29:46 +0100
Subject: [PATCH] [mlir][spirv] Initial support for TOSA Extended Instruction
 Set (001000.1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch adds initial support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended
instruction set provides a standardized set of machine learning
operations designed to be used within `spirv.ARM.Graph` operations
(corresponding to OpGraphARM in SPV_ARM_graph) and typed with
`!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in
SPV_ARM_tensors).

The change introduces:
* Dialect plumbing for import, serialization, and deserialization of
  the TOSA extended instruction set.
* The `spirv.Tosa.ArgMax` operation from the TOSA extended instruction
  set, which lowers to the corresponding `OpExtInst`.
* Verification enforcing that `spirv.Tosa.ArgMax` appears only within
  `spirv.ARM.Graph` regions, operates on `!spirv.arm.tensor<...>`
  types, and is well-formed according to the TOSA 001000.1
  specification.
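The rank and axis rules enforced by the verifier can be sketched as a standalone function. This is a minimal sketch that mirrors the checks in `TosaArgMaxOp::verify()` without using any MLIR API; the name `checkArgMaxShapes` and the plain `std::vector` shape representation are hypothetical. The graph-scope restriction (the `InGraphScope` trait) and the element-type constraints (ODS type constraints) are not modeled here.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical standalone mirror of the rank/axis checks in
// TosaArgMaxOp::verify(); it does not depend on MLIR.
// Returns an error message, or std::nullopt when the shapes are consistent.
std::optional<std::string>
checkArgMaxShapes(const std::vector<int64_t> &inputShape,
                  const std::vector<int64_t> &resultShape, int64_t axis) {
  const auto inputRank = static_cast<int64_t>(inputShape.size());
  const auto resultRank = static_cast<int64_t>(resultShape.size());

  // The reduced axis is dropped, but the result never goes below rank 1.
  const int64_t expectedRank = inputRank > 1 ? inputRank - 1 : 1;
  if (resultRank != expectedRank)
    return "result rank must be max of 1 and (input rank - 1)";

  // The axis must name an existing dimension of the input.
  if (axis < 0 || axis >= inputRank)
    return "specified axis is outside the rank of input";

  return std::nullopt;
}
```

With the shapes used in the tests, `checkArgMaxShapes({3, 28, 17, 17}, {3, 28, 17}, 3)` is accepted, while an axis of 4 on the same input is rejected as out of range.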

Only the ArgMax operation from the TOSA 001000.1 extended instructions is
introduced, to showcase the work needed: the parser, printer, verifier,
and round-trip tests using MLIR’s SPIR-V serialization/deserialization
infrastructure are included.
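At the binary level, `spirv.Tosa.ArgMax` becomes an `OpExtInst` that references the imported TOSA instruction set. The sketch below illustrates the word layout under several assumptions: the ids are hypothetical, the serializer is assumed to emit operands in the ODS argument order (axis, nan_mode, input), and the `nan_mode` attribute is assumed to be materialized as an i32 `OpConstant` whose <id> is passed as an operand, following the note in SPIRVBase.td. Opcodes 12 (`OpExtInst`) and 43 (`OpConstant`) are from the SPIR-V core specification; instruction number 0 is the opcode that SPIRVTosaOps.td passes for ArgMax. The helper names are illustrative only.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Opcodes from the SPIR-V core specification.
constexpr uint32_t kOpExtInst = 12;
constexpr uint32_t kOpConstant = 43;

// ArgMax instruction number, as passed to SPIRV_TosaOp in SPIRVTosaOps.td.
constexpr uint32_t kTosaArgMax = 0;

// The first word of every SPIR-V instruction packs the word count into the
// high 16 bits and the opcode into the low 16 bits.
uint32_t firstWord(uint32_t wordCount, uint32_t opcode) {
  return (wordCount << 16) | opcode;
}

// Appends `%result = OpConstant %type value` for a 32-bit scalar.
void emitConstantI32(std::vector<uint32_t> &out, uint32_t typeId,
                     uint32_t resultId, uint32_t value) {
  out.insert(out.end(), {firstWord(4, kOpConstant), typeId, resultId, value});
}

// Appends `%result = OpExtInst %type %set ArgMax %axis %nanMode %input`.
// Assumption: operands follow the ODS argument order.
void emitArgMax(std::vector<uint32_t> &out, uint32_t resultTypeId,
                uint32_t resultId, uint32_t setId, uint32_t axisId,
                uint32_t nanModeId, uint32_t inputId) {
  out.insert(out.end(), {firstWord(8, kOpExtInst), resultTypeId, resultId,
                         setId, kTosaArgMax, axisId, nanModeId, inputId});
}
```

This only illustrates the layout; the round-trip tests in mlir/test/Target/SPIRV/tosa-ops.mlir are what exercise the real serializer.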

This work aligns with the Khronos SPIR-V TOSA extended instruction set
specification.

Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html

Change-Id: Ibf2aad7c9e86a9dc28c6133f5d6cb0cd67163ddf
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 15 ++++
 .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td |  1 +
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 74 +++++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td   | 39 ++++++++++
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt      |  1 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp    | 55 ++++++++++++++
 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir      | 27 +++++++
 mlir/test/Target/SPIRV/tosa-ops.mlir          | 43 +++++++++++
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      | 10 ++-
 9 files changed, 262 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
 create mode 100644 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
 create mode 100644 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
 create mode 100644 mlir/test/Target/SPIRV/tosa-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 97ee9e15a68ef..ca6872c56ad06 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4233,8 +4233,10 @@ def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_s
 def SPIRV_Void : TypeAlias<NoneType, "void">;
 def SPIRV_Bool : TypeAlias<I1, "bool">;
 def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int8 : TypeAlias<I8, "Int8">;
 def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
+def SPIRV_Float16 : TypeAlias<F16, "Float16">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
 def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
@@ -4909,4 +4911,17 @@ def SPIRV_FPFastMathModeAttr :
       SPIRV_FPFMM_AllowReassocINTEL
     ]>;
 
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA enum definitions.
+//===----------------------------------------------------------------------===//
+
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
+  "TosaExtNaNPropagationModeType", "Tosa Ext NaN Propagation Mode Type", "tosa_ext_nan_propagation_mode_type",
+  [
+      I32EnumAttrCase<"Propagate", 1>,
+      I32EnumAttrCase<"Ignore", 2>,
+  ]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 96ef035eda37a..3ef9699154cd1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -45,6 +45,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 #endif // MLIR_DIALECT_SPIRV_IR_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
new file mode 100644
index 0000000000000..5afa0ca9440bb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -0,0 +1,74 @@
+//===- SPIRVTosaOps.td - TOSA extended insts spec file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of TOSA extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+#define MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all TOSA ops.
+class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
+  SPIRV_ExtInstOp<mnemonic, "Tosa", "TOSA.001000.1", opcode, !listconcat(traits, [InGraphScope])> {
+
+  let availability = [
+    MinVersion<SPIRV_V_1_5>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
+    Capability<[SPIRV_C_GraphARM]>
+  ];
+}
+
+
+def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
+  let summary = "Perform argmax on the input.";
+
+  let description = [{
+    Returns the index with the largest value across the given axis of the
+    input tensor. If multiple locations have equal values, returns the first
+    match along the search axis.
+
+    References:
+    https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
+    https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_argmax
+
+    #### Example:
+    ```mlir
+    %2 = spirv.Tosa.ArgMax %0, %1 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+    %5 = spirv.Tosa.ArgMax %3, %4 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Int32: $axis,
+    SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+    SPIRV_TosaNumerical_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaInteger_TensorArmUpTo5D: $output
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type(operands) `->` type(results)
+  }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
new file mode 100644
index 0000000000000..40ad96813c1bd
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -0,0 +1,39 @@
+//===- SPIRVTosaTypes.td - Tosa Types insts spec file --------*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This specifies Tosa types used by the Graph Extension and Tosa Ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+#define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
+def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
+def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
+
+// TensorARM Types
+
+class RankedTensorArmOf<list<Type> allowedTypes, list<Pred> preds = [],
+                        string summary = "ranked tensorArm">
+  : ShapedContainerType<
+      allowedTypes, And<!listconcat([SPIRV_IsTensorArmType], preds)>,
+      summary, "::mlir::spirv::TensorArmType">;
+
+class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
+  : RankedTensorArmOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>],
+      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+
+def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_TosaInteger_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_TosaInteger], [1, 2, 3, 4, 5]>;
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 60d705d940cfc..f05f596aa9f23 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   SPIRVOpDefinition.cpp
   SPIRVOps.cpp
   SPIRVParsingUtils.cpp
+  SPIRVTosaOps.cpp
   SPIRVTypes.cpp
   TargetAndABI.cpp
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
new file mode 100644
index 0000000000000..838c1782dc1b2
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -0,0 +1,55 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V operations --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Tosa operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir::spirv {
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaArgMaxOp::verify() {
+  auto inputTy = cast<ShapedType>(getInput().getType());
+  auto resultTy = cast<ShapedType>(getType());
+
+  if (inputTy.hasRank() && resultTy.hasRank() &&
+      resultTy.getRank() !=
+          (inputTy.getRank() > 1 ? inputTy.getRank() - 1 : 1)) {
+    return emitOpError("result rank must be max of 1 and (input rank - 1)");
+  }
+
+  Type resultETy = resultTy.getElementType();
+  if (!resultETy.isIntOrIndex()) {
+    return emitOpError("result is not of integer type");
+  }
+
+  IntegerAttr axisAttr;
+  if (!matchPattern(getAxis(), m_Constant(&axisAttr))) {
+    return emitOpError("axis must be a constant integer");
+  }
+
+  const int64_t axis = axisAttr.getInt();
+  if (inputTy.hasRank() && ((axis < 0) || axis >= inputTy.getRank())) {
+    return emitOpError("specified axis is outside the rank of input");
+  }
+
+  return success();
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
new file mode 100644
index 0000000000000..09c4e9a287d7d
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --split-input-file %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+  %0 = spirv.Constant 3 : i32
+  // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+  %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+  %0 = spirv.Constant 2 : i32
+  // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+  %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
new file mode 100644
index 0000000000000..5306209f90dac
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @argmax_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17x17xi8>, UniformConstant>
+  spirv.GlobalVariable @argmax_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @argmax_int, @argmax_int_arg_0, @argmax_int_res_0
+  spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+    %0 = spirv.Constant 3 : i32
+    // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+    %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  spirv.GlobalVariable @argmax_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x7x14xf32>, UniformConstant>
+  spirv.GlobalVariable @argmax_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x14xi32>, UniformConstant>
+  spirv.ARM.GraphEntryPoint @argmax_fp, @argmax_fp_arg_0, @argmax_fp_res_0
+  spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+    %0 = spirv.Constant 2 : i32
+    // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+    %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ca291b57f4344..d24eb74278e89 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -496,9 +496,13 @@ static mlir::GenRegistration
 // directly use the constant value as attribute in SPIR-V dialect. So need
 // to handle them separately from normal enum attributes.
 constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
-    "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
-    "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
-    "SPIRV_MatrixLayoutAttr"};
+    "SPIRV_ScopeAttr",
+    "SPIRV_KHR_CooperativeMatrixUseAttr",
+    "SPIRV_KHR_CooperativeMatrixLayoutAttr",
+    "SPIRV_MemorySemanticsAttr",
+    "SPIRV_MatrixLayoutAttr",
+    "SPIRV_TosaExtNaNPropagationModeAttr",
+};
 
 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
 /// generates code extracts the attribute with name `attrName` from
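
The ArgMax semantics given in the op description (the index of the largest value along `axis`, with the first match winning on ties) can be illustrated with a reference implementation over a dense row-major buffer. This is a sketch, not the TOSA reference model: the name `argMax` and the flat-buffer layout are assumptions, and the NaN handling selected by `nan_mode` is deliberately left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference ArgMax over a dense row-major tensor. For every
// position outside `axis`, returns the index along `axis` of the largest
// element. Ties keep the first (lowest) index because only a strictly
// greater value replaces the current best. nan_mode is not modeled.
template <typename T>
std::vector<int32_t> argMax(const std::vector<T> &data,
                            const std::vector<int64_t> &shape, size_t axis) {
  // View the shape as [outer, n, inner] around the reduced axis.
  int64_t outer = 1, inner = 1;
  for (size_t d = 0; d < axis; ++d)
    outer *= shape[d];
  for (size_t d = axis + 1; d < shape.size(); ++d)
    inner *= shape[d];
  const int64_t n = shape[axis];

  // One result per (outer, inner) pair: the input shape with `axis`
  // removed, flattened in row-major order.
  std::vector<int32_t> result(static_cast<size_t>(outer * inner), 0);
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t i = 0; i < inner; ++i) {
      int64_t best = 0;
      for (int64_t k = 1; k < n; ++k) {
        const auto cand = static_cast<size_t>((o * n + k) * inner + i);
        const auto cur = static_cast<size_t>((o * n + best) * inner + i);
        if (data[cand] > data[cur])
          best = k;
      }
      result[static_cast<size_t>(o * inner + i)] = static_cast<int32_t>(best);
    }
  }
  return result;
}
```

For the 2x3 input `{1, 5, 5, 7, 2, 7}`, reducing axis 1 yields `{1, 0}`: the tie in the first row resolves to the earlier index. A rank-1 input produces a single element, which matches the verifier's rule that the result rank is at least 1.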
