[Mlir-commits] [mlir] [mlir][tosa] Add TOSA RESHAPE_BLOCK_SCALED support (PR #191149)

Jeremy Johnson llvmlistbot at llvm.org
Wed Apr 15 05:50:52 PDT 2026


https://github.com/jjohnson-arm updated https://github.com/llvm/llvm-project/pull/191149

>From 4c895610fd3d86c32a1fe6d1776e8b9cd4c784de Mon Sep 17 00:00:00 2001
From: Jeremy Johnson <jeremy.johnson at arm.com>
Date: Tue, 24 Mar 2026 17:04:03 +0000
Subject: [PATCH 1/4] Add TOSA RESHAPE_BLOCK_SCALED support

Experimental operator support. Level (rank) checks are applied, but
extension validation is disabled for this op.

Signed-off-by: Jeremy Johnson <jeremy.johnson at arm.com>
Change-Id: I3d1f10a1b7765d849be0f71a8634f9a6d0077d69
---
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |   6 +-
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  38 +++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 223 ++++++++++++++++++
 .../Tosa/Transforms/TosaValidation.cpp        |   1 +
 mlir/test/Dialect/Tosa/level_check.mlir       |  37 +++
 mlir/test/Dialect/Tosa/ops.mlir               |  48 ++++
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  97 ++++++++
 mlir/test/Dialect/Tosa/verifier.mlir          | 218 +++++++++++++++++
 8 files changed, 665 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1f05aee3e5eec..591073e9985ae 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -484,11 +484,12 @@ def Tosa_RoundingModeAttr
     : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
                     [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
 
+// BLOCK_SIZE_1: one element per block. Used, e.g., with
+// tosa.reshape_block_scaled on tensors that are not block scaled.
+def Tosa_BLOCK_SIZE_1 : I32EnumAttrCase<"BLOCK_SIZE_1", 1>;
 def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
 
 def Tosa_BlockSizeAttr
-    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
-                    [Tosa_BLOCK_SIZE_32]> {
+    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats",
+                       "block_size", [Tosa_BLOCK_SIZE_1, Tosa_BLOCK_SIZE_32]> {
   let extraClassDeclaration = [{
     static uint32_t getBlockSizeValue(BlockSize blockSize) {
       return static_cast<uint32_t>(blockSize);
@@ -496,7 +497,6 @@ def Tosa_BlockSizeAttr
   }];
 }
 
-
 //===----------------------------------------------------------------------===//
 // TOSA Interfaces.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5ac91e6b65457..951969d3f34f5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2270,6 +2270,44 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape", [Pure]> {
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: reshape_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_ReshapeBlockScaledOp
+    : Tosa_InferTensorTypeOp<"reshape_block_scaled", [Pure]> {
+  let summary = "Reshape with support for block scaled tensors.";
+
+  let description = [{
+    Returns a tensor-list with the same type/values as the input, with a new
+    shape specified by the shape argument. Reshape may operate on block-scaled
+    or non-block-scaled tensors. No data conversion happens during a reshape
+    operation. Reshape must retain the relationship between values and their
+    scale in a block for block-scaled content.
+
+    The input tensor-list holds either a single tensor of any rank
+    (non-block-scaled content) or a value tensor followed by its scale tensor
+    (block-scaled content). For block-scaled content:
+
+    * both inputs, both results and `new_value_shape` must have a rank greater
+      than 0, and the value and scale inputs must have the same rank;
+    * all but the last dimension of the value and scale inputs must match;
+    * the last dimension of the value input and of `new_value_shape` must be
+      divisible by `block_size`, and the last dimension of the scale input
+      must equal the last value dimension divided by `block_size`;
+    * the scale result has the shape `new_value_shape` with its last dimension
+      divided by `block_size`.
+  }];
+
+  let hasVerifier = 1;
+
+  let arguments = (ins Variadic<Tosa_Tensor>:$input,
+      Tosa_Shape:$new_value_shape, Tosa_BlockSizeAttr:$block_size);
+
+  let results = (outs Variadic<Tosa_Tensor>:$output);
+
+  list<Availability> availability = [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+                                     // NOTE: Validation of extensions is
+                                     // disabled for this op
+                                     Extension<[]>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: reverse
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 29318023092a1..76eb59ecf3cad 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2784,6 +2784,229 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   return mlir::success();
 }
 
+/// Result types are compatible when both lists hold one type (values) or two
+/// types (values and scales) and the element types agree position by
+/// position. Shapes are deliberately not compared.
+bool tosa::ReshapeBlockScaledOp::isCompatibleReturnTypes(TypeRange l,
+                                                         TypeRange r) {
+  const size_t numTypes = l.size();
+  if (numTypes == 0 || numTypes > 2 || numTypes != r.size())
+    return false;
+  for (size_t idx = 0; idx < numTypes; ++idx) {
+    if (getElementTypeOrSelf(l[idx]) != getElementTypeOrSelf(r[idx]))
+      return false;
+  }
+  return true;
+}
+
+/// Infers result shapes for tosa.reshape_block_scaled.
+///
+/// Result 0 gets the constant new shape with the element type of input 0.
+/// With a scale tensor (two inputs), result 1 gets the same shape with its
+/// last dimension divided by block_size and the element type of input 1.
+/// When the new shape is not a constant only its rank is known, so every
+/// dimension is dynamic.
+///
+/// NOTE(review): InferTypeOpInterface verification is expected to run before
+/// the custom verifier, so this function avoids assuming well-formed operands
+/// (operand count, non-empty new shape, non-negative extents).
+LogicalResult tosa::ReshapeBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ReshapeBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const auto inputs = adaptor.getInput();
+  const size_t numInputs = inputs.size();
+  // Never index past the operand list, even on malformed IR.
+  if (numInputs == 0 || numInputs > 2)
+    return emitOptionalError(location, "requires one or two inputs, but got ",
+                             numInputs);
+
+  const bool isBlockScaled = numInputs == 2;
+  const Type valueElementType = getElementTypeOrSelf(inputs[0].getType());
+  const Type scaleElementType =
+      isBlockScaled ? getElementTypeOrSelf(inputs[1].getType()) : Type();
+
+  llvm::SmallVector<int64_t> newShapeValue;
+  const auto newShape = adaptor.getNewValueShape();
+  if (!tosa::getConstShapeValues(newShape.getDefiningOp(), newShapeValue)) {
+    // Only the rank of a non-constant shape is known.
+    const auto rank = cast<tosa::shapeType>(newShape.getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(fallback, valueElementType));
+    if (isBlockScaled)
+      inferredReturnShapes.push_back(
+          ShapedTypeComponents(fallback, scaleElementType));
+    return success();
+  }
+
+  // Negative entries (e.g. -1) are not valid tensor extents. Model them as
+  // dynamic so no invalid tensor type is built. The custom verifier compares
+  // the raw values against the declared result types.
+  for (int64_t &dim : newShapeValue)
+    if (dim < 0)
+      dim = ShapedType::kDynamic;
+
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(newShapeValue, valueElementType));
+
+  if (isBlockScaled) {
+    // Each scale covers blockSize consecutive values along the last
+    // dimension; every other dimension is shared with the value tensor.
+    // A dynamic last dimension stays dynamic (never divide kDynamic).
+    const uint32_t blockSize =
+        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+    llvm::SmallVector<int64_t> newScaleShapeValue(newShapeValue.begin(),
+                                                  newShapeValue.end());
+    // A rank-0 new shape is rejected by the verifier; just avoid back() here.
+    if (!newScaleShapeValue.empty() &&
+        ShapedType::isStatic(newScaleShapeValue.back()))
+      newScaleShapeValue.back() /= blockSize;
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(newScaleShapeValue, scaleElementType));
+  }
+  return success();
+}
+
+/// Verifies tosa.reshape_block_scaled.
+///
+/// The op takes either one tensor (non-block-scaled values) or two tensors
+/// (values followed by their scales) and produces the same number of results.
+/// Each group of checks runs whenever the operands it depends on are known,
+/// so an unranked input or result does not disable unrelated checks:
+///  - operand/result counts and per-position element types;
+///  - block-scaled only: non-zero ranks, equal data/scale ranks and leading
+///    dimensions, a last value dimension divisible by block_size and a last
+///    scale dimension equal to that value divided by block_size;
+///  - with a constant new shape: block_size divisibility of its last
+///    dimension (block-scaled only), agreement with ranked results, and
+///    element-count preservation for statically shaped inputs.
+llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
+  const Operation::operand_range inputList = getInput();
+  const Operation::result_range outputList = getResults();
+
+  if (inputList.empty())
+    return emitOpError("requires at least one input");
+
+  if (inputList.size() > 2)
+    return emitOpError("requires at most two inputs");
+
+  if (inputList.size() != outputList.size())
+    return emitOpError("requires number of results to match inputs");
+
+  if (verifySameElementTypes(*this, /* inType = */ inputList[0].getType(),
+                             /* outType = */ outputList[0].getType())
+          .failed()) {
+    return failure();
+  }
+
+  const bool isBlockScaled = inputList.size() == 2;
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  const auto inputType = llvm::cast<ShapedType>(inputList[0].getType());
+
+  // Block-scaled operand checks.
+  if (isBlockScaled) {
+    const auto hasRankZero = [](Value v) {
+      const auto type = cast<ShapedType>(v.getType());
+      return type.hasRank() && type.getRank() == 0;
+    };
+    if (llvm::any_of(inputList, hasRankZero))
+      return emitOpError(
+          "requires all input shapes have a rank greater than 0");
+    if (llvm::any_of(outputList, hasRankZero))
+      return emitOpError(
+          "requires all result shapes have a rank greater than 0");
+
+    if (verifySameElementTypes(*this, /* inType = */ inputList[1].getType(),
+                               /* outType = */ outputList[1].getType())
+            .failed()) {
+      return failure();
+    }
+
+    // The data/scale relationship can only be checked when both are ranked.
+    const auto inputScaleType = llvm::cast<ShapedType>(inputList[1].getType());
+    if (inputType.hasRank() && inputScaleType.hasRank()) {
+      const int64_t rank = inputType.getRank();
+      if (rank != inputScaleType.getRank())
+        return emitOpError("input shapes do not have same rank");
+
+      // All but the last dimension must agree between data and scale.
+      for (int64_t dimIdx = 0; dimIdx < rank - 1; ++dimIdx) {
+        const int64_t inputValueDim = inputType.getDimSize(dimIdx);
+        const int64_t inputScaleDim = inputScaleType.getDimSize(dimIdx);
+        if (ShapedType::isStatic(inputValueDim) &&
+            ShapedType::isStatic(inputScaleDim) &&
+            inputValueDim != inputScaleDim)
+          return emitOpError("input shapes for data and scale do not match on "
+                             "dimension ")
+                 << dimIdx;
+      }
+
+      // The last data dimension is split into blocks of blockSize values,
+      // each with one scale. rank >= 1 is guaranteed by the rank-0 check.
+      const int64_t lastValueDim = inputType.getDimSize(rank - 1);
+      if (ShapedType::isStatic(lastValueDim)) {
+        if (lastValueDim % blockSize != 0)
+          return emitOpError("expect last dimension of input_data (")
+                 << lastValueDim << ") to be divisible by block_size ("
+                 << blockSize << ")";
+
+        const int64_t lastScaleDim = inputScaleType.getDimSize(rank - 1);
+        if (ShapedType::isStatic(lastScaleDim) &&
+            lastScaleDim != lastValueDim / blockSize)
+          return emitOpError("expect last dimension of scale_data (")
+                 << lastScaleDim << ") to be " << lastValueDim << "/"
+                 << blockSize;
+      }
+    }
+  }
+
+  // Checks against a constant new shape.
+  SmallVector<int64_t> shapeValues;
+  if (!tosa::getConstShapeValues(getNewValueShape().getDefiningOp(),
+                                 shapeValues)) {
+    // skip following checks if shape is not constant
+    return mlir::success();
+  }
+
+  // NOTE(review): entries are compared as written. ShapedType::isStatic only
+  // excludes ShapedType::kDynamic, so a negative entry such as -1 is treated
+  // as a literal extent below rather than as an inferred dimension.
+
+  if (isBlockScaled) {
+    if (shapeValues.empty())
+      return emitOpError("requires new shape to have a rank greater than 0");
+
+    const int64_t lastShapeDim = shapeValues.back();
+    if (ShapedType::isStatic(lastShapeDim) && lastShapeDim % blockSize != 0)
+      return emitOpError("expect last dimension of new shape (")
+             << lastShapeDim << ") to be divisible by block_size (" << blockSize
+             << ")";
+  }
+
+  const auto outputType = llvm::cast<ShapedType>(outputList[0].getType());
+  if (outputType.hasRank()) {
+    if (static_cast<int64_t>(shapeValues.size()) != outputType.getRank())
+      return emitOpError() << "result does not match new shape rank";
+
+    for (auto [newShapeDim, outputShapeDim] :
+         llvm::zip(shapeValues, outputType.getShape())) {
+      if (ShapedType::isStatic(newShapeDim) &&
+          ShapedType::isStatic(outputShapeDim) && newShapeDim != outputShapeDim)
+        return emitOpError() << "result shape is inconsistent with new shape";
+    }
+  }
+
+  if (isBlockScaled) {
+    // Expected scale shape: the new shape with its last dimension divided by
+    // block_size (shapeValues is non-empty, checked above).
+    SmallVector<int64_t> scaleShapeValues(shapeValues.begin(),
+                                          shapeValues.end());
+    if (ShapedType::isStatic(scaleShapeValues.back()))
+      scaleShapeValues.back() /= blockSize;
+
+    const auto outputScaleType =
+        llvm::cast<ShapedType>(outputList[1].getType());
+    if (outputScaleType.hasRank()) {
+      if (static_cast<int64_t>(scaleShapeValues.size()) !=
+          outputScaleType.getRank())
+        return emitOpError() << "result scale does not match new shape rank";
+
+      for (auto [newScaleShapeDim, outputScaleShapeDim] :
+           llvm::zip(scaleShapeValues, outputScaleType.getShape())) {
+        if (ShapedType::isStatic(newScaleShapeDim) &&
+            ShapedType::isStatic(outputScaleShapeDim) &&
+            newScaleShapeDim != outputScaleShapeDim)
+          return emitOpError()
+                 << "result scale shape is inconsistent with new shape";
+      }
+    }
+  }
+
+  // Element-count preservation (values only).
+  if (inputType.hasStaticShape()) {
+    const int64_t inputElementsNum = inputType.getNumElements();
+    if (outputType.hasStaticShape()) {
+      const int64_t outputElementsNum = outputType.getNumElements();
+      if (inputElementsNum != outputElementsNum)
+        return emitOpError() << "cannot reshape " << inputElementsNum
+                             << " elements into " << outputElementsNum;
+    }
+
+    // Only positive entries are multiplied in. If every entry is positive the
+    // product must equal the input element count; otherwise it must not
+    // exceed it.
+    const int64_t newShapeElementsNum =
+        llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
+          return (dim > 0) ? acc * dim : acc;
+        });
+    const bool isStaticNewShape =
+        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
+    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
+        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
+      return emitOpError() << "cannot reshape " << inputElementsNum
+                           << " elements into " << newShapeElementsNum;
+    }
+  }
+
+  return mlir::success();
+}
+
 // return failure if val is not a constant
 // set zp to -1 if val is non-zero float or val is not integer nor float
 // otherwise set zp to val's constant value
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6169003881487..bdfde330d73f3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -781,6 +781,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(Concat);
   CHECK_RANKS_AND_SIZES(Pad);
   CHECK_RANKS_AND_SIZES(Reshape);
