[Mlir-commits] [mlir] [mlir][tosa] Add support for CONV2D_BLOCK_SCALED operator (PR #172294)

Luke Hutton llvmlistbot at llvm.org
Thu Jan 15 11:01:32 PST 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/172294

>From 3def7a84fc24e071be48045d3c0a283d468429f6 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 27 Nov 2025 10:03:58 +0000
Subject: [PATCH 1/3] [mlir][tosa] Add support for CONV2D_BLOCK_SCALED operator

This commit adds support for an MXFP CONV2D operation,
CONV2D_BLOCK_SCALED, added to the specification in
https://github.com/arm/tosa-specification/commit/408a5e53f5a7357adef7121ba3cc88e2225d4231.

This includes:
- Operator definition
- Addition of the EXT_MXFP_CONV extension
- Verification logic for the operator
- Output shape inference for the operator
- Validation checks to ensure compliance with the TOSA specification.

Change-Id: I7553f7796d2d156f43310108e9a69a593cdece33
---
 .../Dialect/Tosa/IR/TosaComplianceData.h.inc  |  12 +
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |  12 +-
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  37 ++
 .../Dialect/Tosa/IR/TosaProfileCompliance.h   |   1 +
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  10 +
 mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp        |   1 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 392 ++++++++++++++----
 .../Tosa/Transforms/TosaProfileCompliance.cpp |  13 +
 .../Tosa/Transforms/TosaValidation.cpp        |  54 ++-
 mlir/test/Dialect/Tosa/invalid.mlir           |  13 +-
 mlir/test/Dialect/Tosa/invalid_extension.mlir |  15 +-
 mlir/test/Dialect/Tosa/level_check.mlir       |  71 ++++
 mlir/test/Dialect/Tosa/ops.mlir               |  20 +
 .../Tosa/profile_pro_fp_unsupported.mlir      |  16 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  48 +++
 .../tosa-validation-version-1p1-valid.mlir    |  13 +-
 mlir/test/Dialect/Tosa/verifier.mlir          | 132 ++++++
 17 files changed, 767 insertions(+), 93 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 9e6471aa7d04e..009775293a987 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -516,6 +516,18 @@ extensionComplianceMap = {
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
          SpecificationVersion::V_1_0}}}}},
