[Mlir-commits] [mlir] 8168577 - [mlir][spirv] Initial support for TOSA Extended Instruction Set (#174402)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 19 08:29:53 PST 2026
Author: Davide Grohmann
Date: 2026-01-19T11:29:47-05:00
New Revision: 8168577795e1e36f73ccba17f566343fb613a3da
URL: https://github.com/llvm/llvm-project/commit/8168577795e1e36f73ccba17f566343fb613a3da
DIFF: https://github.com/llvm/llvm-project/commit/8168577795e1e36f73ccba17f566343fb613a3da.diff
LOG: [mlir][spirv] Initial support for TOSA Extended Instruction Set (#174402)
This patch adds initial support for the TOSA Extended Instruction Set
(001000.1) to the SPIR-V dialect in MLIR. The TOSA extended instruction
set provides a standardized set of machine learning operations designed
to be used within `spirv.ARM.Graph` operations (corresponding to
OpGraphARM in SPV_ARM_graph) and typed with `!spirv.arm.tensor<...>`
(corresponding to OpTypeTensorARM in SPV_ARM_tensors).
The change introduces:
* Dialect plumbing for import, serialization, and deserialization of the
TOSA extended instruction set.
* The `spirv.Tosa.ArgMax` operation from the TOSA extended instruction set,
which lowers to the corresponding `OpExtInst`.
* Verification enforcing that `spirv.Tosa.ArgMax` appears only within
`spirv.ARM.Graph` regions, operates on `!spirv.arm.tensor<...>` types,
and is well-formed according to the TOSA 001000.1 specification.
Only the ArgMax operation from the TOSA 001000.1 extended instruction set is
introduced, in order to showcase the work needed: parser, printer,
verifier, and round-trip tests using MLIR’s SPIR-V
serialization/deserialization infrastructure are included.
This work aligns with the Khronos SPIR-V TOSA specification.
Specification:
https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Added:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
mlir/test/Target/SPIRV/tosa-ops.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 97ee9e15a68ef..21010d91dc47c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4233,8 +4233,10 @@ def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_s
def SPIRV_Void : TypeAlias<NoneType, "void">;
def SPIRV_Bool : TypeAlias<I1, "bool">;
def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int8 : TypeAlias<I8, "Int8">;
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
+def SPIRV_Float16 : TypeAlias<F16, "Float16">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
@@ -4909,4 +4911,18 @@ def SPIRV_FPFastMathModeAttr :
SPIRV_FPFMM_AllowReassocINTEL
]>;
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA enum definitions.
+//===----------------------------------------------------------------------===//
+
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
+//
+// NaN propagation mode passed as the `nan_mode` operand of TOSA extended
+// instructions such as spirv.Tosa.ArgMax.
+def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr<
+ "TosaExtNaNPropagationModeType", "Tosa Ext NaN Propagation Mode Type",
+ "tosa_ext_nan_propagation_mode_type",
+ [
+ I32EnumAttrCase<"Propagate", 1>,
+ I32EnumAttrCase<"Ignore", 2>,
+ ]>;
+
#endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 96ef035eda37a..3ef9699154cd1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -45,6 +45,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
#endif // MLIR_DIALECT_SPIRV_IR_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
new file mode 100644
index 0000000000000..6c6a318db4827
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -0,0 +1,86 @@
+//===- SPIRVTosaOps.td - TOSA extended insts spec file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of TOSA extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+#define MLIR_DIALECT_SPIRV_IR_TOSA_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V TOSA opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all TOSA ops.
+//
+// Each subclass is an OpExtInst of the "TOSA.001000.1" extended instruction
+// set with the given `opcode`, and is exposed in the dialect as
+// spirv.Tosa.<mnemonic>. InGraphScope is always appended to `traits`, so TOSA
+// ops are only accepted inside spirv.ARM.Graph regions.
+class SPIRV_TosaOp<string mnemonic, int opcode, list<Trait> traits = []> :
+ SPIRV_ExtInstOp<mnemonic, "Tosa", "TOSA.001000.1", opcode,
+ !listconcat(traits, [InGraphScope])> {
+
+ // Available from SPIR-V 1.5 through 1.6, with the SPV_ARM_graph extension
+ // and the GraphARM capability.
+ // NOTE(review): TensorsARM / SPV_ARM_tensors are not listed although the
+ // operands are !spirv.arm.tensor values; confirm whether that is implied.
+ let availability = [
+ MinVersion<SPIRV_V_1_5>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_ARM_graph]>,
+ Capability<[SPIRV_C_GraphARM]>
+ ];
+}
+
+
+def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> {
+ let summary = "Perform argmax on the input.";
+
+ let description = [{
+ Returns the index with the largest value across the given axis of the
+ input tensor. If multiple locations have equal values, returns the first
+ match along the search axis.
+
+ `axis` must be smaller than the rank of `input`, and the result has rank
+ max(1, rank(input) - 1); the op verifier checks both. Per the TOSA ARGMAX
+ definition referenced below, the result shape is the input shape with the
+ `axis` dimension removed. The verifier currently checks only the rank.
+ `nan_mode` is the TOSA NaN propagation mode (`Propagate` or `Ignore`).
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_argmax
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_argmax
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ %1 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg1 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TensorArmAxisAttr: $axis,
+ SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+ SPIRV_TosaNumerical_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_Int32_TensorArmUpTo5D: $output
+ );
+
+ let hasVerifier = 1;
+
+ let assemblyFormat = [{
+ `axis` `=` $axis `,` `nan_mode` `=` $nan_mode `,`
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ ::mlir::spirv::TensorArmType getResultType() {
+ return cast<::mlir::spirv::TensorArmType>(getType());
+ }
+ }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
new file mode 100644
index 0000000000000..e731388182eb4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -0,0 +1,41 @@
+//===- SPIRVTosaTypes.td - Tosa Types insts spec file --------*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This specifies Tosa types used by the Graph Extension and Tosa Ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+#define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+// Element-type constraints for TOSA operands: integers of width 8/16/32/64,
+// f16/f32/bf16 floats, their union ("Numerical"), and Numerical plus bool.
+def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
+def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
+def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
+
+// Tensor axis index: a non-negative i32 no larger than 5, i.e. the largest
+// axis of the rank-6 inputs accepted by SPIRV_TosaNumerical_TensorArm below.
+def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
+
+// TensorARM Types
+
+// !spirv.arm.tensor values whose element type is one of `allowedTypes` and
+// which additionally satisfy every predicate in `preds`.
+class RankedTensorArmOf<list<Type> allowedTypes, list<Pred> preds = [],
+ string summary = "ranked tensorArm">
+ : ShapedContainerType<
+ allowedTypes, And<!listconcat([SPIRV_IsTensorArmType], preds)>,
+ summary, "::mlir::spirv::TensorArmType">;
+
+// Further restricts the tensor rank to one of `ranks`; the generated summary
+// reads like "1D/2D/3D tensorArm", which surfaces in verifier diagnostics.
+class TensorArmRankOf<list<Type> allowedTypes, list<int> ranks>
+ : RankedTensorArmOf<allowedTypes,
+ [HasAnyRankOfPred<ranks>],
+ !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
+
+// Numerical tensors of rank 1-6 (e.g. the ArgMax input) and i32 tensors of
+// rank 1-5 (e.g. the ArgMax result).
+def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
+
+#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 60d705d940cfc..f05f596aa9f23 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
SPIRVOpDefinition.cpp
SPIRVOps.cpp
SPIRVParsingUtils.cpp
+ SPIRVTosaOps.cpp
SPIRVTypes.cpp
TargetAndABI.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
new file mode 100644
index 0000000000000..4f3c91d4a1c12
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp
@@ -0,0 +1,49 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V Tosa operations ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Tosa operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir::spirv {
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax (TosaArgMaxOp)
+//===----------------------------------------------------------------------===//
+
+/// Verifies the ArgMax invariants that the ODS type constraints do not cover:
+///  * the result rank equals max(1, input rank - 1), and
+///  * `axis` names an existing input dimension, i.e. axis < input rank.
+/// Individual result dimensions are not compared against the input shape.
+LogicalResult TosaArgMaxOp::verify() {
+  ShapedType inputTy = getInputType();
+  ShapedType resultTy = getResultType();
+
+  // Both checks are expressed in terms of the input rank; without it there is
+  // nothing to verify.
+  if (!inputTy.hasRank())
+    return success();
+  const int64_t inputRank = inputTy.getRank();
+
+  // Inputs of rank <= 1 reduce to a rank-1 result; otherwise `axis` is dropped.
+  const int64_t expectedResultRank = inputRank > 1 ? inputRank - 1 : 1;
+  if (resultTy.hasRank() && resultTy.getRank() != expectedResultRank)
+    return emitOpError(
+               "result rank must be max of 1 and (input rank - 1), got ")
+           << resultTy.getRank();
+
+  // NOTE(review): the message says "greater than" although axis == rank is
+  // rejected too; the exact wording is matched by tosa-ops-verification.mlir.
+  const uint32_t axis = getAxis();
+  if (axis >= inputRank)
+    return emitOpError(
+               "specified axis is greater than the rank of input, got axis = ")
+           << axis << " and input rank = " << inputRank;
+
+  return success();
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
new file mode 100644
index 0000000000000..a6496316f9881
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt --verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): all cases share one input (the RUN line has no
+// --split-input-file); confirm every diagnostic is still reported.
+
+// The result element type must be i32; an i16 result is rejected by the ODS
+// result type constraint.
+spirv.ARM.Graph @argmax_non_i32_result_element_type(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi16>) {
+ // expected-error @+1 {{op result #0 must be 1D/2D/3D/4D/5D tensorArm of Int32 values}}
+ %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi16>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi16>
+}
+
+// A rank-4 input requires a rank-3 result; the custom verifier rejects rank 2.
+spirv.ARM.Graph @argmax_incorrect_output_rank(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28xi32>) {
+ // expected-error @+1 {{op result rank must be max of 1 and (input rank - 1), got 2}}
+ %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28xi32>
+}
+
+// The axis must be strictly less than the input rank; axis == rank (4) is
+// rejected by the custom verifier.
+spirv.ARM.Graph @argmax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+ // expected-error @+1 {{op specified axis is greater than the rank of input, got axis = 4 and input rank = 4}}
+ %2 = spirv.Tosa.ArgMax axis = 4, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
new file mode 100644
index 0000000000000..c9832b903b79e
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// i8 input: reducing axis 3 of 3x28x17x17 yields a 3x28x17 i32 result.
+spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// f32 input: reducing axis 2 of 2x2x7x14 yields a 2x2x14 i32 result.
+spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
new file mode 100644
index 0000000000000..8c0429bca68e4
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// Serializes and deserializes a module with a single ArgMax graph bound to
+// two UniformConstant tensor globals.
+// NOTE(review): the VCE triple below declares v1.3, while SPIRV_TosaOp in
+// SPIRVTosaOps.td declares MinVersion<SPIRV_V_1_5>; confirm spirv-val accepts
+// this module or bump the version.
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @argmax_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17x17xi8>, UniformConstant>
+ spirv.GlobalVariable @argmax_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<3x28x17xi32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @argmax_int, @argmax_int_arg_0, @argmax_int_res_0
+ spirv.ARM.Graph @argmax_int(%arg0: !spirv.arm.tensor<3x28x17x17xi8>) -> (!spirv.arm.tensor<3x28x17xi32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x28x17xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Tosa.ArgMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// Same as above with an f32 input, reducing axis 2 (2x2x7x14 -> 2x2x14).
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @argmax_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x7x14xf32>, UniformConstant>
+ spirv.GlobalVariable @argmax_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x2x14xi32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @argmax_fp, @argmax_fp_arg_0, @argmax_fp_res_0
+ spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.arm.tensor<2x2x14xi32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32>
+ spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32>
+ }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ca291b57f4344..f3327e31aae04 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -496,9 +496,13 @@ static mlir::GenRegistration
// directly use the constant value as attribute in SPIR-V dialect. So need
// to handle them separately from normal enum attributes.
constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
- "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
- "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
- "SPIRV_MatrixLayoutAttr"};
+ "SPIRV_ScopeAttr",
+ "SPIRV_KHR_CooperativeMatrixUseAttr",
+ "SPIRV_KHR_CooperativeMatrixLayoutAttr",
+ "SPIRV_MemorySemanticsAttr",
+ "SPIRV_MatrixLayoutAttr",
+ "SPIRV_TosaExtNaNPropagationModeAttr",
+};
/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
/// generates code extracts the attribute with name `attrName` from
@@ -552,6 +556,11 @@ static void emitAttributeSerialization(const Attribute &attr,
os << tabs << " return failure();\n";
os << tabs << " }\n";
os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList);
+ } else if (attr.getAttrDefName() == "SPIRV_TensorArmAxisAttr") {
+ os << tabs
+ << formatv(
+ " {0}.push_back(prepareConstantScalar({1}.getLoc(), attr));\n",
+ operandList, opVar);
} else {
PrintFatalError(
loc,
@@ -846,6 +855,23 @@ static void emitAttributeDeserialization(const Attribute &attr,
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"TypeAttr::get(getType({2}[{3}++]))));\n",
attrList, attrName, words, wordIndex);
+ } else if (attr.getAttrDefName() == "SPIRV_TensorArmAxisAttr") {
+ os << tabs
+ << formatv("std::optional<std::pair<Attribute, Type>> c = "
+ "getConstant({0}[{1}++]);\n",
+ words, wordIndex);
+ os << tabs << "if (!c.has_value()) {\n";
+ os << tabs
+ << formatv(" "
+ "return emitError(unknownLoc, \"could not fetch "
+ "constant attribute for {0}\") << "
+ "{1} << \" of \" << {2}.size() << \" processed\";\n",
+ attrName, wordIndex, words);
+ os << tabs << "}\n";
+ os << tabs
+ << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+ "c.value().first));\n",
+ attrList, attrName);
} else {
PrintFatalError(
loc, llvm::Twine(
More information about the Mlir-commits
mailing list