+  CHECK_RANKS_AND_SIZES(ReshapeBlockScaled);
   CHECK_RANKS_AND_SIZES(Reverse);
   CHECK_RANKS_AND_SIZES(Slice);
   CHECK_RANKS_AND_SIZES(Tile);
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b3bdb02c20103..8887c8b5ecc70 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -395,6 +395,43 @@ func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1
   return %0 : tensor<1x1x1x1x1x1x819xf32>
 }
 
+
+// -----
+
+// Non-block-scaled reshape (single input) whose 7-D result exceeds MAX_RANK.
+func.func @test_reshape_non_block_scaled_output_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
+  %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  // expected-error at +1 {{'tosa.reshape_block_scaled' op failed level check: result rank(shape) <= MAX_RANK}}
+  %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
+  return %0 : tensor<1x1x1x1x1x1x819xf32>
+}
+
+// -----
+
+// Non-block-scaled reshape (single input) whose 7-D input exceeds MAX_RANK.
+func.func @test_reshape_non_block_scaled_input_rank_invalid(%arg0: tensor<1x1x1x1x1x1x819xf32>) -> tensor<13x21x3xf32> {
+  %1 = tosa.const_shape {values = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // expected-error at +1 {{'tosa.reshape_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<1x1x1x1x1x1x819xf32>, !tosa.shape<3>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// Block-scaled reshape (values + scales) whose 7-D results exceed MAX_RANK.
+func.func @test_reshape_block_scaled_output_rank_invalid(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> (tensor<1x1x1x1x1x2x64xf4E2M1FN>, tensor<1x1x1x1x1x2x2xf8E8M0FNU>) {
+  %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 2, 64]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  // expected-error at +1 {{'tosa.reshape_block_scaled' op failed level check: result rank(shape) <= MAX_RANK}}
+  %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<7>) -> (tensor<1x1x1x1x1x2x64xf4E2M1FN>, tensor<1x1x1x1x1x2x2xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<1x1x1x1x1x2x64xf4E2M1FN>, tensor<1x1x1x1x1x2x2xf8E8M0FNU>
+}
+
+// -----
+
+// Block-scaled reshape (values + scales) whose 7-D inputs exceed MAX_RANK.
+func.func @test_reshape_block_scaled_input_rank_invalid(%arg0: tensor<1x1x1x1x1x4x32xf4E2M1FN>, %arg1: tensor<1x1x1x1x1x4x1xf8E8M0FNU>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>) {
+  %1 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error at +1 {{'tosa.reshape_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<1x1x1x1x1x4x32xf4E2M1FN>, tensor<1x1x1x1x1x4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
+
 // -----
 
 func.func @test_reverse_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e80d3d84a8105..507c77c54d0da 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -784,6 +784,54 @@ func.func @test_reshape_unranked_output(%arg0: tensor<13x21x3xf32>) -> tensor<*x
   return %0 : tensor<*xf32>
 }
 
+// -----
+// CHECK-LABEL: reshape_non_block_scaled
+// Non-block-scaled reshape (single input, BLOCK_SIZE_1) with static types.
+func.func @test_reshape_non_block_scaled(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
+  %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+  return %0 : tensor<1x819xf32>
+}
+
+// -----
+// CHECK-LABEL: reshape_non_block_scaled_unranked_input
+// Non-block-scaled reshape with an unranked input.
+func.func @test_reshape_non_block_scaled_unranked_input(%arg0: tensor<*xf32>) -> tensor<1x819xf32> {
+  %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<*xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+  return %0 : tensor<1x819xf32>
+}
+
+// -----
+// CHECK-LABEL: reshape_non_block_scaled_unranked_output
+// Non-block-scaled reshape with an unranked result.
+func.func @test_reshape_non_block_scaled_unranked_output(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
+  %1 = tosa.const_shape {values = dense<[21, 13, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: reshape_block_scaled
+// Block-scaled reshape: the scale result is the new shape with its last
+// dimension divided by block_size (64 / 32 = 2).
+func.func @test_reshape_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>) {
+  %1 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: reshape_block_scaled_unranked_input
+// Block-scaled reshape with unranked value and scale inputs.
+func.func @test_reshape_block_scaled_unranked_input(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>) {
+  %1 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: reshape_block_scaled_unranked_output
+// Block-scaled reshape with unranked value and scale results.
+func.func @test_reshape_block_scaled_unranked_output(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+  %1 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
 // -----
 // CHECK-LABEL: reverse
 func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 408300fa7034b..18d117c4692fe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -450,6 +450,103 @@ func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_static_reshape_non_block_scaled
+// Non-block-scaled reshapes of a static input; the static declared results
+// are expected to be unchanged by shape inference.
+func.func @test_static_reshape_non_block_scaled(%arg0 : tensor<4x4xi32>) -> () {
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+  %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.reshape_block_scaled %arg0, %0 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  %2 = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3 = tosa.reshape_block_scaled %arg0, %2 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_reshape_non_block_scaled
+// Non-block-scaled reshapes of a partially dynamic input; the dynamic declared
+// results are expected to be refined to the constant new shapes.
+func.func @test_dynamic_reshape_non_block_scaled(%arg0 : tensor<4x?xi32>) -> () {
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+  %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.reshape_block_scaled %arg0, %0 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  %2 = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3 = tosa.reshape_block_scaled %arg0, %2 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<?x?xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_unranked_reshape_non_block_scaled
+// Non-block-scaled reshapes with unranked declared results, expected to be
+// refined to the constant new shapes.
+func.func @test_unranked_reshape_non_block_scaled(%arg0 : tensor<4x4xi32>) -> () {
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+  %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.reshape_block_scaled %arg0, %0 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<*xi32>
+  %2 = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3 = tosa.reshape_block_scaled %arg0, %2 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<*xi32>
+
+  return
+}
+
+
+// -----
+
+// CHECK-LABEL: @test_static_reshape_block_scaled
+// Block-scaled reshapes of static inputs; the scale result is the new shape
+// with its last dimension divided by block_size (128 -> 4, 64 -> 2).
+func.func @test_static_reshape_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> () {
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<128xf4E2M1FN>, tensor<4xf8E8M0FNU>)
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+  %0 = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1:2 = tosa.reshape_block_scaled %arg0, %arg1, %0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<128xf4E2M1FN>, tensor<4xf8E8M0FNU>)
+  %2 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3:2 = tosa.reshape_block_scaled %arg0, %arg1, %2 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_reshape_block_scaled
+// Block-scaled reshapes of partially dynamic inputs; the dynamic declared
+// results are expected to be refined to static value and scale shapes.
+func.func @test_dynamic_reshape_block_scaled(%arg0: tensor<4x?xf4E2M1FN>, %arg1: tensor<?x1xf8E8M0FNU>) -> () {
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf4E2M1FN>, tensor<?x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<128xf4E2M1FN>, tensor<4xf8E8M0FNU>)
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf4E2M1FN>, tensor<?x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+  %0 = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1:2 = tosa.reshape_block_scaled %arg0, %arg1, %0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x?xf4E2M1FN>, tensor<?x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<?xf4E2M1FN>, tensor<?xf8E8M0FNU>)
+  %2 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3:2 = tosa.reshape_block_scaled %arg0, %arg1, %2 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x?xf4E2M1FN>, tensor<?x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x?xf4E2M1FN>, tensor<?x2xf8E8M0FNU>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_unranked_reshape_block_scaled
+// Block-scaled reshapes with unranked declared results, expected to be
+// refined to ranked value and scale shapes.
+func.func @test_unranked_reshape_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> () {
+  // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<128xf4E2M1FN>, tensor<4xf8E8M0FNU>)
+  // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+  %0 = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1:2 = tosa.reshape_block_scaled %arg0, %arg1, %0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  %2 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3:2 = tosa.reshape_block_scaled %arg0, %arg1, %2 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+
+  return
+}
+
+// -----
+
 // CHECK: @test_reduce_binary
 func.func @test_reduce_binary(%arg0 : tensor<2x3x?x?xi1>) -> () {
   // CHECK: tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 80d5bca039909..d62aa09734416 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1537,6 +1537,8 @@ func.func @test_conv2d_block_scaled_invalid_bias_size(%arg0: tensor<1x4x4x64xf4E
   return %0 : tensor<1x4x4x8xf32>
 }
 
+// -----
+
 func.func @test_missmatched_ranks() {
   %0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
   %1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
@@ -1544,3 +1546,219 @@ func.func @test_missmatched_ranks() {
   tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<1>, !tosa.shape<2>) -> ()
   return
 }
+
+// -----
+// A statically shaped input with a zero-sized dimension is rejected by the operand type constraint.
+func.func @test_reshape_block_scaled_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op operand #0 must be variadic of tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<13x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
+  return
+}
+
+// -----
+// A dynamically shaped input with a zero-sized dimension is rejected by the operand type constraint.
+func.func @test_reshape_block_scaled_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op operand #0 must be variadic of tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
+  return
+}
+
+// -----
+// A zero-sized leading dimension must be rejected by the operand type constraint as well.
+func.func @test_reshape_block_scaled_leading_zero_dim_input(%arg0 : tensor<0x?x3xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op operand #0 must be variadic of tosa-conformant tensor of number values, but got 'tensor<0x?x3xf32>'}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<0x?x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
+  return
+}
+
+// -----
+// The declared result type (?x4) disagrees with the constant new shape [1, -1].
+func.func @test_reshape_block_scaled_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op result shape is inconsistent with new shape}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x?xf32>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return
+}
+
+// -----
+// Reshape must preserve the element count: 2x4 (8 elements) cannot become 3x5 (15 elements).
+func.func @test_reshape_block_scaled_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[3, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op cannot reshape 8 elements into 15}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<2x4xf32>, !tosa.shape<2>) -> tensor<3x5xf32>
+  return
+}
+
+// -----
+// The new shape [1, 4] needs 4 elements, but the input holds only 1.
+func.func @test_reshape_block_scaled_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op cannot reshape 1 elements into 4}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<1xf32>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return
+}
+
+// -----
+// The declared result type (?x3x5) disagrees with the constant new shape [2, 4, 1].
+func.func @test_reshape_block_scaled_inconsistent_dynamic_result(%arg0 : tensor<?xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op result shape is inconsistent with new shape}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<?xf32>, !tosa.shape<3>) -> tensor<?x3x5xf32>
+  return
+}
+
+// -----
+// The result rank (1) must equal the rank of the new shape (2).
+func.func @test_reshape_block_scaled_rank_mismatch(%arg0 : tensor<?xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op result does not match new shape rank}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<?xf32>, !tosa.shape<2>) -> tensor<?xf32>
+  return
+}
+
+// -----
+// Reshape must not change the element type (f32 input, i32 result).
+func.func @test_reshape_block_scaled_types_mismatch(%arg0 : tensor<2x4xf32>) -> () {
+  %s = tosa.const_shape {values = dense<[8]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op expect input and output to have same element type, got 'f32' and 'i32'}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<2x4xf32>, !tosa.shape<1>) -> tensor<8xi32>
+  return
+}
+
+// -----
+// A single input cannot produce two results; the result count must match the input count.
+func.func @test_reshape_block_scaled_outputs_mismatch_inputs(%arg0 : tensor<64xf8E4M3FN>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op requires number of results to match inputs}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+  return
+}
+
+// -----
+// A data + scale input pair must produce two results, not one.
+func.func @test_reshape_block_scaled_inputs_mismatch_outputs(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op requires number of results to match inputs}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>)
+  return
+}
+
+// -----
+// At least one tensor input is required; only the shape operand is passed here.
+func.func @test_reshape_block_scaled_no_inputs() -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op requires at least one input}}
+  %0 = "tosa.reshape_block_scaled"(%s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (!tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>)
+  return
+}
+
+// -----
+// At most two tensor inputs (data and scale) are accepted; three are passed here.
+func.func @test_reshape_block_scaled_too_many_inputs(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>, %arg2 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op requires at most two inputs}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %arg1, %arg2, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>)
+  return
+}
+
+// -----
+// The scale result must keep the scale input's element type (f8E8M0FNU, not f8E4M3FN).
+func.func @test_reshape_block_scaled_scale_types_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op expect input and output to have same element type, got 'f8E8M0FNU' and 'f8E4M3FN'}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E4M3FN>)
+  return
+}
+
+// -----
+// The data (rank 1) and scale (rank 2) inputs must have the same rank.
+func.func @test_reshape_block_scaled_input_ranks_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2x1xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op input shapes do not have same rank}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+  return
+}
+
+// -----
+// The data (1x64) and scale (2x2) inputs disagree on dimension 0.
+func.func @test_reshape_block_scaled_scale_dims_mismatch(%arg0 : tensor<1x64xf8E4M3FN>, %arg1 : tensor<2x2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op input shapes for data and scale do not match on dimension 0}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<1x64xf8E4M3FN>, tensor<2x2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+  return
+}
+
+// -----
+// The last data dimension (60) must be divisible by the block size (32).
+func.func @test_reshape_block_scaled_block_size_mismatch(%arg0 : tensor<60xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op expect last dimension of input_data (60) to be divisible by block_size (32)}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<60xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+  return
+}
+
+// -----
+// The last scale dimension must equal last data dimension / block size (64/32 = 2), not 3.
+func.func @test_reshape_block_scaled_scale_blocks_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<3xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op expect last dimension of scale_data (3) to be 64/32}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<3xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+  return
+}
+
+// -----
+// The last dimension of the new shape (30) must be divisible by the block size (32).
+func.func @test_reshape_block_scaled_new_shape_block_size_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 30]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op expect last dimension of new shape (30) to be divisible by block_size (32)}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x30xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+  return
+}
+
+// -----
+// The scale result (rank 1) must have the same rank as the new shape (rank 2).
+func.func @test_reshape_block_scaled_inconsistent_scale_output_rank(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op result scale does not match new shape rank}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2xf8E8M0FNU>)
+  return
+}
+
+// -----
+// The scale result shape (3x?) is inconsistent with the new shape [2, 32].
+func.func @test_reshape_block_scaled_inconsistent_scale(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op result scale shape is inconsistent with new shape}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<3x?xf8E8M0FNU>)
+  return
+}
+
+// -----
+// Block-scaled inputs must have rank >= 1; the data input here is rank 0.
+func.func @test_reshape_block_scaled_rank0_input(%arg0 : tensor<f8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op requires all input shapes have a rank greater than 0}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<f8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x2xf8E8M0FNU>)
+  return
+}
+
+// -----
+// Block-scaled results must have rank >= 1; the scale result here is rank 0.
+func.func @test_reshape_block_scaled_rank0_scale_output(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op requires all result shapes have a rank greater than 0}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<f8E8M0FNU>)
+  return
+}
+
+// -----
+// A rank-0 (empty) new shape is rejected for block-scaled reshapes; both tensor inputs are rank 1.
+func.func @test_reshape_block_scaled_rank0_new_shape(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+  %s = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op requires new shape to have a rank greater than 0}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<0>) -> (tensor<2x32xf8E4M3FN>, tensor<2x2xf8E8M0FNU>)
+  return
+}

