[Mlir-commits] [mlir] [mlir][spirv] Initial support for TOSA Extended Instruction Set (0010… (PR #174402)
Davide Grohmann
llvmlistbot at llvm.org
Fri Jan 9 05:51:12 PST 2026
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/174402
>From 047593b43c953a42a29eea9ddbb7d8b5dea4fac9 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 7 Nov 2025 15:29:46 +0100
Subject: [PATCH] [mlir][spirv] Initial support for TOSA Extended Instruction
Set (001000.1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This patch adds initial support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended
instruction set provides a standardized set of machine learning
operations designed to be used within `spirv.ARM.Graph` operations
(corresponding to OpGraphARM in SPV_ARM_graph) and typed with
`!spirv.arm.tensor<...>` (corresponding to OpTypeTensorARM in
SPV_ARM_tensor).
The change introduces:
* Dialect plumbing for import, serialization, and deserialization of
the TOSA extended instruction set.
* The `spirv.Tosa.ArgMax` operation from the TOSA extended instruction set,
lowering to the corresponding `OpExtInst`.
* Verification enforcing that `spirv.Tosa.ArgMax` appears only within
`spirv.ARM.Graph` regions, operates on `!spirv.arm.tensor<...>`
types, and is well-formed according to the TOSA 001000.1
specification.
Only the ArgMax operation from the TOSA 001000.1 extended instructions is
introduced, in order to showcase the work needed. Parser, printer,
verifier, and round-trip tests using MLIR’s SPIR-V
serialization/deserialization infrastructure are included.
This work aligns with Khronos SPIR-V TOSA specifications.
Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html
Change-Id: Ibf2aad7c9e86a9dc28c6133f5d6cb0cd67163ddf
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 15 ++++
.../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td | 1 +
.../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td | 74 +++++++++++++++++++
.../mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td | 39 ++++++++++
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp | 55 ++++++++++++++
mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir | 27 +++++++
mlir/test/Target/SPIRV/tosa-ops.mlir | 43 +++++++++++
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 10 ++-
9 files changed, 262 insertions(+), 3 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
create mode 100644 mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
create mode 100644 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
create mode 100644 mlir/test/Target/SPIRV/tosa-ops.mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 97ee9e15a68ef..ca6872c56ad06 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4233,8 +4233,10 @@ def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_s
def SPIRV_Void : TypeAlias<NoneType, "void">;
def SPIRV_Bool : TypeAlias<I1, "bool">;
def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int8 : TypeAlias<I8, "Int8">;
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
+def SPIRV_Float16 : TypeAlias<F16, "Float16">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
@@ -4909,4 +4911,17 @@ def SPIRV_FPFastMathModeAttr :
SPIRV_FPFMM_AllowReassocINTEL
]>;
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA enum definitions.
+//===----------------------------------------------------------------------===//
+
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
+    "TosaExtNaNPropagationModeType", "Tosa Ext NaN Propagation Mode Type", "tosa_ext_nan_propagation_mode_type",
+    [
+      I32EnumAttrCase<"Propagate", 1>,
+      I32EnumAttrCase<"Ignore", 2>,
+    ]>;
+
#endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 96ef035eda37a..3ef9699154cd1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -45,6 +45,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
#endif // MLIR_DIALECT_SPIRV_IR_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
new file mode 100644
index 0000000000000..5afa0ca9440bb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -0,0 +1,74 @@
+//===- SPIRVTosaOps.td - TOSA extended insts spec file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of TOSA extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+#define MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for TOSA.001000.1 extended-instruction ops; adds InGraphScope.
+class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
+ SPIRV_ExtInstOp<mnemonic, "Tosa", "TOSA.001000.1", opcode, !listconcat(traits, [InGraphScope])> {
+
+ let availability = [
+ MinVersion<SPIRV_V_1_5>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
+ Capability<[SPIRV_C_GraphARM]>
+ ];
+}
+
+
+def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
+ let summary = "Perform argmax on the input.";
+
+ let description = [{
+ Returns the index with the largest value across the given axis of the
+ input tensor. If multiple locations have equal values, returns the first
+ match along the search axis.
+
+ References:
+ https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
+ https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_argmax
+
+ #### Example:
+ ```mlir
+ %2 = spirv.Tosa.ArgMax %0, %1 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ %2 = spirv.Tosa.ArgMax %0, %1 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_Int32: $axis,
+ SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+ SPIRV_TosaNumerical_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaInteger_TensorArmUpTo5D: $output
+ );
+
+ let hasVerifier = 1;
+
+ let assemblyFormat = [{
+ operands attr-dict `:` type(operands) `->` type(results)
+ }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
new file mode 100644
index 0000000000000..002ffc886698e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -0,0 +1,39 @@
+//===- SPIRVTosaTypes.td - Tosa Types insts spec file ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This specifies Tosa types used by the Graph Extension and Tosa Ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+#define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
+def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
+def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
+
+// TensorARM Types
+
+class RankedTensorArmOf<list<Type> allowedTypes, list<Pred> preds = [],
+ string summary = "ranked tensorArm">
+ : ShapedContainerType<
+ allowedTypes, And<!listconcat([SPIRV_IsTensorArmType], preds)>,
+ summary, "::mlir::spirv::TensorArmType">;
+
+class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
+ : RankedTensorArmOf<allowedTypes,
+ [HasAnyRankOfPred<ranks>],
+ !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+
+def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_TosaInteger_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_TosaInteger], [1, 2, 3, 4, 5]>;
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 60d705d940cfc..f05f596aa9f23 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
SPIRVOpDefinition.cpp
SPIRVOps.cpp
SPIRVParsingUtils.cpp
+ SPIRVTosaOps.cpp
SPIRVTypes.cpp
TargetAndABI.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
new file mode 100644
index 0000000000000..838c1782dc1b2
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -0,0 +1,55 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V operations --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Tosa operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir::spirv {
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaArgMaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaArgMaxOp::verify() {
+  auto inputTy = cast<ShapedType>(getInput().getType());
+  auto resultTy = cast<ShapedType>(getType());
+  // Expected result rank is max(1, input rank - 1).
+  if (inputTy.hasRank() && resultTy.hasRank() &&
+      resultTy.getRank() !=
+          (inputTy.getRank() > 1 ? inputTy.getRank() - 1 : 1)) {
+    return emitOpError("result rank must be max of 1 and (input rank - 1)");
+  }
+  // The result element type must be an integer (or index) type.
+  Type resultETy = resultTy.getElementType();
+  if (!resultETy.isIntOrIndex()) {
+    return emitOpError("result is not of integer type");
+  }
+  // The axis operand must be produced by a constant with an integer value.
+  IntegerAttr axisAttr;
+  if (!matchPattern(getAxis(), m_Constant(&axisAttr))) {
+    return emitOpError("axis type must be a constant integer");
+  }
+  // The axis must name an existing input dimension: 0 <= axis < input rank.
+  const int64_t axis = axisAttr.getInt(); // Keep getInt()'s full width.
+  if (inputTy.hasRank() && ((axis < 0) || axis >= inputTy.getRank())) {
+    return emitOpError("specified axis is outside the rank of input");
+  }
+
+  return success();
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
new file mode 100644
index 0000000000000..09c4e9a287d7d
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --split-input-file %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+ %0 = spirv.Constant 3 : i32
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+ %0 = spirv.Constant 2 : i32
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
new file mode 100644
index 0000000000000..5306209f90dac
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @argmax_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17x17xi8>, UniformConstant>
+ spirv.GlobalVariable @argmax_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17xi32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @argmax_int, @argmax_int_arg_0, @argmax_int_res_0
+ spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+ %0 = spirv.Constant 3 : i32
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @argmax_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x7x14xf32>, UniformConstant>
+ spirv.GlobalVariable @argmax_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x14xi32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @argmax_fp, @argmax_fp_arg_0, @argmax_fp_res_0
+ spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+ %0 = spirv.Constant 2 : i32
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax {{%.*}}, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ %2 = spirv.Tosa.ArgMax %0, %arg0 {nan_mode = #spirv.tosa_ext_nan_propagation_mode_type<Propagate>} : i32, !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+ }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ca291b57f4344..d24eb74278e89 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -496,9 +496,13 @@ static mlir::GenRegistration
// directly use the constant value as attribute in SPIR-V dialect. So need
// to handle them separately from normal enum attributes.
constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
- "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
- "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
- "SPIRV_MatrixLayoutAttr"};
+ "SPIRV_ScopeAttr",
+ "SPIRV_KHR_CooperativeMatrixUseAttr",
+ "SPIRV_KHR_CooperativeMatrixLayoutAttr",
+ "SPIRV_MemorySemanticsAttr",
+ "SPIRV_MatrixLayoutAttr",
+ "SPIRV_TosaExtNaNPropagationModeAttr",
+};
/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
/// generates code extracts the attribute with name `attrName` from
More information about the Mlir-commits
mailing list