[Mlir-commits] [llvm] [mlir] [mlir][tosa] Add profile-based operation validation (PR #126992)

TatWai Chong llvmlistbot at llvm.org
Thu Feb 13 18:08:06 PST 2025


https://github.com/tatwaichong updated https://github.com/llvm/llvm-project/pull/126992

>From 1e99d1ee9bd3ca656c45a45252df6218f89a0cfc Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Wed, 12 Feb 2025 16:21:29 -0800
Subject: [PATCH] [mlir][tosa] Add profile-based operation validation

TOSA MLIR profile-based validation is designed to identify the
profile/extension requirements for each operation in a TOSA MLIR
graph, ensuring that TOSA operators conform to the profiles and
extensions enabled by the target implementation.

The available profiles/extensions are reflected in the availability
property attached to each TOSA operator in the dialect. The design
of the availability mechanism, the profile/extension classes, and
their interfaces is inspired by the SPIR-V implementation.

This patch includes the following changes:
 - Introduces profile and extension knowledge within the dialect
   and establishes an interface to query this information.
 - Implements profile-based validation logic in the pass.
 - Adds a TargetEnv class that represents the capabilities enabled
   in the target implementation, such as profiles, extensions, and
   levels.
 - Adds a set of tests to ensure that profile and extension
   requirements are properly attached to the operations and that
   validation correctly verifies the requirements of a given
   operation against the target implementation.
---
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |   2 +-
 .../mlir/Dialect/Tosa/IR/CMakeLists.txt       |  11 +
 mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h |  84 +++
 .../mlir/Dialect/Tosa/IR/TosaComplianceData.h | 403 +++++++++++
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        | 200 +++++
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |   7 +
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 356 +++++++++
 .../Dialect/Tosa/IR/TosaProfileCompliance.h   | 163 +++++
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      |   4 +
 .../mlir/Dialect/Tosa/IR/TosaUtilOps.td       |  15 +
 .../mlir/Dialect/Tosa/Transforms/Passes.h     |  23 +-
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  14 +-
 .../TosaToLinalg/TosaToLinalgPass.cpp         |   3 +-
 mlir/lib/Dialect/Tosa/CMakeLists.txt          |   2 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |   3 +
 .../Dialect/Tosa/Transforms/CMakeLists.txt    |   1 +
 .../Tosa/Transforms/TosaProfileCompliance.cpp | 476 ++++++++++++
 .../Tosa/Transforms/TosaValidation.cpp        |  48 +-
 mlir/test/Dialect/Tosa/availability.mlir      | 681 ++++++++++++++++++
 mlir/test/Dialect/Tosa/invalid.mlir           |   3 +-
 mlir/test/Dialect/Tosa/invalid_extension.mlir |  38 +
 mlir/test/Dialect/Tosa/level_check.mlir       |   5 +-
 .../Dialect/Tosa/profile_all_unsupported.mlir |  83 +++
 .../Dialect/Tosa/profile_bi_unsupported.mlir  |  26 +
 .../Dialect/Tosa/profile_mi_unsupported.mlir  |  62 ++
 mlir/test/lib/Dialect/Tosa/CMakeLists.txt     |   1 +
 .../lib/Dialect/Tosa/TestAvailability.cpp     |  78 ++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 mlir/tools/mlir-tblgen/CMakeLists.txt         |   1 +
 mlir/tools/mlir-tblgen/TosaUtilsGen.cpp       | 226 ++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  36 +
 31 files changed, 3005 insertions(+), 52 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
 create mode 100644 mlir/test/Dialect/Tosa/availability.mlir
 create mode 100644 mlir/test/Dialect/Tosa/invalid_extension.mlir
 create mode 100644 mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
 create mode 100644 mlir/test/Dialect/Tosa/profile_bi_unsupported.mlir
 create mode 100644 mlir/test/Dialect/Tosa/profile_mi_unsupported.mlir
 create mode 100644 mlir/test/lib/Dialect/Tosa/TestAvailability.cpp
 create mode 100644 mlir/tools/mlir-tblgen/TosaUtilsGen.cpp

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index a1eb22eba6987..195a58432737b 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
     // Note: Default to 'none' level unless otherwise specified.
     std::optional<tosa::TosaValidationOptions> validationOptions =
         tosa::TosaValidationOptions{
-            {"none"}, false, tosa::TosaLevelEnum::None});
+            {"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
 
 /// Populates TOSA to linalg pipelines
 /// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index cc8d5ed9b0044..0a855d701d7b8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -12,3 +12,14 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen)
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
 mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
 add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaOpBase.td)