>From 809ed65aa97fb31f2ef0a69635ced83b478b862e Mon Sep 17 00:00:00 2001
From: Jeremy Johnson <jeremy.johnson at arm.com>
Date: Mon, 13 Apr 2026 16:58:25 +0100
Subject: [PATCH 2/4] Fix up RESHAPE_BLOCK_SCALED

Add a better description of tensor usage
Add checking for BLOCK_SIZE_1 and BLOCK_SIZE_32 usage

Change-Id: I4a625ffdd637bd7b391b631ff8a6e1e1453f13c7
Signed-off-by: Jeremy Johnson <jeremy.johnson at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  3 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 15 ++++-
 mlir/test/Dialect/Tosa/verifier.mlir         | 59 +++++++++++++++++++-
 3 files changed, 72 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 951969d3f34f5..1f14c43136529 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2280,7 +2280,8 @@ def Tosa_ReshapeBlockScaledOp
   let description = [{
     Returns a tensor-list with the same type/values as the input, with a new
     shape specified by the shape argument. Reshape may operate on block-scaled
-    or non-block-scaled tensors of any rank. No data conversion happens during
+    tensors (values tensor followed by scale tensor) of rank 1 or higher; or a
+    single non-block-scaled tensor of any rank. No data conversion happens during
     a reshape operation. Reshape must retain the relationship between values
     and their scale in a block for block-scaled content.
   }];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 76eb59ecf3cad..9d85f65e65da3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2160,6 +2160,8 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   // Verify C is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+    return emitOpError("expect block size to be 32, got ") << blockSize;
   if (ShapedType::isStatic(C) && C % blockSize != 0)
     return emitOpError("expect C to be a multiple of block size, got C=")
            << C << ", block_size=" << blockSize;