+    {"tosa.conv2d_block_scaled",
+     {{{Extension::mxfp_conv},
+       {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.conv3d",
      {{{Extension::int4},
        {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 419340256fa59..1498fad2f08e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
 // DYNAMIC      : Removes all Compile Time Constant state for CTC inputs.
 // MXFP         : Microscaling formats.
+// MXFP_CONV    : Microscaling format convolution.
 // SHAPE        : Shape calcuation operators.
 //===----------------------------------------------------------------------===//
 
@@ -275,24 +276,25 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
 def Tosa_EXT_MXFP         : I32EnumAttrCase<"mxfp", 12>;
 def Tosa_EXT_INT64        : I32EnumAttrCase<"int64", 13>;
-def Tosa_EXT_SHAPE        : I32EnumAttrCase<"shape", 14>;
-
+def Tosa_EXT_MXFP_CONV    : I32EnumAttrCase<"mxfp_conv", 14>;
+def Tosa_EXT_SHAPE        : I32EnumAttrCase<"shape", 15>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
       Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
       Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
-      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_SHAPE,
+      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_MXFP_CONV,
+      Tosa_EXT_SHAPE,
     ]> {
   let extraClassDeclaration = [{
-    static llvm::SmallVector<Extension, 13> getAllValues() {
+    // Inline capacity equals the number of extensions returned below (every
+    // extension except `none`), so the list never spills to the heap.
+    static llvm::SmallVector<Extension, 15> getAllValues() {
       return {
         Extension::int16, Extension::int4, Extension::bf16,
         Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
         Extension::variable, Extension::controlflow, Extension::doubleround,
         Extension::inexactround, Extension::dynamic, Extension::mxfp,
-        Extension::int64, Extension::shape
+        Extension::int64, Extension::mxfp_conv, Extension::shape
       };
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 370ce8c161d0b..edd8f0fc266bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -163,6 +163,43 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: conv2d_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_Conv2DBlockScaledOp : Tosa_InferShapedTypeOp<"conv2d_block_scaled"> {
+  let summary = "Performs two dimensional convolution using block scaled tensors.";
+
+  let description = [{
+    Performs a 2D convolution over the given input data and scales, using
+    the weight data and scales. Implementations may choose to skip calculation
+    of multiplies in the padding area.
+
+    Expected operand and result shapes:
+
+    * `input_data`: `[N, IH, IW, IC]`
+    * `input_scale`: `[N, IH, IW, IC / block_size]`
+    * `weight_data`: `[OC, KH, KW, IC]`
+    * `weight_scale`: `[OC, KH, KW, IC / block_size]`
+    * `bias`: `[OC]`, or `[1]` to broadcast a single value
+    * `output`: `[N, OH, OW, OC]`
+
+    `IC` must be a multiple of `block_size`. `pad` is ordered
+    `[top, bottom, left, right]` and must be non-negative; `stride` and
+    `dilation` are ordered `[y, x]` and must be at least 1. Each output
+    spatial size satisfies, e.g. for the height,
+    `OH = (IH - 1 + pad_top + pad_bottom - (KH - 1) * dilation_y) / stride_y + 1`
+    where the division must be exact.
+  }];
+
+  let arguments = (ins
+    Tosa_MXFPDataTensor4D:$input_data,
+    Tosa_MXFPScaleTensor4D:$input_scale,
+    Tosa_MXFPDataTensor4D:$weight_data,
+    Tosa_MXFPScaleTensor4D:$weight_scale,
+    Tosa_Tensor1D:$bias,
+    Rank4TosaShape:$pad,
+    Rank2TosaShape:$stride,
+    Rank2TosaShape:$dilation,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+
+  // Requires the floating-point profile plus the MXFP convolution extension.
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_MXFP_CONV]>,
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: conv3d
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index bee253689bab7..d9de79e415292 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -149,6 +149,7 @@ class TosaProfileCompliance {
     case Extension::fp8e5m2:
     case Extension::fft:
     case Extension::mxfp:
+    case Extension::mxfp_conv:
       return {Profile::pro_fp};
     case Extension::variable:
     case Extension::controlflow:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index a105b58e57e2c..5b5189c84f4a3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -203,6 +203,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
 def Tosa_TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
+def Tosa_IndexTensor1D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [1]>]>;
 def Tosa_IndexTensor2D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
 
@@ -217,6 +219,14 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
   TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
 ]>;
+def Tosa_MXFPDataTensor4D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPNumber], [4]>
+]>;
+def Tosa_MXFPScaleTensor4D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPScaleNumber], [4]>
+]>;
 def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
   TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 01f78f86d427b..9f616b223bf92 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
     return TosaSpecificationVersion(1, 0);
   case Extension::mxfp:
   case Extension::int64:
+  case Extension::mxfp_conv:
   case Extension::shape:
     return TosaSpecificationVersion(1, 1);
   case Extension::none:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6e934526d9035..192fbbcac6c15 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -550,6 +550,15 @@ void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
   printWithEnumHandling(parser, *this);
 }
 
+// Custom parser: delegates to parseWithEnumHandling, instantiated for the
+// tosa::BlockSize enum (the type behind the `block_size` attribute).
+ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
+                                       OperationState &state) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, state);
+}
+
+// Custom printer: the counterpart of parse(), via printWithEnumHandling.
+void Conv2DBlockScaledOp::print(OpAsmPrinter &printer) {
+  printWithEnumHandling(printer, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -612,6 +621,55 @@ unsigned mlir::tosa::getBitWidth(Type type) {
   return type.getIntOrFloatBitWidth();
 }
 
+// Merges a newly observed size `newDim` into the tracked size `currDim`.
+// If `currDim` is still dynamic it simply adopts `newDim` (which may itself be
+// dynamic). Otherwise, if `newDim` is static and differs from `currDim`, an
+// error naming `dimName` and `operandName` is emitted on `op`. A dynamic
+// `newDim` never conflicts. File-local helper, hence internal linkage.
+static LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
+                                           const int64_t newDim,
+                                           const StringRef operandName,
+                                           const StringRef dimName) {
+  if (ShapedType::isDynamic(currDim)) {
+    currDim = newDim;
+    return success();
+  }
+  if (ShapedType::isStatic(newDim) && currDim != newDim)
+    return op->emitOpError("expected ")
+           << dimName << " of " << operandName << " to match size " << currDim
+           << ", got " << newDim;
+  return success();
+}
+
+// Checks one spatial dimension of a convolution against the TOSA ERROR_IF
+//   O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+// Nothing is checked when the input or kernel size is dynamic; a dynamic
+// output size skips only the final equality comparison. `dimName`, `dimAxis`,
+// `padBeforeName` and `padAfterName` are used solely to build diagnostics
+// (e.g. "height", "y", "top", "bottom"). File-local helper, hence internal
+// linkage.
+// NOTE(review): assumes `stride` is non-zero; callers are expected to have
+// rejected strides < 1 beforehand (TODO confirm for every caller).
+static LogicalResult verifyConvOutputSize(
+    Operation *op, const int64_t inputSize, const int64_t kernelSize,
+    const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
+    const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
+    const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
+    const llvm::StringRef padAfterName) {
+  if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
+    return success();
+
+  // The numerator must divide exactly by the stride.
+  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+      stride);
+  if (!calculatedOutSizeMinusOne.has_value())
+    return op->emitOpError("expected input_")
+           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+           << dimAxis << " to be wholly divisible by stride_" << dimAxis
+           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+           << padAfter << " - (" << kernelSize << " - 1) * " << dilation
+           << ") / " << stride;
+
+  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+  if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
+    return op->emitOpError("calculated output ")
+           << dimName << " did not match expected: "
+           << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -799,53 +857,16 @@ static LogicalResult verifyConvOpErrorIf(T op) {
       llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
 
   if (inputType && weightType) {
-    const auto verifyOutputSize =
-        [&op](const int64_t inputSize, const int64_t kernelSize,
-              const int64_t outputSize, const int64_t padBefore,
-              const int64_t padAfter, const int64_t stride,
-              const int64_t dilation, const llvm::StringRef dimName,
-              const llvm::StringRef dimAxis,
-              const llvm::StringRef padBeforeName,
-              const llvm::StringRef padAfterName) -> LogicalResult {
-      if (inputSize == ShapedType::kDynamic ||
-          kernelSize == ShapedType::kDynamic)
-        return success();
-
-      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
-
-      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
-          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
-          stride);
-      if (!calculatedOutSizeMinusOne.has_value())
-        return op.emitOpError("expected input_")
-               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
-               << padAfterName << " - (kernel_" << dimName
-               << " - 1) * dilation_" << dimAxis
-               << " to be wholly divisible by stride_" << dimAxis << ", got ("
-               << inputSize << " - 1 + " << padBefore << " + " << padAfter
-               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
-               << stride;
-
-      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
-        return op.emitOpError("calculated output ")
-               << dimName << " did not match expected: "
-               << "calculated=" << calculatedOutSize
-               << ", expected=" << outputSize;
-
-      return success();
-    };
-
     // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
     if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(1),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(2),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(2),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "width", "x", "left", "right")))
         return failure();
@@ -853,14 +874,14 @@ static LogicalResult verifyConvOpErrorIf(T op) {
 
     // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
     if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(0),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(0),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(1),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "width", "x", "left", "right")))
         return failure();
