[Mlir-commits] [mlir] [mlir][tosa] Finalize profile-based validation for TOSA v1.0 (PR #131208)
TatWai Chong
llvmlistbot at llvm.org
Tue Mar 18 23:30:36 PDT 2025
https://github.com/tatwaichong updated https://github.com/llvm/llvm-project/pull/131208
>From db4907f2933de156300860ec57a0b4a699d6adcd Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Wed, 3 Jul 2024 14:12:07 -0700
Subject: [PATCH] [Mlir][tosa] Finalize profile-based validation for TOSA v1.0
- When the operand type of an operation changes to a profile-dependent
type, the compliance metadata must be updated. Update the compliance
checks for the following:
- CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D, as zero
points have changed to variable inputs.
- PAD, because pad_const has been changed to a variable input.
- GATHER and SCATTER, as indices have changed to index_t.
- Add an int16 extension check for CONCAT.
- Add a compliance check for COND_IF, WHILE_LOOP, VARIABLE, VARIABLE_READ,
and VARIABLE_WRITE.
- Correct the profile requirements for IDENTITY, TABLE, MATMUL and
LOGICAL-like operations.
- Remove unnecessary checks for non-v1.0 operations.
- Add condition requirements (anyOf and allOf) to the type modes of the
metadata for modes that involve multiple profiles or extensions.
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 232 ++++++++++--------
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 12 +-
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 14 +-
.../Tosa/Transforms/TosaProfileCompliance.cpp | 155 ++++++++----
mlir/test/Dialect/Tosa/availability.mlir | 12 +-
mlir/test/Dialect/Tosa/invalid.mlir | 2 +-
mlir/test/Dialect/Tosa/invalid_extension.mlir | 6 +-
.../Dialect/Tosa/profile_all_unsupported.mlir | 122 +++++++++
.../Tosa/profile_pro_int_unsupported.mlir | 4 -
9 files changed, 379 insertions(+), 180 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index f06b156c1e41a..014d98e16e18d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -11,29 +11,23 @@ profileComplianceMap = {
{fp16T, fp16T, fp16T, fp32T, fp16T},
{fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+ {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.conv3d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+ {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.depthwise_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+ {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.fully_connected",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T}}},
- {{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp32T, fp32T},
- {fp32T, fp32T, fp32T, fp32T}}}}},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.matmul",
{{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
{{Profile::pro_fp},
@@ -44,11 +38,11 @@ profileComplianceMap = {
{{{Profile::pro_int}, {{i8T, i8T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.transpose_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+ {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.clamp",
{{{Profile::pro_int}, {{i8T, i8T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -56,7 +50,7 @@ profileComplianceMap = {
{"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.add",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{"tosa.arithmetic_right_shift",
{{{Profile::pro_int},
@@ -70,20 +64,22 @@ profileComplianceMap = {
{"tosa.bitwise_xor",
{{{Profile::pro_int},
{{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
- {"tosa.intdiv",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}}}},
+ {"tosa.int_div",
+ {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}},
{"tosa.logical_and",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
{"tosa.logical_left_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ anyOf}}},
{"tosa.logical_right_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ anyOf}}},
{"tosa.logical_or",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
{"tosa.logical_xor",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
{"tosa.maximum",
{{{Profile::pro_int}, {{i32T, i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
@@ -92,12 +88,12 @@ profileComplianceMap = {
{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{"tosa.mul",
{{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
- {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+ {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{"tosa.pow",
{{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{"tosa.sub",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
{"tosa.abs",
@@ -112,7 +108,7 @@ profileComplianceMap = {
{"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.logical_not",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
{"tosa.negate",
{{{Profile::pro_int},
{{i8T, i8T, i8T, i8T},
@@ -123,12 +119,12 @@ profileComplianceMap = {
{"tosa.reciprocal",
{{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.select",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
{{Profile::pro_int},
{{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.equal",
{{{Profile::pro_int}, {{i32T, i32T, boolT}}},
{{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
@@ -139,9 +135,9 @@ profileComplianceMap = {
{{{Profile::pro_int}, {{i32T, i32T, boolT}}},
{{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
{"tosa.reduce_all",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
{"tosa.reduce_any",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
{"tosa.reduce_max",
{{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -154,40 +150,45 @@ profileComplianceMap = {
{{{Profile::pro_int}, {{i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.concat",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
+ {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.pad",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{Profile::pro_int},
+ {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
+ {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{"tosa.reshape",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.reverse",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.slice",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.tile",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.transpose",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+ {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.gather",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}},
+ {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}},
{"tosa.scatter",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{i8T, i32T, i8T, i8T},
+ {i16T, i32T, i16T, i16T},
+ {i32T, i32T, i32T, i32T}}},
+ {{Profile::pro_fp},
+ {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}},
{"tosa.resize",
{{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -232,16 +233,21 @@ profileComplianceMap = {
{i32T, i32T, i16T, i16T},
{i32T, i32T, i32T, i32T}}}}},
{"tosa.const",
- {{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{boolT}, {i8T}, {i16T}, {i32T}},
+ anyOf},
{{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
{"tosa.identity",
- {{{Profile::pro_int},
- {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}},
+ anyOf},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.dim",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT}}},
- {{Profile::pro_int}, {{i8T}, {i16T}, {i32T}}},
- {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {"tosa.variable",
+ {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {"tosa.variable_write",
+ {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {"tosa.variable_read",
+ {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
};
extensionComplianceMap = {
@@ -256,32 +262,47 @@ extensionComplianceMap = {
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
+ {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{Extension::fp8e4m3},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{Extension::fp8e5m2},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{Extension::bf16},
+ {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.conv3d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
+ {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{Extension::fp8e4m3},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{Extension::fp8e5m2},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{Extension::bf16},
+ {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.depthwise_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
+ {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{Extension::fp8e4m3},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{Extension::fp8e5m2},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{Extension::bf16},
+ {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
- {"tosa.fully_connected",
- {{{Extension::int4}, {{i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
- {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
{"tosa.matmul",
{{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}},
+ {{Extension::fp8e4m3},
+ {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
+ {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}},
+ {{Extension::fp8e5m2},
+ {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
+ {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}},
+ {{Extension::fp8e4m3, Extension::fp8e5m2},
+ {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
+ {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
+ {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
+ {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}},
+ allOf},
{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
{"tosa.max_pool2d",
{{{Extension::int16}, {{i16T, i16T}}},
@@ -290,11 +311,14 @@ extensionComplianceMap = {
{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
{"tosa.transpose_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
+ {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{Extension::fp8e4m3},
+ {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{Extension::fp8e5m2},
+ {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{Extension::bf16},
+ {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.clamp",
{{{Extension::int16}, {{i16T, i16T}}},
{{Extension::bf16}, {{bf16T, bf16T}}}}},
@@ -317,8 +341,8 @@ extensionComplianceMap = {
{"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
{"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
{"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
{"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
{"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
{"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
@@ -327,13 +351,14 @@ extensionComplianceMap = {
{"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.concat",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+ {{{Extension::int16}, {{i16T, i16T}}},
+ {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.pad",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}},
+ {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}},
+ {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
{"tosa.reshape",
{{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
@@ -355,13 +380,13 @@ extensionComplianceMap = {
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.gather",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}},
+ {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}},
+ {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}},
{"tosa.scatter",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}},
+ {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}},
+ {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}},
{"tosa.resize",
{{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}},
{{Extension::bf16}, {{bf16T, bf16T}}}}},
@@ -376,9 +401,11 @@ extensionComplianceMap = {
{bf16T, fp32T},
{fp32T, bf16T}}},
{{Extension::bf16, Extension::fp8e4m3},
- {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}}},
+ {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}},
+ allOf},
{{Extension::bf16, Extension::fp8e5m2},
- {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}}},
+ {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}},
+ allOf},
{{Extension::fp8e4m3},
{{fp8e4m3T, fp16T},
{fp8e4m3T, fp32T},
@@ -406,9 +433,12 @@ extensionComplianceMap = {
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.dim",
- {{{Extension::fp8e4m3}, {{fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T}}}}},
+ {"tosa.cond_if", {{{Extension::controlflow}, {{boolT}}}}},
+ {"tosa.while_loop", {{{Extension::controlflow}, {{boolT}}}}},
+ {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {"tosa.variable_write",
+ {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {"tosa.variable_read",
+ {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
};
// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 4301ee5a583b7..14e15173de7bc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -801,7 +801,7 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT]>,
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
}
@@ -828,7 +828,7 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT]>,
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
}
@@ -856,7 +856,7 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT]>,
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
}
@@ -884,7 +884,7 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT]>,
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
}
@@ -912,7 +912,7 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT]>,
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
}
@@ -1108,7 +1108,7 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
list<Availability> availability = [
Profile<[Tosa_PRO_INT]>,
- Extension<[Tosa_EXT_BF16]>,
+ Extension<[Tosa_EXT_INT16]>,
];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 69b827fe14dee..1df1761d38455 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -29,11 +29,11 @@ typedef struct {
} TypeInfo;
enum CheckCondition {
+ // No condition. As the first enumerator it is also the value of a
+ // zero-initialized CheckCondition.
+ invalid,
- // Valid when any of the profile (extension) requirement is meet.
+ // Valid when any of the profile (extension) requirements is met.
anyOf,
- // Valid when all of the profile (extension) requirement are meet.
+ // Valid when all of the profile (extension) requirements are met.
- allOf,
- invalid
+ allOf
};
template <typename T>
@@ -76,20 +76,20 @@ class ProfileInfoDepot {
LogicalResult populatationDispatch(Operation *op);
- void populateProfileInfo(ValueRange operands, Value output);
+ LogicalResult populateProfileInfo(ValueRange operands, Value output);
// Base
+ // Fallback for op types without a dedicated specialization: report the
+ // missing profile definition and return failure to the caller.
template <typename T>
- void populateProfileInfo(T op) {
+ LogicalResult populateProfileInfo(T op) {
- op->emitOpError() << "profile requirement for this op has not been defined";
+ return op->emitOpError()
+ << "profile requirement for this op has not been defined";
}
// For conv2d, conv3d, transpose_conv2d, and depthwise_conv2d.
template <typename T>
- void populateProfileInfoConv(T op);
+ LogicalResult populateProfileInfoConv(T op);
- // For pad, reshape, slice, tile, and transpose.
+ // For reshape, slice, tile, and transpose.
template <typename T>
- void populateProfileInfoDataLayout(T op);
+ LogicalResult populateProfileInfoDataLayout(T op);
private:
SmallVector<TypeInfo> tyInfo;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ed2c40598458c..4aeb095ffff07 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -43,158 +43,206 @@ TosaProfileCompliance::getProfileComplianceMap() {
}
// Base populating function
-void ProfileInfoDepot::populateProfileInfo(ValueRange operands, Value output) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
+ Value output) {
for (auto operand : operands)
addValue(operand);
addValue(output);
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
addValue(op.getInput1().front());
addValue(op.getOutput());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
addValue(op.getInput());
addValue(op.getInputZp());
addValue(op.getOutputZp());
addType(op.getAccType());
addValue(op.getOutput());
+ return success();
}
template <typename T>
-void ProfileInfoDepot::populateProfileInfoConv(T op) {
+LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
addValue(op.getInput());
addValue(op.getWeight());
addValue(op.getBias());
+ addValue(op.getInputZp());
+ addValue(op.getWeightZp());
addType(op.getAccType());
addValue(op.getOutput());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
- populateProfileInfoConv(op);
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
+ return populateProfileInfoConv(op);
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
- populateProfileInfoConv(op);
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
+ return populateProfileInfoConv(op);
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
- populateProfileInfoConv(op);
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
+ return populateProfileInfoConv(op);
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
- populateProfileInfoConv(op);
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
+ return populateProfileInfoConv(op);
}
-template <typename T>
-void ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
addValue(op.getInput1());
+ addValue(op.getPadConst());
addValue(op.getOutput());
+ return success();
}
-template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
- populateProfileInfoDataLayout(op);
+template <typename T>
+LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
+ addValue(op.getInput1());
+ addValue(op.getOutput());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
- populateProfileInfoDataLayout(op);
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
+ return populateProfileInfoDataLayout(op);
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
- populateProfileInfoDataLayout(op);
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
+ return populateProfileInfoDataLayout(op);
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
- populateProfileInfoDataLayout(op);
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
+ return populateProfileInfoDataLayout(op);
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
- populateProfileInfoDataLayout(op);
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
+ return populateProfileInfoDataLayout(op);
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
+ // Record types in the order of the tosa.gather compliance entries:
+ // values, indices, output.
addValue(op.getValues());
+ addValue(op.getIndices());
addValue(op.getOutput());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
+ // Record types in the order of the tosa.scatter compliance entries:
+ // values_in, indices, input, values_out.
addValue(op.getValuesIn());
+ addValue(op.getIndices());
addValue(op.getInput());
addValue(op.getValuesOut());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
addValue(op.getInput1());
addValue(op.getInput2());
addValue(op.getOutput());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
addValue(op.getInput());
addValue(op.getOutput());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
addValue(op.getInputReal());
addValue(op.getInputImag());
addValue(op.getOutputReal());
addValue(op.getOutputImag());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
addValue(op.getInputReal());
addValue(op.getOutputReal());
addValue(op.getOutputImag());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
addValue(op.getInput2());
addValue(op.getInput3());
addValue(op.getOutput());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
addValue(op.getInput());
addValue(op.getInputZp());
addValue(op.getOutputZp());
addValue(op.getOutput());
+ return success();
}
template <>
-void ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
addValue(op.getA());
addValue(op.getB());
addValue(op.getAZp());
addValue(op.getBZp());
addValue(op.getOutput());
+ return success();
+}
+
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
+ ::mlir::Attribute attr = op.getInitialValueAttr();
+ if (attr == nullptr)
+ return failure();
+
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ addType(getElementTypeOrSelf(typedAttr));
+ return success();
+ }
+ return failure();
+}
+
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
+ addValue(op.getCondition());
+ return success();
+}
+
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
+ // The checked type is that of the first value yielded by the terminator
+ // of the cond graph's first block.
+ Block *block = &op.getCondGraph().front();
+ Operation *terminator = block->getTerminator();
+ // Fail instead of dereferencing an empty operand range.
+ if (terminator->getNumOperands() == 0)
+ return failure();
+ addValue(terminator->getOperand(0));
+ return success();
}
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function only populates the info for the customised operands.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
- return success(); \
+ return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
}
#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
@@ -204,8 +252,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function populates the info for all operands.
+// Operations without results (e.g. variable_write) contribute only their
+// operands, because op->getResult(0) would be out of range for them.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- populateProfileInfo(op->getOperands(), op->getResult(0)); \
- return success(); \
+ if (op->getNumResults() == 0) { \
+ for (Value operand : op->getOperands()) \
+ addValue(operand); \
+ return success(); \
+ } \
+ return populateProfileInfo(op->getOperands(), op->getResult(0)); \
}
// Skip irrelevant operands when they are independent and not tied to any
@@ -230,17 +277,9 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Select)
POPULATE_PROFILE_INFO_CUSTOM(Rescale)
POPULATE_PROFILE_INFO_CUSTOM(MatMul)
-
- // Type Invariant Extension, a capability extension that is independent
- // of the data type, meaning any compatible type can be used. No type
- // constraint for those operations.
- POPULATE_PROFILE_INFO_SKIP(ConstShape)
- POPULATE_PROFILE_INFO_SKIP(Variable)
- POPULATE_PROFILE_INFO_SKIP(VariableRead)
- POPULATE_PROFILE_INFO_SKIP(VariableWrite)
- POPULATE_PROFILE_INFO_SKIP(If)
- POPULATE_PROFILE_INFO_SKIP(While)
- POPULATE_PROFILE_INFO_SKIP(Yield)
+ POPULATE_PROFILE_INFO_CUSTOM(Variable)
+ POPULATE_PROFILE_INFO_CUSTOM(If)
+ POPULATE_PROFILE_INFO_CUSTOM(While)
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
@@ -292,6 +331,14 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_COMMON(Greater)
POPULATE_PROFILE_INFO_COMMON(Reverse)
POPULATE_PROFILE_INFO_COMMON(Identity)
+ POPULATE_PROFILE_INFO_COMMON(VariableRead)
+ POPULATE_PROFILE_INFO_COMMON(VariableWrite)
+
+ // Type Invariant Extension, a capability extension that is independent
+ // of the data type, meaning any compatible type can be used. No type
+ // constraint for those operations.
+ POPULATE_PROFILE_INFO_SKIP(ConstShape)
+ POPULATE_PROFILE_INFO_SKIP(Yield)
return failure();
}
@@ -314,7 +361,7 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
auto it = compMap.find(opName);
if (it == compMap.end()) {
- // Operators such as variable and shape ops do not have an operand type
+ // Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
// specification.
@@ -425,7 +472,8 @@ template <typename T>
SmallVector<T> TosaProfileCompliance::findMatchedProfile(
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
CheckCondition &condition) {
- assert(compInfo.size() != 0);
+ assert(compInfo.size() != 0 &&
+ "profile-based compliance information is empty");
// Populate the type of profile/extension relevant operands.
ProfileInfoDepot depot(op);
@@ -437,7 +485,10 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
for (SmallVector<TypeInfo> expected : sets) {
- assert(present.size() == expected.size());
+ assert(present.size() == expected.size() &&
+ "the entries for profile-based compliance do not match between "
+ "the generated metadata and the type definition retrieved from "
+ "the operation");
bool is_found = true;
// Compare the type signature between the given operation and the
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 7867df2e3917f..ff910a40cf219 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -197,7 +197,7 @@ func.func @test_int_div(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>)
// -----
// CHECK-LABEL: logical_and
func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> {
- // CHECK: profiles: [ [pro_int] ]
+ // CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ ]
%0 = tosa.logical_and %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
@@ -206,7 +206,7 @@ func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>
// -----
// CHECK-LABEL: logical_left_shift
func.func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
- // CHECK: profiles: [ [pro_int] ]
+ // CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ ]
%0 = tosa.logical_left_shift %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
@@ -215,7 +215,7 @@ func.func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x
// -----
// CHECK-LABEL: logical_right_shift
func.func @test_logical_right_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
- // CHECK: profiles: [ [pro_int] ]
+ // CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ ]
%0 = tosa.logical_right_shift %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
@@ -224,7 +224,7 @@ func.func @test_logical_right_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13
// -----
// CHECK-LABEL: logical_or
func.func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
- // CHECK: profiles: [ [pro_int] ]
+ // CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ ]
%0 = tosa.logical_or %arg0, %arg1 : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
@@ -233,7 +233,7 @@ func.func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>)
// -----
// CHECK-LABEL: logical_xor
func.func @test_logical_xor(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
- // CHECK: profiles: [ [pro_int] ]
+ // CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ ]
%0 = tosa.logical_xor %arg0, %arg1 : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
@@ -289,7 +289,7 @@ func.func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> te
// CHECK-LABEL: table
func.func @test_table(%arg0: tensor<64xi32>, %arg1: tensor<513x!quant.uniform<i16:f32, 1.0:0>>) -> tensor<64x!quant.uniform<i16:f32, 1.0:0>> {
// CHECK: profiles: [ [pro_int] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [int16] ]
%0 = tosa.table %arg0, %arg1 : (tensor<64xi32>, tensor<513x!quant.uniform<i16:f32, 1.000000e+00>>) -> tensor<64x!quant.uniform<i16:f32, 1.000000e+00>>
return %0 : tensor<64x!quant.uniform<i16:f32, 1.0:0>>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ca7c71cd3b137..3203c64b439da 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -619,7 +619,7 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
- tosa.variable @stored_var : tensor<1x4x8xi32>
+ tosa.variable @stored_var = dense<3> : tensor<1x4x8xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 13952716a9611..bde5b5ec7cffe 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -12,9 +12,9 @@ func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (te
}
// -----
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.read' op illegal: requires [variable]}}
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
return
@@ -23,7 +23,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
// -----
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.write' op illegal: requires [variable]}}
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
return
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 8b221eb38b4f4..e9cff00cbde37 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -53,6 +53,61 @@ func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> t
return %0 : tensor<13x21x3xf32>
}
+// -----
+func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ // expected-error@+1 {{'tosa.add' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+func.func @test_int_div(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ // expected-error@+1 {{'tosa.int_div' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.int_div %arg0, %arg1 : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> {
+ // expected-error@+1 {{'tosa.logical_and' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.logical_and %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+func.func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
+ // expected-error@+1 {{'tosa.logical_left_shift' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.logical_left_shift %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+func.func @test_mul(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+ // expected-error@+1 {{'tosa.mul' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+func.func @test_sub(%arg0: tensor<1x21x3xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ // expected-error@+1 {{'tosa.sub' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
+ // expected-error@+1 {{'tosa.logical_not' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.logical_not %arg0 : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1>
+ return %0 : tensor<1x21x3xi1>
+}
+
+// -----
+func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xi1>, %arg2: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ // expected-error@+1 {{'tosa.select' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x1x1xi1>, tensor<13x21x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
// -----
func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> {
// expected-error@+1 {{'tosa.reduce_all' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
@@ -60,6 +115,13 @@ func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> {
return %0 : tensor<1x21x3xi1>
}
+// -----
+func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> {
+ // expected-error@+1 {{'tosa.reduce_any' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+ return %0 : tensor<1x21x3xi1>
+}
+
// -----
func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
// expected-error@+1 {{'tosa.concat' op illegal: requires [pro_fp] but not enabled in target}}
@@ -67,6 +129,66 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -
return %0 : tensor<26x21x3xf32>
}
+// -----
+func.func @test_concat(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<26x21x3xi1> {
+ // expected-error@+1 {{'tosa.concat' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi1>, tensor<13x21x3xi1>) -> tensor<26x21x3xi1>
+ return %0 : tensor<26x21x3xi1>
+}
+
+// -----
+func.func @test_pad(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ // expected-error@+1 {{'tosa.const_shape' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // expected-error@+1 {{'tosa.const' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %pad_const = "tosa.const"() {values = dense<1> : tensor<1xi1>} : () -> tensor<1xi1>
+ // expected-error@+1 {{'tosa.pad' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.pad %arg0, %padding, %pad_const : (tensor<13x21x3xi1>, !tosa.shape<6>, tensor<1xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+func.func @test_reshape(%arg0: tensor<13x21x3xi1>) -> tensor<1x819xi1> {
+ // expected-error@+1 {{'tosa.const_shape' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error@+1 {{'tosa.reshape' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xi1>, !tosa.shape<2>) -> tensor<1x819xi1>
+ return %0 : tensor<1x819xi1>
+}
+
+// -----
+func.func @test_reverse(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ // expected-error@+1 {{'tosa.reverse' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+func.func @test_slice(%arg0: tensor<13x21x3xi1>) -> tensor<4x11x1xi1> {
+ // expected-error@+1 {{'tosa.const_shape' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // expected-error@+1 {{'tosa.const_shape' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // expected-error@+1 {{'tosa.slice' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xi1>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi1>
+ return %2 : tensor<4x11x1xi1>
+}
+
+// -----
+func.func @test_tile(%arg0: tensor<13x21x3xi1>) -> tensor<39x21x6xi1> {
+ // expected-error@+1 {{'tosa.const_shape' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ // expected-error@+1 {{'tosa.tile' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xi1>, !tosa.shape<3>) -> tensor<39x21x6xi1>
+ return %0 : tensor<39x21x6xi1>
+}
+
+// -----
+func.func @test_transpose(%arg0: tensor<13x21x3xi1>) -> tensor<3x13x21xi1> {
+ // expected-error@+1 {{'tosa.transpose' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi1>) -> tensor<3x13x21xi1>
+ return %1 : tensor<3x13x21xi1>
+}
// -----
func.func @test_cast_i32_f32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
// expected-error@+1 {{'tosa.cast' op illegal: requires [pro_fp] but not enabled in target}}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index c7221a88bda00..c69f78fcb9d1a 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -27,13 +27,9 @@ func.func @test_cast_i8_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi8> {
// -----
func.func @test_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
- // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32>} : () -> tensor<1xi32>
- // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
%input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op illegal: requires [pro_int] but not enabled in target}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32>
More information about the Mlir-commits
mailing list