[Mlir-commits] [mlir] 9a5ae34 - [mlir][tosa] Add support for matmul_t_block_scaled (#163433)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 23 07:35:23 PDT 2025


Author: Luke Hutton
Date: 2025-10-23T15:35:18+01:00
New Revision: 9a5ae34eb6e90c51e2231ceb1a8cf933341f3222

URL: https://github.com/llvm/llvm-project/commit/9a5ae34eb6e90c51e2231ceb1a8cf933341f3222
DIFF: https://github.com/llvm/llvm-project/commit/9a5ae34eb6e90c51e2231ceb1a8cf933341f3222.diff

LOG: [mlir][tosa] Add support for matmul_t_block_scaled (#163433)

This commit adds support for the MATMUL_T_BLOCK_SCALED operation from
the EXT_MXFP extension. This includes:
- Operation definition in TosaOps.td
- Definition of the supported micro-scaling types
- Shape inference and verifiers
- Validation pass checks to ensure usage is only valid when the target
environment includes ext-mxfp and at least v1.1.draft of the
specification.

As part of this commit, the EXT_MXFP extension itself is also introduced. The
extension can be specified as part of the target environment and can
only be used if the specification version is at least 1.1.

Note: support for mxint8 is currently excluded; it will be added in a
later commit.

Note: this commit adds support as defined in the spec at
https://github.com/arm/tosa-specification/commit/063846a75b9687ab01e58cb3538472bffb3a03b0.
The EXT_MXFP extension is considered experimental and subject to breaking
changes.

Co-authored-by: Tat Wai Chong <tatwai.chong at arm.com>
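
For illustration, a minimal usage sketch of the new operation, mirroring the
tests added in this patch (function and value names are illustrative; the
pipeline in the comment is a reduced form of the RUN line used by
tosa-validation-version-1p1-valid.mlir):

  // Illustrative pipeline:
  //   mlir-opt -tosa-attach-target="specification_version=1.1.draft profiles=pro_fp extensions=mxfp" \
  //            -tosa-validate="strict-op-spec-alignment"
  func.func @example_matmul_t_block_scaled(%a_data: tensor<4x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>,
                                           %b_data: tensor<4x16x32xf8E4M3FN>, %b_scale: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
    // C = 32 with block_size = 32, so each scale tensor carries one scale per row.
    %0 = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>}
           : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
    return %0 : tensor<4x8x16xf32>
  }

The custom assembly format also accepts the bare enum keyword, i.e.
{block_size = BLOCK_SIZE_32}, as exercised in
tosa-validation-version-1p1-valid.mlir below.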

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
    mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
    mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/test/Dialect/Tosa/invalid_extension.mlir
    mlir/test/Dialect/Tosa/level_check.mlir
    mlir/test/Dialect/Tosa/ops.mlir
    mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
    mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
    mlir/test/Dialect/Tosa/verifier.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 4ecf03c34c1a5..e088eb31338dc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -54,6 +54,8 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
 /// and provide utilities around the TOSA specification version.
 class TosaSpecificationVersion {
 public:
+  TosaSpecificationVersion() = default;
+
   TosaSpecificationVersion(uint32_t major, uint32_t minor)
       : majorVersion(major), minorVersion(minor) {}
   TosaSpecificationVersion(SpecificationVersion version)
@@ -83,6 +85,10 @@ class TosaSpecificationVersion {
   }
 };
 
+TosaSpecificationVersion getMinVersion(const Profile &profile);
+TosaSpecificationVersion getMinVersion(const Extension &extension);
+TosaSpecificationVersion getMinVersion(const Level &level);
+
 llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
 
 /// This class represents the capability enabled in the target implementation
@@ -91,22 +97,19 @@ llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
 class TargetEnv {
 public:
   TargetEnv() {}
-  explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
-                     const ArrayRef<Profile> &profiles,
-                     const ArrayRef<Extension> &extensions)
-      : specificationVersion(specificationVersion), level(level) {
-    enabledProfiles.insert_range(profiles);
-    enabledExtensions.insert_range(extensions);
-  }
 
-  explicit TargetEnv(TargetEnvAttr targetAttr)
-      : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
-                  targetAttr.getProfiles(), targetAttr.getExtensions()) {}
+  static FailureOr<TargetEnv>
+  createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc);
+
+  static LogicalResult verifyTargetInformation(TargetEnvAttr targetAttr,
+                                               Location targetAttrLoc);
 
   void addProfile(Profile p) { enabledProfiles.insert(p); }
   void addExtension(Extension e) { enabledExtensions.insert(e); }
 
-  SpecificationVersion getSpecVersion() const { return specificationVersion; }
+  TosaSpecificationVersion getSpecVersion() const {
+    return specificationVersion;
+  }
 
   TosaLevel getLevel() const {
     if (level == Level::eightK)
@@ -140,7 +143,17 @@ class TargetEnv {
   }
 
 private:
-  SpecificationVersion specificationVersion;
+  // Require that target information is verified before construction, via the
+  // use of `createTargetEnvFromAttr`.
+  explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
+                     const ArrayRef<Profile> &profiles,
+                     const ArrayRef<Extension> &extensions)
+      : specificationVersion(specificationVersion), level(level) {
+    enabledProfiles.insert_range(profiles);
+    enabledExtensions.insert_range(extensions);
+  }
+
+  TosaSpecificationVersion specificationVersion;
   Level level;
   llvm::SmallSet<Profile, 3> enabledProfiles;
   llvm::SmallSet<Extension, 13> enabledExtensions;

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index c1b5e785bd739..294fb9d99fdb6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -554,6 +554,18 @@ extensionComplianceMap = {
        allOf},
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.matmul_t_block_scaled",
+     {{{Extension::mxfp},
+       {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.max_pool2d",
      {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
       {{Extension::fp8e4m3},

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 8376a4c87dbf2..48e0073c76ab6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -270,13 +270,14 @@ def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
 def Tosa_EXT_DOUBLEROUND  : I32EnumAttrCase<"doubleround", 9>;
 def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
+def Tosa_EXT_MXFP         : I32EnumAttrCase<"mxfp", 12>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
       Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
       Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
-      Tosa_EXT_DYNAMIC
+      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP
     ]> {
   let extraClassDeclaration = [{
     static llvm::SmallVector<Extension, 11> getAllValues() {
@@ -284,7 +285,7 @@ def Tosa_ExtensionAttr
         Extension::int16, Extension::int4, Extension::bf16,
         Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
         Extension::variable, Extension::controlflow, Extension::doubleround,
-        Extension::inexactround, Extension::dynamic
+        Extension::inexactround, Extension::dynamic, Extension::mxfp
       };
     }
   }];
@@ -437,7 +438,7 @@ def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
 }
 
 //===----------------------------------------------------------------------===//
-// Iterable attributes.
+// Enum attributes.
 //===----------------------------------------------------------------------===//
 // Defined in `section 3. Enumerations` of the TOSA specification.
 
@@ -463,6 +464,18 @@ def Tosa_RoundingModeAttr
     : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
                     [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
 
+def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
+
+def Tosa_BlockSizeAttr
+    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
+                    [Tosa_BLOCK_SIZE_32]> {
+  let extraClassDeclaration = [{
+    static uint32_t getBlockSizeValue(BlockSize blockSize) {
+      return static_cast<uint32_t>(blockSize);
+    }
+  }];
+}
+
 
 //===----------------------------------------------------------------------===//
 // TOSA Interfaces.

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 137554f49460d..6f07247b478c8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -347,6 +347,40 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
       "operands attr-dict `:` functional-type(operands, results)";
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: matmul_t_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_MatmulTBlockScaledOp : Tosa_InferShapedTypeOp<"matmul_t_block_scaled"> {
+  let summary = "Performs two-dimensional matrix multiplications using block-scaled tensors.";
+
+  let description = [{
+    Performs two-dimensional matrix multiplications using block-scaled tensors. The block
+    dimension is always the last dimension of the tensor, so the result is effectively
+    a matrix multiply of A by the transposed B matrix. If the N dimension of input B is of
+    size 1, the B matrix is broadcast.
+  }];
+
+  let arguments = (ins
+    Tosa_MXFPDataTensor3D:$a_data,
+    Tosa_MXFPScaleTensor3D:$a_scale,
+    Tosa_MXFPDataTensor3D:$b_data,
+    Tosa_MXFPScaleTensor3D:$b_scale,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$output_data
+  );
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_MXFP]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: max_pool2d
 //===----------------------------------------------------------------------===//
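
To make the broadcast rule in the matmul_t_block_scaled description concrete,
a small sketch with assumed shapes (mirroring the new tosa-infer-shapes.mlir
test): a_data is [N=4, H=8, C=32], b_data is [D=1, W=16, C=32], the batch-1 B
operands are broadcast across A's batch, and the result is [4, 8, 16].

  // B operands have batch size 1 and are broadcast across A's batch of 4 (illustrative shapes).
  func.func @example_broadcast(%a_data: tensor<4x8x32xf8E4M3FN>, %a_scale: tensor<4x8x1xf8E8M0FNU>,
                               %b_data: tensor<1x16x32xf8E4M3FN>, %b_scale: tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
    %0 = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>}
           : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<1x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
    return %0 : tensor<4x8x16xf32>
  }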

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 7b946ad6c6a89..79df1b888b40e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -147,6 +147,7 @@ class TosaProfileCompliance {
     case Extension::fp8e4m3:
     case Extension::fp8e5m2:
     case Extension::fft:
+    case Extension::mxfp:
       return {Profile::pro_fp};
     case Extension::variable:
     case Extension::controlflow:

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 93ab120339d55..20bb961482ad8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -84,6 +84,10 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                 "number">;
 
+def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
+                                "micro-scaling format number">;
+def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;
+
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
 //===----------------------------------------------------------------------===//
@@ -187,6 +191,15 @@ def Tosa_Int32Tensor2D : AnyTypeOf<[
 def Tosa_TensorAtLeast1D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
 
+def Tosa_MXFPDataTensor3D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPNumber], [3]>
+]>;
+def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
+]>;
+
 //===----------------------------------------------------------------------===//
 // Generic scalar, vector, or tensor of a particular type.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 1cba1bb540c02..32eb286531d28 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -12,6 +12,96 @@
 namespace mlir {
 namespace tosa {
 
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+  return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
+TosaSpecificationVersion getMinVersion(const Profile &profile) {
+  switch (profile) {
+  case Profile::pro_int:
+  case Profile::pro_fp:
+    return TosaSpecificationVersion(1, 0);
+  case Profile::none:
+    return TosaSpecificationVersion(0, 0);
+  }
+  llvm_unreachable("Unknown TOSA profile");
+}
+
+TosaSpecificationVersion getMinVersion(const Extension &extension) {
+  switch (extension) {
+  case Extension::int16:
+  case Extension::int4:
+  case Extension::bf16:
+  case Extension::fp8e4m3:
+  case Extension::fp8e5m2:
+  case Extension::fft:
+  case Extension::variable:
+  case Extension::controlflow:
+  case Extension::doubleround:
+  case Extension::inexactround:
+  case Extension::dynamic:
+    return TosaSpecificationVersion(1, 0);
+  case Extension::mxfp:
+    return TosaSpecificationVersion(1, 1);
+  case Extension::none:
+    return TosaSpecificationVersion(0, 0);
+  }
+  llvm_unreachable("Unknown TOSA extension");
+}
+
+TosaSpecificationVersion getMinVersion(const Level &level) {
+  switch (level) {
+  case Level::eightK:
+  case Level::none:
+    return TosaSpecificationVersion(1, 0);
+  }
+  llvm_unreachable("Unknown TOSA level");
+}
+
+FailureOr<TargetEnv>
+TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
+                                   Location targetEnvAttrLoc) {
+  if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
+    return failure();
+
+  return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+                   targetAttr.getProfiles(), targetAttr.getExtensions());
+}
+
+LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
+                                                 Location targetAttrLoc) {
+  TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());
+
+  const auto isCompatibleWithTargetVersion =
+      [&](const auto &targetEnum, Location targetAttrLoc,
+          StringRef enumName) -> LogicalResult {
+    const TosaSpecificationVersion minRequiredVersion =
+        getMinVersion(targetEnum);
+    if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
+      return emitError(targetAttrLoc, enumName)
+             << " '" << stringifyEnum(targetEnum)
+             << "' is not compatible with the target version "
+             << stringifyVersion(targetVersion)
+             << ", minimum required version is "
+             << stringifyVersion(minRequiredVersion);
+    return success();
+  };
+
+  for (const auto &profile : targetAttr.getProfiles())
+    if (failed(
+            isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
+      return failure();
+  for (const auto &extension : targetAttr.getExtensions())
+    if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
+                                             "extension")))
+      return failure();
+  if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
+                                           "level")))
+    return failure();
+
+  return success();
+}
+
 TargetEnvAttr lookupTargetEnv(Operation *op) {
   while (op) {
     op = SymbolTable::getNearestSymbolTable(op);
@@ -39,9 +129,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
   return getDefaultTargetEnv(op->getContext());
 }
 
-llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
-  return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
-}
-
 } // namespace tosa
 } // namespace mlir
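
In practice this means the validation pass now rejects an inconsistent target
before walking any op. A rough sketch of the expected behaviour (assuming the
default specification version is below 1.1; the diagnostic wording follows
verifyTargetInformation above):

  // Assumed pipeline (illustrative):
  //   mlir-opt -tosa-attach-target="profiles=pro_fp extensions=mxfp" -tosa-validate="strict-op-spec-alignment"
  // Expected to fail up front, roughly as:
  //   error: extension 'mxfp' is not compatible with the target version 1.0, minimum required version is 1.1
  module {
  }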

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 00f84bc43f444..6cd0eaea3ce6c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
     }
   }
 
+  // special handling: block_size accepts a *bare* BlockSize enum keyword
+  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
+    if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
+      auto sym = symbolizeBlockSize(kw);
+      if (!sym)
+        return parser.emitError(parser.getCurrentLocation())
+               << "invalid block_size value: " << kw;
+      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
+      outAttrs.push_back(NamedAttribute(name, attr));
+      return success();
+    }
+  }
+
   // Default path: parse any normal attribute literal, including fully qualified
   // enum keyword
   Attribute attr;
@@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
   } else if (auto nanPropagationModeAttr =
                  dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
     parser << nanPropagationModeAttr.getValue();
+  } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
+    parser << blockSizeAttr.getValue();
   } else {
     parser.printAttribute(attr);
   }
@@ -508,6 +523,15 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
   printWithNanPropagationHandling(parser, *this);
 }
 
+ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
+                                        OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -933,32 +957,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
 
 // verify that inType and outType have same element types
 template <typename T>
-static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
-  auto inputType = llvm::dyn_cast<TensorType>(inType);
-  auto outputType = llvm::dyn_cast<TensorType>(outType);
-  if (!inputType) {
-    op.emitOpError("expect shaped tensor for input, got ") << inType;
+static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
+                                            StringRef aName = "input",
+                                            StringRef bName = "output") {
+  auto aTType = llvm::dyn_cast<TensorType>(aType);
+  auto bTType = llvm::dyn_cast<TensorType>(bType);
+  if (!aTType) {
+    op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
     return failure();
   }
-  if (!outputType) {
-    op.emitOpError("expect shaped tensor for output, got ") << outType;
+  if (!bTType) {
+    op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
     return failure();
   }
-  auto inputElementType = inputType.getElementType();
-  auto outputElementType = outputType.getElementType();
-  auto inputQuantType =
-      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
-  auto outputQuantType =
-      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
-  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
-      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
-      inputElementType != outputElementType) {
+  auto aElementType = aTType.getElementType();
+  auto bElementType = bTType.getElementType();
+  auto aQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
+  auto bQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
+  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
+      (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
+      aElementType != bElementType) {
     // only check if both element types are int/index/float/UniformQuantized
     // eg, not sure how to check quant::QuantizedType
     // this happens in test_conv2d_q_grouped_convolution in
     // tfl-to-tosa-pipeline.mlir
-    op.emitOpError("expect input and output to have same element type, got ")
-        << inputElementType << " and " << outputElementType;
+    op.emitOpError("expect ")
+        << aName << " and " << bName << " to have same element type, got "
+        << aElementType << " and " << bElementType;
     return failure();
   }
   return success();
@@ -1846,6 +1873,161 @@ LogicalResult MatMulOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    MatmulTBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
+
+  const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
+  if (aDataShape.hasRank()) {
+    outShape[0] = aDataShape.getDimSize(0);
+    outShape[1] = aDataShape.getDimSize(1);
+  }
+
+  const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
+  if (aScaleShape.hasRank()) {
+    outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
+                                                     : outShape[0];
+    outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
+                                                     : outShape[1];
+  }
+
+  // If B batch size is 1, it is broadcast across A's batch size
+  const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
+  if (bDataShape.hasRank()) {
+    const int64_t bDataBatchSize = bDataShape.getDimSize(0);
+    if (bDataBatchSize != 1)
+      outShape[0] =
+          ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
+    outShape[2] = bDataShape.getDimSize(1);
+  }
+
+  const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
+  if (bScaleShape.hasRank()) {
+    const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
+    if (bScaleBatchSize != 1)
+      outShape[0] =
+          ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
+    outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
+                                                     : outShape[2];
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+LogicalResult MatmulTBlockScaledOp::verify() {
+  // Verify same input data types
+  const Type aDataType = getAData().getType();
+  const Type bDataType = getBData().getType();
+  if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
+                                    "B_data")))
+    return failure();
+
+  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
+                                   const StringRef operandName,
+                                   const StringRef dimName) -> LogicalResult {
+    if (ShapedType::isDynamic(currDim)) {
+      currDim = newDim;
+      return success();
+    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+      return emitOpError("expected ")
+             << dimName << " of " << operandName << " to match size " << currDim
+             << ", got " << newDim;
+    }
+    return success();
+  };
+
+  // Verify input shape compatibility
+  int64_t N = ShapedType::kDynamic;
+  int64_t D = ShapedType::kDynamic;
+  int64_t H = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+  int64_t multiplesOfC = ShapedType::kDynamic;
+
+  const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
+  if (aDataShape.hasRank()) {
+    N = aDataShape.getDimSize(0);
+    H = aDataShape.getDimSize(1);
+    C = aDataShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
+  if (aScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
+                                     "height")))
+      return failure();
+    multiplesOfC = aScaleShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
+  if (bDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
+                                     "channels")))
+      return failure();
+    W = bDataShape.getDimSize(1);
+  }
+
+  const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
+  if (bScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
+                                     "width")) ||
+        failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
+                                     "b_scale", "C/block_size")))
+      return failure();
+  }
+
+  // Verify batch size is broadcast compatible
+  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
+    return emitOpError("expect B matrix batch size to be broadcast compatible "
+                       "with A, got D=")
+           << D << " vs N=" << N;
+
+  // Verify C is a multiple of block size
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (ShapedType::isStatic(C) && C % blockSize != 0)
+    return emitOpError("expect C to be a multiple of block size, got C=")
+           << C << ", block_size=" << blockSize;
+
+  // Verify multiplesOfC is C / block size
+  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
+      multiplesOfC != C / blockSize)
+    return emitOpError(
+               "expect scale operands dimension 2 to equal C/block_size (")
+           << C << "/" << blockSize << ")"
+           << ", got " << multiplesOfC;
+
+  // Verify output shape
+  N = ShapedType::isDynamic(N) ? D : N;
+  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+  const auto outputType = cast<ShapedType>(getResult().getType());
+  if (outputType.hasRank() &&
+      failed(
+          verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+    InFlightDiagnostic opError = emitOpError("expected output shape ");
+    auto stringifyDim = [&](int64_t d) {
+      if (ShapedType::isDynamic(d))
+        opError << "?";
+      else
+        opError << d;
+    };
+    llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+    opError << " to be compatible with expected output shape ";
+    llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+    return opError;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
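
As a worked example of the channel checks above (assumed shapes): with C = 64
and BLOCK_SIZE_32, C is a multiple of the block size and the scale operands
must carry 64/32 = 2 scales along their last dimension.

  // C = 64, block_size = 32, so each scale tensor has 64/32 = 2 entries in its last dimension (illustrative shapes).
  func.func @example_two_blocks(%a_data: tensor<4x8x64xf8E4M3FN>, %a_scale: tensor<4x8x2xf8E8M0FNU>,
                                %b_data: tensor<4x16x64xf8E4M3FN>, %b_scale: tensor<4x16x2xf8E8M0FNU>) -> tensor<4x8x16xf32> {
    %0 = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale {block_size = #tosa.block_size<BLOCK_SIZE_32>}
           : (tensor<4x8x64xf8E4M3FN>, tensor<4x8x2xf8E8M0FNU>, tensor<4x16x64xf8E4M3FN>, tensor<4x16x2xf8E8M0FNU>) -> tensor<4x8x16xf32>
    return %0 : tensor<4x8x16xf32>
  }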

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index f072e3eff1975..e965ae0cf9888 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -25,6 +25,12 @@ TosaProfileCompliance::TosaProfileCompliance() {
   const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
   const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
 
+  // micro-scaling formats
+  const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
+  const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
+  const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
+  const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
+
 // The profile-based compliance content below is auto-generated by a script
 // in https://git.mlplatform.org/tosa/specification.git
 #include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
@@ -269,6 +275,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 
   // For the most of tosa operators, all operands are profile/extension related
   // and hence are all considered in this profile-based compilance check.
+  POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
   POPULATE_PROFILE_INFO_COMMON(Cast)
   POPULATE_PROFILE_INFO_COMMON(Const)
   POPULATE_PROFILE_INFO_COMMON(ArgMax)
@@ -623,6 +630,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
     return {"fp8e4m3"};
   } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
     return {"fp8e5m2"};
+  } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) {
+    return {"fp6e2m3"};
+  } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) {
+    return {"fp6e3m2"};
+  } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) {
+    return {"fp4e2m1"};
+  } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
+    return {"fp8e8m0"};
   }
   llvm_unreachable("unknown type");
 }

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 82f2f7eb17af4..3f874d94ab9be 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -657,6 +657,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_SIZES(TransposeConv2D);
   CHECK_SIZES(FFT2d);
   CHECK_SIZES(MatMul);
+  CHECK_SIZES(MatmulTBlockScaled);
   CHECK_SIZES(MaxPool2d);
   CHECK_SIZES(RFFT2d);
   // Scatter/Gather Operators
@@ -1192,9 +1193,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
 bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
   if (isa<FloatType>(type)) {
     return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
-               Float8E5M2Type>(type);
-  }
-  if (auto intTy = dyn_cast<IntegerType>(type)) {
+               Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
+               Float6E3M2FNType, Float8E8M0FNUType>(type);
+  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isSignless()) {
       switch (intTy.getWidth()) {
       case 1:
@@ -1220,13 +1221,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
 }
 
 void TosaValidation::runOnOperation() {
+  ModuleOp modOp = getOperation();
+  const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
+  const auto maybeTargetEnv =
+      tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
+  if (failed(maybeTargetEnv))
+    return signalPassFailure();
+  targetEnv = *maybeTargetEnv;
+
   TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
   if (!tosaDialect)
     return;
 
-  targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
-
-  getOperation().walk([&](Operation *op) {
+  modOp.walk([&](Operation *op) {
     if (op->getDialect() != tosaDialect)
       return;
 

diff  --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index e5c9402caaddc..005601d4017b8 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -538,3 +538,11 @@ func.func @test_avg_pool2d_non_const_output_zp(%arg0: tensor<1x32x32x8xf32>, %ou
          (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
   return %0 : tensor<1x32x32x8xf32>
 }
+
+// -----
+
+func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}

diff  --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 8cc357efa0c77..8771e6e2476e4 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1622,3 +1622,12 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<*xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_invalid_size
+func.func @test_matmul_t_block_scaled_invalid_size(%arg0: tensor<4x8x536870912xf4E2M1FN>, %arg1: tensor<4x8x16777216xf8E8M0FNU>, %arg2: tensor<4x16x536870912xf4E2M1FN>, %arg3: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 868b7b7a93335..9bf36b5fd4c7d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1226,3 +1226,45 @@ func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<
   %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
   return %0 : tensor<13x29x3xf8E4M3FN>
 }
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_static
+func.func @test_matmul_t_block_scaled_static(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked
+func.func @test_matmul_t_block_scaled_unranked(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2
+func.func @test_matmul_t_block_scaled_fp6e3m2(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E3M2FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3
+func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1
+func.func @test_matmul_t_block_scaled_fp4e2m1(%arg0: tensor<4x8x32xf4E2M1FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf4E2M1FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf4E2M1FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf4E2M1FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast
+func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<?x16x32xf8E4M3FN>, %arg3: tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<?x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}

diff  --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 7ff8065ee41fd..0271d71561a52 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment"
 
 // -----
 func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
@@ -325,3 +325,10 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
   %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
   return %1 : tensor<1x64x64x8xf32>
 }
+
+// -----
+func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E3M2FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 80f06f11fe4ad..72479fe21ade8 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1574,3 +1574,57 @@ func.func @test_mul_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf
   %0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_static
+func.func @test_matmul_t_block_scaled_static(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<1x16x32xf8E4M3FN>, %arg3: tensor<1x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // CHECK: -> tensor<4x8x16xf32>
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<1x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_a_data
+func.func @test_matmul_t_block_scaled_unranked_a_data(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // CHECK: -> tensor<4x8x16xf32>
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_b_data_and_scale
+func.func @test_matmul_t_block_scaled_unranked_b_data_and_scale(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // CHECK: -> tensor<4x8x?xf32>
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_all
+func.func @test_matmul_t_block_scaled_unranked_all(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // CHECK: -> tensor<?x?x?xf32>
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_data
+func.func @test_matmul_t_block_scaled_broadcast_b_data(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<1x4x32xf8E4M3FN>, %arg3: tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // CHECK: -> tensor<?x?x4xf32>
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<1x4x32xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_scale
+func.func @test_matmul_t_block_scaled_broadcast_b_scale(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+  // CHECK: -> tensor<?x?x4xf32>
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}

diff  --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 81645092bf195..2040a4bc7e6af 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
 
 // -----
 
@@ -18,3 +18,11 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>,
   %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>)  -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3
+func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}

diff  --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 430b06ad16c39..4be5d725ad612 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1102,3 +1102,51 @@ func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
   return
 }
+
+// -----
+
+func.func @test_matmul_t_block_scaled_data_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E5M2>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E5M2>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_batch_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<?x8x1xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape 5, ?, ? to be compatible with expected output shape 4, 8, ?}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x8x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32>
+  return %0 : tensor<5x?x?xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_height_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<?x9x1xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape 4, 8, ? to be compatible with expected output shape 4, 9, ?}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x9x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32>
+  return %0 : tensor<4x8x?xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_width_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<?x?x1xf8E8M0FNU>, %arg2: tensor<?x1x?xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<?x?x10xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape ?, ?, 10 to be compatible with expected output shape ?, ?, 1}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x?x1xf8E8M0FNU>, tensor<?x1x?xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x10xf32>
+  return %0 : tensor<?x?x10xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_channel_not_multiple_of_block_size(%arg0: tensor<4x8x55xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected channels of b_data to match size 55, got 32}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x55xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<2x16x32xf8E4M3FN>, %arg3: tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect B matrix batch size to be broadcast compatible with A, got D=2 vs N=4}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}


        