@@ -2868,6 +2870,8 @@ llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
 
   if (inputList.size() == 2) {
+    if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+        return emitOpError("expect block size to be 32, got ") << blockSize;
     if (llvm::any_of(inputList, [](Value v) {
           const auto input = cast<ShapedType>(v.getType());
           return input.hasRank() && input.getRank() == 0;
@@ -2923,6 +2927,9 @@ llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
                  << blockSize;
       }
     }
+  } else {
+    if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_1))
+      return emitOpError("expect block size to be 1, got ") << blockSize;
   }
 
   // Get the new value shape dimension values
@@ -4173,8 +4180,10 @@ LogicalResult Conv2DBlockScaledOp::verify() {
       return failure();
   }
 
-  // Verify IC is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+      return emitOpError("expect block size to be 32, got ") << blockSize;
+  // Verify IC is a multiple of block size
   if (ShapedType::isStatic(IC) && IC % blockSize != 0)
     return emitOpError("expect IC to be a multiple of block size, got IC=")
            << IC << ", block_size=" << blockSize;
@@ -4723,6 +4732,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
   if (inputDataShape.hasRank()) {
     const unsigned int blockSize =
         BlockSizeAttr::getBlockSizeValue(getBlockSize());
+    if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+      return emitOpError("expect block size to be 32, got ") << blockSize;
     const int64_t inputDataLastDim =
         inputDataShape.getDimSize(inputDataShape.getRank() - 1);
     if (inputDataLastDim % blockSize != 0)
@@ -4796,6 +4807,8 @@ LogicalResult CastToBlockScaledOp::verify() {
 
   const unsigned int blockSize =
       BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+    return emitOpError("expect block size to be 32, got ") << blockSize;
   const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
   if (inputDataShape.hasRank()) {
     const int64_t inputDataLastDim =
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index d62aa09734416..0666feb5b367e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1195,6 +1195,14 @@ func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3
 
 // -----
 
+func.func @test_matmul_t_block_scaled_block_size_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect block size to be 32, got 1}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
 func.func @cast_from_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> {
   // expected-error at +1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and output_data ('tensor<5x32xf32>')}}
   %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32>
@@ -1235,6 +1243,14 @@ func.func @cast_from_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32
 
 // -----
 
+func.func @test_cast_from_block_scaled_block_size_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect block size to be 32, got 1}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+
 func.func @test_cast_to_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
   // expected-error at +1 {{'tosa.cast_to_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf32>') and output_data ('tensor<5x32xf4E2M1FN>')}}
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
@@ -1275,6 +1291,14 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
 
 // -----
 
+func.func @test_cast_to_block_scaled_block_size_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect block size to be 32, got 1}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
 func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
     // expected-error at +1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
     %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
@@ -1539,6 +1563,17 @@ func.func @test_conv2d_block_scaled_invalid_bias_size(%arg0: tensor<1x4x4x64xf4E
 
 // -----
 
+func.func @test_conv2d_block_scaled_block_size_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect block size to be 32, got 1}}
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
+
+// -----
+
 func.func @test_missmatched_ranks() {
   %0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
   %1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
@@ -1633,7 +1668,7 @@ func.func @test_reshape_block_scaled_types_mismatch(%arg0 : tensor<2x4xf32>) ->
 func.func @test_reshape_block_scaled_outputs_mismatch_inputs(%arg0 : tensor<64xf8E4M3FN>) -> () {
   %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
   // expected-error@+1 {{'tosa.reshape_block_scaled' op requires number of results to match inputs}}
-  %0:2 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<64xf8E4M3FN>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
   return
 }
 
@@ -1693,7 +1728,7 @@ func.func @test_reshape_block_scaled_scale_dims_mismatch(%arg0 : tensor<1x64xf8E
 
 // -----
 
-func.func @test_reshape_block_scaled_block_size_mismatch(%arg0 : tensor<60xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+func.func @test_reshape_block_scaled_block_size_dim_mismatch(%arg0 : tensor<60xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
   %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
   // expected-error at +1 {{'tosa.reshape_block_scaled' op expect last dimension of input_data (60) to be divisible by block_size (32)}}
   %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<60xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
@@ -1711,7 +1746,7 @@ func.func @test_reshape_block_scaled_scale_blocks_mismatch(%arg0 : tensor<64xf8E
 
 // -----
 
-func.func @test_reshape_block_scaled_new_shape_block_size_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+func.func @test_reshape_block_scaled_new_shape_block_size_dim_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
   %s = tosa.const_shape {values = dense<[2, 30]> : tensor<2xindex>} : () -> !tosa.shape<2>
   // expected-error at +1 {{'tosa.reshape_block_scaled' op expect last dimension of new shape (30) to be divisible by block_size (32)}}
   %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x30xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
@@ -1762,3 +1797,21 @@ func.func @test_reshape_block_scaled_rank0_scale_input(%arg0 : tensor<64xf8E4M3F
   %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<0>) -> (tensor<2x32xf8E4M3FN>, tensor<2x2xf8E8M0FNU>)
   return
 }
+
+// -----
+// A single (non-block-scaled) input requires BLOCK_SIZE_1; BLOCK_SIZE_32 is rejected.
+func.func @test_reshape_non_block_scaled_block_size_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
+  %s = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op expect block size to be 1, got 32}}
+  %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+  return %0 : tensor<1x819xf32>
+}
+
+// -----
+// A data + scale input pair requires BLOCK_SIZE_32; BLOCK_SIZE_1 is rejected.
+func.func @test_reshape_block_scaled_block_size_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>) {
+  %s = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape_block_scaled' op expect block size to be 32, got 1}}
+  %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}