+mlir_tablegen(TosaEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TosaEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRTosaEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaAvailability.h.inc -gen-avail-interface-decls)
+mlir_tablegen(TosaAvailability.cpp.inc -gen-avail-interface-defs)
+mlir_tablegen(TosaOpAvailabilityImpl.inc -gen-tosa-avail-impls)
+add_public_tablegen_target(MLIRTosaAvailabilityIncGen)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
new file mode 100644
index 0000000000000..86fb4077b9207
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -0,0 +1,84 @@
+//===- TargetEnv.h - Tosa target environment utilities ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for Tosa target environment (implementation).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_IR_TARGETENV_H
+#define MLIR_DIALECT_TOSA_IR_TARGETENV_H
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+
+namespace mlir {
+namespace tosa {
+
+/// This class represents the capabilities enabled in the target
+/// implementation: the sets of enabled profiles and extensions. Levels and
+/// the spec version are not tracked yet (see the TODO below).
+class TargetEnv {
+public:
+  TargetEnv() = default;
+  /// Enables every profile in `profiles` and extension in `extensions`.
+  /// ArrayRef accepts SmallVector, std::vector and braced lists alike.
+  explicit TargetEnv(ArrayRef<Profile> profiles,
+                     ArrayRef<Extension> extensions) {
+    enabledProfiles.insert(profiles.begin(), profiles.end());
+    enabledExtensions.insert(extensions.begin(), extensions.end());
+  }
+
+  void addProfile(Profile p) { enabledProfiles.insert(p); }
+  void addExtension(Extension e) { enabledExtensions.insert(e); }
+
+  // TODO implement the following utilities.
+  // Version getSpecVersion() const;
+  // TosaLevel getLevel() const;
+
+  /// Returns true if the given profile is enabled in this target.
+  bool allows(Profile prof) const { return enabledProfiles.contains(prof); }
+
+  /// Returns true if at least one of `profs` is enabled; false when `profs`
+  /// is empty.
+  bool allowsAnyOf(ArrayRef<Profile> profs) const {
+    return llvm::any_of(profs, [this](Profile prof) { return allows(prof); });
+  }
+
+  /// Returns true if every profile in `profs` is enabled; vacuously true
+  /// when `profs` is empty.
+  bool allowsAllOf(ArrayRef<Profile> profs) const {
+    return llvm::all_of(profs, [this](Profile prof) { return allows(prof); });
+  }
+
+  /// Returns true if the given extension is enabled in this target.
+  bool allows(Extension ext) const { return enabledExtensions.contains(ext); }
+
+  /// Returns true if at least one of `exts` is enabled; false when `exts`
+  /// is empty.
+  bool allowsAnyOf(ArrayRef<Extension> exts) const {
+    return llvm::any_of(exts, [this](Extension ext) { return allows(ext); });
+  }
+
+  /// Returns true if every extension in `exts` is enabled; vacuously true
+  /// when `exts` is empty.
+  bool allowsAllOf(ArrayRef<Extension> exts) const {
+    return llvm::all_of(exts, [this](Extension ext) { return allows(ext); });
+  }
+
+private:
+  // Inline capacities match the number of enumerators currently defined;
+  // llvm::SmallSet falls back to heap storage if more are ever added.
+  llvm::SmallSet<Profile, 3> enabledProfiles;
+  llvm::SmallSet<Extension, 8> enabledExtensions;
+};
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_IR_TARGETENV_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h
new file mode 100644
index 0000000000000..1a10d8579962d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h
@@ -0,0 +1,403 @@
+// The profile-based compliance content below is auto-generated by a script
+// in https://git.mlplatform.org/tosa/specification.git
+profileComplianceMap = {
+    {"tosa.argmax",
+     {{{Profile::pro_int}, {{i8T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
+    {"tosa.avg_pool2d",
+     {{{Profile::pro_int}, {{i8T, i32T, i8T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.conv3d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.depthwise_conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.fully_connected",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp32T, fp32T},
+        {fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.matmul",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.max_pool2d",
+     {{{Profile::pro_int}, {{i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.transpose_conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.clamp",
+     {{{Profile::pro_int}, {{i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.add",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.arithmetic_right_shift",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_and",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_or",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_xor",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.intdiv",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}}}},
+    {"tosa.logical_and",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.logical_left_shift",
+     {{{Profile::pro_int, Profile::pro_fp},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.logical_right_shift",
+     {{{Profile::pro_int, Profile::pro_fp},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.logical_or",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.logical_xor",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.maximum",
+     {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.minimum",
+     {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.mul",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
+      {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.pow",
+     {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.sub",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
+    {"tosa.abs",
+     {{{Profile::pro_int}, {{i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.bitwise_not",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}},
+    {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}},
+    {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.logical_not",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.negate",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reciprocal",
+     {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.select",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}},
+      {{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.equal",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.greater",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.greater_equal",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.reduce_all",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.reduce_any",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.reduce_max",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_min",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_product",
+     {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_sum",
+     {{{Profile::pro_int}, {{i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.concat",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.pad",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reshape",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reverse",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.slice",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.tile",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.transpose",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.gather",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.scatter",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.resize",
+     {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.cast",
+     {{{Profile::pro_int},
+       {{boolT, i8T},
+        {boolT, i16T},
+        {boolT, i32T},
+        {i8T, boolT},
+        {i8T, i16T},
+        {i8T, i32T},
+        {i16T, boolT},
+        {i16T, i8T},
+        {i16T, i32T},
+        {i32T, boolT},
+        {i32T, i8T},
+        {i32T, i16T}}},
+      {{Profile::pro_fp},
+       {{i8T, fp16T},
+        {i8T, fp32T},
+        {i16T, fp16T},
+        {i16T, fp32T},
+        {i32T, fp16T},
+        {i32T, fp32T},
+        {fp16T, i8T},
+        {fp16T, i16T},
+        {fp16T, i32T},
+        {fp16T, fp32T},
+        {fp32T, i8T},
+        {fp32T, i16T},
+        {fp32T, i32T},
+        {fp32T, fp16T}}}}},
+    {"tosa.rescale",
+     {{{Profile::pro_int},
+       {{i8T, i8T},
+        {i8T, i16T},
+        {i8T, i32T},
+        {i16T, i8T},
+        {i16T, i16T},
+        {i16T, i32T},
+        {i32T, i8T},
+        {i32T, i16T},
+        {i32T, i32T}}}}},
+    {"tosa.const",
+     {{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
+      {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+    {"tosa.identity",
+     {{{Profile::pro_int},
+       {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.dim",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT}}},
+      {{Profile::pro_int}, {{i8T}, {i16T}, {i32T}}},
+      {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+};
+
+extensionComplianceMap = {
+    {"tosa.argmax",
+     {{{Extension::int16}, {{i16T, i32T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
+      {{Extension::bf16}, {{bf16T, i32T}}}}},
+    {"tosa.avg_pool2d",
+     {{{Extension::int16}, {{i16T, i32T, i16T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
+    {"tosa.conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.conv3d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.depthwise_conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.fully_connected",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
+    {"tosa.matmul",
+     {{{Extension::int16}, {{i16T, i16T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
+    {"tosa.max_pool2d",
+     {{{Extension::int16}, {{i16T, i16T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
+    {"tosa.transpose_conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.clamp",
+     {{{Extension::int16}, {{i16T, i16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}},
+    {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.concat",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.pad",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reshape",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reverse",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.slice",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.tile",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.transpose",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.gather",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.scatter",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.resize",
+     {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.cast",
+     {{{Extension::bf16},
+       {{i8T, bf16T},
+        {i16T, bf16T},
+        {i32T, bf16T},
+        {bf16T, i8T},
+        {bf16T, i16T},
+        {bf16T, i32T},
+        {bf16T, fp32T},
+        {fp32T, bf16T}}},
+      {{Extension::bf16, Extension::fp8e4m3},
+       {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}}},
+      {{Extension::bf16, Extension::fp8e5m2},
+       {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}}},
+      {{Extension::fp8e4m3},
+       {{fp8e4m3T, fp16T},
+        {fp8e4m3T, fp32T},
+        {fp16T, fp8e4m3T},
+        {fp32T, fp8e4m3T}}},
+      {{Extension::fp8e5m2},
+       {{fp8e5m2T, fp16T},
+        {fp8e5m2T, fp32T},
+        {fp16T, fp8e5m2T},
+        {fp32T, fp8e5m2T}}}}},
+    {"tosa.rescale",
+     {{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
+    {"tosa.const",
+     {{{Extension::int4}, {{i4T}}},
+      {{Extension::int16}, {{i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T}}}}},
+    {"tosa.identity",
+     {{{Extension::int4}, {{i4T, i4T}}},
+      {{Extension::int16}, {{i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.dim",
+     {{{Extension::fp8e4m3}, {{fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T}}}}},
+};
+// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 862d98ad436a6..ffcb2c91d3619 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -14,8 +14,15 @@
 #define TOSA_OP_BASE
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+
+include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
+
 //===----------------------------------------------------------------------===//
 // The TOSA Dialect.
 //===----------------------------------------------------------------------===//
@@ -200,6 +207,189 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
                                          input, paddings, pad_value);
   }]>;
 
+// Wrapper over base I32EnumAttr to set common fields.
+class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
+     : I32EnumAttr<name, description, cases> {
+   let genSpecializedAttr = 0;
+   let cppNamespace = "::mlir::tosa";
+}
+
+class Tosa_I32EnumAttr<string name, string description, string mnemonic,
+                         list<I32EnumAttrCase> cases>
+    : EnumAttr<Tosa_Dialect, Tosa_I32Enum<name, description, cases>, mnemonic> {
+   let assemblyFormat = "`<` $value `>`";
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 1.5.
+//
+// Profile:
+// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
+// FP  : Floating-Point Inference. Primarily FP16 and FP32 operations.
+//
+// Extension:
+// INT16    : 16-bit integer operations.
+// INT4     : 4-bit integer weights.
+// BF16     : BFloat16 operations.
+// FP8E4M3  : 8-bit floating-point operations E4M3.
+// FP8E5M2  : 8-bit floating-point operations E5M2.
+// FFT      : Fast Fourier Transform operations.
+// VARIABLE : Stateful variable operations.
+//===----------------------------------------------------------------------===//
+
+def Tosa_PRO_INT   : I32EnumAttrCase<"pro_int", 1>;
+def Tosa_PRO_FP   : I32EnumAttrCase<"pro_fp", 2>;
+def Tosa_NONE : I32EnumAttrCase<"none", 3>;
+
+def Tosa_EXT_INT16    : I32EnumAttrCase<"int16", 1>;
+def Tosa_EXT_INT4     : I32EnumAttrCase<"int4", 2>;
+def Tosa_EXT_BF16     : I32EnumAttrCase<"bf16", 3>;
+def Tosa_EXT_FP8E4M3  : I32EnumAttrCase<"fp8e4m3", 4>;
+def Tosa_EXT_FP8E5M2  : I32EnumAttrCase<"fp8e5m2", 5>;
+def Tosa_EXT_FFT      : I32EnumAttrCase<"fft", 6>;
+def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
+def Tosa_EXT_NONE     : I32EnumAttrCase<"none", 8>;
+
+def Tosa_ExtensionAttr
+    : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
+      Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
+      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE
+    ]>;
+
+def Tosa_ExtensionArrayAttr
+    : TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
+
+def Tosa_ProfileAttr
+    : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
+                       [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>;
+
+def Tosa_ProfileArrayAttr
+    : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
+// The base class for defining op availability dimensions.
+class Availability {
+  // The following are fields for controlling the generated C++ OpInterface.
+
+  // The namespace for the generated C++ OpInterface subclass.
+  string cppNamespace = "::mlir::tosa";
+
+  // The name for the generated C++ OpInterface subclass.
+  string interfaceName = ?;
+
+  // The description for the generated C++ OpInterface subclass.
+  string interfaceDescription = "";
+
+  // The query function's return type in the generated C++ OpInterface subclass.
+  string queryFnRetType = ?;
+
+  // The query function's name in the generated C++ OpInterface subclass.
+  string queryFnName = ?;
+
+  code mergeAction = ?; // C++ that folds `$instance` into `$overall`; may be "".
+
+  // The initializer for the final availability requirement.
+  string initializer = ?;
+
+  // An availability instance's type.
+  string instanceType = ?;
+
+  // The following are fields for a concrete availability instance.
+
+  // The code for preparing a concrete instance. This should be C++ statements
+  // and will be generated before the `mergeAction` logic.
+  code instancePreparation = "";
+
+  // The availability requirement carried by a concrete instance.
+  string instance = ?;
+}
+
+
+class Profile<list<I32EnumAttrCase> profiles> : Availability {
+  let interfaceName = "QueryProfileInterface";
+  let interfaceDescription = [{
+    Querying interface for the supported set of TOSA profiles.
+
+    This interface provides a `getProfiles()` method to query
+    the supported set of TOSA profiles. The returned value is a
+    list of `::mlir::tosa::Profile` arrays (one per non-empty Profile<>).
+  }];
+
+  let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<"
+                          "::mlir::tosa::Profile>, 1>";
+  let queryFnName = "getProfiles";
+
+  let mergeAction = !if(
+      !empty(profiles), "", "$overall.emplace_back($instance)");
+
+  let initializer = "{}";
+
+  let instanceType = "::llvm::ArrayRef<::mlir::tosa::Profile>";
+
+  // Pack all profiles as a static array and get its (qualified) reference.
+  let instancePreparation = !if(!empty(profiles), "",
+    "static const ::mlir::tosa::Profile profs[] = {" #
+    !interleave(!foreach(prof, profiles,
+                         "::mlir::tosa::Profile::" # prof.symbol), ", ") #
+    "}; " #
+    "::llvm::ArrayRef<::mlir::tosa::Profile> " #
+      "ref(profs, std::size(profs));");
+
+  let instance = "ref";
+}
+
+class Extension<list<I32EnumAttrCase> extensions> : Availability {
+  let interfaceName = "QueryExtensionInterface";
+  let interfaceDescription = [{
+    Querying interface for the supported set of TOSA extensions.
+
+    This interface provides a `getExtensions()` method to query
+    the supported set of TOSA extensions. The returned value is a
+    list of `::mlir::tosa::Extension` arrays (one per non-empty Extension<>).
+  }];
+
+  let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<"
+                          "::mlir::tosa::Extension>, 1>";
+  let queryFnName = "getExtensions";
+
+  let mergeAction = !if(
+      !empty(extensions), "", "$overall.emplace_back($instance)");
+
+  let initializer = "{}";
+
+  let instanceType = "::llvm::ArrayRef<::mlir::tosa::Extension>";
+
+  // Pack all extensions as a static array and get its (qualified) reference.
+  let instancePreparation = !if(!empty(extensions), "",
+    "static const ::mlir::tosa::Extension exts[] = {" #
+    !interleave(!foreach(ext, extensions,
+                         "::mlir::tosa::Extension::" # ext.symbol), ", ") #
+    "}; " #
+    "::llvm::ArrayRef<::mlir::tosa::Extension> " #
+      "ref(exts, std::size(exts));");
+
+  let instance = "ref";
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Interfaces.
+//===----------------------------------------------------------------------===//
+
+def QueryProfileInterface : OpInterface<"QueryProfileInterface"> {
+  let cppNamespace = "::mlir::tosa";
+  let methods = [InterfaceMethod<
+    "get supported profiles",
+    "::llvm::SmallVector<::llvm::ArrayRef<::mlir::tosa::Profile>, 1>",
+    "getProfiles">];
+}
+
+def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> {
+  let cppNamespace = "::mlir::tosa";
+  let methods = [InterfaceMethod<
+    "get supported extensions",
+    "::llvm::SmallVector<::llvm::ArrayRef<::mlir::tosa::Extension>, 1>",
+    "getExtensions">];
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Trait.
 //===----------------------------------------------------------------------===//
@@ -223,7 +413,17 @@ def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
 
 class Tosa_Op<string mnemonic, list<Trait> traits = []> :
     Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
+    DeclareOpInterfaceMethods<QueryProfileInterface>,
+    DeclareOpInterfaceMethods<QueryExtensionInterface>,
     TosaResolvableShapeOperands])> {
+
+  // Default availability specification.
+  list<Availability> availability = [
+    Profile<[]>,
+    Extension<[]>];
+
+  // When not set, manual implementation of these methods is required.
+  bit autogenAvailability = 1;
 }
 
 class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 069073bc2d164..358e5dabfeb62 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -29,9 +29,16 @@
 // TOSA dialect and structs includes.
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Tosa/IR/TosaEnums.h.inc"
 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
 #include "mlir/Transforms/DialectConversion.h"
 
+//===----------------------------------------------------------------------===//
+// TOSA operation validation includes.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaAvailability.h.inc"
+
 namespace mlir {
 class PatternRewriter;
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..66956c9948f65 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -50,6 +50,11 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
     Tosa_Tensor: $output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -86,6 +91,11 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     Tosa_Tensor4D:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -118,6 +128,11 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -149,6 +164,11 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
     Tosa_Tensor5D:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -181,6 +201,11 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -218,6 +243,11 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
     Tosa_Tensor3D:$output_imag
   );
 
+  list<Availability> availability = [
+    Profile<[]>,
+    Extension<[Tosa_EXT_FFT]>,
+  ];
+
   let assemblyFormat = [{
     $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
     type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
@@ -247,6 +277,11 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
     Tosa_Tensor3D:$c
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
 }
 
@@ -276,6 +311,11 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
     Tosa_Tensor4D:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let hasCanonicalizer = 1;
 }
 
@@ -310,6 +350,11 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
     Tosa_Tensor3D:$output_imag
   );
 
+  list<Availability> availability = [
+    Profile<[]>,
+    Extension<[Tosa_EXT_FFT]>,
+  ];
+
   let assemblyFormat = [{
     $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
@@ -343,6 +388,11 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let builders = [Tosa_TransConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -377,6 +427,11 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT16, Tosa_EXT_BF16]>,
+  ];
+
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
@@ -402,6 +457,11 @@ def Tosa_SigmoidOp : Tosa_ElementwiseUnaryOp<"sigmoid"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -424,6 +484,11 @@ def Tosa_TanhOp : Tosa_ElementwiseUnaryOp<"tanh"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -447,6 +512,11 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -488,6 +558,11 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -512,6 +587,11 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -535,6 +615,11 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -558,6 +643,11 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -581,6 +671,11 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -603,6 +698,11 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementT
     Tosa_Int32Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -627,6 +727,11 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
   let results = (outs
     Tosa_I1Tensor:$z
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -649,6 +754,11 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -671,6 +781,11 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -694,6 +809,11 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
   let results = (outs
     Tosa_I1Tensor:$z
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -717,6 +837,11 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
   let results = (outs
     Tosa_I1Tensor:$z
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -741,6 +866,11 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -765,6 +895,11 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 def MulOperandsAndResultElementType :
@@ -799,6 +934,11 @@ def Tosa_MulOp : Tosa_Op<"mul", [
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 
@@ -825,6 +965,11 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -847,6 +992,11 @@ def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -882,6 +1032,11 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[Tosa_EXT_INT16]>,
+  ];
+
   let assemblyFormat = [{
     $input1 `,` $table attr-dict `:` `(` type($input1) `,` type($table) `)` `->` type($output)
   }];
@@ -919,6 +1074,11 @@ def Tosa_AbsOp : Tosa_ElementwiseUnaryOp<"abs"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -939,6 +1099,11 @@ def Tosa_BitwiseNotOp : Tosa_ElementwiseUnaryOp<"bitwise_not"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -958,6 +1123,11 @@ def Tosa_CeilOp : Tosa_ElementwiseUnaryOp<"ceil"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -977,6 +1147,11 @@ def Tosa_ClzOp : Tosa_ElementwiseUnaryOp<"clz"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -996,6 +1171,11 @@ def Tosa_CosOp : Tosa_ElementwiseUnaryOp<"cos"> {
   let results = (outs
     Tosa_FloatTensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1016,6 +1196,11 @@ def Tosa_ExpOp : Tosa_ElementwiseUnaryOp<"exp"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -1036,6 +1221,11 @@ def Tosa_FloorOp : Tosa_ElementwiseUnaryOp<"floor"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1056,6 +1246,11 @@ def Tosa_LogOp : Tosa_ElementwiseUnaryOp<"log"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -1076,6 +1271,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
   let results = (outs
     Tosa_I1Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1098,6 +1298,11 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let builders = [Tosa_UnaryOpQuantInfoBuilder];
 
   let hasFolder = 1;
@@ -1122,6 +1327,11 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseUnaryOp<"reciprocal"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let extraClassDeclaration = [{
     /// Return the reciprocal result on the operand.
     static inline APFloat calcOneElement(const APFloat &operand) {
@@ -1152,6 +1362,11 @@ def Tosa_RsqrtOp : Tosa_ElementwiseUnaryOp<"rsqrt"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1171,6 +1386,11 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
   let results = (outs
     Tosa_FloatTensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1198,6 +1418,12 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
 
@@ -1234,6 +1460,11 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
     Tosa_I1Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let extraClassDeclaration = [{
     /// Returns when two result types are compatible for this op; method used by
     /// InferTypeOpInterface.
@@ -1262,6 +1493,11 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
     Tosa_I1Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -1285,6 +1521,11 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
     Tosa_I1Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -1312,6 +1553,11 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 
@@ -1346,6 +1592,11 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 
@@ -1381,6 +1632,11 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 
@@ -1417,6 +1673,11 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 
@@ -1452,6 +1713,11 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 
@@ -1486,6 +1752,11 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 
@@ -1526,6 +1797,11 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 
@@ -1573,6 +1849,11 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Tosa_RankedTensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let builders = [Tosa_PadOpQuantInfoBuilder,
                   Tosa_ExplicitValuePadOpQuantInfoBuilder];
 
@@ -1605,6 +1886,11 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
     Tosa_RankedTensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
@@ -1637,6 +1923,11 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 
@@ -1665,6 +1956,11 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
@@ -1688,6 +1984,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let extraClassDeclaration = [{
     LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
   }];
@@ -1717,6 +2018,11 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
     outs Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let extraClassDeclaration = [{
     LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
   }];
@@ -1750,6 +2056,11 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
   let results = (outs
     Tosa_Tensor3D:$output
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1772,6 +2083,11 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
   let results = (outs
     Tosa_Tensor3D:$values_out
   );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1806,6 +2122,11 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
     Tosa_Tensor4D:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT16, Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
 }
 
@@ -1856,6 +2177,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 
   let hasFolder = 1;
@@ -1907,6 +2233,11 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT]>,
+    Extension<[Tosa_EXT_INT16]>,
+  ];
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -1943,6 +2274,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
     TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -1967,6 +2303,11 @@ def Tosa_IdentityOp: Tosa_Op<"identity", [Pure,
     Tosa_Tensor:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
+  ];
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -2022,6 +2363,11 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
     Variadic<Tosa_Tensor>:$output_list
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -2056,6 +2402,11 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
     Variadic<Tosa_Tensor>:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
+
   let regions = (region
     SizedRegion<1>:$then_branch,
     SizedRegion<1>:$else_branch
@@ -2092,6 +2443,11 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
     Variadic<Tosa_Tensor>:$output
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
+
   let regions = (region
     SizedRegion<1>:$cond,
     SizedRegion<1>:$body
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
new file mode 100644
index 0000000000000..35798452065c5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -0,0 +1,212 @@
+//===- TosaProfileCompliance.h - Tosa Profile-based Compliance Validation -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares the data model (TypeInfo, OpComplianceInfo) and the checker
+// (TosaProfileCompliance) used to verify that the profiles and extensions a
+// TOSA operation requires are enabled in a tosa::TargetEnv.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_IR_TOSAPROFILECOMPLIANCE_H
+#define MLIR_DIALECT_TOSA_IR_TOSAPROFILECOMPLIANCE_H
+
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Support/TypeID.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+
+// NOTE(review): using-directives in a header leak into every includer. They
+// are kept because the declarations below are written against them; consider
+// moving these declarations into namespace mlir::tosa instead.
+using namespace mlir;
+using namespace mlir::tosa;
+
+//===----------------------------------------------------------------------===//
+// Type Compliance Definition
+//===----------------------------------------------------------------------===//
+
+/// Element-type signature matched against the compliance tables: the TypeID
+/// of the element type class together with its bit width.
+struct TypeInfo {
+  mlir::TypeID typeID;
+  uint32_t bitWidth;
+};
+
+/// How the modes of a matched OpComplianceInfo entry must be satisfied.
+enum CheckCondition {
+  // Valid when any of the profile (extension) requirements is met.
+  anyOf,
+  // Valid when all of the profile (extension) requirements are met.
+  allOf,
+  invalid
+};
+
+/// One compliance entry: the profiles/extensions (`mode`) an operation
+/// requires when its operand types match a signature in `operandTypeInfoSet`.
+template <typename T>
+struct OpComplianceInfo {
+  // Certain operations require multiple modes enabled.
+  // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
+  SmallVector<T> mode;
+  SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
+  CheckCondition condition = CheckCondition::anyOf;
+};
+
+/// Compliance entries keyed by full operation name (as returned by
+/// `op->getName().getStringRef()`).
+using OperationProfileComplianceMap =
+    std::unordered_map<std::string, SmallVector<OpComplianceInfo<Profile>>>;
+using OperationExtensionComplianceMap =
+    std::unordered_map<std::string, SmallVector<OpComplianceInfo<Extension>>>;
+
+//===----------------------------------------------------------------------===//
+// Tosa Profile And Extension Information Depot
+//===----------------------------------------------------------------------===//
+
+/// Collects, in a fixed per-operation order, the element-type signature of
+/// the values (and accumulator types) of an operation that take part in the
+/// profile/extension compliance check.
+class ProfileInfoDepot {
+public:
+  explicit ProfileInfoDepot(Operation *op) {
+    if (failed(populatationDispatch(op)))
+      op->emitOpError() << "fail to populate the profile info\n";
+  }
+
+  void addType(Type t) { tyInfo.push_back(convertTypeToInfo(t)); }
+  void addValue(Value v) { tyInfo.push_back(convertValueToInfo(v)); }
+  const SmallVector<TypeInfo> &getInfo() const { return tyInfo; }
+
+private:
+  TypeInfo convertTypeToInfo(Type type) {
+    // getIntOrFloatBitWidth() asserts on types that are neither integer nor
+    // float (e.g. index or quantized element types). Record width 0 for those
+    // instead; the TypeInfo constants the compliance tables are built from all
+    // have non-zero widths, so such types simply never match an entry.
+    uint32_t bitWidth = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() : 0;
+    return {type.getTypeID(), bitWidth};
+  }
+
+  TypeInfo convertValueToInfo(Value value) {
+    return convertTypeToInfo(getElementTypeOrSelf(value.getType()));
+  }
+
+  // Populates tyInfo according to the kind of `op`; fails when `op` has no
+  // population rule.
+  LogicalResult populatationDispatch(Operation *op);
+
+  // Records every value in `operands`, followed by `output`.
+  void populateProfileInfo(ValueRange operands, Value output);
+
+  // Base
+  template <typename T>
+  void populateProfileInfo(T op) {
+    op->emitOpError() << "profile requirement for this op has not been defined";
+  }
+  // For conv2d, conv3d, transpose_conv2d, and depthwise_conv2d.
+  template <typename T>
+  void populateProfileInfoConv(T op);
+
+  // For pad, reshape, slice, tile, and transpose.
+  template <typename T>
+  void populateProfileInfoDataLayout(T op);
+
+private:
+  SmallVector<TypeInfo> tyInfo;
+};
+
+//===----------------------------------------------------------------------===//
+// Tosa Profile And Extension Compliance Checker
+//===----------------------------------------------------------------------===//
+
+class TosaProfileCompliance {
+public:
+  explicit TosaProfileCompliance();
+
+  // Accessor of the compliance info map. Only Profile and Extension are
+  // provided: their specializations are declared after this class and defined
+  // in TosaProfileCompliance.cpp. The map is returned by value (copied).
+  template <typename T>
+  std::unordered_map<std::string, SmallVector<OpComplianceInfo<T>>>
+  getProfileComplianceMap() {
+    // Only profile and extension compliance info are provided.
+    return {};
+  }
+
+  // Verify if the operation is allowed to be executed in the given target
+  // environment.
+  LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
+  LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
+
+  // Check `op` against the modes the specification requires for it
+  // (`specRequiredModeSet`) and the modes enabled in `targetEnv`.
+  template <typename T>
+  LogicalResult checkProfileOrExtension(
+      Operation *op, const tosa::TargetEnv &targetEnv,
+      const SmallVector<ArrayRef<T>> &specRequiredModeSet);
+
+  bool isSameTypeInfo(TypeInfo a, TypeInfo b) {
+    return a.typeID == b.typeID && a.bitWidth == b.bitWidth;
+  }
+
+  // Find the required profiles or extensions from the compliance info according
+  // to the operand type combination.
+  template <typename T>
+  SmallVector<T> findMatchedProfile(Operation *op,
+                                    SmallVector<OpComplianceInfo<T>> compInfo,
+                                    CheckCondition &condition);
+
+  // Profiles that must be enabled (any of them) for `ext` to be usable.
+  SmallVector<Profile> getCooperativeProfiles(Extension ext) {
+    switch (ext) {
+    case Extension::int16:
+    case Extension::int4:
+      return {Profile::pro_int};
+    case Extension::bf16:
+    case Extension::fp8e4m3:
+    case Extension::fp8e5m2:
+    case Extension::fft:
+      return {Profile::pro_fp};
+    case Extension::variable:
+      return {Profile::pro_fp, Profile::pro_int};
+    case Extension::none:
+      return {};
+    }
+    // Every enumerator returns above; without this, control could fall off
+    // the end of a non-void function (and compilers warn about it).
+    llvm_unreachable("unknown TOSA extension");
+  }
+
+  // Debug utilities.
+  template <typename T>
+  SmallVector<StringRef> stringifyProfile(ArrayRef<T> profiles);
+
+  template <typename T>
+  SmallVector<StringRef>
+  stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
+
+private:
+  OperationProfileComplianceMap profileComplianceMap;
+  OperationExtensionComplianceMap extensionComplianceMap;
+};
+
+// Declare the accessor specializations so every includer uses the definitions
+// in TosaProfileCompliance.cpp instead of the empty primary template.
+template <>
+OperationProfileComplianceMap
+TosaProfileCompliance::getProfileComplianceMap<Profile>();
+template <>
+OperationExtensionComplianceMap
+TosaProfileCompliance::getProfileComplianceMap<Extension>();
+
+#endif // MLIR_DIALECT_TOSA_IR_TOSAPROFILECOMPLIANCE_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 597dc32e84402..82cfe01865853 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -30,6 +30,10 @@ def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
 
 class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
     : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index f9f25da1b649d..8756cb9e5de3a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -96,6 +96,11 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
     OptionalAttr<AnyAttr>:$initial_value
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_VARIABLE]>,
+  ];
+
   let assemblyFormat = [{
     $name
     attr-dict
@@ -118,6 +123,11 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
     AnyType:$value
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_VARIABLE]>,
+  ];
+
   let assemblyFormat = [{
     $name attr-dict `,` $value `:` type($value)
   }];
@@ -141,6 +151,11 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
     AnyType:$value
   );
 
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_VARIABLE]>,
+  ];
+
   let assemblyFormat = [{
     $name attr-dict `:` type($value)
   }];
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 565970367e5dc..33bbc069c521d 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
 #include "mlir/Pass/Pass.h"
 
@@ -48,28 +49,6 @@ std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 std::unique_ptr<Pass> createTosaOptionalDecompositions();
 
-struct ValidationOptions {
-  /// Validate if operations match for the given profile.
-  TosaProfileEnum profile = TosaProfileEnum::Undefined;
-  ValidationOptions &setProfile(TosaProfileEnum profile) {
-    this->profile = profile;
-    return *this;
-  }
-  /// Verify if the properties of certain operations align the spec requirement.
-  bool strictOperationSpecAlignment = false;
-  ValidationOptions &enableStrictOperationSpecAlignment(bool enable = true) {
-    strictOperationSpecAlignment = enable;
-    return *this;
-  }
-  /// Validate if operator parameters are within specfication for the given
-  /// level.
-  TosaLevelEnum level = TosaLevelEnum::EightK;
-  ValidationOptions &setLevel(TosaLevelEnum level) {
-    this->level = level;
-    return *this;
-  }
-};
-
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index dac67633769c7..f6ead2b6ba3dd 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -71,16 +71,6 @@ def TosaOptionalDecompositions
   let constructor = "tosa::createTosaOptionalDecompositions()";
 }
 
-def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile",
-    [
-      I32EnumAttrCase<"BaseInference", 0, "bi">,
-      I32EnumAttrCase<"MainInference", 1, "mi">,
-      I32EnumAttrCase<"MainTraining", 2, "mt">,
-      I32EnumAttrCase<"Undefined", 3, "none">
-    ]>{
-  let cppNamespace = "mlir::tosa";
-}
-
 def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
     [
       I32EnumAttrCase<"None", 0, "none">,
@@ -99,7 +89,9 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
   let options = [
       ListOption<"profile", "profile", "std::string",
              "Validate if operations match for the given profile set">,
-      Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
+      ListOption<"extension", "extension", "std::string",
+             "Validate if operations match for the given extension set">,
+      Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
              /*default=*/"false",
              "Verify if the properties of certain operations align the spec requirement">,
       Option<"level", "level", "mlir::tosa::TosaLevelEnum",
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 8dfa55bef74fc..bfadebba12708 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -117,7 +117,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
         TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
         TosaValidationOptions validationOptions;
         validationOptions.profile = {"none"};
-        validationOptions.StrictOperationSpecAlignment = true;
+        validationOptions.extension = {"none"};
+        validationOptions.strictOpSpecAlignment = false;
         validationOptions.level = tosa::TosaLevelEnum::EightK;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
                                     tosaToLinalgNamedOptions,
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index e6999f6fa0d85..b1fac8c85a204 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -12,6 +12,8 @@ add_mlir_dialect_library(MLIRTosaDialect
   MLIRTosaDialectBytecodeIncGen
   MLIRTosaOpsIncGen
   MLIRTosaInterfacesIncGen
+  MLIRTosaEnumsIncGen
+  MLIRTosaAvailabilityIncGen
   MLIRShardingInterfaceIncGen
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..b8de753812f86 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -42,7 +42,10 @@ using namespace mlir::tosa;
 // Tosa dialect interface includes.
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
+#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
+#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
 
 namespace {
 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 9c3345b617cc5..bbf079faea3d0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaOptionalDecompositions.cpp
   TosaReduceTransposes.cpp
   TosaTypeConverters.cpp
+  TosaProfileCompliance.cpp
   TosaValidation.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
new file mode 100644
index 0000000000000..3e8c3d8698d43
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -0,0 +1,500 @@
+//===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+
+#include <algorithm>
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+// Builds the profile and extension compliance tables. The TypeInfo
+// constants are not referenced directly in this function; presumably the
+// generated data included below refers to them by name.
+TosaProfileCompliance::TosaProfileCompliance() {
+  const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
+  const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
+  const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
+  const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
+  const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
+  const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
+  const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
+  const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
+  const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
+  const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
+  const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
+
+// The profile-based compliance content below is auto-generated by a script
+// in https://git.mlplatform.org/tosa/specification.git
+#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h"
+  // End of auto-generated metadata
+}
+
+// Accessor specializations; each returns a copy of the requested table.
+template <>
+OperationProfileComplianceMap TosaProfileCompliance::getProfileComplianceMap() {
+  return profileComplianceMap;
+}
+
+template <>
+OperationExtensionComplianceMap
+TosaProfileCompliance::getProfileComplianceMap() {
+  return extensionComplianceMap;
+}
+
+// Base populating function: records every operand, then `output`.
+void ProfileInfoDepot::populateProfileInfo(ValueRange operands, Value output) {
+  for (auto operand : operands)
+    addValue(operand);
+  addValue(output);
+}
+
+// Only the first concat input is recorded; presumably every input shares its
+// element type.
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
+  addValue(op.getInput1().front());
+  addValue(op.getOutput());
+}
+
+// avg_pool2d: input, accumulator type, output.
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
+  addValue(op.getInput());
+  addType(op.getAccType());
+  addValue(op.getOutput());
+}
+
+// Shared by conv2d, conv3d, transpose_conv2d and depthwise_conv2d: input,
+// weight, bias, accumulator type, output.
+template <typename T>
+void ProfileInfoDepot::populateProfileInfoConv(T op) {
+  addValue(op.getInput());
+  addValue(op.getWeight());
+  addValue(op.getBias());
+  addType(op.getAccType());
+  addValue(op.getOutput());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
+  populateProfileInfoConv(op);
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
+  populateProfileInfoConv(op);
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
+  populateProfileInfoConv(op);
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
+  populateProfileInfoConv(op);
+}
+
+// Shared by pad, reshape, slice, tile and transpose: only the data input
+// and the output are recorded.
+template <typename T>
+void ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
+  addValue(op.getInput1());
+  addValue(op.getOutput());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
+  populateProfileInfoDataLayout(op);
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
+  populateProfileInfoDataLayout(op);
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
+  populateProfileInfoDataLayout(op);
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
+  populateProfileInfoDataLayout(op);
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
+  populateProfileInfoDataLayout(op);
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
+  addValue(op.getValues());
+  addValue(op.getOutput());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
+  addValue(op.getValuesIn());
+  addValue(op.getInput());
+  addValue(op.getValuesOut());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
+  addValue(op.getInput1());
+  addValue(op.getInput2());
+  addValue(op.getOutput());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
+  addValue(op.getInput());
+  addValue(op.getOutput());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
+  addValue(op.getInputReal());
+  addValue(op.getInputImag());
+  addValue(op.getOutputReal());
+  addValue(op.getOutputImag());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
+  addValue(op.getInput());
+  addValue(op.getOutputReal());
+  addValue(op.getOutputImag());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
+  addValue(op.getOnTrue());
+  addValue(op.getOnFalse());
+  addValue(op.getOutput());
+}
+
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
+  addValue(op.getInput());
+  addValue(op.getOutput());
+}
+
+// Populates tyInfo for `op`; returns failure when `op` has no rule below.
+LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
+// This helper macro only populates the info for the customised operands.
+#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
+  if (isa<tosa::tosaOp##Op>(op)) {                                             \
+    populateProfileInfo(cast<tosa::tosaOp##Op>(op));                           \
+    return success();                                                          \
+  }
+
+// This helper macro records nothing for the operation.
+#define POPULATE_PROFILE_INFO_SKIP(tosaOp)                                     \
+  if (isa<tosa::tosaOp##Op>(op))                                               \
+    return success();
+
+// This helper macro populates the info for all operands and result 0.
+#define POPULATE_PROFILE_INFO_COMMON(tosaOp)                                   \
+  if (isa<tosa::tosaOp##Op>(op)) {                                             \
+    populateProfileInfo(op->getOperands(), op->getResult(0));                  \
+    return success();                                                          \
+  }
+
+  // Skip irrelevant operands when they are independent and not tied to any
+  // specific profile/extension.
+  POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d)
+  POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
+  POPULATE_PROFILE_INFO_CUSTOM(Conv2D)
+  POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
+  POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
+  POPULATE_PROFILE_INFO_CUSTOM(Mul)
+  POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
+  POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
+  POPULATE_PROFILE_INFO_CUSTOM(Concat)
+  POPULATE_PROFILE_INFO_CUSTOM(Pad)
+  POPULATE_PROFILE_INFO_CUSTOM(Reshape)
+  POPULATE_PROFILE_INFO_CUSTOM(Slice)
+  POPULATE_PROFILE_INFO_CUSTOM(Tile)
+  POPULATE_PROFILE_INFO_CUSTOM(Transpose)
+  POPULATE_PROFILE_INFO_CUSTOM(Gather)
+  POPULATE_PROFILE_INFO_CUSTOM(Scatter)
+  POPULATE_PROFILE_INFO_CUSTOM(Resize)
+  POPULATE_PROFILE_INFO_CUSTOM(Select)
+  POPULATE_PROFILE_INFO_CUSTOM(Rescale)
+
+  // Type Invariant Extension, a capability extension that is independent
+  // of the data type, meaning any compatible type can be used. No type
+  // constraint for those operations.
+  POPULATE_PROFILE_INFO_SKIP(ConstShape)
+  POPULATE_PROFILE_INFO_SKIP(Variable)
+  POPULATE_PROFILE_INFO_SKIP(VariableRead)
+  POPULATE_PROFILE_INFO_SKIP(VariableWrite)
+  POPULATE_PROFILE_INFO_SKIP(If)
+  POPULATE_PROFILE_INFO_SKIP(While)
+  POPULATE_PROFILE_INFO_SKIP(Yield)
+
+  // For most TOSA operators, all operands are profile/extension related
+  // and hence are all considered in this profile-based compliance check.
+  POPULATE_PROFILE_INFO_COMMON(Cast)
+  POPULATE_PROFILE_INFO_COMMON(Const)
+  POPULATE_PROFILE_INFO_COMMON(ArgMax)
+  POPULATE_PROFILE_INFO_COMMON(MatMul)
+  POPULATE_PROFILE_INFO_COMMON(Sub)
+  POPULATE_PROFILE_INFO_COMMON(Maximum)
+  POPULATE_PROFILE_INFO_COMMON(Minimum)
+  POPULATE_PROFILE_INFO_COMMON(MaxPool2d)
+  POPULATE_PROFILE_INFO_COMMON(Clamp)
+  POPULATE_PROFILE_INFO_COMMON(Erf)
+  POPULATE_PROFILE_INFO_COMMON(Sigmoid)
+  POPULATE_PROFILE_INFO_COMMON(Tanh)
+  POPULATE_PROFILE_INFO_COMMON(Add)
+  POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
+  POPULATE_PROFILE_INFO_COMMON(BitwiseAnd)
+  POPULATE_PROFILE_INFO_COMMON(BitwiseNot)
+  POPULATE_PROFILE_INFO_COMMON(BitwiseOr)
+  POPULATE_PROFILE_INFO_COMMON(BitwiseXor)
+  POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
+  POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
+  POPULATE_PROFILE_INFO_COMMON(LogicalAnd)
+  POPULATE_PROFILE_INFO_COMMON(LogicalNot)
+  POPULATE_PROFILE_INFO_COMMON(LogicalOr)
+  POPULATE_PROFILE_INFO_COMMON(LogicalXor)
+  POPULATE_PROFILE_INFO_COMMON(IntDiv)
+  POPULATE_PROFILE_INFO_COMMON(Pow)
+  POPULATE_PROFILE_INFO_COMMON(Table)
+  POPULATE_PROFILE_INFO_COMMON(Abs)
+  POPULATE_PROFILE_INFO_COMMON(Ceil)
+  POPULATE_PROFILE_INFO_COMMON(Clz)
+  POPULATE_PROFILE_INFO_COMMON(Sin)
+  POPULATE_PROFILE_INFO_COMMON(Cos)
+  POPULATE_PROFILE_INFO_COMMON(Exp)
+  POPULATE_PROFILE_INFO_COMMON(Floor)
+  POPULATE_PROFILE_INFO_COMMON(Log)
+  POPULATE_PROFILE_INFO_COMMON(Negate)
+  POPULATE_PROFILE_INFO_COMMON(Reciprocal)
+  POPULATE_PROFILE_INFO_COMMON(Rsqrt)
+  POPULATE_PROFILE_INFO_COMMON(ReduceAll)
+  POPULATE_PROFILE_INFO_COMMON(ReduceAny)
+  POPULATE_PROFILE_INFO_COMMON(ReduceMax)
+  POPULATE_PROFILE_INFO_COMMON(ReduceMin)
+  POPULATE_PROFILE_INFO_COMMON(ReduceProd)
+  POPULATE_PROFILE_INFO_COMMON(ReduceSum)
+  POPULATE_PROFILE_INFO_COMMON(Equal)
+  POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
+  POPULATE_PROFILE_INFO_COMMON(Greater)
+  POPULATE_PROFILE_INFO_COMMON(Reverse)
+  POPULATE_PROFILE_INFO_COMMON(Identity)
+
+#undef POPULATE_PROFILE_INFO_CUSTOM
+#undef POPULATE_PROFILE_INFO_SKIP
+#undef POPULATE_PROFILE_INFO_COMMON
+
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// Tosa Profile And Extension Compliance Checker
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+LogicalResult TosaProfileCompliance::checkProfileOrExtension(
+    Operation *op, const tosa::TargetEnv &targetEnv,
+    const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
+  static_assert(std::is_same_v<T, Profile> || std::is_same_v<T, Extension>,
+                "only Profile and Extension compliance is supported");
+
+  // None of profile requirement is set in the specification.
+  if (specRequiredModeSet.empty())
+    return success();
+
+  // Bind the compliance table by reference. getProfileComplianceMap<T>()
+  // returns it by value, which would copy the whole table for every checked
+  // operation.
+  const auto &compMap = [this]() -> const auto & {
+    if constexpr (std::is_same_v<T, Profile>)
+      return profileComplianceMap;
+    else
+      return extensionComplianceMap;
+  }();
+
+  const std::string opName = op->getName().getStringRef().str();
+  auto it = compMap.find(opName);
+
+  if (it == compMap.end()) {
+    // Operators such as variable and shape ops do not have an operand type
+    // restriction. When the profile compliance information of the operation
+    // is not found, confirm that the target has enabled a profile required
+    // by the specification.
+    size_t modeCount = 0;
+    for (const auto &cands : specRequiredModeSet) {
+      if (targetEnv.allowsAnyOf(cands))
+        return success();
+      modeCount += cands.size();
+    }
+
+    op->emitOpError() << "illegal: requires"
+                      << (modeCount > 1 ? " any of " : " ") << "["
+                      << llvm::join(stringifyProfile<T>(specRequiredModeSet),
+                                    ", ")
+                      << "] but not enabled in target\n";
+
+    return failure();
+  }
+
+  CheckCondition condition = CheckCondition::invalid;
+  // Find the profiles or extensions requirement according to the signature of
+  // type of the operand list.
+  SmallVector<T> opRequiredMode =
+      findMatchedProfile<T>(op, it->second, condition);
+
+  if (opRequiredMode.empty()) {
+    // No matched restriction found.
+    return success();
+  }
+
+  if (condition == CheckCondition::allOf &&
+      !targetEnv.allowsAllOf(opRequiredMode)) {
+    op->emitOpError() << "illegal: requires"
+                      << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
+                      << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
+                      << "] but not enabled in target\n";
+    return failure();
+  }
+
+  if (condition == CheckCondition::anyOf &&
+      !targetEnv.allowsAnyOf(opRequiredMode)) {
+    op->emitOpError() << "illegal: requires"
+                      << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
+                      << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
+                      << "] but not enabled in target\n";
+    return failure();
+  }
+
+  // Each extension can contain a list of profiles that it works with, usually
+  // have the same data type.
+  if constexpr (std::is_same_v<T, Extension>) {
+    for (const auto &mode : opRequiredMode) {
+      SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
+      if (!targetEnv.allowsAnyOf(coProfs)) {
+        op->emitOpError() << "illegal: requires ["
+                          << llvm::join(stringifyProfile<Profile>(coProfs),
+                                        ", ")
+                          << "] to work with but not enabled in target\n";
+        return failure();
+      }
+    }
+  }
+
+  // Ensure the profile inference matches the profile knowledge of the
+  // specification: every inferred mode must appear in each spec-defined set.
+  for (const auto &cands : specRequiredModeSet) {
+    for (const T &mode : opRequiredMode) {
+      if (!llvm::is_contained(cands, mode)) {
+        op->emitOpError() << "illegal: requires ["
+                          << llvm::join(stringifyProfile<T>(opRequiredMode),
+                                        ", ")
+                          << "] but not included in the profile compliance ["
+                          << llvm::join(
+                                 stringifyProfile<T>(specRequiredModeSet), ", ")
+                          << "]\n";
+        return failure();
+      }
+    }
+  }
+
+  return success();
+}
+
+LogicalResult
+TosaProfileCompliance::checkProfile(Operation *op,
+                                    const tosa::TargetEnv &targetEnv) {
+  if (auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
+    return checkProfileOrExtension<Profile>(op, targetEnv,
+                                            interface.getProfiles());
+
+  // Operations without profile information are not restricted.
+  return success();
+}
+
+LogicalResult
+TosaProfileCompliance::checkExtension(Operation *op,
+                                      const tosa::TargetEnv &targetEnv) {
+  if (auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
+    return checkProfileOrExtension<Extension>(op, targetEnv,
+                                              interface.getExtensions());
+
+  // Operations without extension information are not restricted.
+  return success();
+}
+
+// Find the profiles or extensions requirement according to the signature of
+// type of the operand list.
+template <typename T>
+SmallVector<T> TosaProfileCompliance::findMatchedProfile(
+    Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
+    CheckCondition &condition) {
+  assert(!compInfo.empty());
+
+  // Populate the type of profile/extension relevant operands.
+  ProfileInfoDepot depot(op);
+  const SmallVector<TypeInfo> &present = depot.getInfo();
+  if (present.empty())
+    return {};
+
+  for (const OpComplianceInfo<T> &info : compInfo) {
+    for (const SmallVector<TypeInfo> &expected : info.operandTypeInfoSet) {
+      // The compliance data is expected to describe exactly the operands
+      // populated by ProfileInfoDepot.
+      assert(present.size() == expected.size());
+
+      // Compare the type signature between the given operation and the
+      // compliance metadata. The four-iterator std::equal also compares the
+      // lengths, so a mismatched entry is skipped rather than read past the
+      // end of `present` when asserts are disabled.
+      if (std::equal(present.begin(), present.end(), expected.begin(),
+                     expected.end(), [this](TypeInfo a, TypeInfo b) {
+                       return isSameTypeInfo(a, b);
+                     })) {
+        condition = info.condition;
+        return info.mode;
+      }
+    }
+  }
+
+  return {};
+}
+
+// Debug utilities.
+template <typename T>
+SmallVector<StringRef>
+TosaProfileCompliance::stringifyProfile(ArrayRef<T> profiles) {
+  SmallVector<StringRef> debugStrings;
+  for (const auto &profile : profiles) {
+    if constexpr (std::is_same_v<T, Profile>)
+      debugStrings.push_back(tosa::stringifyProfile(profile));
+    else
+      debugStrings.push_back(tosa::stringifyExtension(profile));
+  }
+  return debugStrings;
+}
+
+// Flattens the stringified names of every set in `profileSet`, in order.
+template <typename T>
+SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(
+    const SmallVector<ArrayRef<T>> &profileSet) {
+  SmallVector<StringRef> debugStrings;
+  for (const auto &profiles : profileSet)
+    llvm::append_range(debugStrings, stringifyProfile<T>(profiles));
+  return debugStrings;
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 7f59ff70d3374..bb2373aac7ff3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -11,6 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
 
@@ -24,6 +26,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringExtras.h"
 
 namespace mlir {
 namespace tosa {
@@ -85,10 +88,12 @@ static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 public:
   explicit TosaValidation() { populateConstantOperandChecks(); }
+
   explicit TosaValidation(const TosaValidationOptions &options)
       : TosaValidation() {
     this->profile = options.profile;
-    this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
+    this->extension = options.extension;
+    this->strictOpSpecAlignment = options.strictOpSpecAlignment;
     this->level = options.level;
   }
   void runOnOperation() final;
@@ -394,9 +399,25 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
     if (!profile.empty()) {
       for (std::string &prof : profile) {
-        auto profSymbol = symbolizeTosaProfileEnum(prof);
+        auto profSymbol = symbolizeProfile(prof);
         if (profSymbol) {
-          enabled_profiles.push_back(profSymbol.value());
+          targetEnv.addProfile(profSymbol.value());
+        } else {
+          llvm::errs() << "warning: unknown profile name passed in, supported "
+                          "profile are bi and mi\n";
+        }
+      }
+    }
+
+    if (!extension.empty()) {
+      for (std::string &ext : extension) {
+        auto extSymbol = symbolizeExtension(ext);
+        if (extSymbol) {
+          targetEnv.addExtension(extSymbol.value());
+        } else {
+          llvm::errs() << "warning: unknown extension name passed in, "
+                          "supported extension are int16, int4, bf16, fp8e4m3, "
+                          "fp8e5m2, fft, and variable\n";
         }
       }
     }
@@ -404,17 +425,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   bool CheckVariable(Operation *op);
   bool CheckVariableReadOrWrite(Operation *op);
-
   bool isValidElementType(Type type);
-  bool isEnabledProfile(TosaProfileEnum prof) {
-    return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) !=
-           std::end(enabled_profiles);
-  }
 
   SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
-  SmallVector<TosaProfileEnum, 3> enabled_profiles;
   TosaLevel tosaLevel;
   DenseMap<StringAttr, mlir::Type> variablesMap;
+  TosaProfileCompliance profileComp; // Per-op profile/extension checker.
+  tosa::TargetEnv targetEnv; // Profiles/extensions enabled by pass options.
 };
 
 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -507,8 +524,6 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
 
 bool TosaValidation::isValidElementType(Type type) {
   if (isa<FloatType>(type)) {
-    if (!isEnabledProfile(TosaProfileEnum::MainInference))
-      return false;
     return type.isF32() || type.isF16() || type.isBF16();
   } else if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isSignless()) {
@@ -539,6 +554,15 @@ void TosaValidation::runOnOperation() {
     if (op->getDialect() != tosaDialect)
       return;
 
+    // Profile/extension compliance is checked first, and only in strict
+    // spec-alignment mode, so an op the target cannot support is rejected
+    // before the element-type and level checks below run. The extension
+    // check is skipped when the profile check has already failed.
+    if (strictOpSpecAlignment &&
+        (failed(profileComp.checkProfile(op, targetEnv)) ||
+         failed(profileComp.checkExtension(op, targetEnv))))
+      return signalPassFailure();
+
     for (Value operand : op->getOperands()) {
       auto elementTy = getElementTypeOrSelf(operand);
       if (!isValidElementType(elementTy)) {
@@ -558,7 +582,7 @@ void TosaValidation::runOnOperation() {
 
     // Some uses of TOSA rely on the constant operands of particular
     // operations.
-    if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
+    if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
       signalPassFailure();
 
     // do level checks
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
new file mode 100644
index 0000000000000..7a78616e0b9d7
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -0,0 +1,681 @@
+//--------------------------------------------------------------------------------------------------
+// Test that the supported profiles and extensions are properly attached to each operation.
+// The data types of the operation arguments are irrelevant to this test.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt -mlir-disable-threading -test-tosa-op-availability %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: argmax
+func.func @test_argmax(%input: tensor<14x19xf32>) -> tensor<14xi32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.argmax %input {axis = 1 : i32} : (tensor<14x19xf32>) -> tensor<14xi32>
+  return %result : tensor<14xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d
+func.func @test_avg_pool2d(%input: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.avg_pool2d %input {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
+  return %result : tensor<1x7x7x9xf32>
+}
+
+// -----
+// CHECK-LABEL: conv2d
+func.func @test_conv2d(%input: tensor<1x4x4x4xf32>, %weight: tensor<8x1x1x4xf32>, %bias: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.conv2d %input, %weight, %bias {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+  return %result : tensor<1x4x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: conv3d
+func.func @test_conv3d(%input: tensor<1x4x8x21x17xf32>, %weight: tensor<34x1x1x1x17xf32>, %bias: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.conv3d %input, %weight, %bias {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
+  return %result : tensor<1x4x8x21x34xf32>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d
+func.func @test_depthwise_conv2d(%input: tensor<1x4x4x4xf32>, %weight: tensor<1x1x4x2xf32>, %bias: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.depthwise_conv2d %input, %weight, %bias {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+  return %result : tensor<1x4x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: fft2d
+func.func @test_fft2d(%input_real: tensor<1x4x8xf32>, %input_imag: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  // CHECK: profiles: [ ]
+  // CHECK: extensions: [ [fft] ]
+  %output_real, %output_imag = tosa.fft2d %input_real, %input_imag {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %output_real, %output_imag : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: matmul
+func.func @test_matmul(%lhs: tensor<1x14x19xf32>, %rhs: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.matmul %lhs, %rhs : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  return %result : tensor<1x14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f32
+func.func @test_max_pool2d_f32(%input: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.max_pool2d %input {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %result : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: rfft2d
+func.func @test_rfft2d(%input: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+  // CHECK: profiles: [ ]
+  // CHECK: extensions: [ [fft] ]
+  %output_real, %output_imag = tosa.rfft2d %input : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
+  return %output_real, %output_imag : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d
+func.func @test_transpose_conv2d(%input: tensor<1x32x32x8xf32>, %weight: tensor<16x1x1x8xf32>, %bias: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.transpose_conv2d %input, %weight, %bias {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %result : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: clamp
+func.func @test_clamp(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int16, bf16] ]
+  %result = tosa.clamp %input {min_val = -3.40282347E+38 : f32, max_val = 3.40282347E+38 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: sigmoid
+func.func @test_sigmoid(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.sigmoid %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: tanh
+func.func @test_tanh(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.tanh %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+// -----
+// CHECK-LABEL: erf
+func.func @test_erf(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.erf %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: add
+func.func @test_add(%lhs: tensor<13x21x1xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.add %lhs, %rhs : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: arithmetic_right_shift
+func.func @test_arithmetic_right_shift(%lhs: tensor<13x21x1xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.arithmetic_right_shift %lhs, %rhs {round = false} : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_and
+func.func @test_bitwise_and(%lhs: tensor<13x21x3xi32>, %rhs: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.bitwise_and %lhs, %rhs : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_or
+func.func @test_bitwise_or(%lhs: tensor<13x21x3xi32>, %rhs: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.bitwise_or %lhs, %rhs : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_xor
+func.func @test_bitwise_xor(%lhs: tensor<13x21x1xi32>, %rhs: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.bitwise_xor %lhs, %rhs : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: int_div
+func.func @test_int_div(%lhs: tensor<13x21x1xi32>, %rhs: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.int_div %lhs, %rhs : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_and
+func.func @test_logical_and(%lhs: tensor<13x21x3xi1>, %rhs: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.logical_and %lhs, %rhs : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
+  return %result : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: logical_left_shift
+func.func @test_logical_left_shift(%lhs: tensor<13x21x3xi32>, %rhs: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.logical_left_shift %lhs, %rhs : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_right_shift
+func.func @test_logical_right_shift(%lhs: tensor<13x21x3xi32>, %rhs: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.logical_right_shift %lhs, %rhs : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_or
+func.func @test_logical_or(%lhs: tensor<13x1x3xi1>, %rhs: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.logical_or %lhs, %rhs : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+  return %result : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: logical_xor
+func.func @test_logical_xor(%lhs: tensor<13x1x3xi1>, %rhs: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.logical_xor %lhs, %rhs : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+  return %result : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: maximum
+func.func @test_max(%lhs: tensor<13x21x3xf32>, %rhs: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.maximum %lhs, %rhs : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: minimum
+func.func @test_min(%lhs: tensor<13x21x3xf32>, %rhs: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.minimum %lhs, %rhs : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: mul
+func.func @test_mul(%lhs: tensor<13x21x3xf32>, %rhs: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.mul %lhs, %rhs, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: pow
+func.func @test_pow(%base: tensor<13x21x3xf32>, %exponent: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.pow %base, %exponent : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: sub
+func.func @test_sub(%lhs: tensor<1x21x3xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.sub %lhs, %rhs : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: table
+func.func @test_table(%input: tensor<64xi32>, %table: tensor<513x!quant.uniform<i16:f32, 1.0:0>>) -> tensor<64x!quant.uniform<i16:f32, 1.0:0>> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.table %input, %table : (tensor<64xi32>, tensor<513x!quant.uniform<i16:f32, 1.000000e+00>>) -> tensor<64x!quant.uniform<i16:f32, 1.000000e+00>>
+  return %result : tensor<64x!quant.uniform<i16:f32, 1.0:0>>
+}
+
+// -----
+// CHECK-LABEL: abs
+func.func @test_abs(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.abs %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_not
+func.func @test_bitwise_not(%input: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.bitwise_not %input : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32>
+  return %result : tensor<13x21x1xi32>
+}
+
+// -----
+// CHECK-LABEL: ceil
+func.func @test_ceil(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.ceil %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: clz
+func.func @test_clz(%input: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.clz %input : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: cos
+func.func @test_cos(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.cos %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: exp
+func.func @test_exp(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.exp %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: floor
+func.func @test_floor(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.floor %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: log
+func.func @test_log(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.log %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: logical_not
+func.func @test_logical_not(%input: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.logical_not %input : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1>
+  return %result : tensor<1x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: negate
+func.func @test_negate(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.negate %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reciprocal
+func.func @test_reciprocal(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.reciprocal %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: rsqrt
+func.func @test_rsqrt(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.rsqrt %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: sin
+func.func @test_sin(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.sin %input : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: select
+func.func @test_select(%pred: tensor<1x1x1xi1>, %on_true: tensor<13x21x3xf32>, %on_false: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.select %pred, %on_true, %on_false : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: equal
+func.func @test_equal(%lhs: tensor<13x21x3xf32>, %rhs: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.equal %lhs, %rhs : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1>
+  return %result : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: greater
+func.func @test_greater(%lhs: tensor<13x21x1xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.greater %lhs, %rhs : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+  return %result : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: greater_equal
+func.func @test_greater_equal(%lhs: tensor<13x1x3xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.greater_equal %lhs, %rhs : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+  return %result : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_all
+func.func @test_reduce_all(%input: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.reduce_all %input {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+  return %result : tensor<1x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_any
+func.func @test_reduce_any(%input: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.reduce_any %input {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+  return %result : tensor<1x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_max
+func.func @test_reduce_max(%input: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.reduce_max %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  return %result : tensor<1x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_min
+func.func @test_reduce_min(%input: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.reduce_min %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  return %result : tensor<1x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_product
+func.func @test_reduce_product(%input: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.reduce_prod %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  return %result : tensor<1x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_sum
+func.func @test_reduce_sum(%input: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.reduce_sum %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  return %result : tensor<1x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: concat
+func.func @test_concat(%first: tensor<13x21x3xf32>, %second: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.concat %first, %second {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+  return %result : tensor<26x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: pad
+func.func @test_pad(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.pad %input, %padding : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reshape
+func.func @test_reshape(%input: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
+  %new_shape = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.reshape %input, %new_shape : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+  return %result : tensor<1x819xf32>
+}
+
+// -----
+// CHECK-LABEL: reverse
+func.func @test_reverse(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.reverse %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: slice
+func.func @test_slice(%input: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
+  %0 = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %1 = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.slice %input, %0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32>
+  return %result : tensor<4x11x1xf32>
+}
+
+// -----
+// CHECK-LABEL: tile
+func.func @test_tile(%input: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
+  %multiples = tosa.const_shape {value = dense<[3, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.tile %input, %multiples : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32>
+  return %result : tensor<39x21x6xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose
+func.func @test_transpose(%input: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
+  %perms = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.transpose %input, %perms : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+  return %result : tensor<3x13x21xf32>
+}
+
+// -----
+// CHECK-LABEL: gather
+func.func @test_gather(%values: tensor<13x21x3xf32>, %indices: tensor<13x26xi32>) -> tensor<13x26x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.gather %values, %indices : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x3xf32>
+  return %result : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter
+func.func @test_scatter(%values_in: tensor<13x21x3xf32>, %indices: tensor<13x26xi32>, %input: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.scatter %values_in, %indices, %input : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: resize
+func.func @test_resize(%input: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int16, bf16] ]
+  %result = tosa.resize %input {mode = "BILINEAR", scale = array<i64: 4, 2, 4, 2>, offset = array<i64: -1, -1>, border = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32>
+  return %result : tensor<1x64x64x8xf32>
+}
+
+// -----
+// CHECK-LABEL: cast
+func.func @test_cast1(%input: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.cast %input : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
+  return %result : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: rescale
+func.func @test_rescale(%input: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+  // CHECK: profiles: [ [pro_int] ]
+  // CHECK: extensions: [ [int16] ]
+  %result = tosa.rescale %input {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  return %result : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+}
+
+// -----
+// CHECK-LABEL: const
+func.func @test_const(%unused : index) -> tensor<4xi32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+  return %result : tensor<4xi32>
+}
+
+// -----
+// CHECK-LABEL: identity
+func.func @test_identity(%input: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
+  %result = tosa.identity %input : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: cond_if
+func.func @test_cond_if(%lhs: tensor<f32>, %rhs: tensor<f32>, %cond: tensor<i1>) -> tensor<f32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %result = tosa.cond_if %cond -> (tensor<f32>) {
+    %sum = tosa.add %lhs, %rhs : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %sum : tensor<f32>
+  } else {
+    %diff = tosa.sub %lhs, %rhs : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %diff : tensor<f32>
+  }
+  return %result : tensor<f32>
+}
+
+// -----
+// CHECK-LABEL: while_loop
+func.func @test_while_loop(%buffer: tensor<10xi32>, %limit: tensor<i32>) {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [bf16] ]
+  %results:3 = tosa.while_loop (%iter = %zero, %counter = %zero, %acc = %buffer) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %done = tosa.greater_equal %counter, %limit : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %keep_going = tosa.logical_not %done : (tensor<i1>) -> tensor<i1>
+    tosa.yield %keep_going : tensor<i1>
+  } do {
+  ^bb0(%iter: tensor<i32>, %counter: tensor<i32>, %acc: tensor<10xi32>):
+    %one = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %next_counter = tosa.add %counter, %one : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %shape_1d = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
+    %one_1d = tosa.reshape %one, %shape_1d : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
+    %next_acc = tosa.add %acc, %one_1d : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+    %next_iter = tosa.add %iter, %one : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %next_iter, %next_counter, %next_acc : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
+// -----
+// CHECK-LABEL: custom
+func.func @test_custom(%input: tensor<10xi32>) -> tensor<10xi32> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ ]
+  %result = tosa.custom %input {operator_name = "custom_test", domain_name = "tosa.mlir_test", implementation_attrs = ""} : (tensor<10xi32>) -> (tensor<10xi32>)
+  return %result : tensor<10xi32>
+}
+
+// -----
+// CHECK-LABEL: const_shape
+func.func @test_const_shape() -> !tosa.shape<4> {
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ ]
+  %shape = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
+  return %shape : !tosa.shape<4>
+}
+
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index f35c37a1ef70f..342106040d62e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,8 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt strict-op-spec-alignment"
-
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
 
 func.func @test_const() -> tensor<1xf32> {
   // expected-error at +1{{'tosa.const' op expected same attr/result element types}}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
new file mode 100644
index 0000000000000..046b9d5615074
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -0,0 +1,38 @@
+//--------------------------------------------------------------------------------------------------
+// Enable all profiles but no extensions, so ops that require an extension are rejected.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
+
+// -----
+func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) { // only the fft extension requirement is unmet
+  // expected-error at +1 {{'tosa.fft2d' op illegal: requires [fft] but not enabled in target}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () { // NOTE(review): the i16 read type differs from the declared i32; only the extension errors are expected here
+  // expected-error at +1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable.read' op illegal: requires [variable]}}
+  %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
+  return
+}
+
+// -----
+func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { // NOTE(review): the i16 write type differs from the declared i32; only the extension errors are expected here
+  // expected-error at +1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  // expected-error at +1 {{'tosa.variable.write' op illegal: requires [variable]}}
+  tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
+  return
+}
+
+// -----
+func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32> { // a bf16 operand requires the bf16 extension
+  // expected-error at +1 {{'tosa.cast' op illegal: requires [bf16] but not enabled in target}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xbf16>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index a7f76f2d0fa64..f0b287200f5c5 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,9 +1,8 @@
 //--------------------------------------------------------------------------------------------------
-// Enable all supported profiles to focus the verification of expected level errors.
+// Enable all supported profiles and extensions to focus the verification of expected level errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt"
-
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
 
 func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
new file mode 100644
index 0000000000000..6dddcf329d110
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -0,0 +1,83 @@
+//--------------------------------------------------------------------------------------------------
+// Enable all extensions but no profile, so every op below fails its profile requirement check.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+
+// -----
+func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () { // i8 TABLE needs pro_int; the RUN line enables no profile
+  // expected-error at +1 {{'tosa.table' op illegal: requires [pro_int] but not enabled in target}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<?x?xi8>
+  return
+}
+
+// -----
+func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error at +1 {{'tosa.conv2d' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
+  // expected-error at +1 {{'tosa.avg_pool2d' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+  // expected-error at +1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error at +1 {{'tosa.sigmoid' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error at +1 {{'tosa.add' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> { // either profile would satisfy it, but none is enabled
+  // expected-error at +1 {{'tosa.reduce_all' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+  return %0 : tensor<1x21x3xi1>
+}
+
+// -----
+func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
+  // expected-error at +1 {{'tosa.concat' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+  return %0 : tensor<26x21x3xf32>
+}
+
+// -----
+func.func @test_cast_i32_f32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+  // expected-error at +1 {{'tosa.cast' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> { // either profile would satisfy it, but none is enabled
+  // expected-error at +1 {{'tosa.custom' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}}
+  %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>)
+  return %0 : tensor<10xi32>
+}
+
diff --git a/mlir/test/Dialect/Tosa/profile_bi_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_bi_unsupported.mlir
new file mode 100644
index 0000000000000..479b7569f54ae
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/profile_bi_unsupported.mlir
@@ -0,0 +1,26 @@
+//--------------------------------------------------------------------------------------------------
+// Enable all extensions and only the pro_fp profile, so ops that require pro_int are rejected.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+
+// -----
+func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
+  // expected-error at +1 {{'tosa.table' op illegal: requires [pro_int] but not enabled in target}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<?x?xi8>
+  return
+}
+
+// -----
+func.func @test_reduce_max(%arg0: tensor<13x21x3xi16>) -> tensor<1x21x3xi16> {
+  // expected-error at +1 {{'tosa.reduce_max' op illegal: requires [pro_int] but not enabled in target}}
+  %0 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<13x21x3xi16>) -> tensor<1x21x3xi16>
+  return %0 : tensor<1x21x3xi16>
+}
+
+// -----
+func.func @test_cast_i32_i8(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi8> { // i32 -> i8 cast needs pro_int; only pro_fp is enabled
+  // expected-error at +1 {{'tosa.cast' op illegal: requires [pro_int] but not enabled in target}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
diff --git a/mlir/test/Dialect/Tosa/profile_mi_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_mi_unsupported.mlir
new file mode 100644
index 0000000000000..c46b2543fbed5
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/profile_mi_unsupported.mlir
@@ -0,0 +1,62 @@
+//--------------------------------------------------------------------------------------------------
+// Enable all extensions and only the pro_int profile, so ops that require pro_fp are rejected.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+
+// -----
+func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { // f32 ops in this file need pro_fp; only pro_int is enabled
+  // expected-error at +1 {{'tosa.conv2d' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
+  // expected-error at +1 {{'tosa.avg_pool2d' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+  // expected-error at +1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error at +1 {{'tosa.sigmoid' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  // expected-error at +1 {{'tosa.transpose_conv2d' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error at +1 {{'tosa.add' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
+  // expected-error at +1 {{'tosa.concat' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+  return %0 : tensor<26x21x3xf32>
+}
+
+// -----
+func.func @test_cast_i32_f32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> { // producing f32 needs pro_fp even from an i32 source
+  // expected-error at +1 {{'tosa.cast' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
diff --git a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt
index 7d40881ee6ee4..43f0d0d21c1c0 100644
--- a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRTosaTestPasses
   TosaTestPasses.cpp
+  TestAvailability.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Dialect/Tosa/TestAvailability.cpp b/mlir/test/lib/Dialect/Tosa/TestAvailability.cpp
new file mode 100644
index 0000000000000..bec563d1ec747
--- /dev/null
+++ b/mlir/test/lib/Dialect/Tosa/TestAvailability.cpp
@@ -0,0 +1,78 @@
+//===- TestAvailability.cpp - Pass to test Tosa op availability -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Printing op availability pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Test pass printing the profiles and extensions queried from each TOSA op.
+struct PrintOpAvailability
+    : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
+
+  StringRef getArgument() const final { return "test-tosa-op-availability"; }
+  StringRef getDescription() const final { return "Test Tosa op availability"; }
+  void runOnOperation() override;
+};
+} // namespace
+
+void PrintOpAvailability::runOnOperation() {
+  func::FuncOp funcOp = getOperation();
+  raw_ostream &os = llvm::outs();
+  os << funcOp.getName() << "\n";
+
+  Dialect *tosaDialect = getContext().getLoadedDialect("tosa");
+
+  // Prints `<op> <label>: [ [a, b] [c] ]`: one bracketed, comma-separated
+  // group per element of `groups`, rendering each entry via `stringify`.
+  auto printGroups = [&os](OperationName opName, StringRef label,
+                           const auto &groups, auto stringify) {
+    os << opName << " " << label << ": [";
+    for (const auto &group : groups) {
+      os << " [";
+      llvm::interleaveComma(group, os,
+                            [&](auto entry) { os << stringify(entry); });
+      os << "]";
+    }
+    os << " ]\n";
+  };
+
+  funcOp->walk([&](Operation *op) {
+    // Only TOSA ops are reported; ops from other dialects are skipped.
+    if (op->getDialect() != tosaDialect)
+      return WalkResult::advance();
+
+    if (auto query = dyn_cast<tosa::QueryProfileInterface>(op))
+      printGroups(op->getName(), "profiles", query.getProfiles(),
+                  [](tosa::Profile prof) {
+                    return tosa::stringifyProfile(prof);
+                  });
+
+    if (auto query = dyn_cast<tosa::QueryExtensionInterface>(op))
+      printGroups(op->getName(), "extensions", query.getExtensions(),
+                  [](tosa::Extension ext) {
+                    return tosa::stringifyExtension(ext);
+                  });
+
+    os.flush();
+    return WalkResult::advance();
+  });
+}
+
+namespace mlir {
+void registerPrintTosaAvailabilityPass() {
+  PassRegistration<PrintOpAvailability> registration; // registers on creation
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 74007d01347ae..f18ad45dfb708 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -39,6 +39,7 @@ void registerLoopLikeInterfaceTestPasses();
 void registerPassManagerTestPass();
 void registerPrintSpirvAvailabilityPass();
+void registerPrintTosaAvailabilityPass(); // in test/lib/Dialect/Tosa
 void registerRegionTestPasses();
 void registerShapeFunctionTestPasses();
 void registerSideEffectTestPasses();
 void registerSliceAnalysisTestPass();
@@ -175,6 +176,7 @@ void registerTestTransformDialectExtension(DialectRegistry &);
 void registerTestPasses() {
   registerCloneTestPasses();
   registerConvertToTargetEnvPass();
+  registerPrintTosaAvailabilityPass();
   registerLazyLoadingTestPasses();
   registerLoopLikeInterfaceTestPasses();
   registerPassManagerTestPass();
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index fb507dc7f8c3c..9431c59860522 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -32,6 +32,7 @@ add_tablegen(mlir-tblgen MLIR
   PassGen.cpp
   RewriterGen.cpp
   SPIRVUtilsGen.cpp
+  TosaUtilsGen.cpp
   )
 
 target_link_libraries(mlir-tblgen
diff --git a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp
new file mode 100644
index 0000000000000..491f9143edb02
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp
@@ -0,0 +1,226 @@
+//===- TosaUtilsGen.cpp - Tosa utility generator --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TosaUtilsGen generates the availability query implementations of Tosa ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+#include <list>
+#include <optional>
+
+using llvm::ArrayRef;
+using llvm::formatv;
+using llvm::raw_ostream;
+using llvm::raw_string_ostream;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::SmallVector;
+using llvm::SMLoc;
+using llvm::StringMap;
+using llvm::StringRef;
+using mlir::tblgen::Attribute;
+using mlir::tblgen::EnumAttr;
+using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::NamedAttribute;
+using mlir::tblgen::NamedTypeConstraint;
+using mlir::tblgen::NamespaceEmitter;
+using mlir::tblgen::Operator;
+
+//===----------------------------------------------------------------------===//
+// Availability Wrapper Class
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Wrapper around a TableGen record deriving from `Availability`, offering
+/// typed accessors for the fields the availability generator consumes.
+class Availability {
+public:
+  explicit Availability(const Record *def);
+
+  /// Name of the direct TableGen superclass of this availability instance.
+  /// Exactly one direct superclass is expected.
+  StringRef getClass() const;
+
+  /// Name of the query method declared in the generated C++ interface.
+  StringRef getQueryFnName() const;
+
+  /// Return type of the query method declared in the generated C++
+  /// interface.
+  StringRef getQueryFnRetType() const;
+
+  /// Code snippet that merges one instance into the overall requirement.
+  StringRef getMergeActionCode() const;
+
+  /// Expression used to initialize the overall (merged) availability
+  /// requirement.
+  StringRef getMergeInitializer() const;
+
+  /// C++ statements that prepare this availability instance before merging.
+  StringRef getMergeInstancePreparation() const;
+
+  /// The concrete availability instance carried by this record.
+  StringRef getMergeInstance() const;
+
+  /// The underlying TableGen record.
+  const Record *getDef() const { return def; }
+
+private:
+  /// TableGen definition backing this availability.
+  const Record *def;
+};
+} // namespace
+
+Availability::Availability(const Record *def) : def(def) {
+  assert(def->isSubClassOf("Availability") &&
+         "must be subclass of TableGen 'Availability' class");
+}
+
+StringRef Availability::getClass() const {
+  SmallVector<const Record *, 1> directParents;
+  def->getDirectSuperClasses(directParents);
+  if (directParents.size() != 1)
+    PrintFatalError(def->getLoc(),
+                    "expected to only have one direct superclass");
+  return directParents.front()->getName();
+}
+
+// The accessors below each read one string field of the wrapped record.
+StringRef Availability::getQueryFnName() const {
+  return def->getValueAsString("queryFnName");
+}
+
+StringRef Availability::getQueryFnRetType() const {
+  return def->getValueAsString("queryFnRetType");
+}
+
+StringRef Availability::getMergeActionCode() const {
+  return def->getValueAsString("mergeAction");
+}
+
+StringRef Availability::getMergeInitializer() const {
+  return def->getValueAsString("initializer");
+}
+
+StringRef Availability::getMergeInstancePreparation() const {
+  return def->getValueAsString("instancePreparation");
+}
+
+StringRef Availability::getMergeInstance() const {
+  return def->getValueAsString("instance");
+}
+
+// Returns the availability specs in `def`'s `availability` field (empty if absent).
+static std::vector<Availability> getAvailabilities(const Record &def) {
+  std::vector<Availability> availabilities;
+
+  if (def.getValue("availability")) {
+    std::vector<const Record *> availDefs =
+        def.getValueAsListOfDefs("availability");
+    availabilities.reserve(availDefs.size());
+    for (const Record *avail : availDefs)
+      availabilities.emplace_back(avail);
+  }
+
+  return availabilities;
+}
+
+//===----------------------------------------------------------------------===//
+// Tosa Availability Impl AutoGen
+//===----------------------------------------------------------------------===//
+
+static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
+  mlir::tblgen::FmtContext fctx;
+  fctx.addSubst("overall", "tblgen_overall");
+
+  std::vector<Availability> opAvailabilities =
+      getAvailabilities(srcOp.getDef());
+
+  // Group the op's availability instances by TableGen class; each class maps
+  // to one generated query method. All instances of a class carry the same
+  // interface information, so the first instance seen (try_emplace keeps it)
+  // serves as the representative used to emit the method skeleton.
+  llvm::StringMap<Availability> availClasses;
+  for (const Availability &avail : opAvailabilities)
+    availClasses.try_emplace(avail.getClass(), avail);
+
+  // Emit one query-method implementation per availability class.
+  for (const auto &availClass : availClasses) {
+    StringRef availClassName = availClass.getKey();
+    const Availability &avail = availClass.getValue();
+
+    // Method signature: `<RetType> <OpClass>::<queryFn>() {`.
+    os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
+                  srcOp.getCppClassName(), avail.getQueryFnName());
+
+    // Declare the merged requirement and seed it with the initializer.
+    os << formatv("  {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(),
+                  avail.getMergeInitializer());
+
+    // Fold in each instance of this class that contributes merge code.
+    for (const Availability &inst : opAvailabilities)
+      if (inst.getClass() == availClassName &&
+          (!inst.getMergeInstancePreparation().empty() ||
+           !inst.getMergeActionCode().empty())) {
+        os << "  {\n    "
+           // Prepare this instance.
+           << inst.getMergeInstancePreparation()
+           << "\n    "
+           // Merge this instance.
+           << std::string(
+                  tgfmt(inst.getMergeActionCode(),
+                        &fctx.addSubst("instance", inst.getMergeInstance())))
+           << ";\n  }\n";
+      }
+
+    os << "  return tblgen_overall;\n";
+    os << "}\n";
+  }
+}
+
+static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
+                                 raw_ostream &os) {
+  llvm::emitSourceFileHeader("Tosa Op Availability Implementations", os,
+                             recordKeeper);
+
+  // Only ops that opt in via `autogenAvailability` get a generated body, so
+  // the Operator wrapper is constructed just for those.
+  for (const Record *def : recordKeeper.getAllDerivedDefinitions("Tosa_Op"))
+    if (def->getValueAsBit("autogenAvailability"))
+      emitAvailabilityImpl(Operator(def), os);
+  // Report success to the mlir-tblgen driver.
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Op Availability Implementation Hook Registration
+//===----------------------------------------------------------------------===//
+
+// Registers the `-gen-tosa-avail-impls` generator with mlir-tblgen.
+static mlir::GenRegistration genOpAvailabilityImpl(
+    "gen-tosa-avail-impls", "Generate Tosa operation utility definitions",
+    [](const RecordKeeper &records, raw_ostream &os) -> bool {
+      return emitAvailabilityImpl(records, os);
+    });
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 92aedac837197..e4ca002acd9cf 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -12196,6 +12196,44 @@ gentbl_cc_library(
     deps = [":TosaDialectTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "MLIRTosaEnumsIncGen",
+    tbl_outs = [
+        (
+            ["-gen-enum-decls"],
+            "include/mlir/Dialect/Tosa/IR/TosaEnums.h.inc",
+        ),
+        (
+            ["-gen-enum-defs"],
+            "include/mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Tosa/IR/TosaOpBase.td",
+    deps = [":TosaDialectTdFiles"],
+)
+
+gentbl_cc_library(
+    name = "MLIRTosaAvailabilityIncGen",
+    tbl_outs = [
+        (
+            ["-gen-avail-interface-decls"],
+            "include/mlir/Dialect/Tosa/IR/TosaAvailability.h.inc",
+        ),
+        (
+            ["-gen-avail-interface-defs"],
+            "include/mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc",
+        ),
+        (
+            ["-gen-tosa-avail-impls"],
+            "include/mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Tosa/IR/TosaOps.td",
+    deps = [":TosaDialectTdFiles"],
+)
+
 gentbl_cc_library(
     name = "TosaDialectBytecodeGen",
     strip_include_prefix = "include",



More information about the Mlir-commits mailing list