@@ -868,20 +889,20 @@ static LogicalResult verifyConvOpErrorIf(T op) {
 
     // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
     if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(1),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "depth", "d", "front", "back")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(2),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(2),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(3), weightType.getDimSize(3),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(3), weightType.getDimSize(3),
               outputType.getDimSize(3), padding[4], padding[5], strides[2],
               dilations[2], "width", "x", "left", "right")))
         return failure();
@@ -1962,20 +1983,6 @@ LogicalResult MatmulTBlockScaledOp::verify() {
                                     "B_data")))
     return failure();
 
-  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
-                                   const StringRef operandName,
-                                   const StringRef dimName) -> LogicalResult {
-    if (ShapedType::isDynamic(currDim)) {
-      currDim = newDim;
-      return success();
-    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
-      return emitOpError("expected ")
-             << dimName << " of " << operandName << " to match size " << currDim
-             << ", got " << newDim;
-    }
-    return success();
-  };
-
   // Verify input shape compatibility
   int64_t N = ShapedType::kDynamic;
   int64_t D = ShapedType::kDynamic;
@@ -1993,32 +2000,33 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
   if (aScaleShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
-                                     "height")))
+    if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
+                                     "a_scale", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
+                                     "a_scale", "height")))
       return failure();
     multiplesOfC = aScaleShape.getDimSize(2);
   }
 
   const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
   if (bDataShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
-                                     "channels")))
+    if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
+                                     "b_data", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
+                                     "b_data", "channels")))
       return failure();
     W = bDataShape.getDimSize(1);
   }
 
   const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
   if (bScaleShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
-                                     "width")) ||
-        failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
-                                     "b_scale", "C/block_size")))
+    if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
+                                     "b_scale", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
+                                     "b_scale", "width")) ||
+        failed(tryUpdateDimOrFailure(*this, multiplesOfC,
+                                     bScaleShape.getDimSize(2), "b_scale",
+                                     "C/block_size")))
       return failure();
   }
 
@@ -3493,6 +3501,232 @@ LogicalResult Conv2DOp::verify() {
   return success();
 }
 
