[Mlir-commits] [mlir] [mlir][tosa] Add specification versioning to target environment (PR #156425)
Luke Hutton
llvmlistbot at llvm.org
Tue Sep 2 01:43:28 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/156425
This commit adds a new "specification_version" field to the TOSA target environment attribute. This allows a user to specify which version of the TOSA specification they would like to target during lowering.
A leading example in the validation pass has also been added. This addition attaches a version to each profile compliance entry to track the version of the specification in which the entry was added. This allows a backwards compatibility check to be implemented between the target version and the profile compliance entry version.
For now a default version of "1.0" is assumed. "1.1.draft" is added to denote an in-development version of the specification targeting the next release.
Note: this PR is marked as draft as it is dependent on the following PRs being merged prior to it: #153771, #155799
>From f262af4ec13edc5461b66fbbed742673bf7eb1c4 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 14 Aug 2025 16:18:10 +0000
Subject: [PATCH 1/3] [mlir][tosa] Add the concept of a TOSA target environment
This commit introduces a new module-level attribute `tosa.target_env`.
It encapsulates target information for use during compilation such as:
level, profiles and extensions. For example:
```mlir
module attributes {tosa.target_env =
#tosa.target_env<level = none, profiles = [pro_int], extensions = [int16, int4]>} {
<my-tosa-program>
}
```
Previously the validation pass accepted target information as a series of
command line pass options. This commit changes the behaviour to query
the attached target environment from the module attribute. This
refactoring allows other passes to query the same target information.
A new target environment can be attached using the `--tosa-attach-target`
pass, which takes the same command line options as the previous validation
pass arguments. For example:
```bash
mlir-opt --tosa-attach-target="profiles=pro_int extensions=int4,int16 level=none" test.mlir
```
Change-Id: I74a254855f6320dc70b29ae3509997764e3e5d95
---
.../Conversion/TosaToLinalg/TosaToLinalg.h | 3 +-
mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 56 +++++++++--
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 50 ++++++++--
.../Dialect/Tosa/Transforms/CMakeLists.txt | 2 -
.../mlir/Dialect/Tosa/Transforms/Passes.h | 1 -
.../mlir/Dialect/Tosa/Transforms/Passes.td | 64 ++++++++-----
.../TosaToLinalg/TosaToLinalgPass.cpp | 3 -
mlir/lib/Dialect/Tosa/CMakeLists.txt | 1 +
mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 42 ++++++++
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 +
.../Tosa/Transforms/TosaAttachTarget.cpp | 87 +++++++++++++++++
.../Tosa/Transforms/TosaValidation.cpp | 96 ++++---------------
mlir/test/Dialect/Tosa/dynamic_extension.mlir | 2 +-
mlir/test/Dialect/Tosa/error_if_check.mlir | 2 +-
mlir/test/Dialect/Tosa/invalid.mlir | 2 +-
mlir/test/Dialect/Tosa/invalid_extension.mlir | 2 +-
mlir/test/Dialect/Tosa/level_check.mlir | 2 +-
.../Dialect/Tosa/profile_all_unsupported.mlir | 2 +-
.../Tosa/profile_pro_fp_unsupported.mlir | 2 +-
.../Tosa/profile_pro_int_unsupported.mlir | 2 +-
.../test/Dialect/Tosa/tosa-attach-target.mlir | 14 +++
.../Dialect/Tosa/tosa-validation-valid.mlir | 2 +-
22 files changed, 309 insertions(+), 129 deletions(-)
create mode 100644 mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
create mode 100644 mlir/test/Dialect/Tosa/tosa-attach-target.mlir
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index f4823858e3893..ab9b9f24ef3dd 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -39,8 +39,7 @@ void addTosaToLinalgPasses(
TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
- tosa::TosaValidationOptions{
- {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
+ tosa::TosaValidationOptions{false, false});
/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 9ee5079559d2b..10491f65d37af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -20,24 +20,67 @@
namespace mlir {
namespace tosa {
+struct TosaLevel {
+ int32_t MAX_RANK = 0;
+ int32_t MAX_KERNEL = 0;
+ int32_t MAX_STRIDE = 0;
+ int32_t MAX_SCALE = 0;
+ int32_t MAX_LOG2_SIZE = 0;
+ int32_t MAX_NESTING = 0;
+ int32_t MAX_TENSOR_LIST_SIZE = 0;
+
+ bool operator==(const TosaLevel &rhs) {
+ return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
+ MAX_NESTING == rhs.MAX_NESTING &&
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+ }
+};
+
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
+ 63, 256, 256};
+
+TargetEnvAttr lookupTargetEnv(Operation *op);
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
+
+/// Queries the target environment recursively from enclosing symbol table ops
+/// containing the given `op` or returns the default target environment as
+/// returned by getDefaultTargetEnv() if not provided.
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+
/// This class represents the capability enabled in the target implementation
-/// such as profile, extension, and level.
+/// such as profile, extension, and level. It's a wrapper class around
+/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
- const SmallVectorImpl<Extension> &extensions) {
+ explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ const ArrayRef<Extension> &extensions)
+ : level(level) {
enabledProfiles.insert_range(profiles);
-
enabledExtensions.insert_range(extensions);
}
+ explicit TargetEnv(TargetEnvAttr targetAttr)
+ : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
+ targetAttr.getExtensions()) {}
+
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
// TODO implement the following utilities.
// Version getSpecVersion() const;
- // TosaLevel getLevel() const;
+
+ TosaLevel getLevel() const {
+ if (level == Level::eightK)
+ return TOSA_LEVEL_EIGHTK;
+ else if (level == Level::none)
+ return TOSA_LEVEL_NONE;
+ else
+ llvm_unreachable("Unknown TOSA level");
+ };
// Returns true if the given profile is allowed.
bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
@@ -62,8 +105,9 @@ class TargetEnv {
}
private:
+ Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
- llvm::SmallSet<Extension, 8> enabledExtensions;
+ llvm::SmallSet<Extension, 13> enabledExtensions;
};
} // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..9e9ff21db1fa1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
+def Tosa_ProfileAttr
+ : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
+ [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Profile, 2> getAllValues() {
+ return {Profile::pro_int, Profile::pro_fp};
+ }
+ }];
+}
+
+def Tosa_ProfileArrayAttr
+ : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
@@ -264,17 +277,27 @@ def Tosa_ExtensionAttr
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
Tosa_EXT_DYNAMIC
- ]>;
+ ]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Extension, 11> getAllValues() {
+ return {
+ Extension::int16, Extension::int4, Extension::bf16,
+ Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
+ Extension::variable, Extension::controlflow, Extension::doubleround,
+ Extension::inexactround, Extension::dynamic
+ };
+ }
+ }];
+}
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_ProfileAttr
- : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
- [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>;
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-def Tosa_ProfileArrayAttr
- : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
// The base class for defining op availability dimensions.
class Availability {
@@ -381,6 +404,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
+//===----------------------------------------------------------------------===//
+// TOSA target environment.
+//===----------------------------------------------------------------------===//
+def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
+ let summary = "Target environment information.";
+ let parameters = ( ins
+ "Level": $level,
+ ArrayRefParameter<"Profile">: $profiles,
+ ArrayRefParameter<"Extension">: $extensions
+ );
+
+ let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ "`extensions` `=` `[` $extensions `]` `>`";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
index 7484473c0db23..f52b82a964da9 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,7 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
-mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
-mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_mlir_dialect_tablegen_target(MLIRTosaPassIncGen)
add_mlir_doc(Passes TosaPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 306e4b1f218e7..ba99d2f1d2727 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index b96682843538c..6ae19d81e0820 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass
}];
}
-def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
- [
- I32EnumAttrCase<"None", 0, "none">,
- I32EnumAttrCase<"EightK", 1, "8k">,
- ]>{
- let cppNamespace = "mlir::tosa";
-}
-
def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
@@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
}];
let options = [
- ListOption<"profile", "profile", "std::string",
- "Validate if operations match for the given profile set">,
- ListOption<"extension", "extension", "std::string",
- "Validate if operations match for the given extension set">,
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
@@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
/*default=*/"false",
"Disable checks for operations that are determined to be invalid due to their "
"operand/result datatypes not aligning with the 'Supported Data Types' "
- "sections of the specifciation">,
- Option<"level", "level", "mlir::tosa::TosaLevelEnum",
- /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
- "Validate if operator parameters are within specfication for the given level",
- [{::llvm::cl::values(
- clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
- "Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
- clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
- "Allows the full range of arguments specified by the operations according "
- "to the operation data types.")
- )}]>
+ "sections of the specifciation">
];
}
@@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle
}];
}
+def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
+ let summary = "Attach tosa.target_env information to the given module.";
+
+ let description = [{
+ This pass allows the user to specify a TOSA target environment consisting of
+ the following components: level, profiles and extensions.
+
+ The target environment is attached to the module as an attribute, allowing other
+ transformations to query the selected target and adapt their behaviour based on
+ this information.
+ }];
+
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tosa::TosaDialect",
+ ];
+
+ let options = [
+ Option<"level", "level", "mlir::tosa::Level",
+ /*default=*/"mlir::tosa::Level::eightK",
+ "The TOSA level that operators should conform to. A TOSA level defines "
+ "operator argument ranges that an implementation shall support.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::Level::eightK, "8k",
+ "Ranges are expected to be sufficient for applications with frame "
+ "sizes up to 8K."),
+ clEnumValN(mlir::tosa::Level::none, "none",
+ "Allows the full range of arguments specified by the operations according "
+ "to the operation data types.")
+ )}]>,
+ ListOption<"profiles", "profiles", "std::string",
+ "The TOSA profile(s) that operators should conform to. TOSA profiles "
+ "enable efficient implementation on different classes of device. Each "
+ "profile is an independent set of operations and data type combinations.">,
+ ListOption<"extensions", "extensions", "std::string",
+ "The TOSA extension(s) that operators should conform to. TOSA profile "
+ "extensions define optional operation and data type combinations.">
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index c6a3ba9f1439f..e7602b4508cf1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
TosaToLinalgOptions tosaToLinalgOptions;
TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
TosaValidationOptions validationOptions;
- validationOptions.profile = {"none"};
- validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
validationOptions.allowInvalidOpDatatypeCombinations = false;
- validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
validationOptions);
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index c6a438d348946..a95906aa8352e 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTosaDialect
IR/TosaOps.cpp
IR/TosaCanonicalizations.cpp
+ IR/TargetEnv.cpp
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
new file mode 100644
index 0000000000000..5aad67173cc61
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -0,0 +1,42 @@
+//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+
+namespace mlir {
+namespace tosa {
+
+TargetEnvAttr lookupTargetEnv(Operation *op) {
+ while (op) {
+ op = SymbolTable::getNearestSymbolTable(op);
+ if (!op)
+ break;
+
+ if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
+ return attr;
+
+ op = op->getParentOp();
+ }
+
+ return {};
+}
+
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
+ return TargetEnvAttr::get(context, Level::eightK,
+ {Profile::pro_int, Profile::pro_fp}, {});
+}
+
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
+ if (auto attr = lookupTargetEnv(op))
+ return attr;
+
+ return getDefaultTargetEnv(op->getContext());
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 803993bb1008d..41b338d6e7189 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaAttachTarget.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
new file mode 100644
index 0000000000000..bcb880a808b36
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -0,0 +1,87 @@
+//===- TosaAttachTarget.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Attach target information to a TOSA module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSAATTACHTARGET
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+
+class TosaAttachTarget
+ : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> {
+ using Base::Base;
+
+public:
+ void runOnOperation() override {
+ llvm::SmallVector<Profile, 2> selectedProfiles;
+ if (!profiles.empty()) {
+ for (const std::string &prof : profiles) {
+ std::optional<Profile> profSymbol = symbolizeProfile(prof);
+ if (!profSymbol) {
+ llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allProfiles,
+ "profile", prof);
+ return signalPassFailure();
+ }
+ selectedProfiles.push_back(profSymbol.value());
+ }
+ }
+
+ llvm::SmallVector<Extension, 10> selectedExtensions;
+ if (!extensions.empty()) {
+ for (const std::string &ext : extensions) {
+ std::optional<Extension> extSymbol = symbolizeExtension(ext);
+ if (!extSymbol) {
+ llvm::SmallVector<Extension> allExtensions =
+ ExtensionAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allExtensions,
+ "extension", ext);
+ return signalPassFailure();
+ }
+ selectedExtensions.push_back(extSymbol.value());
+ }
+ }
+
+ ModuleOp mod = getOperation();
+ MLIRContext *ctx = &getContext();
+ const auto targetEnvAttr =
+ TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
+ }
+
+private:
+ template <typename T>
+ std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues,
+ std::string enumName,
+ std::string unknownArgument) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument
+ << "', supported " << enumName << "s are: ";
+ llvm::interleaveComma(enumValues, os);
+ os << "\n";
+ return message;
+ }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 881736f020090..0124a168cefe2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
@@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
return success();
}
-struct TosaLevel {
- int32_t MAX_RANK = 0;
- int32_t MAX_KERNEL = 0;
- int32_t MAX_STRIDE = 0;
- int32_t MAX_SCALE = 0;
- int32_t MAX_LOG2_SIZE = 0;
- int32_t MAX_NESTING = 0;
- int32_t MAX_TENSOR_LIST_SIZE = 0;
-
- bool operator==(const TosaLevel &rhs) {
- return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
- MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
- MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
- }
-};
-
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
-static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
- 63, 256, 256};
-
//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//
@@ -162,12 +139,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
explicit TosaValidation(const TosaValidationOptions &options)
: TosaValidation() {
- this->profile = options.profile;
- this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
this->allowInvalidOpDatatypeCombinations =
options.allowInvalidOpDatatypeCombinations;
- this->level = options.level;
}
void runOnOperation() final;
@@ -206,7 +180,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_KERNEL) {
+ if (v > targetEnv.getLevel().MAX_KERNEL) {
op->emitOpError() << "failed level check: " << checkDesc;
return false;
}
@@ -214,7 +188,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_STRIDE) {
+ if (v > targetEnv.getLevel().MAX_STRIDE) {
op->emitOpError() << "failed level check: " << checkDesc;
return false;
}
@@ -222,7 +196,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_SCALE) {
+ if (v > targetEnv.getLevel().MAX_SCALE) {
op->emitOpError() << "failed level check: " << checkDesc;
return false;
}
@@ -230,7 +204,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
+ if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE) {
op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
<< checkDesc;
return false;
@@ -291,6 +265,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
template <typename T>
bool levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
+ const TosaLevel tosaLevel = targetEnv.getLevel();
for (auto v : op->getOperands()) {
if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
return false;
@@ -472,7 +447,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
int32_t maxNestedDepth = 0;
getMaxNestedDepth(op, maxNestedDepth);
- if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
+ if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
op->emitOpError() << "failed level check: " << maxNestedDepth
<< " >= MAX_NESTING";
return false;
@@ -526,43 +501,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return true;
}
- // configure profile and level values from pass options profileName and
- // levelName
- void configLevelAndProfile() {
- tosaLevel = TOSA_LEVEL_NONE;
- if (level == TosaLevelEnum::EightK) {
- tosaLevel = TOSA_LEVEL_EIGHTK;
- }
-
- if (!profile.empty()) {
- for (std::string &prof : profile) {
- auto profSymbol = symbolizeProfile(prof);
- if (profSymbol) {
- targetEnv.addProfile(profSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA profile name passed in: " << prof
- << ", supported profiles are `pro_int` and `pro_fp`\n";
- return signalPassFailure();
- }
- }
- }
-
- if (!extension.empty()) {
- for (std::string &ext : extension) {
- auto extSymbol = symbolizeExtension(ext);
- if (extSymbol) {
- targetEnv.addExtension(extSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA extension name passed in: " << ext
- << ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
- << "doubleround, inexactround and dynamic\n";
- return signalPassFailure();
- }
- }
- }
- }
-
bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type, const bool allowUnsigned = false);
@@ -570,7 +508,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
constCheckers;
- TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
TosaProfileCompliance profileComp;
tosa::TargetEnv targetEnv;
@@ -579,11 +516,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
template <>
bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto op = tosaOp.getOperation();
- if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
+ if (!levelCheckRank(op, tosaOp.getInput(), "operand",
+ targetEnv.getLevel().MAX_RANK))
return false;
// rank(output) = rank(input) - 1
- if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
+ if (!levelCheckRank(op, tosaOp.getOutput(), "result",
+ targetEnv.getLevel().MAX_RANK - 1))
return false;
return true;
@@ -594,7 +533,8 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
auto op = tosaOp.getOperation();
// Only the condition input has rank limitation.
- if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
+ if (!levelCheckRank(op, tosaOp.getCondition(), "operand",
+ targetEnv.getLevel().MAX_RANK))
return false;
return true;
@@ -604,7 +544,8 @@ template <>
bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
auto op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
- if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
+ if (!levelCheckRank(op, variableType, "variable type",
+ targetEnv.getLevel().MAX_RANK))
return false;
return true;
@@ -764,7 +705,8 @@ bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
// defined in 1.7. Levels.
// For each tensor, the number of tensor elements multiplied by the
// element size in bytes must be representable as a tensor_size_t.
- const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
+ const int64_t max_size =
+ (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
if (size > max_size) {
op->emitOpError()
<< "failed level check: " << operandOrResult
@@ -776,7 +718,7 @@ bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
}
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
- if (tosaLevel == TOSA_LEVEL_NONE) {
+ if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
// no need to do level checks
return success();
}
@@ -1335,12 +1277,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
- configLevelAndProfile();
-
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
+ targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
+
getOperation().walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index e23ce43031a24..f73b41cf434fc 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,7 +2,7 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations"
// -----
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index fad1bec0e3ecc..c57287fa8e368 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment"
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7c76c178436c7..0f9252ff7a838 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 3154f541e0519..40622e6d233c7 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
// Enable all supported profiles to focus the verification of expected extension requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0184d2b05f0ee..c824ca0e33753 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate
func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 225b962589df9..09e96eca776e2 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index fad4859351251..0bdba0a1d3fbd 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 9438179622aad..a4ba71a72d798 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_i1() -> tensor<3x11x11x3xi1> {
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
new file mode 100644
index 0000000000000..d6c886c44b013
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
+
+// -----
+
+// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
+// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-LABEL: test_simple
+func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
+ %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+ return %1 : tensor<1x1x1x1xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index cab14201dc0ce..79032f39af99d 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s
// -----
>From 9d0903bf2d288ea57e28882dd9e0625ac011de42 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 30 Jul 2025 09:12:27 +0000
Subject: [PATCH 2/3] [mlir][tosa] Relax constraint on matmul verifier
requiring equal operand types
Removes the verifier constraint that required the non-quantized inputs
a and b of matmul to have the same element type. This allows matmul
with different operand types such as fp8e5m2xfp8e4m3. Operand type
combinations that do not strictly adhere to the TOSA specification will
still be caught in the validation pass.
Change-Id: I1453ded48326ea0460fa6caf52651c02b7d8c055
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 6 ------
mlir/test/Dialect/Tosa/ops.mlir | 9 +++++++++
2 files changed, 9 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3cafb199d2db3..d58acf6a53d5e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1605,12 +1605,6 @@ LogicalResult MatMulOp::verify() {
return emitOpError("expect quantized operands to have same widths, got ")
<< aQuantWidth << " and " << bQuantWidth;
}
- } else {
- // non-quantized element types
- if (aElementType != bElementType) {
- return emitOpError("expect same element type for inputs a and b, got ")
- << aElementType << " and " << bElementType;
- }
}
// check a_zp and b_zp
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 30361a882afe5..6061c04662de8 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -934,6 +934,15 @@ func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x
return %0 : tensor<1x14x28xf16>
}
+// -----
+// CHECK-LABEL: test_matmul_f8E5M2_f8E4M3
+func.func @test_matmul_f8E5M2_f8E4M3(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E5M2>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
// -----
// CHECK-LABEL: max_pool2d_f8E5M2
func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {
>From cd92efaeae7f6490b1c88c8c954986ef4674b224 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 29 Aug 2025 11:23:58 +0000
Subject: [PATCH 3/3] [mlir][tosa] Add specification versioning to target
environment
This commit adds a new "specification_version" field to the TOSA
target environment attribute. This allows a user to specify which
version of the TOSA specification they would like to target during
lowering.
A leading example in the validation pass has also been added. This
addition attaches a version to each profile compliance entry to track
the version of the specification in which the entry was added. This
allows a backwards compatibility check to be implemented between the
target version and the profile compliance entry version.
For now a default version of "1.0" is assumed. "1.1.draft" is added
to denote an in-development version of the specification targeting
the next release.
Change-Id: I6549e05bd4fe975d12ea31e8acc783233db66171
---
mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 48 +-
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 1087 ++++++++++-------
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 33 +-
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 13 +-
.../mlir/Dialect/Tosa/Transforms/Passes.td | 7 +
mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 7 +-
.../Tosa/Transforms/TosaAttachTarget.cpp | 4 +-
.../Tosa/Transforms/TosaProfileCompliance.cpp | 74 +-
.../test/Dialect/Tosa/tosa-attach-target.mlir | 8 +-
.../tosa-validation-version-1p0-invalid.mlir | 21 +
.../tosa-validation-version-1p1-valid.mlir | 20 +
11 files changed, 835 insertions(+), 487 deletions(-)
create mode 100644 mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
create mode 100644 mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 10491f65d37af..4ecf03c34c1a5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
/// returned by getDefaultTargetEnv() if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+/// A thin wrapper around the SpecificationVersion enum to represent
+/// and provide utilities around the TOSA specification version.
+class TosaSpecificationVersion {
+public:
+ TosaSpecificationVersion(uint32_t major, uint32_t minor)
+ : majorVersion(major), minorVersion(minor) {}
+ TosaSpecificationVersion(SpecificationVersion version)
+ : TosaSpecificationVersion(fromVersionEnum(version)) {}
+
+ bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
+ return this->majorVersion == baseVersion.majorVersion &&
+ this->minorVersion >= baseVersion.minorVersion;
+ }
+
+ uint32_t getMajor() const { return majorVersion; }
+ uint32_t getMinor() const { return minorVersion; }
+
+private:
+ uint32_t majorVersion = 0;
+ uint32_t minorVersion = 0;
+
+ static TosaSpecificationVersion
+ fromVersionEnum(SpecificationVersion version) {
+ switch (version) {
+ case SpecificationVersion::V_1_0:
+ return TosaSpecificationVersion(1, 0);
+ case SpecificationVersion::V_1_1_DRAFT:
+ return TosaSpecificationVersion(1, 1);
+ }
+ llvm_unreachable("Unknown TOSA version");
+ }
+};
+
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
+
/// This class represents the capability enabled in the target implementation
/// such as profile, extension, and level. It's a wrapper class around
/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
+ const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
- : level(level) {
+ : specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}
explicit TargetEnv(TargetEnvAttr targetAttr)
- : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
- targetAttr.getExtensions()) {}
+ : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions()) {}
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
- // TODO implement the following utilities.
- // Version getSpecVersion() const;
+ SpecificationVersion getSpecVersion() const { return specificationVersion; }
TosaLevel getLevel() const {
if (level == Level::eightK)
@@ -105,6 +140,7 @@ class TargetEnv {
}
private:
+ SpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 1f718accabd15..dc46ba73cd624 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -1,442 +1,661 @@
// The profile-based compliance content below is auto-generated by the script
// `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git
profileComplianceMap = {
- {"tosa.argmax",
- {{{Profile::pro_int}, {{i8T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
- {"tosa.avg_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
- {{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
- {{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.conv3d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
- {{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.depthwise_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
- {{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.matmul",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
- {{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp32T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.max_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.transpose_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
- {{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.clamp",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.add",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.arithmetic_right_shift",
- {{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
- {"tosa.bitwise_and",
- {{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
- {"tosa.bitwise_or",
- {{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
- {"tosa.bitwise_xor",
- {{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
- {"tosa.intdiv",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}},
- {"tosa.logical_and",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
- {"tosa.logical_left_shift",
- {{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
- anyOf}}},
- {"tosa.logical_right_shift",
- {{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
- anyOf}}},
- {"tosa.logical_or",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
- {"tosa.logical_xor",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
- {"tosa.maximum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.minimum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.mul",
- {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
- {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.pow",
- {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.sub",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
- {"tosa.abs",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.bitwise_not",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}},
- {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}},
- {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.logical_not",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
- {"tosa.negate",
- {{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
- {{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.reciprocal",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.select",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
- {{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
- {"tosa.greater",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
- {"tosa.greater_equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
- {"tosa.reduce_all",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
- {"tosa.reduce_any",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
- {"tosa.reduce_max",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.reduce_min",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.reduce_product",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.reduce_sum",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.concat",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.pad",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
- {{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.reshape",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.reverse",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.slice",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.tile",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.transpose",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.gather",
- {{{Profile::pro_int},
- {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}},
- {"tosa.scatter",
- {{{Profile::pro_int},
- {{i8T, i32T, i8T, i8T},
- {i16T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
- {{Profile::pro_fp},
- {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}},
- {"tosa.resize",
- {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.cast",
- {{{Profile::pro_int},
- {{boolT, i8T},
- {boolT, i16T},
- {boolT, i32T},
- {i8T, boolT},
- {i8T, i16T},
- {i8T, i32T},
- {i16T, boolT},
- {i16T, i8T},
- {i16T, i32T},
- {i32T, boolT},
- {i32T, i8T},
- {i32T, i16T}}},
- {{Profile::pro_fp},
- {{i8T, fp16T},
- {i8T, fp32T},
- {i16T, fp16T},
- {i16T, fp32T},
- {i32T, fp16T},
- {i32T, fp32T},
- {fp16T, i8T},
- {fp16T, i16T},
- {fp16T, i32T},
- {fp16T, fp32T},
- {fp32T, i8T},
- {fp32T, i16T},
- {fp32T, i32T},
- {fp32T, fp16T}}}}},
- {"tosa.rescale",
- {{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i8T, i8T, i16T, i16T},
- {i8T, i8T, i32T, i32T},
- {i16T, i16T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i16T, i16T, i32T, i32T},
- {i32T, i32T, i8T, i8T},
- {i32T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}}}},
- {"tosa.const",
- {{{Profile::pro_int, Profile::pro_fp},
- {{boolT}, {i8T}, {i16T}, {i32T}},
- anyOf},
- {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
- {"tosa.identity",
- {{{Profile::pro_int, Profile::pro_fp},
- {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}},
- anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.variable",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
- {"tosa.variable_write",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
- {"tosa.variable_read",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {"tosa.argmax",
+ {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T}, SpecificationVersion::V_1_0}, {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.avg_pool2d",
+ {{{Profile::pro_int}, {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.conv2d",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.conv3d",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.depthwise_conv2d",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.matmul",
+ {{{Profile::pro_int}, {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.max_pool2d",
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.transpose_conv2d",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.clamp",
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.add",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.arithmetic_right_shift",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.bitwise_and",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.bitwise_or",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.bitwise_xor",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.intdiv",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.logical_and",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.logical_left_shift",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.logical_right_shift",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.logical_or",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.logical_xor",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.maximum",
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.minimum",
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.mul",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.pow",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sub",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table", {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.abs",
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.bitwise_not",
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.clz", {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.logical_not",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.negate",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reciprocal",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.select",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.equal",
+ {{{Profile::pro_int}, {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater",
+ {{{Profile::pro_int}, {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater_equal",
+ {{{Profile::pro_int}, {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_all",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.reduce_any",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
+ {"tosa.reduce_max",
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_min",
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_product",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_sum",
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.concat",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0}, {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.pad",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reshape",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reverse",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.slice",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tile",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.transpose",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.gather",
+ {{{Profile::pro_int},
+ {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.scatter",
+ {{{Profile::pro_int},
+ {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.resize",
+ {{{Profile::pro_int},
+ {{{i8T, i32T}, SpecificationVersion::V_1_0}, {{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cast",
+ {{{Profile::pro_int},
+ {{{boolT, i8T}, SpecificationVersion::V_1_0},
+ {{boolT, i16T}, SpecificationVersion::V_1_0},
+ {{boolT, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, boolT}, SpecificationVersion::V_1_0},
+ {{i16T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, boolT}, SpecificationVersion::V_1_0},
+ {{i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{i8T, fp16T}, SpecificationVersion::V_1_0},
+ {{i8T, fp32T}, SpecificationVersion::V_1_0},
+ {{i16T, fp16T}, SpecificationVersion::V_1_0},
+ {{i16T, fp32T}, SpecificationVersion::V_1_0},
+ {{i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{i32T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, i8T}, SpecificationVersion::V_1_0},
+ {{fp16T, i16T}, SpecificationVersion::V_1_0},
+ {{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i8T}, SpecificationVersion::V_1_0},
+ {{fp32T, i16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rescale",
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.const",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT}, SpecificationVersion::V_1_0},
+ {{i8T}, SpecificationVersion::V_1_0},
+ {{i16T}, SpecificationVersion::V_1_0},
+ {{i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0}, {{fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.identity",
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable",
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0}, {{fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable_write",
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0}, {{fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable_read",
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0}, {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
extensionComplianceMap = {
- {"tosa.argmax",
- {{{Extension::int16}, {{i16T, i32T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
- {{Extension::bf16}, {{bf16T, i32T}}}}},
- {"tosa.avg_pool2d",
- {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
- {{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
- {{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
- {{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.conv3d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
- {{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
- {{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
- {{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.depthwise_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
- {{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
- {{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
- {{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.matmul",
- {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
- {{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
- {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}},
- {{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
- {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}},
- {{Extension::fp8e4m3, Extension::fp8e5m2},
- {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
- {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}},
- allOf},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
- {"tosa.max_pool2d",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
- {"tosa.transpose_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
- {{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
- {{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
- {{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.clamp",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}},
- {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
- {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.concat",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.pad",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.reshape",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reverse",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.slice",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.tile",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.transpose",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.gather",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}},
- {"tosa.scatter",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}},
- {"tosa.resize",
- {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.cast",
- {{{Extension::bf16},
- {{i8T, bf16T},
- {i16T, bf16T},
- {i32T, bf16T},
- {bf16T, i8T},
- {bf16T, i16T},
- {bf16T, i32T},
- {bf16T, fp32T},
- {fp32T, bf16T}}},
- {{Extension::bf16, Extension::fp8e4m3},
- {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}},
- allOf},
- {{Extension::bf16, Extension::fp8e5m2},
- {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}},
- allOf},
- {{Extension::fp8e4m3},
- {{fp8e4m3T, fp16T},
- {fp8e4m3T, fp32T},
- {fp16T, fp8e4m3T},
- {fp32T, fp8e4m3T}}},
- {{Extension::fp8e5m2},
- {{fp8e5m2T, fp16T},
- {fp8e5m2T, fp32T},
- {fp16T, fp8e5m2T},
- {fp32T, fp8e5m2T}}}}},
- {"tosa.rescale",
- {{{Extension::int16},
- {{i48T, i48T, i8T, i8T},
- {i48T, i48T, i16T, i16T},
- {i48T, i48T, i32T, i32T}}}}},
- {"tosa.const",
- {{{Extension::int4}, {{i4T}}},
- {{Extension::int16}, {{i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T}}}}},
- {"tosa.identity",
- {{{Extension::int4}, {{i4T, i4T}}},
- {{Extension::int16}, {{i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
- {"tosa.variable_write",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
- {"tosa.variable_read",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {"tosa.argmax",
+ {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.avg_pool2d",
+ {{{Extension::int16}, {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.conv2d",
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.conv3d",
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.depthwise_conv2d",
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.fft2d",
+ {{{Extension::fft}, {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.matmul",
+ {{{Extension::int16}, {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e4m3, Extension::fp8e5m2},
+ {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}, SpecificationVersion::V_1_1_DRAFT}},
+ allOf},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.max_pool2d",
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rfft2d",
+ {{{Extension::fft}, {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.transpose_conv2d",
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.clamp",
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.add",
+ {{{Extension::bf16}, {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.maximum",
+ {{{Extension::bf16}, {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.minimum",
+ {{{Extension::bf16}, {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.mul",
+ {{{Extension::bf16}, {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.pow",
+ {{{Extension::bf16}, {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sub",
+ {{{Extension::bf16}, {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Extension::int16}, {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.abs", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.negate",
+ {{{Extension::bf16}, {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reciprocal",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin", {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.select",
+ {{{Extension::bf16}, {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.equal",
+ {{{Extension::bf16}, {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater",
+ {{{Extension::bf16}, {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater_equal",
+ {{{Extension::bf16}, {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_max",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_min",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_product",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_sum",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.concat",
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.pad",
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reshape",
+ {{{Extension::fp8e4m3}, {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reverse",
+ {{{Extension::fp8e4m3}, {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.slice",
+ {{{Extension::fp8e4m3}, {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tile",
+ {{{Extension::fp8e4m3}, {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.transpose",
+ {{{Extension::fp8e4m3}, {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.gather",
+ {{{Extension::fp8e4m3}, {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.scatter",
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.resize",
+ {{{Extension::int16},
+ {{{i16T, i48T}, SpecificationVersion::V_1_0}, {{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cast",
+ {{{Extension::bf16},
+ {{{i8T, bf16T}, SpecificationVersion::V_1_0},
+ {{i16T, bf16T}, SpecificationVersion::V_1_0},
+ {{i32T, bf16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i8T}, SpecificationVersion::V_1_0},
+ {{bf16T, i16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i32T}, SpecificationVersion::V_1_0},
+ {{bf16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, bf16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16, Extension::fp8e4m3},
+ {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}},
+ allOf},
+ {{Extension::bf16, Extension::fp8e5m2},
+ {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}},
+ allOf},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rescale",
+ {{{Extension::int16},
+ {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.const",
+ {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.identity",
+ {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable",
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable_write",
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable_read",
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
+
// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 9e9ff21db1fa1..6379aedf80089 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
}
//===----------------------------------------------------------------------===//
-// TOSA Spec Section 1.5.
+// TOSA Profiles and extensions
//
// Profile:
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
@@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
-def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-
-def Tosa_LevelAttr
- : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
-
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -404,18 +398,41 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
+//===----------------------------------------------------------------------===//
+// TOSA Levels
+//===----------------------------------------------------------------------===//
+
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
+
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
+
+//===----------------------------------------------------------------------===//
+// TOSA Specification versions
+//===----------------------------------------------------------------------===//
+
+def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
+def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;
+
+def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
+ "SpecificationVersion", "TOSA specification version", "specification_version",
+ [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;
+
//===----------------------------------------------------------------------===//
// TOSA target environment.
//===----------------------------------------------------------------------===//
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
let summary = "Target environment information.";
let parameters = ( ins
+ "SpecificationVersion": $specification_version,
"Level": $level,
ArrayRefParameter<"Profile">: $profiles,
ArrayRefParameter<"Extension">: $extensions
);
- let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
+ "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
"`extensions` `=` `[` $extensions `]` `>`";
}
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 8f5c72bc5f7a9..7b946ad6c6a89 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -36,12 +36,15 @@ enum CheckCondition {
allOf
};
+using VersionedTypeInfo =
+ std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
+
template <typename T>
struct OpComplianceInfo {
// Certain operations require multiple modes enabled.
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
SmallVector<T> mode;
- SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
+ SmallVector<VersionedTypeInfo> operandTypeInfoSet;
CheckCondition condition = CheckCondition::anyOf;
};
@@ -130,9 +133,8 @@ class TosaProfileCompliance {
// Find the required profiles or extensions from the compliance info according
// to the operand type combination.
template <typename T>
- SmallVector<T> findMatchedProfile(Operation *op,
- SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition);
+ OpComplianceInfo<T>
+ findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
switch (ext) {
@@ -168,8 +170,7 @@ class TosaProfileCompliance {
private:
template <typename T>
- FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
- CheckCondition &condition);
+ FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 6ae19d81e0820..14b00b04ccc18 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
let options = [
+ Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
+ /*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
+ "The specification version that TOSA operators should conform to.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
+ )}]>,
Option<"level", "level", "mlir::tosa::Level",
/*default=*/"mlir::tosa::Level::eightK",
"The TOSA level that operators should conform to. A TOSA level defines "
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 5aad67173cc61..1cba1bb540c02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
}
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
- return TargetEnvAttr::get(context, Level::eightK,
+ return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
{Profile::pro_int, Profile::pro_fp}, {});
}
@@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
index bcb880a808b36..a0661e4ee0bd2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -61,8 +61,8 @@ class TosaAttachTarget
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
- const auto targetEnvAttr =
- TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ const auto targetEnvAttr = TargetEnvAttr::get(
+ ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 9543fa1fe39d8..058ec4fa5c8e8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
//===----------------------------------------------------------------------===//
template <typename T>
-FailureOr<SmallVector<T>>
-TosaProfileCompliance::getOperatorDefinition(Operation *op,
- CheckCondition &condition) {
+FailureOr<OpComplianceInfo<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op) {
const std::string opName = op->getName().getStringRef().str();
const auto complianceMap = getProfileComplianceMap<T>();
const auto it = complianceMap.find(opName);
if (it == complianceMap.end())
return {};
- return findMatchedProfile<T>(op, it->second, condition);
+ return findMatchedEntry<T>(op, it->second);
}
template <typename T>
@@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
- if (failed(maybeOpRequiredMode)) {
+ const auto maybeOpDefinition = getOperatorDefinition<T>(op);
+ if (failed(maybeOpDefinition)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
// specification.
- int mode_count = 0;
+ int modeCount = 0;
for (const auto &cands : specRequiredModeSet) {
if (targetEnv.allowsAnyOf(cands))
return success();
- mode_count += cands.size();
+ modeCount += cands.size();
}
op->emitOpError() << "illegal: requires"
- << (mode_count > 1 ? " any of " : " ") << "["
+ << (modeCount > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
", ")
<< "] but not enabled in target\n";
@@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
// Find the required profiles or extensions according to the operand type
// combination.
- const auto opRequiredMode = maybeOpRequiredMode.value();
+ const auto opDefinition = maybeOpDefinition.value();
+ const SmallVector<T> opRequiredMode = opDefinition.mode;
+ const CheckCondition condition = opDefinition.condition;
+
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
}
}
+ // Ensure the matched op compliance version does not exceed the target
+ // specification version.
+ const VersionedTypeInfo versionedTypeInfo =
+ opDefinition.operandTypeInfoSet[0];
+ const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
+ const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
+ if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
+ op->emitOpError() << "illegal: requires the target specification version ("
+ << stringifyVersion(targetVersion)
+ << ") be backwards compatible with the op compliance "
+ "specification version ("
+ << stringifyVersion(complianceVersion) << ")\n";
+ return failure();
+ }
+
return success();
}
@@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
}
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
- const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ const auto maybeProfDef = getOperatorDefinition<Profile>(op);
+ const auto maybeExtDef = getOperatorDefinition<Extension>(op);
if (failed(maybeProfDef) && failed(maybeExtDef))
return success();
- const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
- (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ const bool hasEntry =
+ (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
@@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
SmallVector<TypeInfo> bestTypeInfo;
const auto searchBestMatch = [&](auto map) {
for (const auto &complianceInfos : map[opName]) {
- for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
+ for (const auto &versionedTypeInfos :
+ complianceInfos.operandTypeInfoSet) {
+ const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
const int matches = llvm::count_if(
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
return isSameTypeInfo(std::get<0>(zipType),
@@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
-SmallVector<T> TosaProfileCompliance::findMatchedProfile(
- Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition) {
+OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
+ Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
assert(compInfo.size() != 0 &&
"profile-based compliance information is empty");
@@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
return {};
for (size_t i = 0; i < compInfo.size(); i++) {
- SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
- for (SmallVector<TypeInfo> expected : sets) {
+ SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
+ for (const auto &set : sets) {
+ SmallVector<TypeInfo> expected = set.first;
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
"the generated metadata and the type definition retrieved from "
" the operation");
- bool is_found = true;
+ bool isFound = true;
// Compare the type signature between the given operation and the
// compliance metadata.
for (size_t j = 0; j < expected.size(); j++) {
if (!isSameTypeInfo(present[j], expected[j])) {
// Verify the next mode set from the list.
- is_found = false;
+ isFound = false;
break;
}
}
- if (is_found == true) {
- condition = compInfo[i].condition;
- return compInfo[i].mode;
+ if (isFound == true) {
+ SmallVector<VersionedTypeInfo> typeInfoSet{set};
+ OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
+ compInfo[i].condition};
+ return info;
}
}
}
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
index d6c886c44b013..a0c59c0c4bb3b 100644
--- a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
@@ -1,12 +1,14 @@
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1
// -----
-// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
-// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
-// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
+// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
+// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
+// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>}
// CHECK-LABEL: test_simple
func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
%1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
new file mode 100644
index 0000000000000..d4b5b75fbf9fd
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // expected-error@+1 {{'tosa.matmul' op illegal: requires the target specification version (1.0) be backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ // expected-error@+1 {{'tosa.matmul' op illegal: requires the target specification version (1.0) be backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
new file mode 100644
index 0000000000000..81645092bf195
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
More information about the Mlir-commits
mailing list