>From 0f33e145bbe9561ace28570b490631ec7f33eed4 Mon Sep 17 00:00:00 2001
From: Jeremy Johnson <jeremy.johnson at arm.com>
Date: Wed, 15 Apr 2026 08:54:37 +0100
Subject: [PATCH 3/4] Missed formatting

Change-Id: I5959d4586a9b3a7cbc5e2bdc5794baf1dcbda0ee
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 621e6c2b71bc1..35ffe6c9eccf7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2906,7 +2906,7 @@ llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
 
   if (inputList.size() == 2) {
     if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
-        return emitOpError("expect block size to be 32, got ") << blockSize;
+      return emitOpError("expect block size to be 32, got ") << blockSize;
     if (llvm::any_of(inputList, [](Value v) {
           const auto input = cast<ShapedType>(v.getType());
           return input.hasRank() && input.getRank() == 0;
@@ -4217,7 +4217,7 @@ LogicalResult Conv2DBlockScaledOp::verify() {
 
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
   if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
-      return emitOpError("expect block size to be 32, got ") << blockSize;
+    return emitOpError("expect block size to be 32, got ") << blockSize;
   // Verify IC is a multiple of block size
   if (ShapedType::isStatic(IC) && IC % blockSize != 0)
     return emitOpError("expect IC to be a multiple of block size, got IC=")

>From a71cf699253dd6ab061901f19ed11cd8abdc3f6c Mon Sep 17 00:00:00 2001
From: Jeremy Johnson <jeremy.johnson at arm.com>
Date: Wed, 15 Apr 2026 13:12:30 +0100
Subject: [PATCH 4/4] Fix up bad merge

Change-Id: I1ef6af7e52b0b925776a37a139f49bf180349cb7
---
 mlir/test/Dialect/Tosa/verifier.mlir | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 5c16f33a43b0a..4fd42aff6988a 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1951,6 +1951,10 @@ func.func @test_reshape_block_scaled_block_size_mismatch(%arg0: tensor<4x32xf4E2
   // expected-error at +1 {{'tosa.reshape_block_scaled' op expect block size to be 32, got 1}}
   %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
+
+// -----
+
 func.func @test_maxpool2d_adaptive_invalid_kernel(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x2x32x8xf32> {
   %kernel = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>



More information about the Mlir-commits mailing list