+// Infers the conv2d_block_scaled result shape [N, OH, OW, OC].
+//
+// N/IH/IW come from the first of input_data, input_scale that provides a
+// static size, and OC/KH/KW likewise from weight_data, then weight_scale.
+// bias may also supply OC unless its size is 1 (broadcast). OH/OW are
+// computed only when pad, stride and dilation are all constant shapes and
+// the matching input and kernel sizes are static; otherwise they stay dynamic.
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    Conv2DBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Adopts `candidate` only while `dim` is still unknown.
+  const auto fillIfDynamic = [](int64_t &dim, const int64_t candidate) {
+    if (ShapedType::isDynamic(dim))
+      dim = candidate;
+  };
+
+  SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+  int64_t inH = ShapedType::kDynamic;
+  int64_t inW = ShapedType::kDynamic;
+  int64_t kH = ShapedType::kDynamic;
+  int64_t kW = ShapedType::kDynamic;
+
+  // input_data, then input_scale: [N, IH, IW, ...].
+  for (const ShapeAdaptor shape :
+       {ShapeAdaptor(adaptor.getInputData().getType()),
+        ShapeAdaptor(adaptor.getInputScale().getType())}) {
+    if (!shape.hasRank())
+      continue;
+    fillIfDynamic(outShape[0], shape.getDimSize(0));
+    fillIfDynamic(inH, shape.getDimSize(1));
+    fillIfDynamic(inW, shape.getDimSize(2));
+  }
+
+  // weight_data, then weight_scale: [OC, KH, KW, ...].
+  for (const ShapeAdaptor shape :
+       {ShapeAdaptor(adaptor.getWeightData().getType()),
+        ShapeAdaptor(adaptor.getWeightScale().getType())}) {
+    if (!shape.hasRank())
+      continue;
+    fillIfDynamic(outShape[3], shape.getDimSize(0));
+    fillIfDynamic(kH, shape.getDimSize(1));
+    fillIfDynamic(kW, shape.getDimSize(2));
+  }
+
+  // A bias of size 1 broadcasts, so it says nothing about OC.
+  const ShapeAdaptor biasShape(adaptor.getBias().getType());
+  if (biasShape.hasRank() && biasShape.getDimSize(0) != 1)
+    fillIfDynamic(outShape[3], biasShape.getDimSize(0));
+
+  SmallVector<int64_t> pad;
+  SmallVector<int64_t> stride;
+  SmallVector<int64_t> dilation;
+  const bool allConst =
+      tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), pad) &&
+      tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+                                stride) &&
+      tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
+                                dilation);
+
+  if (allConst) {
+    // O = (I + pad_before + pad_after - ((K - 1) * dilation + 1)) / stride + 1
+    const auto spatialOutSize = [](const int64_t in, const int64_t padBefore,
+                                   const int64_t padAfter, const int64_t kernel,
+                                   const int64_t dil, const int64_t str) {
+      const int64_t paddedIn = in + padBefore + padAfter;
+      const int64_t dilatedKernel = (kernel - 1) * dil + 1;
+      return (paddedIn - dilatedKernel) / str + 1;
+    };
+    if (ShapedType::isStatic(inH) && ShapedType::isStatic(kH))
+      outShape[1] =
+          spatialOutSize(inH, pad[0], pad[1], kH, dilation[0], stride[0]);
+    if (ShapedType::isStatic(inW) && ShapedType::isStatic(kW))
+      outShape[2] =
+          spatialOutSize(inW, pad[2], pad[3], kW, dilation[1], stride[1]);
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+// Verifies conv2d_block_scaled:
+//  - input_data/weight_data, input_scale/weight_scale and bias/output must
+//    have matching element types;
+//  - every static dimension must agree across operands and result, using
+//      input_data   [N, IH, IW, IC]
+//      input_scale  [N, IH, IW, IC / block_size]
+//      weight_data  [OC, KH, KW, IC]
+//      weight_scale [OC, KH, KW, IC / block_size]
+//      output       [N, OH, OW, OC]
+//    with IC a multiple of block_size and bias of size OC (or 1);
+//  - constant pad values must be >= 0, constant stride/dilation values >= 1;
+//  - when pad, stride and dilation are all constant, OH/OW must satisfy the
+//    TOSA convolution output-size ERROR_IF (see verifyConvOutputSize).
+// Dynamic dimensions are skipped rather than rejected.
+LogicalResult Conv2DBlockScaledOp::verify() {
+  if (failed(verifySameElementTypes(*this, getInputData().getType(),
+                                    getWeightData().getType())) ||
+      failed(verifySameElementTypes(*this, getInputScale().getType(),
+                                    getWeightScale().getType())) ||
+      failed(verifySameElementTypes(*this, getBias().getType(),
+                                    getOutput().getType())))
+    return failure();
+
+  // Shape symbols: each stays dynamic until an operand supplies a static size,
+  // after which every later operand must agree with it.
+  int64_t N = ShapedType::kDynamic;
+  int64_t IH = ShapedType::kDynamic;
+  int64_t IW = ShapedType::kDynamic;
+  int64_t IC = ShapedType::kDynamic;
+  int64_t multiplesOfIC = ShapedType::kDynamic;
+  int64_t OC = ShapedType::kDynamic;
+  int64_t KH = ShapedType::kDynamic;
+  int64_t KW = ShapedType::kDynamic;
+
+  const ShapeAdaptor inputDataShape(getInputData().getType());
+  if (inputDataShape.hasRank()) {
+    N = inputDataShape.getDimSize(0);
+    IH = inputDataShape.getDimSize(1);
+    IW = inputDataShape.getDimSize(2);
+    IC = inputDataShape.getDimSize(3);
+  }
+
+  const ShapeAdaptor inputScaleShape(getInputScale().getType());
+  if (inputScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
+                                     "input_scale", "batch size")) ||
+        failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
+                                     "input_scale", "input height")) ||
+        failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
+                                     "input_scale", "input width")))
+      return failure();
+    multiplesOfIC = inputScaleShape.getDimSize(3);
+  }
+
+  const ShapeAdaptor weightDataShape(getWeightData().getType());
+  if (weightDataShape.hasRank()) {
+    OC = weightDataShape.getDimSize(0);
+    KH = weightDataShape.getDimSize(1);
+    KW = weightDataShape.getDimSize(2);
+    if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
+                                     "weight_data", "input channels")))
+      return failure();
+  }
+
+  const ShapeAdaptor weightScaleShape(getWeightScale().getType());
+  if (weightScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
+                                     "weight_scale", "output channels")) ||
+        failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
+                                     "weight_scale", "kernel height")) ||
+        failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
+                                     "weight_scale", "kernel width")) ||
+        failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
+                                     weightScaleShape.getDimSize(3),
+                                     "weight_scale", "input channel blocks")))
+      return failure();
+  }
+
+  // IC must split evenly into blocks of `block_size` channels.
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (ShapedType::isStatic(IC) && IC % blockSize != 0)
+    return emitOpError("expect IC to be a multiple of block size, got IC=")
+           << IC << ", block_size=" << blockSize;
+
+  // The last dimension (index 3) of both 4-D scale tensors must equal
+  // IC / block_size. (The 3-D matmul variant uses index 2; this op does not.)
+  if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
+      multiplesOfIC != IC / blockSize)
+    return emitOpError(
+               "expect scale operands dimension 3 to equal IC/block_size (")
+           << IC << "/" << blockSize << ")"
+           << ", got " << multiplesOfIC;
+
+  // pad/stride/dilation can only be checked when they are constant shapes.
+  SmallVector<int64_t> padValues;
+  if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
+    if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
+      return emitOpError("expect all padding values to be >= 0, got ")
+             << padValues;
+  }
+
+  SmallVector<int64_t> strideValues;
+  if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
+    if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
+      return emitOpError("expect all stride values to be >= 1, got ")
+             << strideValues;
+  }
+
+  SmallVector<int64_t> dilationValues;
+  if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
+                                dilationValues)) {
+    if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
+      return emitOpError("expect all dilation values to be >= 1, got ")
+             << dilationValues;
+  }
+
+  // Verify the output spatial sizes (dims 1 and 2). Strides are known to be
+  // >= 1 here, so the division inside verifyConvOutputSize is safe.
+  const ShapeAdaptor outputShape(getOutput().getType());
+  if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
+      outputShape.hasRank()) {
+    if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
+                                    padValues[0], padValues[1], strideValues[0],
+                                    dilationValues[0], "height", "y", "top",
+                                    "bottom")) ||
+        failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
+                                    padValues[2], padValues[3], strideValues[1],
+                                    dilationValues[1], "width", "x", "left",
+                                    "right")))
+      return failure();
+  }
+
+  if (outputShape.hasRank()) {
+    // Bias is either one value per output channel or a single broadcast value.
+    const ShapeAdaptor biasShape(getBias().getType());
+    if (biasShape.hasRank()) {
+      const int64_t biasChannels = biasShape.getDimSize(0);
+      const int64_t outputChannels =
+          outputShape.getDimSize(outputShape.getRank() - 1);
+      if (ShapedType::isStatic(biasChannels) &&
+          ShapedType::isStatic(outputChannels) &&
+          biasChannels != outputChannels && biasChannels != 1)
+        return emitOpError(
+                   "bias channels expected to be equal to output channels (")
+               << outputChannels << ") or 1, got " << biasChannels;
+    }
+
+    // The output batch must match the inputs' batch, and the output channels
+    // must match the weights' output-channel count.
+    if (failed(tryUpdateDimOrFailure(*this, N, outputShape.getDimSize(0),
+                                     "output", "batch size")) ||
+        failed(tryUpdateDimOrFailure(*this, OC, outputShape.getDimSize(3),
+                                     "output", "output channels")))
+      return failure();
+  }
+
+  return success();
+}
+
 LogicalResult Conv3DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     Conv3DOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 08c702bd2f29f..f26554fb5768a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -111,6 +111,18 @@ ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
   return populateProfileInfoConv(op);
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::Conv2DBlockScaledOp op) {
+  // Order matters: it forms the type tuple compared with the compliance data.
+  const Value orderedValues[] = {op.getInputData(),  op.getInputScale(),
+                                 op.getWeightData(), op.getWeightScale(),
+                                 op.getBias(),       op.getOutput()};
+  for (const Value value : orderedValues)
+    addValue(value);
+  return success();
+}
+
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
   addValue(op.getInput1());
