[Mlir-commits] [mlir] [mlir][tosa] Add support for cast_from/to_block_scaled (PR #163436)
Luke Hutton
llvmlistbot at llvm.org
Tue Oct 14 12:16:57 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/163436
This commit adds support for the cast_from_block_scaled and cast_to_block_scaled operations from the ext-mxfp extension. This includes:
- Operation definition in TosaOps.td
- Micro-scaling supported types definition
- Shape inference and verifiers
- Validation pass checks ensuring these operations are only accepted when the target environment includes ext-mxfp and targets at least v1.1.draft of the specification.
Note: support for mxint8 is currently excluded. It will be added in a later commit.
Note: this commit adds support as defined in the specification change at https://review.mlplatform.org/c/tosa/specification/+/15362. The EXT_MXFP extension is considered experimental and subject to breaking changes.
Note: this PR depends on #156425 and #163433, so it also contains their changes.
Co-authored-by: Tat Wai Chong <tatwai.chong at arm.com>
>From 0f7a02f264ecd2b48adf4d886cdf00b0ccf2de32 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 29 Aug 2025 11:23:58 +0000
Subject: [PATCH 1/3] [mlir][tosa] Add specification versioning to target
environment
This commit adds a new "specification_version" field to the TOSA
target environment attribute. This allows a user to specify which
version of the TOSA specification they would like to target during
lowering.
An initial use of this field has also been added to the validation pass:
each profile compliance entry now records the version of the
specification in which the entry was added. This allows a backwards
compatibility check to be implemented between the target version and
the version of each profile compliance entry.
For now, a default version of "1.0" is assumed. "1.1.draft" is added
to denote the in-development version of the specification targeting
the next release.
Change-Id: I6549e05bd4fe975d12ea31e8acc783233db66171
---
mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 48 +-
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 940 ++++++++++++------
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 33 +-
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 13 +-
.../mlir/Dialect/Tosa/Transforms/Passes.td | 7 +
mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 7 +-
.../Tosa/Transforms/TosaAttachTarget.cpp | 4 +-
.../Tosa/Transforms/TosaProfileCompliance.cpp | 74 +-
.../test/Dialect/Tosa/tosa-attach-target.mlir | 8 +-
.../tosa-validation-version-1p0-invalid.mlir | 21 +
.../tosa-validation-version-1p1-valid.mlir | 20 +
11 files changed, 821 insertions(+), 354 deletions(-)
create mode 100644 mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
create mode 100644 mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 10491f65d37af..4ecf03c34c1a5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
/// returned by getDefaultTargetEnv() if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+/// A thin wrapper around the SpecificationVersion enum that represents a
+/// TOSA specification version as a numeric (major, minor) pair and provides
+/// utilities for comparing versions.
+class TosaSpecificationVersion {
+public:
+ /// Construct a version directly from its major and minor components.
+ TosaSpecificationVersion(uint32_t major, uint32_t minor)
+ : majorVersion(major), minorVersion(minor) {}
+ /// Construct a version from a SpecificationVersion enum value.
+ /// Not marked `explicit`, so a SpecificationVersion converts implicitly
+ /// to this type (e.g. when passed to isBackwardsCompatibleWith).
+ TosaSpecificationVersion(SpecificationVersion version)
+ : TosaSpecificationVersion(fromVersionEnum(version)) {}
+
+ /// Returns true when this version and `baseVersion` share the same major
+ /// version and this version's minor version is greater than or equal to
+ /// that of `baseVersion`. Versions with different major numbers are never
+ /// considered compatible, in either direction.
+ bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
+ return this->majorVersion == baseVersion.majorVersion &&
+ this->minorVersion >= baseVersion.minorVersion;
+ }
+
+ uint32_t getMajor() const { return majorVersion; }
+ uint32_t getMinor() const { return minorVersion; }
+
+private:
+ uint32_t majorVersion = 0;
+ uint32_t minorVersion = 0;
+
+ /// Map a SpecificationVersion enum value to its numeric (major, minor) pair.
+ /// NOTE(review): V_1_1_DRAFT maps to (1, 1), so the "draft" status is not
+ /// represented here; a draft 1.1 and a final 1.1 would compare equal.
+ /// Confirm this is intended once a final 1.1 enumerator is added.
+ static TosaSpecificationVersion
+ fromVersionEnum(SpecificationVersion version) {
+ switch (version) {
+ case SpecificationVersion::V_1_0:
+ return TosaSpecificationVersion(1, 0);
+ case SpecificationVersion::V_1_1_DRAFT:
+ return TosaSpecificationVersion(1, 1);
+ }
+ // Only reached for a value not handled by the switch above
+ // (e.g. a newly added enumerator that was not mapped here).
+ llvm_unreachable("Unknown TOSA version");
+ }
+};
+
+/// Convert `version` to a short string representation.
+/// NOTE(review): the definition is not part of this hunk; presumably it
+/// formats as "<major>.<minor>" (e.g. "1.0") — confirm in TargetEnv.cpp.
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
+
/// This class represents the capability enabled in the target implementation
/// such as profile, extension, and level. It's a wrapper class around
/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
+ const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
- : level(level) {
+ : specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}
explicit TargetEnv(TargetEnvAttr targetAttr)
- : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
- targetAttr.getExtensions()) {}
+ : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions()) {}
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
- // TODO implement the following utilities.
- // Version getSpecVersion() const;
+ SpecificationVersion getSpecVersion() const { return specificationVersion; }
TosaLevel getLevel() const {
if (level == Level::eightK)
@@ -105,6 +140,7 @@ class TargetEnv {
}
private:
+ SpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 1f718accabd15..c1b5e785bd739 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -2,441 +2,779 @@
// `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git
profileComplianceMap = {
{"tosa.argmax",
- {{{Profile::pro_int}, {{i8T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.avg_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.conv3d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.depthwise_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.matmul",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp32T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.clamp",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.add",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.arithmetic_right_shift",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_and",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_or",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_xor",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.intdiv",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_and",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_left_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf}}},
{"tosa.logical_right_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf}}},
{"tosa.logical_or",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_xor",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.maximum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.minimum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.mul",
- {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
- {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pow",
- {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.sub",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}},
{"tosa.abs",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_not",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}},
- {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}},
- {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.clz",
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.logical_not",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.negate",
{{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reciprocal",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.select",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.greater",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.greater_equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_all",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.reduce_any",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.reduce_max",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_min",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_product",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_sum",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.concat",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pad",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reshape",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reverse",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.slice",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.tile",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.gather",
{{{Profile::pro_int},
- {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}},
+ {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.scatter",
{{{Profile::pro_int},
- {{i8T, i32T, i8T, i8T},
- {i16T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
+ {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}},
+ {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.resize",
- {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.cast",
{{{Profile::pro_int},
- {{boolT, i8T},
- {boolT, i16T},
- {boolT, i32T},
- {i8T, boolT},
- {i8T, i16T},
- {i8T, i32T},
- {i16T, boolT},
- {i16T, i8T},
- {i16T, i32T},
- {i32T, boolT},
- {i32T, i8T},
- {i32T, i16T}}},
- {{Profile::pro_fp},
- {{i8T, fp16T},
- {i8T, fp32T},
- {i16T, fp16T},
- {i16T, fp32T},
- {i32T, fp16T},
- {i32T, fp32T},
- {fp16T, i8T},
- {fp16T, i16T},
- {fp16T, i32T},
- {fp16T, fp32T},
- {fp32T, i8T},
- {fp32T, i16T},
- {fp32T, i32T},
- {fp32T, fp16T}}}}},
+ {{{boolT, i8T}, SpecificationVersion::V_1_0},
+ {{boolT, i16T}, SpecificationVersion::V_1_0},
+ {{boolT, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, boolT}, SpecificationVersion::V_1_0},
+ {{i16T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, boolT}, SpecificationVersion::V_1_0},
+ {{i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{i8T, fp16T}, SpecificationVersion::V_1_0},
+ {{i8T, fp32T}, SpecificationVersion::V_1_0},
+ {{i16T, fp16T}, SpecificationVersion::V_1_0},
+ {{i16T, fp32T}, SpecificationVersion::V_1_0},
+ {{i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{i32T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, i8T}, SpecificationVersion::V_1_0},
+ {{fp16T, i16T}, SpecificationVersion::V_1_0},
+ {{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i8T}, SpecificationVersion::V_1_0},
+ {{fp32T, i16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i8T, i8T, i16T, i16T},
- {i8T, i8T, i32T, i32T},
- {i16T, i16T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i16T, i16T, i32T, i32T},
- {i32T, i32T, i8T, i8T},
- {i32T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.const",
{{{Profile::pro_int, Profile::pro_fp},
- {{boolT}, {i8T}, {i16T}, {i32T}},
+ {{{boolT}, SpecificationVersion::V_1_0},
+ {{i8T}, SpecificationVersion::V_1_0},
+ {{i16T}, SpecificationVersion::V_1_0},
+ {{i32T}, SpecificationVersion::V_1_0}},
anyOf},
- {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.identity",
{{{Profile::pro_int, Profile::pro_fp},
- {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_write",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_read",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
extensionComplianceMap = {
{"tosa.argmax",
- {{{Extension::int16}, {{i16T, i32T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
- {{Extension::bf16}, {{bf16T, i32T}}}}},
+ {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.avg_pool2d",
- {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{Extension::int16},
+ {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.conv3d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.depthwise_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
+ {"tosa.fft2d",
+ {{{Extension::fft},
+ {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.matmul",
- {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+ {{{Extension::int16},
+ {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
- {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
- {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e4m3, Extension::fp8e5m2},
- {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
- {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}},
+ {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}},
allOf},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rfft2d",
+ {{{Extension::fft},
+ {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.clamp",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}},
- {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
- {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.add",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.maximum",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.minimum",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.mul",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.pow",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sub",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Extension::int16},
+ {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.abs",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.negate",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reciprocal",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.select",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.equal",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater_equal",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_max",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_min",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_product",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_sum",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.concat",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pad",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reshape",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reverse",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.slice",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.tile",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.gather",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.scatter",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.resize",
- {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16},
+ {{{i16T, i48T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.cast",
{{{Extension::bf16},
- {{i8T, bf16T},
- {i16T, bf16T},
- {i32T, bf16T},
- {bf16T, i8T},
- {bf16T, i16T},
- {bf16T, i32T},
- {bf16T, fp32T},
- {fp32T, bf16T}}},
+ {{{i8T, bf16T}, SpecificationVersion::V_1_0},
+ {{i16T, bf16T}, SpecificationVersion::V_1_0},
+ {{i32T, bf16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i8T}, SpecificationVersion::V_1_0},
+ {{bf16T, i16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i32T}, SpecificationVersion::V_1_0},
+ {{bf16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, bf16T}, SpecificationVersion::V_1_0}}},
{{Extension::bf16, Extension::fp8e4m3},
- {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}},
+ {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}},
allOf},
{{Extension::bf16, Extension::fp8e5m2},
- {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}},
+ {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}},
allOf},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp16T},
- {fp8e4m3T, fp32T},
- {fp16T, fp8e4m3T},
- {fp32T, fp8e4m3T}}},
+ {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp16T},
- {fp8e5m2T, fp32T},
- {fp16T, fp8e5m2T},
- {fp32T, fp8e5m2T}}}}},
+ {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
{"tosa.rescale",
{{{Extension::int16},
- {{i48T, i48T, i8T, i8T},
- {i48T, i48T, i16T, i16T},
- {i48T, i48T, i32T, i32T}}}}},
+ {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.const",
- {{{Extension::int4}, {{i4T}}},
- {{Extension::int16}, {{i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T}}}}},
+ {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.identity",
- {{{Extension::int4}, {{i4T, i4T}}},
- {{Extension::int16}, {{i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable",
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_write",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_read",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
+
// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 38cb2936ad8d9..8376a4c87dbf2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
}
//===----------------------------------------------------------------------===//
-// TOSA Spec Section 1.5.
+// TOSA Profiles and extensions
//
// Profile:
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
@@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
-def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-
-def Tosa_LevelAttr
- : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
-
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -404,18 +398,41 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
+//===----------------------------------------------------------------------===//
+// TOSA Levels
+//===----------------------------------------------------------------------===//
+
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
+
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
+
+//===----------------------------------------------------------------------===//
+// TOSA Specification versions
+//===----------------------------------------------------------------------===//
+
+def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
+def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;
+
+def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
+ "SpecificationVersion", "TOSA specification version", "specification_version",
+ [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;
+
//===----------------------------------------------------------------------===//
// TOSA target environment.
//===----------------------------------------------------------------------===//
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
let summary = "Target environment information.";
let parameters = ( ins
+ "SpecificationVersion": $specification_version,
"Level": $level,
ArrayRefParameter<"Profile">: $profiles,
ArrayRefParameter<"Extension">: $extensions
);
- let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
+ "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
"`extensions` `=` `[` $extensions `]` `>`";
}
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 8f5c72bc5f7a9..7b946ad6c6a89 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -36,12 +36,15 @@ enum CheckCondition {
allOf
};
+using VersionedTypeInfo =
+ std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
+
template <typename T>
struct OpComplianceInfo {
// Certain operations require multiple modes enabled.
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
SmallVector<T> mode;
- SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
+ SmallVector<VersionedTypeInfo> operandTypeInfoSet;
CheckCondition condition = CheckCondition::anyOf;
};
@@ -130,9 +133,8 @@ class TosaProfileCompliance {
// Find the required profiles or extensions from the compliance info according
// to the operand type combination.
template <typename T>
- SmallVector<T> findMatchedProfile(Operation *op,
- SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition);
+ OpComplianceInfo<T>
+ findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
switch (ext) {
@@ -168,8 +170,7 @@ class TosaProfileCompliance {
private:
template <typename T>
- FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
- CheckCondition &condition);
+ FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 6ae19d81e0820..14b00b04ccc18 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
let options = [
+ Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
+ /*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
+ "The specification version that TOSA operators should conform to.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
+ )}]>,
Option<"level", "level", "mlir::tosa::Level",
/*default=*/"mlir::tosa::Level::eightK",
"The TOSA level that operators should conform to. A TOSA level defines "
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 5aad67173cc61..1cba1bb540c02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
}
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
- return TargetEnvAttr::get(context, Level::eightK,
+ return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
{Profile::pro_int, Profile::pro_fp}, {});
}
@@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
index bcb880a808b36..a0661e4ee0bd2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -61,8 +61,8 @@ class TosaAttachTarget
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
- const auto targetEnvAttr =
- TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ const auto targetEnvAttr = TargetEnvAttr::get(
+ ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 20f9333e7c951..f072e3eff1975 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
//===----------------------------------------------------------------------===//
template <typename T>
-FailureOr<SmallVector<T>>
-TosaProfileCompliance::getOperatorDefinition(Operation *op,
- CheckCondition &condition) {
+FailureOr<OpComplianceInfo<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op) {
const std::string opName = op->getName().getStringRef().str();
const auto complianceMap = getProfileComplianceMap<T>();
const auto it = complianceMap.find(opName);
if (it == complianceMap.end())
return {};
- return findMatchedProfile<T>(op, it->second, condition);
+ return findMatchedEntry<T>(op, it->second);
}
template <typename T>
@@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
- if (failed(maybeOpRequiredMode)) {
+ const auto maybeOpDefinition = getOperatorDefinition<T>(op);
+ if (failed(maybeOpDefinition)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
// specification.
- int mode_count = 0;
+ int modeCount = 0;
for (const auto &cands : specRequiredModeSet) {
if (targetEnv.allowsAnyOf(cands))
return success();
- mode_count += cands.size();
+ modeCount += cands.size();
}
op->emitOpError() << "illegal: requires"
- << (mode_count > 1 ? " any of " : " ") << "["
+ << (modeCount > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
", ")
<< "] but not enabled in target\n";
@@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
// Find the required profiles or extensions according to the operand type
// combination.
- const auto opRequiredMode = maybeOpRequiredMode.value();
+ // Bind by reference: copying would duplicate the entry's SmallVectors.
+ const auto &opDefinition = maybeOpDefinition.value();
+ const SmallVector<T> &opRequiredMode = opDefinition.mode;
+ const CheckCondition condition = opDefinition.condition;
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
}
}
+ // Check the version recorded on the first type-info entry of the matched
+ // definition against the target specification version.
+ const VersionedTypeInfo &versionedTypeInfo =
+ opDefinition.operandTypeInfoSet[0];
+ const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
+ const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
+ if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
+ op->emitOpError() << "illegal: the target specification version ("
+ << stringifyVersion(targetVersion)
+ << ") is not backwards compatible with the op compliance "
+ "specification version ("
+ << stringifyVersion(complianceVersion) << ")\n";
+ return failure();
+ }
+
return success();
}
@@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
}
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
- const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ const auto maybeProfDef = getOperatorDefinition<Profile>(op);
+ const auto maybeExtDef = getOperatorDefinition<Extension>(op);
if (failed(maybeProfDef) && failed(maybeExtDef))
return success();
- const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
- (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ const bool hasEntry =
+ (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
@@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
SmallVector<TypeInfo> bestTypeInfo;
const auto searchBestMatch = [&](auto map) {
for (const auto &complianceInfos : map[opName]) {
- for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
+ for (const auto &versionedTypeInfos :
+ complianceInfos.operandTypeInfoSet) {
+ const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
const int matches = llvm::count_if(
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
return isSameTypeInfo(std::get<0>(zipType),
@@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
-SmallVector<T> TosaProfileCompliance::findMatchedProfile(
- Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition) {
+OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
+ Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
assert(compInfo.size() != 0 &&
"profile-based compliance information is empty");
@@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
return {};
for (size_t i = 0; i < compInfo.size(); i++) {
- SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
- for (SmallVector<TypeInfo> expected : sets) {
+ SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
+ for (const auto &set : sets) {
+ SmallVector<TypeInfo> expected = set.first;
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
"the generated metadata and the type definition retrieved from "
" the operation");
- bool is_found = true;
+ bool isFound = true;
// Compare the type signature between the given operation and the
// compliance metadata.
for (size_t j = 0; j < expected.size(); j++) {
if (!isSameTypeInfo(present[j], expected[j])) {
// Verify the next mode set from the list.
- is_found = false;
+ isFound = false;
break;
}
}
- if (is_found == true) {
- condition = compInfo[i].condition;
- return compInfo[i].mode;
+ if (isFound == true) {
+ SmallVector<VersionedTypeInfo> typeInfoSet{set};
+ OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
+ compInfo[i].condition};
+ return info;
}
}
}
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
index d6c886c44b013..a0c59c0c4bb3b 100644
--- a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
@@ -1,12 +1,14 @@
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1
// -----
-// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
-// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
-// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
+// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
+// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
+// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>}
// CHECK-LABEL: test_simple
func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
%1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
new file mode 100644
index 0000000000000..51089df238b84
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // expected-error at +1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ // expected-error at +1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
new file mode 100644
index 0000000000000..81645092bf195
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
>From 47f8fd14af94cd5fb58d09a14e215fd4b71cc4c4 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 4 Sep 2025 15:21:52 +0000
Subject: [PATCH 2/3] [mlir][tosa] Add support for matmul_t_block_scaled
This commit adds support for the MATMUL_T_BLOCK_SCALED
operation from the EXT_MXFP extension. This includes:
- Operation definition in TosaOps.td
- Micro-scaling supported types definition
- Shape inference and verifiers
- Validation pass checks to ensure usage is only valid when
the target environment includes ext-mxfp and at least
v1.1.draft of the specification.
As part of this commit, a notion of EXT_MXFP is also added.
The extension can be specified as part of the target environment
and can only be used if the specification version is at least 1.1.
Note: currently it excludes support for mxint8. This will be
added in a later commit.
Note: this commit adds support as defined in the spec in
https://review.mlplatform.org/c/tosa/specification/+/15362. EXT_MXFP
extension is considered experimental and subject to breaking change.
Co-authored-by: Tat Wai Chong <tatwai.chong at arm.com>
Change-Id: I92afdea87eef1eea444dfebf9f74796f3a236809
---
mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 37 ++-
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 7 +
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 22 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 34 +++
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 1 +
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 13 ++
mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 94 +++++++-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 219 ++++++++++++++++--
.../Tosa/Transforms/TosaProfileCompliance.cpp | 15 ++
.../Tosa/Transforms/TosaValidation.cpp | 19 +-
mlir/test/Dialect/Tosa/invalid_extension.mlir | 8 +
mlir/test/Dialect/Tosa/level_check.mlir | 9 +
mlir/test/Dialect/Tosa/ops.mlir | 42 ++++
.../Tosa/profile_pro_fp_unsupported.mlir | 9 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 54 +++++
.../tosa-validation-version-1p1-valid.mlir | 10 +-
mlir/test/Dialect/Tosa/verifier.mlir | 48 ++++
17 files changed, 596 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 4ecf03c34c1a5..e088eb31338dc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -54,6 +54,8 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
/// and provide utilities around the TOSA specification version.
class TosaSpecificationVersion {
public:
+ TosaSpecificationVersion() = default;
+
TosaSpecificationVersion(uint32_t major, uint32_t minor)
: majorVersion(major), minorVersion(minor) {}
TosaSpecificationVersion(SpecificationVersion version)
@@ -83,6 +85,10 @@ class TosaSpecificationVersion {
}
};
+TosaSpecificationVersion getMinVersion(const Profile &profile);
+TosaSpecificationVersion getMinVersion(const Extension &extension);
+TosaSpecificationVersion getMinVersion(const Level &level);
+
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
/// This class represents the capability enabled in the target implementation
@@ -91,22 +97,19 @@ llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
- const ArrayRef<Profile> &profiles,
- const ArrayRef<Extension> &extensions)
- : specificationVersion(specificationVersion), level(level) {
- enabledProfiles.insert_range(profiles);
- enabledExtensions.insert_range(extensions);
- }
- explicit TargetEnv(TargetEnvAttr targetAttr)
- : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
- targetAttr.getProfiles(), targetAttr.getExtensions()) {}
+ static FailureOr<TargetEnv>
+ createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc);
+
+ static LogicalResult verifyTargetInformation(TargetEnvAttr targetAttr,
+ Location targetAttrLoc);
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
- SpecificationVersion getSpecVersion() const { return specificationVersion; }
+ TosaSpecificationVersion getSpecVersion() const {
+ return specificationVersion;
+ }
TosaLevel getLevel() const {
if (level == Level::eightK)
@@ -140,7 +143,17 @@ class TargetEnv {
}
private:
- SpecificationVersion specificationVersion;
+ // Require target information is verified before constructing, via the use of
+ // `createTargetEnvFromAttr`.
+ explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
+ const ArrayRef<Profile> &profiles,
+ const ArrayRef<Extension> &extensions)
+ : specificationVersion(specificationVersion), level(level) {
+ enabledProfiles.insert_range(profiles);
+ enabledExtensions.insert_range(extensions);
+ }
+
+ TosaSpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index c1b5e785bd739..ed25d97e8c271 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -554,6 +554,13 @@ extensionComplianceMap = {
allOf},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.matmul_t_block_scaled",
+ {{{Extension::mxfp},
+ {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.max_pool2d",
{{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 8376a4c87dbf2..8b6edef08db20 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -270,13 +270,14 @@ def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
+def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>;
def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
- Tosa_EXT_DYNAMIC
+ Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP
]> {
let extraClassDeclaration = [{
static llvm::SmallVector<Extension, 11> getAllValues() {
@@ -284,7 +285,7 @@ def Tosa_ExtensionAttr
Extension::int16, Extension::int4, Extension::bf16,
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
Extension::variable, Extension::controlflow, Extension::doubleround,
- Extension::inexactround, Extension::dynamic
+ Extension::inexactround, Extension::dynamic, Extension::mxfp
};
}
}];
@@ -437,7 +438,7 @@ def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
}
//===----------------------------------------------------------------------===//
-// Iterable attributes.
+// Enum attributes.
//===----------------------------------------------------------------------===//
// Defined in `section 3. Enumerations` of the TOSA specification.
@@ -463,6 +464,24 @@ def Tosa_RoundingModeAttr
  : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
    [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
+def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 1>;
+
+def Tosa_BlockSizeAttr
+    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
+        [Tosa_BLOCK_SIZE_32]> {
+  let extraClassDeclaration = [{
+    /// Returns the integer block size denoted by `blockSize`.
+    static unsigned int getBlockSizeValue(BlockSize blockSize) {
+      switch (blockSize) {
+      case BlockSize::BLOCK_SIZE_32:
+        return 32;
+      }
+      // Fully covered switch; this keeps GCC's -Wreturn-type quiet.
+      llvm_unreachable("Unknown TOSA block size");
+    }
+  }];
+}
+
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 48759f2a3c9e8..a5251fcada4c9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -348,6 +348,40 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
"operands attr-dict `:` functional-type(operands, results)";
}
+//===----------------------------------------------------------------------===//
+// Operator: matmul_t_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_MatmulTBlockScaledOp : Tosa_InferShapedTypeOp<"matmul_t_block_scaled"> {
+  let summary = "Performs two dimensional matrix multiplications using block scaled tensors.";
+
+  let description = [{
+    Performs two dimensional matrix multiplications using block scaled tensors. The block
+    dimension is always the last dimension of the tensor, so the result is effectively
+    a matrix multiply of A by the transposed B matrix. If the batch dimension of input B
+    is of size 1, the B matrix will be broadcast across the batch dimension of A.
+  }];
+
+  let arguments = (ins
+    Tosa_MXFPDataTensor3D:$a_data,
+    Tosa_MXFPScaleTensor3D:$a_scale,
+    Tosa_MXFPDataTensor3D:$b_data,
+    Tosa_MXFPScaleTensor3D:$b_scale,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$output_data
+  );
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_MXFP]>
+  ];
+}
+
//===----------------------------------------------------------------------===//
// Operator: max_pool2d
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 7b946ad6c6a89..79df1b888b40e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -147,6 +147,7 @@ class TosaProfileCompliance {
case Extension::fp8e4m3:
case Extension::fp8e5m2:
case Extension::fft:
+ case Extension::mxfp:
return {Profile::pro_fp};
case Extension::variable:
case Extension::controlflow:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 93ab120339d55..20bb961482ad8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -84,6 +84,10 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;
+def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
+ "micro-scaling format number">;
+def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;
+
//===----------------------------------------------------------------------===//
// TOSA Tensor Conformance
//===----------------------------------------------------------------------===//
@@ -187,6 +191,15 @@ def Tosa_Int32Tensor2D : AnyTypeOf<[
def Tosa_TensorAtLeast1D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
+def Tosa_MXFPDataTensor3D : AnyTypeOf<[
+ TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+ TosaTensorRankOf<[Tosa_MXFPNumber], [3]>
+]>;
+def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
+ TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+ TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
+]>;
+
//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 1cba1bb540c02..32eb286531d28 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -12,6 +12,96 @@
namespace mlir {
namespace tosa {
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
+TosaSpecificationVersion getMinVersion(const Profile &profile) {
+ switch (profile) {
+ case Profile::pro_int:
+ case Profile::pro_fp:
+ return TosaSpecificationVersion(1, 0);
+ case Profile::none:
+ return TosaSpecificationVersion(0, 0);
+ }
+ llvm_unreachable("Unknown TOSA profile");
+}
+
+TosaSpecificationVersion getMinVersion(const Extension &extension) {
+ switch (extension) {
+ case Extension::int16:
+ case Extension::int4:
+ case Extension::bf16:
+ case Extension::fp8e4m3:
+ case Extension::fp8e5m2:
+ case Extension::fft:
+ case Extension::variable:
+ case Extension::controlflow:
+ case Extension::doubleround:
+ case Extension::inexactround:
+ case Extension::dynamic:
+ return TosaSpecificationVersion(1, 0);
+ case Extension::mxfp:
+ return TosaSpecificationVersion(1, 1);
+ case Extension::none:
+ return TosaSpecificationVersion(0, 0);
+ }
+ llvm_unreachable("Unknown TOSA extension");
+}
+
+TosaSpecificationVersion getMinVersion(const Level &level) {
+ switch (level) {
+ case Level::eightK:
+ case Level::none:
+ return TosaSpecificationVersion(1, 0);
+ }
+ llvm_unreachable("Unknown TOSA level");
+}
+
+FailureOr<TargetEnv>
+TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
+ Location targetEnvAttrLoc) {
+ if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
+ return failure();
+
+ return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions());
+}
+
+LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
+ Location targetAttrLoc) {
+ TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());
+
+ const auto isCompatibleWithTargetVersion =
+ [&](const auto &targetEnum, Location targetAttrLoc,
+ StringRef enumName) -> LogicalResult {
+ const TosaSpecificationVersion minRequiredVersion =
+ getMinVersion(targetEnum);
+ if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
+ return emitError(targetAttrLoc, enumName)
+ << " '" << stringifyEnum(targetEnum)
+ << "' is not compatible with the target version "
+ << stringifyVersion(targetVersion)
+ << ", minimum required version is "
+ << stringifyVersion(minRequiredVersion);
+ return success();
+ };
+
+ for (const auto &profile : targetAttr.getProfiles())
+ if (failed(
+ isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
+ return failure();
+ for (const auto &extension : targetAttr.getExtensions())
+ if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
+ "extension")))
+ return failure();
+ if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
+ "level")))
+ return failure();
+
+ return success();
+}
+
TargetEnvAttr lookupTargetEnv(Operation *op) {
while (op) {
op = SymbolTable::getNearestSymbolTable(op);
@@ -39,9 +129,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
-llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
- return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
-}
-
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c51b5e9cbfc78..6479fe10cf68d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
}
}
+ // special handling: block_size accepts a *bare* BlockSizeMode enum
+ if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
+ if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeBlockSize(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid block_size value: " << kw;
+ auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+
// Default path: parse any normal attribute literal, including fully qualified
// enum keyword
Attribute attr;
@@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
} else if (auto nanPropagationModeAttr =
dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
parser << nanPropagationModeAttr.getValue();
+ } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
+ parser << blockSizeAttr.getValue();
} else {
parser.printAttribute(attr);
}
@@ -508,6 +523,15 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
printWithNanPropagationHandling(parser, *this);
}
+ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -933,32 +957,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
-// verify that inType and outType have same element types
+// verify that aType and bType have same element types
template <typename T>
-static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
-  auto inputType = llvm::dyn_cast<TensorType>(inType);
-  auto outputType = llvm::dyn_cast<TensorType>(outType);
-  if (!inputType) {
-    op.emitOpError("expect shaped tensor for input, got ") << inType;
+static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
+                                            StringRef aName = "input",
+                                            StringRef bName = "output") {
+  auto aTType = llvm::dyn_cast<TensorType>(aType);
+  auto bTType = llvm::dyn_cast<TensorType>(bType);
+  if (!aTType) {
+    op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
    return failure();
  }
-  if (!outputType) {
-    op.emitOpError("expect shaped tensor for output, got ") << outType;
+  if (!bTType) {
+    op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
    return failure();
  }
-  auto inputElementType = inputType.getElementType();
-  auto outputElementType = outputType.getElementType();
-  auto inputQuantType =
-      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
-  auto outputQuantType =
-      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
-  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
-      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
-      inputElementType != outputElementType) {
+  auto aElementType = aTType.getElementType();
+  auto bElementType = bTType.getElementType();
+  auto aQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
+  auto bQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
+  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
+      (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
+      aElementType != bElementType) {
    // only check if both element types are int/index/float/UniformQuantized
    // eg, not sure how to check quant::QuantizedType
    // this happens in test_conv2d_q_grouped_convolution in
    // tfl-to-tosa-pipeline.mlir
-    op.emitOpError("expect input and output to have same element type, got ")
-        << inputElementType << " and " << outputElementType;
+    op.emitOpError("expect ")
+        << aName << " and " << bName << " to have same element type, got "
+        << aElementType << " and " << bElementType;
    return failure();
  }
  return success();
@@ -1846,6 +1873,162 @@ LogicalResult MatMulOp::verify() {
return success();
}
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    MatmulTBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // The result is [N, H, W]. Every entry starts dynamic and is taken from the
+  // first ranked operand (visited in A_data, A_scale, B_data, B_scale order)
+  // that supplies a value for it.
+  SmallVector<int64_t, 3> resultDims(3, ShapedType::kDynamic);
+  auto adoptIfUnknown = [&resultDims](size_t index, int64_t candidate) {
+    if (ShapedType::isDynamic(resultDims[index]))
+      resultDims[index] = candidate;
+  };
+
+  // A_data and A_scale both lead with [N, H]: they provide the output's
+  // batch and height.
+  auto absorbLhs = [&](Type lhsOperandType) {
+    const auto lhsType = cast<ShapedType>(lhsOperandType);
+    if (!lhsType.hasRank())
+      return;
+    adoptIfUnknown(0, lhsType.getDimSize(0));
+    adoptIfUnknown(1, lhsType.getDimSize(1));
+  };
+
+  // B_data and B_scale both lead with [D, W]: they provide the output's
+  // width, and provide its batch only when D != 1, because a batch of 1 is
+  // broadcast across A's batch and therefore says nothing about N.
+  auto absorbRhs = [&](Type rhsOperandType) {
+    const auto rhsType = cast<ShapedType>(rhsOperandType);
+    if (!rhsType.hasRank())
+      return;
+    const int64_t rhsBatch = rhsType.getDimSize(0);
+    if (rhsBatch != 1)
+      adoptIfUnknown(0, rhsBatch);
+    adoptIfUnknown(2, rhsType.getDimSize(1));
+  };
+
+  absorbLhs(adaptor.getAData().getType());
+  absorbLhs(adaptor.getAScale().getType());
+  absorbRhs(adaptor.getBData().getType());
+  absorbRhs(adaptor.getBScale().getType());
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(resultDims));
+  return success();
+}
+
+LogicalResult MatmulTBlockScaledOp::verify() {
+  // A_data and B_data must have the same element type.
+  const Type aDataType = getAData().getType();
+  const Type bDataType = getBData().getType();
+  if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
+                                    "B_data")))
+    return failure();
+
+  // Merge `newDim`, observed on `operandName`, into the running value
+  // `currDim`. A dynamic running value simply adopts `newDim`; two static
+  // values that disagree are reported as a verification error.
+  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
+                                   const StringRef operandName,
+                                   const StringRef dimName) -> LogicalResult {
+    if (ShapedType::isDynamic(currDim)) {
+      currDim = newDim;
+      return success();
+    }
+    if (ShapedType::isStatic(newDim) && currDim != newDim)
+      return emitOpError("expected ")
+             << dimName << " of " << operandName << " to match size " << currDim
+             << ", got " << newDim;
+    return success();
+  };
+
+  // Verify input shape compatibility. A_data is read as [N, H, C], B_data as
+  // [D, W, C], and both scale operands carry C / block_size in dimension 2.
+  int64_t N = ShapedType::kDynamic;
+  int64_t D = ShapedType::kDynamic;
+  int64_t H = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+  int64_t multiplesOfC = ShapedType::kDynamic;
+
+  const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
+  if (aDataShape.hasRank()) {
+    N = aDataShape.getDimSize(0);
+    H = aDataShape.getDimSize(1);
+    C = aDataShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
+  if (aScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
+                                     "height")))
+      return failure();
+    multiplesOfC = aScaleShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
+  if (bDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
+                                     "channels")))
+      return failure();
+    W = bDataShape.getDimSize(1);
+  }
+
+  const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
+  if (bScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
+                                     "width")) ||
+        failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
+                                     "b_scale", "C/block_size")))
+      return failure();
+  }
+
+  // Verify batch size is broadcast compatible: B's batch must equal A's
+  // batch or be 1.
+  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
+    return emitOpError("expect B matrix batch size to be broadcast compatible "
+                       "with A, got D=")
+           << D << " vs N=" << N;
+
+  // Verify C is a multiple of block size
+  const unsigned int blockSize =
+      BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (ShapedType::isStatic(C) && C % blockSize != 0)
+    return emitOpError("expect C to be a multiple of block size, got C=")
+           << C << ", block_size=" << blockSize;
+
+  // Verify multiplesOfC is C / block size
+  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
+      multiplesOfC != C / blockSize)
+    return emitOpError(
+               "expect scale operands dimension 2 to equal C/block_size (")
+           << C << "/" << blockSize << ")"
+           << ", got " << multiplesOfC;
+
+  // Verify output shape. When A's batch is unknown, B's batch may only stand
+  // in for it if it is not 1: a batch of 1 is broadcast across A's (unknown)
+  // batch, so the output batch must stay dynamic rather than be pinned to 1.
+  // This matches inferReturnTypeComponents, which also ignores a B batch of 1.
+  if (ShapedType::isDynamic(N) && D != 1)
+    N = D;
+  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+  const auto outputType = cast<ShapedType>(getResult().getType());
+  if (outputType.hasRank() &&
+      failed(
+          verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+    InFlightDiagnostic opError = emitOpError("expected output shape ");
+    // Print dynamic dimensions as '?' so the message mirrors tensor syntax.
+    auto stringifyDim = [&](int64_t d) {
+      if (ShapedType::isDynamic(d))
+        opError << "?";
+      else
+        opError << d;
+    };
+    llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+    opError << " to be compatible with expected output shape ";
+    llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+    return opError;
+  }
+
+  return success();
+}
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index f072e3eff1975..e965ae0cf9888 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -25,6 +25,12 @@ TosaProfileCompliance::TosaProfileCompliance() {
const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
+ // micro-scaling formats
+ const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
+ const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
+ const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
+ const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
+
// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
@@ -269,6 +275,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
+ POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
POPULATE_PROFILE_INFO_COMMON(Cast)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
@@ -623,6 +630,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
return {"fp8e4m3"};
} else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
return {"fp8e5m2"};
+ } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) {
+ return {"fp6e2m3"};
+ } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) {
+ return {"fp6e3m2"};
+ } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) {
+ return {"fp4e2m1"};
+ } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
+ return {"fp8e8m0"};
}
llvm_unreachable("unknown type");
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 82f2f7eb17af4..3f874d94ab9be 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -657,6 +657,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_SIZES(TransposeConv2D);
CHECK_SIZES(FFT2d);
CHECK_SIZES(MatMul);
+ CHECK_SIZES(MatmulTBlockScaled);
CHECK_SIZES(MaxPool2d);
CHECK_SIZES(RFFT2d);
// Scatter/Gather Operators
@@ -1192,9 +1193,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
- Float8E5M2Type>(type);
- }
- if (auto intTy = dyn_cast<IntegerType>(type)) {
+ Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
+ Float6E3M2FNType, Float8E8M0FNUType>(type);
+ } else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
case 1:
@@ -1220,13 +1221,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
+ ModuleOp modOp = getOperation();
+ const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
+ const auto maybeTargetEnv =
+ tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
+ if (failed(maybeTargetEnv))
+ return signalPassFailure();
+ targetEnv = *maybeTargetEnv;
+
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
- targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
-
- getOperation().walk([&](Operation *op) {
+ modOp.walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index e5c9402caaddc..005601d4017b8 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -538,3 +538,11 @@ func.func @test_avg_pool2d_non_const_output_zp(%arg0: tensor<1x32x32x8xf32>, %ou
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
+
+// -----
+
+func.func @test_matmul_t_block_scaled(%a_data: tensor<4x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf8E4M3FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // Rejected because the mxfp extension is not enabled in this file's target.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 8cc357efa0c77..8771e6e2476e4 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1622,3 +1622,12 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<*xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_invalid_size
+func.func @test_matmul_t_block_scaled_invalid_size(%a_data: tensor<4x8x536870912xf4E2M1FN>, %a_scale: tensor<4x8x16777216xf8E8M0FNU>, %b_data: tensor<4x16x536870912xf4E2M1FN>, %b_scale: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> {
+  // Operands are sized to exceed the level's maximum tensor size; the scale
+  // shapes stay consistent because 536870912 / 32 == 16777216.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32>
+  return %result : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 868b7b7a93335..9bf36b5fd4c7d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1226,3 +1226,45 @@ func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
return %0 : tensor<13x29x3xf8E4M3FN>
}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_static
+func.func @test_matmul_t_block_scaled_static(%a_data: tensor<4x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf8E4M3FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // Fully static shapes with one scale per 32-element block (C = 32).
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked
+func.func @test_matmul_t_block_scaled_unranked(%a_data: tensor<*xf8E4M3FN>, %a_scale: tensor<*xf8E8M0FNU>, %b_data: tensor<*xf8E4M3FN>, %b_scale: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+  // All operands and the result are unranked.
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+  return %result : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2
+func.func @test_matmul_t_block_scaled_fp6e3m2(%a_data: tensor<4x8x32xf6E3M2FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf6E3M2FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // fp6e3m2 data operands.
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3
+func.func @test_matmul_t_block_scaled_fp6e2m3(%a_data: tensor<4x8x32xf6E2M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf6E2M3FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // fp6e2m3 data operands.
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1
+func.func @test_matmul_t_block_scaled_fp4e2m1(%a_data: tensor<4x8x32xf4E2M1FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf4E2M1FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // fp4e2m1 data operands.
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf4E2M1FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf4E2M1FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast
+func.func @test_matmul_t_block_scaled_broadcast(%a_data: tensor<?x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<?x16x32xf8E4M3FN>, %b_scale: tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // B's batch of 1 (from B_scale) broadcasts across A's batch of 4.
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<?x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 7ff8065ee41fd..0271d71561a52 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
@@ -325,3 +325,10 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
return %1 : tensor<1x64x64x8xf32>
}
+
+// -----
+func.func @test_matmul_t_block_scaled(%a_data: tensor<4x8x32xf6E3M2FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf6E3M2FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // The target enables mxfp but only the pro_int profile, so the pro_fp
+  // requirement is reported.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 80f06f11fe4ad..72479fe21ade8 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1574,3 +1574,57 @@ func.func @test_mul_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_static
+func.func @test_matmul_t_block_scaled_static(%a_data: tensor<4x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<1x16x32xf8E4M3FN>, %b_scale: tensor<1x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // B has batch 1, so the inferred batch comes from A.
+  // CHECK: -> tensor<4x8x16xf32>
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<1x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_a_data
+func.func @test_matmul_t_block_scaled_unranked_a_data(%a_data: tensor<*xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf8E4M3FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // Batch and height are recovered from A_scale when A_data is unranked.
+  // CHECK: -> tensor<4x8x16xf32>
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_b_data_and_scale
+func.func @test_matmul_t_block_scaled_unranked_b_data_and_scale(%a_data: tensor<4x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<*xf8E4M3FN>, %b_scale: tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // With both B operands unranked, the width cannot be inferred.
+  // CHECK: -> tensor<4x8x?xf32>
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_all
+func.func @test_matmul_t_block_scaled_unranked_all(%a_data: tensor<*xf8E4M3FN>, %a_scale: tensor<*xf8E8M0FNU>, %b_data: tensor<*xf8E4M3FN>, %b_scale: tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // Nothing is known, so the result stays fully dynamic.
+  // CHECK: -> tensor<?x?x?xf32>
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_data
+func.func @test_matmul_t_block_scaled_broadcast_b_data(%a_data: tensor<*xf8E4M3FN>, %a_scale: tensor<*xf8E8M0FNU>, %b_data: tensor<1x4x32xf8E4M3FN>, %b_scale: tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // A batch of 1 on B says nothing about the output batch; only width is known.
+  // CHECK: -> tensor<?x?x4xf32>
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<1x4x32xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_scale
+func.func @test_matmul_t_block_scaled_broadcast_b_scale(%a_data: tensor<*xf8E4M3FN>, %a_scale: tensor<*xf8E8M0FNU>, %b_data: tensor<*xf8E4M3FN>, %b_scale: tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // Width is taken from B_scale alone; its batch of 1 is ignored.
+  // CHECK: -> tensor<?x?x4xf32>
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 81645092bf195..2040a4bc7e6af 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
// -----
@@ -18,3 +18,11 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>,
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
return %0 : tensor<1x14x28xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3
+func.func @test_matmul_t_block_scaled_fp6e2m3(%a_data: tensor<4x8x32xf6E2M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf6E2M3FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // Accepted: the target selects specification 1.1.draft with mxfp enabled.
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 430b06ad16c39..4be5d725ad612 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1102,3 +1102,51 @@ func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
return
}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_data_mismatch(%a_data: tensor<4x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<4x16x32xf8E5M2>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // A_data is fp8e4m3 while B_data is fp8e5m2.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'}}
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E5M2>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_batch_mismatch(%a_data: tensor<*xf8E4M3FN>, %a_scale: tensor<?x8x1xf8E8M0FNU>, %b_data: tensor<*xf8E4M3FN>, %b_scale: tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32> {
+  // The batch (4) is only known from B_scale, but the output claims 5.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op expected output shape 5, ?, ? to be compatible with expected output shape 4, 8, ?}}
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x8x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32>
+  return %result : tensor<5x?x?xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_height_mismatch(%a_data: tensor<*xf8E4M3FN>, %a_scale: tensor<?x9x1xf8E8M0FNU>, %b_data: tensor<*xf8E4M3FN>, %b_scale: tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32> {
+  // The height (9) is only known from A_scale, but the output claims 8.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op expected output shape 4, 8, ? to be compatible with expected output shape 4, 9, ?}}
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x9x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32>
+  return %result : tensor<4x8x?xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_width_mismatch(%a_data: tensor<*xf8E4M3FN>, %a_scale: tensor<?x?x1xf8E8M0FNU>, %b_data: tensor<?x1x?xf8E4M3FN>, %b_scale: tensor<*xf8E8M0FNU>) -> tensor<?x?x10xf32> {
+  // The width (1) is only known from B_data, but the output claims 10.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op expected output shape ?, ?, 10 to be compatible with expected output shape ?, ?, 1}}
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x?x1xf8E8M0FNU>, tensor<?x1x?xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x10xf32>
+  return %result : tensor<?x?x10xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_channel_mismatch(%arg0: tensor<4x8x55xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // A_data has C=55 but B_data has C=32, so the channel-mismatch check fires
+  // before the block-size check is reached.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op expected channels of b_data to match size 55, got 32}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x55xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_channel_not_multiple_of_block_size(%arg0: tensor<4x8x55xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x55xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // A_data and B_data agree on C=55, which is not a multiple of block size 32.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op expect C to be a multiple of block size, got C=55, block_size=32}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x55xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x55xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_batch_mismatch(%a_data: tensor<4x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>, %b_data: tensor<2x16x32xf8E4M3FN>, %b_scale: tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // B's batch (2) is neither A's batch (4) nor 1, so it cannot broadcast.
+  // expected-error at +1 {{'tosa.matmul_t_block_scaled' op expect B matrix batch size to be broadcast compatible with A, got D=2 vs N=4}}
+  %result = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %result : tensor<4x8x16xf32>
+}
>From b7bc15a3fe6cf4bf839a873c644c841b218f880e Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 15 Sep 2025 21:52:24 +0000
Subject: [PATCH 3/3] [mlir][tosa] Add support for cast_from/to_block_scaled
This commit adds support for the cast_from/to_block_scaled
operations from the ext-mxfp extension. This includes:
- Operation definition in TosaOps.td
- Micro-scaling supported types definition
- Shape inference and verifiers
- Validation pass checks to ensure usage is only valid when
the target environment includes ext-mxfp and at least
v1.1.draft of the specification.
Note: currently it excludes support for mxint8. This will be
added in a later commit.
Note: this commit adds support as defined in the spec in
https://review.mlplatform.org/c/tosa/specification/+/15362. EXT_MXFP
extension is considered experimental and subject to breaking change.
Co-authored-by: Tat Wai Chong <tatwai.chong at arm.com>
Change-Id: I490645ce99b7ccd7021ed06acaf1530b4fbf6dfd
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 26 +++
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 63 +++++++
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 2 +-
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 10 ++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 159 +++++++++++++++++-
.../Tosa/Transforms/TosaProfileCompliance.cpp | 32 +---
.../Tosa/Transforms/TosaValidation.cpp | 2 +
mlir/test/Dialect/Tosa/availability.mlir | 18 ++
mlir/test/Dialect/Tosa/invalid_extension.mlir | 16 ++
mlir/test/Dialect/Tosa/level_check.mlir | 33 +++-
mlir/test/Dialect/Tosa/ops.mlir | 28 +++
.../Tosa/profile_pro_fp_unsupported.mlir | 14 ++
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 45 +++++
.../tosa-validation-version-1p1-valid.mlir | 24 +++
mlir/test/Dialect/Tosa/verifier.mlir | 88 +++++++++-
15 files changed, 526 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index ed25d97e8c271..9eaf0847802cb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -748,6 +748,32 @@ extensionComplianceMap = {
{{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
{{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
{{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cast_from_block_scaled",
+ {{{Extension::bf16, Extension::mxfp},
+ {{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::mxfp},
+ {{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
+ {"tosa.cast_to_block_scaled",
+ {{{Extension::mxfp},
+ {{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::bf16, Extension::mxfp},
+ {{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}},
{"tosa.rescale",
{{{Extension::int16},
{{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index a5251fcada4c9..2f5fe6b347e33 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2472,6 +2472,69 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: cast_from_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_CastFromBlockScaledOp : Tosa_InferShapedTypeOp<"cast_from_block_scaled"> {
+  let summary = "Apply scales from a scale tensor to the values in a value tensor";
+
+  let description = [{
+    Apply the scales from a scale tensor to the values in a value tensor, casting
+    the result to the output type. The block dimension must be the last dimension
+    of the tensor.
+
+    Each element of `input_scale` applies to a block of `block_size`
+    consecutive elements along the last dimension of `input_data`. Hence:
+
+    * the last dimension of `input_data` must be divisible by `block_size`;
+    * `input_scale` must have the same rank as `input_data` and match it in
+      every dimension except the last, whose size must equal the last
+      dimension of `input_data` divided by `block_size`;
+    * `output_data` has the same shape as `input_data`.
+  }];
+
+  let arguments = (ins
+    Tosa_MXFPDataTensorAtLeast1D:$input_data,
+    Tosa_MXFPScaleTensorAtLeast1D:$input_scale,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_TensorAtLeast1D:$output_data
+  );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: cast_to_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled"> {
+  let summary = "Calculate scale tensor values per block, output to separate scale and data tensors.";
+
+  let description = [{
+    Calculate a scale value per block of input values and use that to calculate
+    scaled data values from an input tensor. The output tensors are cast to the
+    specified scale and value types. The block dimension will be the last dimension
+    of the tensor.
+
+    A block is `block_size` consecutive elements along the last dimension of
+    `input_data`, and one scale value is produced per block. Hence:
+
+    * the last dimension of `input_data` must be divisible by `block_size`;
+    * `output_data` has the same shape as `input_data`;
+    * `output_scale` has the same rank as `output_data` and matches it in
+      every dimension except the last, whose size is the last dimension of
+      `output_data` divided by `block_size`.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorAtLeast1D:$input_data,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_MXFPDataTensorAtLeast1D:$output_data,
+    Tosa_MXFPScaleTensorAtLeast1D:$output_scale
+  );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: rescale
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 79df1b888b40e..4a899e3c787e6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -79,7 +79,7 @@ class ProfileInfoDepot {
LogicalResult populatationDispatch(Operation *op);
- LogicalResult populateProfileInfo(ValueRange operands, Value output);
+  // Adds every value in `operands` and then every value in `outputs` via
+  // addValue().
+  LogicalResult populateProfileInfo(ValueRange operands, ValueRange outputs);
// Base
template <typename T>
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 20bb961482ad8..93843e86fd378 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -199,6 +199,16 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
]>;
+// Either an unranked tensor or a tensor of rank >= 1, in both cases with an
+// element type accepted by Tosa_MXFPNumber.
+def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<
+    [TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+     TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
+    "tosa-conformant tensor of at least rank 1",
+    "::mlir::TensorType">;
+// Either an unranked tensor or a tensor of rank >= 1, in both cases with an
+// element type accepted by Tosa_MXFPScaleNumber.
+def Tosa_MXFPScaleTensorAtLeast1D : AnyTypeOf<
+    [TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+     TosaRankedTensorOf<[Tosa_MXFPScaleNumber], [AtLeastRankOne]>],
+    "tosa-conformant tensor of at least rank 1",
+    "::mlir::TensorType">;
//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6479fe10cf68d..2e97f6f2989e3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
result.operands)))
return failure();
- result.addTypes(fnTy.getResult(0));
+ result.addTypes(fnTy.getResults());
result.addAttributes(attrs);
return success();
@@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}
+/// Custom parser for tosa.cast_from_block_scaled. The `block_size` enum
+/// attribute is handled by the shared parseWithEnumHandling helper.
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &asmParser,
+                                         OperationState &state) {
+  return parseWithEnumHandling<tosa::BlockSize>(asmParser, state);
+}
+
+/// Custom printer for tosa.cast_from_block_scaled; counterpart of
+/// CastFromBlockScaledOp::parse.
+void CastFromBlockScaledOp::print(OpAsmPrinter &printer) {
+  printWithEnumHandling(printer, *this);
+}
+
+/// Custom parser for tosa.cast_to_block_scaled. The `block_size` enum
+/// attribute is handled by the shared parseWithEnumHandling helper.
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &asmParser,
+                                       OperationState &state) {
+  return parseWithEnumHandling<tosa::BlockSize>(asmParser, state);
+}
+
+/// Custom printer for tosa.cast_to_block_scaled; counterpart of
+/// CastToBlockScaledOp::parse.
+void CastToBlockScaledOp::print(OpAsmPrinter &printer) {
+  printWithEnumHandling(printer, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
return success();
}
+/// Shape inference for tosa.cast_from_block_scaled: output_data has exactly
+/// the shape of input_data (the cast only affects the element type), so the
+/// input shape is forwarded as the single result component.
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastFromBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const Type dataType = adaptor.getInputData().getType();
+  inferredReturnShapes.push_back(ShapedTypeComponents(ShapeAdaptor(dataType)));
+  return success();
+}
+
+/// Verifies the shape relationships of tosa.cast_from_block_scaled:
+///   - output_data must have a shape compatible with input_data;
+///   - a static last dimension of input_data must be divisible by block_size;
+///   - when input_data and input_scale are both ranked, they must agree in
+///     rank and in every dimension but the last, and (for a static last
+///     dimension of input_data) input_scale's last dimension must equal
+///     input_data's last dimension divided by block_size.
+LogicalResult CastFromBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getResult().getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "output_data (" << outputDataType << ")";
+
+  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+  if (!inputDataShape.hasRank())
+    return success();
+
+  const unsigned int blockSize =
+      BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  const int64_t inputDataLastDim =
+      inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+  // A dynamic extent is encoded as ShapedType::kDynamic; it must not take part
+  // in the modulo/division below (consistent with CastToBlockScaledOp).
+  const bool lastDimIsStatic = ShapedType::isStatic(inputDataLastDim);
+  if (lastDimIsStatic && inputDataLastDim % blockSize != 0)
+    return emitOpError() << "expect last dimension of input_data ("
+                         << inputDataLastDim
+                         << ") to be divisible by block_size (" << blockSize
+                         << ")";
+
+  const Type inputScaleType = getInputScale().getType();
+  const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+  if (!inputScaleShape.hasRank())
+    return success();
+
+  SmallVector<int64_t> inputDataDims, inputScaleDims;
+  inputDataShape.getDims(inputDataDims);
+  inputScaleShape.getDims(inputScaleDims);
+
+  if (inputDataDims.size() != inputScaleDims.size() ||
+      failed(verifyCompatibleShape(
+          ArrayRef<int64_t>(inputDataDims).drop_back(1),
+          ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "input_scale (" << inputScaleType
+                         << ") except for the last dimension";
+
+  if (lastDimIsStatic) {
+    const int64_t expectedScaleLastDim = inputDataLastDim / blockSize;
+    const SmallVector<int64_t, 2> dimsToCheck{expectedScaleLastDim,
+                                              inputScaleDims.back()};
+    if (failed(verifyCompatibleDims(dimsToCheck)))
+      return emitOpError()
+             << "expect last dimension of input_scale ("
+             << inputScaleDims.back()
+             << ") to be equal to last dimension of input_data / block_size ("
+             << expectedScaleLastDim << ")";
+  }
+
+  return success();
+}
+
+/// Shape inference for tosa.cast_to_block_scaled. Exactly one component is
+/// produced per result:
+///   - output_data: the shape of input_data;
+///   - output_scale: the shape of input_data with a static last dimension
+///     divided by block_size (a dynamic last dimension stays dynamic).
+/// When the input rank is unknown, output_scale is reported as unranked.
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastToBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+
+  // Without a known, non-scalar rank nothing can be said about output_scale,
+  // but it still needs its own (unranked) entry so the number of components
+  // matches the number of results. Rank 0 is rejected by the operand type
+  // constraint; guarding it here avoids indexing the shape at -1.
+  if (!inputShape.hasRank() || inputShape.getRank() == 0) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  SmallVector<int64_t> outputScaleShape;
+  inputShape.getDims(outputScaleShape);
+  const int64_t lastDimLoc = inputShape.getRank() - 1;
+  const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+  if (ShapedType::isStatic(lastDimSize)) {
+    const unsigned int blockSize =
+        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+    outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+  return success();
+}
+
+/// Verifies tosa.cast_to_block_scaled:
+///   - output_data must have a shape compatible with input_data;
+///   - a static last dimension of input_data must be a multiple of block_size;
+///   - when output_data and output_scale are both ranked, they must agree in
+///     rank and in all but the last dimension, and a static last dimension of
+///     output_data divided by block_size must match output_scale's last one.
+LogicalResult CastToBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getResult(0).getType();
+  const Type outputScaleType = getResult(1).getType();
+
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "output_data (" << outputDataType << ")";
+
+  const unsigned int blockSize =
+      BlockSizeAttr::getBlockSizeValue(getBlockSize());
+
+  const ShapeAdaptor inputShape(inputDataType);
+  if (inputShape.hasRank()) {
+    const int64_t lastDim = inputShape.getDimSize(inputShape.getRank() - 1);
+    if (ShapedType::isStatic(lastDim) && lastDim % blockSize != 0)
+      return emitOpError() << "expect last dimension of input_data ("
+                           << lastDim << ") to be divisible by block_size ("
+                           << blockSize << ")";
+  }
+
+  const ShapeAdaptor dataShape(outputDataType);
+  const ShapeAdaptor scaleShape(outputScaleType);
+  // The data/scale relationship can only be checked when both are ranked.
+  if (!dataShape.hasRank() || !scaleShape.hasRank())
+    return success();
+
+  SmallVector<int64_t> dataDims;
+  SmallVector<int64_t> scaleDims;
+  dataShape.getDims(dataDims);
+  scaleShape.getDims(scaleDims);
+
+  const bool sameRank = dataDims.size() == scaleDims.size();
+  if (!sameRank || failed(verifyCompatibleShape(
+                       ArrayRef<int64_t>(dataDims).drop_back(1),
+                       ArrayRef<int64_t>(scaleDims).drop_back(1))))
+    return emitOpError() << "require compatible shapes for output_data ("
+                         << outputDataType << ") and "
+                         << "output_scale (" << outputScaleType
+                         << ") except for the last dimension";
+
+  const int64_t dataLastDim = dataDims.back();
+  if (ShapedType::isStatic(dataLastDim)) {
+    const SmallVector<int64_t, 2> dimsToCheck{dataLastDim / blockSize,
+                                              scaleDims.back()};
+    if (failed(verifyCompatibleDims(dimsToCheck)))
+      return emitOpError()
+             << "expect last dimension of output_scale (" << scaleDims.back()
+             << ") to be equal to last dimension of output_data / block_size ("
+             << dataLastDim / blockSize << ")";
+  }
+
+  return success();
+}
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index e965ae0cf9888..92d5bac9c2653 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -50,10 +50,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
// Base populating function
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
- Value output) {
- for (auto operand : operands)
+                                                    ValueRange outputs) {
+  // Operands are recorded before results, each in range order.
+  for (Value operand : operands)
addValue(operand);
- addValue(output);
+  for (Value result : outputs)
+    addValue(result);
return success();
}
@@ -175,23 +176,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
return success();
}
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getInputImag());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
addValue(op.getOnTrue());
@@ -245,7 +229,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function populates the info for all operands.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- return populateProfileInfo(op->getOperands(), op->getResult(0)); \
+ return populateProfileInfo(op->getOperands(), op->getResults()); \
}
// Skip irrelevant operands when they are independent and not tied to any
@@ -256,8 +240,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
POPULATE_PROFILE_INFO_CUSTOM(Mul)
- POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
- POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
POPULATE_PROFILE_INFO_CUSTOM(Concat)
POPULATE_PROFILE_INFO_CUSTOM(Pad)
POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -276,7 +258,11 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(FFT2d)
+ POPULATE_PROFILE_INFO_COMMON(RFFT2d)
POPULATE_PROFILE_INFO_COMMON(Cast)
+ POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
POPULATE_PROFILE_INFO_COMMON(Sub)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3f874d94ab9be..a142926bf87e2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Transpose);
// Type Conversion
CHECK_RANKS_AND_SIZES(Cast);
+ CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
+ CHECK_RANKS_AND_SIZES(CastToBlockScaled);
CHECK_RANKS_AND_SIZES(Rescale);
// Control Flow Operators
CHECK_RANKS_AND_SIZES(If);
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 600c4c717922a..d92d433a7d185 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -696,3 +696,21 @@ func.func @test_const_shape() -> !tosa.shape<4> {
return %cst : !tosa.shape<4>
}
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled
+// Reports the pro_fp profile and the bf16/mxfp extensions for this op.
+func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // CHECK: profiles: [ [pro_fp] ]
+ // CHECK: extensions: [ [bf16, mxfp] ]
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled
+// Reports the pro_fp profile and the bf16/mxfp extensions; block_size (the
+// op's only attribute) is given in its bare keyword form.
+func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // CHECK: profiles: [ [pro_fp] ]
+ // CHECK: extensions: [ [bf16, mxfp] ]
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = BLOCK_SIZE_32} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 005601d4017b8..fff31c294a3f7 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -546,3 +546,19 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: ten
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+
+// The target does not enable the mxfp extension, so this cast is rejected.
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+// The target does not enable the mxfp extension, so this cast is rejected.
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 8771e6e2476e4..cd392fcc20ea1 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1625,9 +1625,40 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor
// -----
-// CHECK-LABEL: test_matmul_t_block_scaled_invalid_size
func.func @test_matmul_t_block_scaled_invalid_size(%arg0: tensor<4x8x536870912xf4E2M1FN>, %arg1: tensor<4x8x16777216xf8E8M0FNU>, %arg2: tensor<4x16x536870912xf4E2M1FN>, %arg3: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> {
// expected-error at +1 {{'tosa.matmul_t_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// -----
+
+// Operands of 536870912x32 elements exceed the level's maximum tensor size.
+func.func @test_cast_from_block_scaled_invalid_size(%arg0: tensor<536870912x32xf6E2M3FN>, %arg1: tensor<536870912x1xf8E8M0FNU>) -> tensor<536870912x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) -> tensor<536870912x32xf32>
+ return %0 : tensor<536870912x32xf32>
+}
+
+// -----
+
+// Rank-8 operands exceed the level's MAX_RANK.
+func.func @test_cast_from_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, %arg1: tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32>
+ return %0 : tensor<1x2x3x4x5x6x7x32xf32>
+}
+
+// -----
+
+// An input of 536870912x32 elements exceeds the level's maximum tensor size.
+func.func @test_cast_to_block_scaled_invalid_size(%arg0: tensor<536870912x32xf32>) -> (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<536870912x32xf32>) -> (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>
+}
+
+// -----
+
+// A rank-8 input exceeds the level's MAX_RANK.
+func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 9bf36b5fd4c7d..865f712ce1a5a 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1268,3 +1268,31 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>,
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<?x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled_static
+// Accepted with fully static shapes (4x32 data, 4x1 scales).
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled_unranked
+// Accepted when every operand and the result are unranked.
+func.func @test_cast_from_block_scaled_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_static
+// Accepted with fully static shapes (4x32 data, 4x1 scales).
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+// Accepted when the input and both results are unranked.
+func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 0271d71561a52..7de7b85bcaedf 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -332,3 +332,17 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: ten
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+// Without the pro_fp profile in the target, the cast is rejected.
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+// Without the pro_fp profile in the target, the cast is rejected.
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 72479fe21ade8..54556a0eb08e0 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1628,3 +1628,48 @@ func.func @test_matmul_t_block_scaled_broadcast_b_scale(%arg0: tensor<*xf8E4M3FN
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_static
+// The unranked result is refined to the 4x32 shape of input_data.
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<*xf32> {
+ // CHECK: -> tensor<4x32xf32>
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_unranked_input_scale
+// An unranked input_scale does not prevent refining the result from input_data.
+func.func @test_cast_from_block_scaled_unranked_input_scale(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+ // CHECK: -> tensor<4x32xf32>
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_static
+// Data keeps the input shape; the scale's last dimension becomes 32 / 32 = 1.
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+// With an unranked input both results stay unranked.
+func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_dynamic_scales
+// A dynamic last input dimension stays dynamic in both data and scale results.
+func.func @test_cast_to_block_scaled_dynamic_scales(%arg0: tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<4x?xf4E2M1FN>, tensor<4x?xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 2040a4bc7e6af..8b6cdc07925f0 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -26,3 +26,27 @@ func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %a
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32
+// fp8e5m2 data with an fp32 result is valid for specification version 1.1.
+func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_bf16
+// fp8e5m2 data with a bf16 result is valid for specification version 1.1.
+func.func @test_cast_from_block_scaled_fp8e5m2_bf16(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16>
+ return %0 : tensor<4x32xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_static
+// fp32 input cast to fp6e3m2 data is valid for specification version 1.1.
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 4be5d725ad612..6cf76cdc7ad8e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1033,7 +1033,6 @@ module {
// -----
-// CHECK-LABEL: @scatter_invalid_indices_N
func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
%1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1042,7 +1041,6 @@ func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3
// -----
-// CHECK-LABEL: @scatter_invalid_input_N
func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1051,7 +1049,6 @@ func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2
// -----
-// CHECK-LABEL: @scatter_invalid_out_N
func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
@@ -1060,7 +1057,6 @@ func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_out_K
func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
@@ -1069,7 +1065,6 @@ func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_input_W
func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
@@ -1078,7 +1073,6 @@ func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2
// -----
-// CHECK-LABEL: @scatter_invalid_input_C
func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
@@ -1087,7 +1081,6 @@ func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2
// -----
-// CHECK-LABEL: @scatter_invalid_out_C
func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
@@ -1096,7 +1089,6 @@ func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_K_W
func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
@@ -1150,3 +1142,83 @@ func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+
+func.func @cast_from_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and output_data ('tensor<5x32xf32>')}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32>
+ return %0 : tensor<5x32xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_not_scalar(%arg0: tensor<f4E2M1FN>, %arg1: tensor<f8E8M0FNU>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f4E2M1FN>'}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<f4E2M1FN>, tensor<f8E8M0FNU>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32>
+ return %0 : tensor<4x33xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and input_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_scale (2) to be equal to last dimension of input_data / block_size (1)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf32>') and output_data ('tensor<5x32xf4E2M1FN>')}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_not_scalar(%arg0: tensor<f32>) -> (tensor<f4E2M1FN>, tensor<f8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<f32>) -> (tensor<f4E2M1FN>, tensor<f8E8M0FNU>)
+ return %0#0, %0#1 : tensor<f4E2M1FN>, tensor<f8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for output_data ('tensor<4x32xf4E2M1FN>') and output_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of output_scale (2) to be equal to last dimension of output_data / block_size (1)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
+}
More information about the Mlir-commits
mailing list