@@ -245,6 +257,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d)
   POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
   POPULATE_PROFILE_INFO_CUSTOM(Conv2D)
+  POPULATE_PROFILE_INFO_CUSTOM(Conv2DBlockScaled)
   POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
   POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
   POPULATE_PROFILE_INFO_CUSTOM(Mul)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 387d38411f0fe..f1798af9198ca 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -390,6 +390,55 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return success();
   }
 
+  // Level check CONV2D_BLOCK_SCALED pad, stride and dilation * kernel dims.
+  LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
+    auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
+    if (!convOp)
+      return success();
+
+    DenseIntElementsAttr padding;
+    if (matchPattern(convOp.getPad(), m_Constant(&padding))) {
+      const SmallVector<int64_t> padValues = convertFromIntAttr(padding, 4);
+      for (const auto p : padValues)
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL")))
+          return failure();
+    }
+
+    // stride and dilation are rank-2 ([y, x]), unlike the rank-4 pad.
+    DenseIntElementsAttr stride;
+    if (matchPattern(convOp.getStride(), m_Constant(&stride))) {
+      const SmallVector<int64_t> strideValues = convertFromIntAttr(stride, 2);
+      for (const auto s : strideValues)
+        if (failed(levelCheckKernel(op, s, "stride <= MAX_KERNEL")))
+          return failure();
+    }
+
+    DenseIntElementsAttr dilation;
+    if (!matchPattern(convOp.getDilation(), m_Constant(&dilation)))
+      return success();
+    const SmallVector<int64_t> dilationValues = convertFromIntAttr(dilation, 2);
+    // KH/KW from weight_data, else weight_scale; never query unranked shapes.
+    const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
+    const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
+    const auto kernelDim = [&](int dim) -> int64_t {
+      if (weightDataShape.hasRank() && !weightDataShape.isDynamicDim(dim))
+        return weightDataShape.getDimSize(dim);
+      return weightScaleShape.hasRank() ? weightScaleShape.getDimSize(dim)
+                                        : ShapedType::kDynamic;
+    };
+    const int64_t KH = kernelDim(1);
+    const int64_t KW = kernelDim(2);
+    if (!ShapedType::isDynamic(KH) &&
+        failed(levelCheckKernel(op, dilationValues[0] * KH,
+                                "dilation_y * KH <= MAX_KERNEL")))
+      return failure();
+    if (!ShapedType::isDynamic(KW) &&
+        failed(levelCheckKernel(op, dilationValues[1] * KW,
+                                "dilation_x * KW <= MAX_KERNEL")))
+      return failure();
+    return success();
+  }
+
   // FFT op: level check H, W in input shape [N,H,W]
   template <typename T>
   LogicalResult levelCheckFFT(Operation *op) {
@@ -700,6 +749,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   // Tensor Operators
   CHECK_SIZES(AvgPool2d);
   CHECK_SIZES(Conv2D);
+  CHECK_SIZES(Conv2DBlockScaled);
   CHECK_SIZES(Conv3D);
   CHECK_SIZES(DepthwiseConv2D);
   CHECK_SIZES(TransposeConv2D);
@@ -800,7 +850,6 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
   if (failed(levelCheckRanksAndSizes(op)))
     return failure();
 
-  // additional level checks from spec 0.70
   if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
       failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
       failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
@@ -808,7 +857,8 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
       failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
       failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
       failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
-      failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
+      failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
+      failed(levelCheckConv2DBlockScaled(op))) {
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index fb79070b3a8a4..aafb688750433 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,shape" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp_conv,shape" -tosa-validate="strict-op-spec-alignment"
 
 
 func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
@@ -2087,3 +2087,14 @@ func.func @test_slice_shape_non_const_size(%arg0: tensor<1xi32>) -> !tosa.shape<
   %3 = tosa.slice_shape %0, %1, %arg0 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
   return %3 : !tosa.shape<3>
 }
+
+// -----
+
+func.func @test_conv2d_block_scaled(%input_data: tensor<*xf4E2M1FN>, %input_scale: tensor<*xf8E8M0FNU>, %weight_data: tensor<*xf4E2M1FN>, %weight_scale: tensor<*xf8E8M0FNU>, %bias: tensor<*xf16>) -> tensor<*xf16> {
+  %pad_shape = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op illegal: operation operand/result data types did not align with any profile or extension, got (fp4e2m1,fp8e8m0,fp4e2m1,fp8e8m0,f16,f16), did you mean (fp4e2m1,fp8e8m0,fp4e2m1,fp8e8m0,f32,f32)? Otherwise, please refer to the 'supported data types' for 'tosa.conv2d_block_scaled' in the specification.}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf16>
+  return %result : tensor<*xf16>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 9162f6167f7ec..2fa1fe71148d1 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -564,7 +564,7 @@ func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
 
 // -----
 
-func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
   // expected-error at +1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
   %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
   return %0 : tensor<4x32xf32>
@@ -572,7 +572,7 @@ func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1:
 
 // -----
 
-func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
   // expected-error at +1 {{'tosa.cast_to_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
@@ -607,3 +607,14 @@ func.func @test_min_shape() -> !tosa.shape<4> {
   %c = tosa.min_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
   return %c : !tosa.shape<4>
 }
+
+// -----
+
+func.func @test_conv2d_block_scaled(%input_data: tensor<*xf4E2M1FN>, %input_scale: tensor<*xf8E8M0FNU>, %weight_data: tensor<*xf4E2M1FN>, %weight_scale: tensor<*xf8E8M0FNU>, %bias: tensor<*xf32>) -> tensor<*xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op illegal: requires [mxfp_conv] but not enabled in target}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %result : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index dd5ece417cf9e..3ebf0ff8a2f69 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1755,3 +1755,74 @@ func.func @test_mod_shape_invalid_rank() -> !tosa.shape<17> {
   %c = tosa.mod_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
   return %c : !tosa.shape<17>
 }
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_size(%input_data: tensor<67108864x4x4x64xf4E2M1FN>, %input_scale: tensor<67108864x4x4x2xf8E8M0FNU>, %weight_data: tensor<8x1x1x64xf4E2M1FN>, %weight_scale: tensor<8x1x1x2xf8E8M0FNU>, %bias: tensor<1xf32>) -> tensor<67108864x4x4x8xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} : (tensor<67108864x4x4x64xf4E2M1FN>, tensor<67108864x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<67108864x4x4x8xf32>
+  return %result : tensor<67108864x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_dilation_y(%input_data: tensor<1x8191x8191x32xf8E4M3FN>, %input_scale: tensor<1x8191x8191x1xf8E8M0FNU>, %weight_data: tensor<16x1025x1024x32xf8E4M3FN>, %weight_scale: tensor<16x1025x1024x1xf8E8M0FNU>, %bias: tensor<16xf32>) -> tensor<1x9x7178x16xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[10, 0, 10, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[8, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op failed level check: dilation_y * KH <= MAX_KERNEL}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} :
+            (tensor<1x8191x8191x32xf8E4M3FN>, tensor<1x8191x8191x1xf8E8M0FNU>, tensor<16x1025x1024x32xf8E4M3FN>, tensor<16x1025x1024x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x9x7178x16xf32>
+  return %result : tensor<1x9x7178x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_dilation_x(%input_data: tensor<1x8191x8191x32xf8E4M3FN>, %input_scale: tensor<1x8191x8191x1xf8E8M0FNU>, %weight_data: tensor<16x1024x1025x32xf8E4M3FN>, %weight_scale: tensor<16x1024x1025x1xf8E8M0FNU>, %bias: tensor<16xf32>) -> tensor<1x7178x9x16xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[10, 0, 10, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op failed level check: dilation_x * KW <= MAX_KERNEL}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} :
+            (tensor<1x8191x8191x32xf8E4M3FN>, tensor<1x8191x8191x1xf8E8M0FNU>, tensor<16x1024x1025x32xf8E4M3FN>, tensor<16x1024x1025x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x7178x9x16xf32>
+  return %result : tensor<1x7178x9x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_pad_top(%input_data: tensor<1x32x32x32xf8E4M3FN>, %input_scale: tensor<1x32x32x1xf8E8M0FNU>, %weight_data: tensor<16x2x2x32xf8E4M3FN>, %weight_scale: tensor<16x2x2x1xf8E8M0FNU>, %bias: tensor<16xf32>) -> tensor<1x8224x31x16xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[8193, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op failed level check: pad <= MAX_KERNEL}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = #tosa.block_size<BLOCK_SIZE_32>} :
+            (tensor<1x32x32x32xf8E4M3FN>, tensor<1x32x32x1xf8E8M0FNU>, tensor<16x2x2x32xf8E4M3FN>, tensor<16x2x2x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x8224x31x16xf32>
+  return %result : tensor<1x8224x31x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_pad_right(%input_data: tensor<1x32x32x32xf8E4M3FN>, %input_scale: tensor<1x32x32x1xf8E8M0FNU>, %weight_data: tensor<16x2x2x32xf8E4M3FN>, %weight_scale: tensor<16x2x2x1xf8E8M0FNU>, %bias: tensor<16xf32>) -> tensor<1x31x8224x16xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[0, 0, 0, 8193]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op failed level check: pad <= MAX_KERNEL}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = #tosa.block_size<BLOCK_SIZE_32>} :
+            (tensor<1x32x32x32xf8E4M3FN>, tensor<1x32x32x1xf8E8M0FNU>, tensor<16x2x2x32xf8E4M3FN>, tensor<16x2x2x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x31x8224x16xf32>
+  return %result : tensor<1x31x8224x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_stride_y(%input_data: tensor<1x8194x33x32xf8E4M3FN>, %input_scale: tensor<1x8194x33x1xf8E8M0FNU>, %weight_data: tensor<16x2x2x32xf8E4M3FN>, %weight_scale: tensor<16x2x2x1xf8E8M0FNU>, %bias: tensor<16xf32>) -> tensor<1x2x32x16xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[1, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op failed level check: stride <= MAX_KERNEL}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = #tosa.block_size<BLOCK_SIZE_32>} :
+            (tensor<1x8194x33x32xf8E4M3FN>, tensor<1x8194x33x1xf8E8M0FNU>, tensor<16x2x2x32xf8E4M3FN>, tensor<16x2x2x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x2x32x16xf32>
+  return %result : tensor<1x2x32x16xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 986bdf019a613..626d2b6caafd1 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1558,3 +1558,23 @@ func.func @test_min_shape() -> !tosa.shape<4> {
   %c = tosa.min_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
   return %c : !tosa.shape<4>
 }
+
+// -----
+// CHECK-LABEL: test_conv2d_block_scaled_static
+func.func @test_conv2d_block_scaled_static(%input_data: tensor<1x4x4x64xf4E2M1FN>, %input_scale: tensor<1x4x4x2xf8E8M0FNU>, %weight_data: tensor<8x1x1x64xf4E2M1FN>, %weight_scale: tensor<8x1x1x2xf8E8M0FNU>, %bias: tensor<1xf32>) -> tensor<*xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %result : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic
+func.func @test_conv2d_block_scaled_dynamic(%input_data: tensor<*xf4E2M1FN>, %input_scale: tensor<*xf8E8M0FNU>, %weight_data: tensor<*xf4E2M1FN>, %weight_scale: tensor<*xf8E8M0FNU>, %bias: tensor<*xf32>) -> tensor<*xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %result : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 7de7b85bcaedf..15aad410c6f44 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp,mxfp_conv" -tosa-validate="strict-op-spec-alignment"
 
 // -----
 func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
@@ -334,15 +334,25 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: ten
 }
 
 // -----
-func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { // Illegal here: needs pro_fp, but the RUN line enables only pro_int.
   // expected-error at +1 {{'tosa.cast_from_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
   %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
   return %0 : tensor<4x32xf32>
 }
 
 // -----
-func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { // Illegal here: needs pro_fp, but the RUN line enables only pro_int.
   // expected-error at +1 {{'tosa.cast_to_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
 }
+
+// -----
+func.func @test_conv2d_block_scaled(%input_data: tensor<*xf4E2M1FN>, %input_scale: tensor<*xf8E8M0FNU>, %weight_data: tensor<*xf4E2M1FN>, %weight_scale: tensor<*xf8E8M0FNU>, %bias: tensor<*xf32>) -> tensor<*xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.conv2d_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %result : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 54556a0eb08e0..b74540f060cfe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1673,3 +1673,51 @@ func.func @test_cast_to_block_scaled_dynamic_scales(%arg0: tensor<4x?xf32>) -> (
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
 }
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_static
+func.func @test_conv2d_block_scaled_static(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<1x4x4x8xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_scales
+func.func @test_conv2d_block_scaled_dynamic_scales(%arg0: tensor<?x4x4x64xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<?x1x1x64xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<?x4x4x?xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x4x4x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<?x1x1x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_data
+func.func @test_conv2d_block_scaled_dynamic_data(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<1x4x4x8xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_unranked
+func.func @test_conv2d_block_scaled_dynamic_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<?x?x?x?xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 4edddfff49f24..ef52b90f194de 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64,shape" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64,mxfp_conv,shape" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
 
 // -----
 
@@ -232,3 +232,14 @@ func.func @test_mod_shape() -> !tosa.shape<3> {
   %c = tosa.mod_shape %a, %b : (!tosa.shape<3>, !tosa.shape<3>) -> !tosa.shape<3>
   return %c : !tosa.shape<3>
 }
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled
+func.func @test_conv2d_block_scaled(%input_data: tensor<1x4x4x64xf4E2M1FN>, %input_scale: tensor<1x4x4x2xf8E8M0FNU>, %weight_data: tensor<8x1x1x64xf4E2M1FN>, %weight_scale: tensor<8x1x1x2xf8E8M0FNU>, %bias: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  %pad_shape = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation_shape = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %result = tosa.conv2d_block_scaled %input_data, %input_scale, %weight_data, %weight_scale, %bias, %pad_shape, %stride_shape, %dilation_shape {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %result : tensor<1x4x4x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index e444664cf2b93..c45a699e3440b 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1335,3 +1335,135 @@ func.func @test_mod_shape_input1_input2_rank_mismatch() -> !tosa.shape<6> {
   %c = tosa.mod_shape %a, %b : (!tosa.shape<6>, !tosa.shape<5>) -> !tosa.shape<6>
   return %c : !tosa.shape<6>
 }
+
+// -----
+
+func.func @test_conv2d_block_scaled_data_type_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf8E4M3FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {  // weight_data is f8E4M3FN while input_data is f4E2M1FN
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect input and output to have same element type, got 'f4E2M1FN' and 'f8E4M3FN'}}
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf8E4M3FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_bias_output_type_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf16>) -> tensor<*xf32> {  // bias is f16 while the output is f32
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect input and output to have same element type, got 'f16' and 'f32'}}
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_padding(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {  // last padding value is -1
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect all padding values to be >= 0, got 0, 0, 0, -1}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_stride(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {  // first stride value is 0
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect all stride values to be >= 1, got 0, 1}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_dilation(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {  // second dilation value is 0
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect all dilation values to be >= 1, got 1, 0}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_input_width_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x5x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x4x8xf32> {  // input_scale width is 5 but input_data width is 4
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expected input width of input_scale to match size 4, got 5}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x5x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_kernel_height_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x2x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x4x8xf32> {  // weight_data kernel height is 2 but weight_scale kernel height is 1
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expected kernel height of weight_scale to match size 2, got 1}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x2x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_output_shape_indivisible(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x5x8xf32> {  // (input_width - 1) = 3 is not divisible by stride_x = 2
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expected input_width - 1 + pad_left + pad_right - (kernel_width - 1) * dilation_x to be wholly divisible by stride_x, got (4 - 1 + 0 + 0 - (1 - 1) * 1) / 2}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x5x8xf32>
+  return %0 : tensor<1x4x5x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_output_shape_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x5x8xf32> {  // declared output width 5 does not match the computed width 4
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op calculated output width did not match expected: calculated=4, expected=5}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x5x8xf32>
+  return %0 : tensor<1x4x5x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_ic(%arg0: tensor<1x4x4x63xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x63xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x4x8xf32> {  // IC = 63 is not a multiple of block size 32; the output shape is otherwise valid so only this check can fire
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect IC to be a multiple of block size, got IC=63, block_size=32}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x63xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x63xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_ic_multiple(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x3xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x3xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x4x8xf32> {  // scale channel dimension is 3, but IC / block size = 64 / 32 = 2
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect scale operands dimension 2 to equal IC/block_size (64/32), got 3}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x3xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x3xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_bias_size(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<6xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x4x8xf32> {  // bias has 6 channels while the output has 8
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op bias channels expected to be equal to output channels (8) or 1, got 6}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<6xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}

>From 4d40ab94d62a81d217dcd62fec0781b3f46622cb Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 19 Dec 2025 11:46:09 +0000
Subject: [PATCH 2/3] update api used to get const shape values in level check

Change-Id: I1de3504fdf75332f5cebf4a2bfaf8e28d8755267
---
 .../Tosa/Transforms/TosaValidation.cpp        | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f1798af9198ca..5f26adabf409c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -395,27 +395,24 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     if (!convOp)
       return success();
 
-    DenseIntElementsAttr padding;
-    if (matchPattern(convOp.getPad(), m_Constant(&padding))) {
-      const SmallVector<int64_t> padValues = convertFromIntAttr(padding, 4);
+    SmallVector<int64_t> padValues;
+    if (tosa::getConstShapeValues(convOp.getPad().getDefiningOp(), padValues)) {
       for (const auto p : padValues)
         if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL")))
           return failure();
     }
 
-    DenseIntElementsAttr stride;
-    if (matchPattern(convOp.getStride(), m_Constant(&stride))) {
-      const SmallVector<int64_t> strideValues = convertFromIntAttr(stride, 4);
+    SmallVector<int64_t> strideValues;
+    if (tosa::getConstShapeValues(convOp.getStride().getDefiningOp(),
+                                  strideValues)) {
       for (const auto s : strideValues)
         if (failed(levelCheckKernel(op, s, "stride <= MAX_KERNEL")))
           return failure();
     }
 
-    DenseIntElementsAttr dilation;
-    if (matchPattern(convOp.getDilation(), m_Constant(&dilation))) {
-      const SmallVector<int64_t> dilationValues =
-          convertFromIntAttr(dilation, 4);
-
+    SmallVector<int64_t> dilationValues;
+    if (tosa::getConstShapeValues(convOp.getDilation().getDefiningOp(),
+                                  dilationValues)) {
       int64_t KH = ShapedType::kDynamic;
       int64_t KW = ShapedType::kDynamic;
       const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());

>From 539ab87ab1b7928537cc5b037530855075601d4c Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 15 Jan 2026 18:27:09 +0000
Subject: [PATCH 3/3] Updated datatype mismatch error message

Change-Id: Ib2b8163423b239fd387d4529647ee0a85895e996
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 8 +++++---
 mlir/test/Dialect/Tosa/verifier.mlir | 4 ++--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 192fbbcac6c15..4aaadf28d7a61 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3594,11 +3594,13 @@ LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
 
 LogicalResult Conv2DBlockScaledOp::verify() {
   if (failed(verifySameElementTypes(*this, getInputData().getType(),
-                                    getWeightData().getType())) ||
+                                    getWeightData().getType(), "input_data",
+                                    "weight_data")) ||
       failed(verifySameElementTypes(*this, getInputScale().getType(),
-                                    getWeightScale().getType())) ||
+                                    getWeightScale().getType(), "input_scale",
+                                    "weight_scale")) ||
       failed(verifySameElementTypes(*this, getBias().getType(),
-                                    getOutput().getType())))
+                                    getOutput().getType(), "bias", "output")))
     return failure();
 
   // Verify input shape compatibility
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index c45a699e3440b..e16a12b94b923 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1342,7 +1342,7 @@ func.func @test_conv2d_block_scaled_data_type_mismatch(%arg0: tensor<1x4x4x64xf4
   %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
   %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect input and output to have same element type, got 'f4E2M1FN' and 'f8E4M3FN'}}
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect input_data and weight_data to have same element type, got 'f4E2M1FN' and 'f8E4M3FN'}}
   %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf8E4M3FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
   return %3 : tensor<*xf32>
 }
@@ -1353,7 +1353,7 @@ func.func @test_conv2d_block_scaled_bias_output_type_mismatch(%arg0: tensor<1x4x
   %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
   %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect input and output to have same element type, got 'f16' and 'f32'}}
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect bias and output to have same element type, got 'f16' and 'f32'}}
   %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
   return %3 : tensor<*xf32>
 }



More information about the Mlir-commits mailing list