[Mlir-commits] [mlir] [mlir][tosa] Add TOSA RESHAPE_BLOCK_SCALED support (PR #191149)
Jeremy Johnson
llvmlistbot at llvm.org
Wed Apr 15 05:50:52 PDT 2026
https://github.com/jjohnson-arm updated https://github.com/llvm/llvm-project/pull/191149
>From 4c895610fd3d86c32a1fe6d1776e8b9cd4c784de Mon Sep 17 00:00:00 2001
From: Jeremy Johnson <jeremy.johnson at arm.com>
Date: Tue, 24 Mar 2026 17:04:03 +0000
Subject: [PATCH 1/4] Add TOSA RESHAPE_BLOCK_SCALED support
Experimental operator support, with no validation.
Signed-off-by: Jeremy Johnson <jeremy.johnson at arm.com>
Change-Id: I3d1f10a1b7765d849be0f71a8634f9a6d0077d69
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 6 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 38 +++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 223 ++++++++++++++++++
.../Tosa/Transforms/TosaValidation.cpp | 1 +
mlir/test/Dialect/Tosa/level_check.mlir | 37 +++
mlir/test/Dialect/Tosa/ops.mlir | 48 ++++
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 97 ++++++++
mlir/test/Dialect/Tosa/verifier.mlir | 218 +++++++++++++++++
8 files changed, 665 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1f05aee3e5eec..591073e9985ae 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -484,11 +484,12 @@ def Tosa_RoundingModeAttr
: Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
[Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
+// BLOCK_SIZE_1: each block holds a single element, so a scale tensor (if any)
+// has the same last dimension as its data tensor.
+def Tosa_BLOCK_SIZE_1 : I32EnumAttrCase<"BLOCK_SIZE_1", 1>;
def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
def Tosa_BlockSizeAttr
- : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
- [Tosa_BLOCK_SIZE_32]> {
+ : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats",
+ "block_size", [Tosa_BLOCK_SIZE_1, Tosa_BLOCK_SIZE_32]> {
let extraClassDeclaration = [{
static uint32_t getBlockSizeValue(BlockSize blockSize) {
return static_cast<uint32_t>(blockSize);
@@ -496,7 +497,6 @@ def Tosa_BlockSizeAttr
}];
}
-
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5ac91e6b65457..951969d3f34f5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2270,6 +2270,44 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape", [Pure]> {
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
+//===----------------------------------------------------------------------===//
+// Operator: reshape_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_ReshapeBlockScaledOp
+ : Tosa_InferTensorTypeOp<"reshape_block_scaled", [Pure]> {
+ let summary = "Reshape with support for block scaled tensors.";
+
+ let description = [{
+ Returns a tensor-list with the same type/values as the input, with a new
+ shape specified by the shape argument. Reshape may operate on block-scaled
+ or non-block-scaled tensors of any rank. No data conversion happens during
+ a reshape operation. Reshape must retain the relationship between values
+ and their scale in a block for block-scaled content.
+ }];
+
+ // Custom checks are implemented in ReshapeBlockScaledOp::verify().
+ let hasVerifier = 1;
+
+ // $input: either one tensor (data only) or two tensors (data, then scale);
+ // the C++ verifier rejects any other count. $new_value_shape is the shape
+ // of the data result; the scale result uses the same shape with its last
+ // dimension divided by $block_size.
+ let arguments = (ins Variadic<Tosa_Tensor>:$input,
+ Tosa_Shape:$new_value_shape, Tosa_BlockSizeAttr:$block_size);
+
+ // One result per input, in the same order (data, then optional scale).
+ let results = (outs Variadic<Tosa_Tensor>:$output);
+
+ list<Availability> availability = [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ // NOTE: Validation of extensions is
+ // disabled for this op
+ Extension<[]>,
+ ];
+
+ // isCompatibleReturnTypes compares only the element types of one or two
+ // result types; shapes are not compared there.
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
+}
+
//===----------------------------------------------------------------------===//
// Operator: reverse
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 29318023092a1..76eb59ecf3cad 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2784,6 +2784,229 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
+/// Two result-type lists are compatible when they have the same length,
+/// that length is one (data) or two (data and scale), and the element types
+/// agree position by position. Shapes are deliberately not compared.
+bool tosa::ReshapeBlockScaledOp::isCompatibleReturnTypes(TypeRange l,
+                                                         TypeRange r) {
+  const size_t count = l.size();
+  if (count != r.size() || count == 0 || count > 2)
+    return false;
+
+  for (size_t pos = 0; pos < count; ++pos) {
+    if (getElementTypeOrSelf(l[pos]) != getElementTypeOrSelf(r[pos]))
+      return false;
+  }
+  return true;
+}
+
+/// Infers result shapes for reshape_block_scaled.
+///
+/// The first (data) result takes the constant new shape directly. When a
+/// scale tensor is also present, the second result takes the same shape with
+/// its last static dimension divided by the block size. If the new shape is
+/// not a constant, every result dimension is dynamic and the rank comes from
+/// the !tosa.shape type.
+LogicalResult tosa::ReshapeBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ReshapeBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const auto inputs = adaptor.getInput();
+  const size_t numInputs = inputs.size();
+  // Inference may run before verify() (builders, shape inference, inferred
+  // type verification), so an empty operand list must be rejected here
+  // instead of being indexed below.
+  if (numInputs == 0)
+    return emitOptionalError(location, "requires at least one input");
+
+  const Type dataElementType = getElementTypeOrSelf(inputs[0].getType());
+  const bool hasScale = numInputs == 2;
+  const Type scaleElementType =
+      hasScale ? getElementTypeOrSelf(inputs[1].getType()) : Type();
+
+  llvm::SmallVector<int64_t> newShapeValue;
+  const auto newShape = adaptor.getNewValueShape();
+  if (!tosa::getConstShapeValues(newShape.getDefiningOp(), newShapeValue)) {
+    // Non-constant shape: only the result rank is known.
+    const auto rank = cast<tosa::shapeType>(newShape.getType()).getRank();
+    const SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(fallback, dataElementType));
+    if (hasScale)
+      inferredReturnShapes.push_back(
+          ShapedTypeComponents(fallback, scaleElementType));
+    return success();
+  }
+
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(newShapeValue, dataElementType));
+
+  if (hasScale) {
+    // The scale shape matches the data shape except for the last dimension,
+    // which holds one entry per block. Dynamic dimensions stay dynamic, and
+    // a rank-0 new shape is left for verify() to reject (no .back() on an
+    // empty vector).
+    const uint32_t blockSize =
+        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+    llvm::SmallVector<int64_t> newScaleShapeValue(newShapeValue);
+    if (!newScaleShapeValue.empty() &&
+        ShapedType::isStatic(newScaleShapeValue.back()))
+      newScaleShapeValue.back() /= blockSize;
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(newScaleShapeValue, scaleElementType));
+  }
+  return success();
+}
+
+/// Verifies reshape_block_scaled.
+///
+/// Checks the operand/result counts and element types, then data/scale input
+/// shape consistency (when both ranks are known). When the new shape is a
+/// constant, it also checks the result shapes and the element count. An
+/// unranked operand or result only disables the checks that need its rank;
+/// it does not end verification early.
+llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
+  const Operation::operand_range inputList = getInput();
+  const Operation::result_range outputList = getResults();
+
+  if (inputList.size() == 0)
+    return emitOpError("requires at least one input");
+
+  if (inputList.size() > 2)
+    return emitOpError("requires at most two inputs");
+
+  if (inputList.size() != outputList.size())
+    return emitOpError("requires number of results to match inputs");
+
+  if (verifySameElementTypes(*this, /* inType = */ inputList[0].getType(),
+                             /* outType = */ outputList[0].getType())
+          .failed()) {
+    return failure();
+  }
+
+  const auto inputType = llvm::cast<ShapedType>(inputList[0].getType());
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  const bool hasScale = inputList.size() == 2;
+
+  if (hasScale) {
+    const auto isRankZero = [](Value v) {
+      const auto type = cast<ShapedType>(v.getType());
+      return type.hasRank() && type.getRank() == 0;
+    };
+    if (llvm::any_of(inputList, isRankZero))
+      return emitOpError(
+          "requires all input shapes have a rank greater than 0");
+    if (llvm::any_of(outputList, isRankZero))
+      return emitOpError(
+          "requires all result shapes have a rank greater than 0");
+
+    if (verifySameElementTypes(*this, /* inType = */ inputList[1].getType(),
+                               /* outType = */ outputList[1].getType())
+            .failed()) {
+      return failure();
+    }
+
+    // Data/scale consistency can only be checked when both ranks are known.
+    const auto inputScaleType = llvm::cast<ShapedType>(inputList[1].getType());
+    if (inputType.hasRank() && inputScaleType.hasRank()) {
+      const int64_t rank = inputType.getRank();
+      if (rank != inputScaleType.getRank())
+        return emitOpError("input shapes do not have same rank");
+
+      // All but the last dimension must agree between data and scale.
+      for (int64_t dimIdx = 0; dimIdx < rank - 1; ++dimIdx) {
+        const int64_t inputValueDim = inputType.getDimSize(dimIdx);
+        const int64_t inputScaleDim = inputScaleType.getDimSize(dimIdx);
+        if (ShapedType::isStatic(inputValueDim) &&
+            ShapedType::isStatic(inputScaleDim) &&
+            inputValueDim != inputScaleDim)
+          return emitOpError("input shapes for data and scale do not match on "
+                             "dimension ")
+                 << dimIdx;
+      }
+
+      // The last data dimension must split into whole blocks, and the last
+      // scale dimension must hold exactly one entry per block.
+      const int64_t lastValueDim = inputType.getDimSize(rank - 1);
+      if (ShapedType::isStatic(lastValueDim)) {
+        if (lastValueDim % blockSize != 0)
+          return emitOpError("expect last dimension of input_data (")
+                 << lastValueDim << ") to be divisible by block_size ("
+                 << blockSize << ")";
+
+        const int64_t lastScaleDim = inputScaleType.getDimSize(rank - 1);
+        if (ShapedType::isStatic(lastScaleDim) &&
+            lastScaleDim != lastValueDim / blockSize)
+          return emitOpError("expect last dimension of scale_data (")
+                 << lastScaleDim << ") to be " << lastValueDim << "/"
+                 << blockSize;
+      }
+    }
+  }
+
+  // Get the new value shape dimension values
+  SmallVector<int64_t> shapeValues;
+  if (!tosa::getConstShapeValues(getNewValueShape().getDefiningOp(),
+                                 shapeValues)) {
+    // skip following checks if shape is not constant
+    return mlir::success();
+  }
+
+  if (hasScale) {
+    if (shapeValues.empty())
+      return emitOpError("requires new shape to have a rank greater than 0");
+
+    const int64_t lastShapeDim = shapeValues.back();
+    if (ShapedType::isStatic(lastShapeDim) && lastShapeDim % blockSize != 0)
+      return emitOpError("expect last dimension of new shape (")
+             << lastShapeDim << ") to be divisible by block_size (" << blockSize
+             << ")";
+  }
+
+  // Data result vs. new shape (skipped only when the data result is
+  // unranked).
+  const auto outputType = llvm::cast<ShapedType>(outputList[0].getType());
+  if (outputType.hasRank()) {
+    if (static_cast<int64_t>(shapeValues.size()) != outputType.getRank())
+      return emitOpError() << "result does not match new shape rank";
+
+    for (auto [newShapeDim, outputShapeDim] :
+         zip(shapeValues, outputType.getShape())) {
+      if (ShapedType::isStatic(newShapeDim) &&
+          ShapedType::isStatic(outputShapeDim) && newShapeDim != outputShapeDim)
+        return emitOpError() << "result shape is inconsistent with new shape";
+    }
+  }
+
+  if (hasScale) {
+    // Expected scale shape: the new shape with its last (static) dimension
+    // divided by the block size. shapeValues is non-empty here (see above).
+    SmallVector<int64_t> scaleShapeValues(shapeValues.begin(),
+                                          shapeValues.end());
+    if (ShapedType::isStatic(scaleShapeValues.back()))
+      scaleShapeValues.back() /= blockSize;
+
+    const auto outputScaleType =
+        llvm::cast<ShapedType>(outputList[1].getType());
+    if (outputScaleType.hasRank()) {
+      if (static_cast<int64_t>(scaleShapeValues.size()) !=
+          outputScaleType.getRank())
+        return emitOpError() << "result scale does not match new shape rank";
+
+      for (auto [newScaleShapeDim, outputScaleShapeDim] :
+           zip(scaleShapeValues, outputScaleType.getShape())) {
+        if (ShapedType::isStatic(newScaleShapeDim) &&
+            ShapedType::isStatic(outputScaleShapeDim) &&
+            newScaleShapeDim != outputScaleShapeDim)
+          return emitOpError()
+                 << "result scale shape is inconsistent with new shape";
+      }
+    }
+  }
+
+  // Element-count checks need a fully static data input. hasStaticShape()
+  // is false for unranked types, so this is safe for any input/result.
+  if (inputType.hasStaticShape()) {
+    const int64_t inputElementsNum = inputType.getNumElements();
+    if (outputType.hasStaticShape()) {
+      const int64_t outputElementsNum = outputType.getNumElements();
+      if (inputElementsNum != outputElementsNum) {
+        return emitOpError() << "cannot reshape " << inputElementsNum
+                             << " elements into " << outputElementsNum;
+      }
+    }
+
+    const int64_t newShapeElementsNum =
+        llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
+          return (dim > 0) ? acc * dim : acc;
+        });
+    const bool isStaticNewShape =
+        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
+    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
+        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
+      return emitOpError() << "cannot reshape " << inputElementsNum
+                           << " elements into " << newShapeElementsNum;
+    }
+  }
+
+  return mlir::success();
+}
+
// return failure if val is not a constant
// set zp to -1 if val is non-zero float or val is not integer nor float
// otherwise set zp to val's constant value
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6169003881487..bdfde330d73f3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -781,6 +781,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Concat);
CHECK_RANKS_AND_SIZES(Pad);
CHECK_RANKS_AND_SIZES(Reshape);
+ CHECK_RANKS_AND_SIZES(ReshapeBlockScaled);
CHECK_RANKS_AND_SIZES(Reverse);
CHECK_RANKS_AND_SIZES(Slice);
CHECK_RANKS_AND_SIZES(Tile);
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b3bdb02c20103..8887c8b5ecc70 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -395,6 +395,43 @@ func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1
return %0 : tensor<1x1x1x1x1x1x819xf32>
}
+
+// -----
+
+// Level checks for tosa.reshape_block_scaled: operands or results whose rank
+// exceeds MAX_RANK are rejected, in both the single-tensor and data+scale forms.
+func.func @test_reshape_non_block_scaled_output_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
+ %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
+ // expected-error@+1 {{'tosa.reshape_block_scaled' op failed level check: result rank(shape) <= MAX_RANK}}
+ %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
+ return %0 : tensor<1x1x1x1x1x1x819xf32>
+}
+
+// -----
+
+func.func @test_reshape_non_block_scaled_input_rank_invalid(%arg0: tensor<1x1x1x1x1x1x819xf32>) -> tensor<13x21x3xf32> {
+ %1 = tosa.const_shape {values = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // expected-error@+1 {{'tosa.reshape_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<1x1x1x1x1x1x819xf32>, !tosa.shape<3>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_output_rank_invalid(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> (tensor<1x1x1x1x1x2x64xf4E2M1FN>, tensor<1x1x1x1x1x2x2xf8E8M0FNU>) {
+ %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 2, 64]> : tensor<7xindex>} : () -> !tosa.shape<7>
+ // expected-error@+1 {{'tosa.reshape_block_scaled' op failed level check: result rank(shape) <= MAX_RANK}}
+ %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<7>) -> (tensor<1x1x1x1x1x2x64xf4E2M1FN>, tensor<1x1x1x1x1x2x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<1x1x1x1x1x2x64xf4E2M1FN>, tensor<1x1x1x1x1x2x2xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_input_rank_invalid(%arg0: tensor<1x1x1x1x1x4x32xf4E2M1FN>, %arg1: tensor<1x1x1x1x1x4x1xf8E8M0FNU>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>) {
+ %1 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error@+1 {{'tosa.reshape_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<1x1x1x1x1x4x32xf4E2M1FN>, tensor<1x1x1x1x1x4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
+
// -----
func.func @test_reverse_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e80d3d84a8105..507c77c54d0da 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -784,6 +784,54 @@ func.func @test_reshape_unranked_output(%arg0: tensor<13x21x3xf32>) -> tensor<*x
return %0 : tensor<*xf32>
}
+// -----
+// Round-trip tests for tosa.reshape_block_scaled. The first three use the
+// single-tensor form (BLOCK_SIZE_1); the last three use the data+scale form
+// (BLOCK_SIZE_32), where the scale's last dimension is the data's divided by 32.
+// CHECK-LABEL: reshape_non_block_scaled
+func.func @test_reshape_non_block_scaled(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
+ %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+ return %0 : tensor<1x819xf32>
+}
+
+// -----
+// CHECK-LABEL: reshape_non_block_scaled_unranked_input
+func.func @test_reshape_non_block_scaled_unranked_input(%arg0: tensor<*xf32>) -> tensor<1x819xf32> {
+ %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<*xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+ return %0 : tensor<1x819xf32>
+}
+
+// -----
+// CHECK-LABEL: reshape_non_block_scaled_unranked_output
+func.func @test_reshape_non_block_scaled_unranked_output(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
+ %1 = tosa.const_shape {values = dense<[21, 13, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %0 = tosa.reshape_block_scaled %arg0, %1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+// Data 4x32 -> 2x64 with scale 4x1 -> 2x2 (64 / 32 = 2 scale entries per row).
+// CHECK-LABEL: reshape_block_scaled
+func.func @test_reshape_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>) {
+ %1 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: reshape_block_scaled_unranked_input
+func.func @test_reshape_block_scaled_unranked_input(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>) {
+ %1 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: reshape_block_scaled_unranked_output
+func.func @test_reshape_block_scaled_unranked_output(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ %1 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0:2 = tosa.reshape_block_scaled %arg0, %arg1, %1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
// -----
// CHECK-LABEL: reverse
func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 408300fa7034b..18d117c4692fe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -450,6 +450,103 @@ func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
// -----
+// Shape inference, single-tensor form: the result type is refined to the
+// constant new shape, whether the declared result is static, dynamic or unranked.
+// CHECK-LABEL: @test_static_reshape_non_block_scaled
+func.func @test_static_reshape_non_block_scaled(%arg0 : tensor<4x4xi32>) -> () {
+ // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+ %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.reshape_block_scaled %arg0, %0 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+ %2 = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %3 = tosa.reshape_block_scaled %arg0, %2 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+
+ return
+}
+
+// -----
+
+// Dynamic declared results (tensor<?xi32>, tensor<?x?xi32>) become static.
+// CHECK-LABEL: @test_dynamic_reshape_non_block_scaled
+func.func @test_dynamic_reshape_non_block_scaled(%arg0 : tensor<4x?xi32>) -> () {
+ // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+ %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.reshape_block_scaled %arg0, %0 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+ %2 = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %3 = tosa.reshape_block_scaled %arg0, %2 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<?x?xi32>
+
+ return
+}
+
+// -----
+
+// Unranked declared results are refined to ranked static types.
+// CHECK-LABEL: @test_unranked_reshape_non_block_scaled
+func.func @test_unranked_reshape_non_block_scaled(%arg0 : tensor<4x4xi32>) -> () {
+ // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+ %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.reshape_block_scaled %arg0, %0 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<*xi32>
+ %2 = tosa.const_shape {values = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %3 = tosa.reshape_block_scaled %arg0, %2 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<*xi32>
+
+ return
+}
+
+
+// -----
+
+// Shape inference, data+scale form (BLOCK_SIZE_32): the data result takes the
+// new shape and the scale result takes it with the last dimension divided by
+// 32, e.g. [128] -> scale 4 and [2, 64] -> scale 2x2.
+// CHECK-LABEL: @test_static_reshape_block_scaled
+func.func @test_static_reshape_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> () {
+ // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<128xf4E2M1FN>, tensor<4xf8E8M0FNU>)
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+ %0 = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1:2 = tosa.reshape_block_scaled %arg0, %arg1, %0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<128xf4E2M1FN>, tensor<4xf8E8M0FNU>)
+ %2 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %3:2 = tosa.reshape_block_scaled %arg0, %arg1, %2 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+
+ return
+}
+
+// -----
+
+// Dynamic operand and declared result dimensions; the results become static.
+// CHECK-LABEL: @test_dynamic_reshape_block_scaled
+func.func @test_dynamic_reshape_block_scaled(%arg0: tensor<4x?xf4E2M1FN>, %arg1: tensor<?x1xf8E8M0FNU>) -> () {
+ // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf4E2M1FN>, tensor<?x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<128xf4E2M1FN>, tensor<4xf8E8M0FNU>)
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf4E2M1FN>, tensor<?x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+ %0 = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1:2 = tosa.reshape_block_scaled %arg0, %arg1, %0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x?xf4E2M1FN>, tensor<?x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<?xf4E2M1FN>, tensor<?xf8E8M0FNU>)
+ %2 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %3:2 = tosa.reshape_block_scaled %arg0, %arg1, %2 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x?xf4E2M1FN>, tensor<?x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x?xf4E2M1FN>, tensor<?x2xf8E8M0FNU>)
+
+ return
+}
+
+// -----
+
+// Unranked declared results are refined to ranked static types.
+// CHECK-LABEL: @test_unranked_reshape_block_scaled
+func.func @test_unranked_reshape_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> () {
+ // CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE1]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<128xf4E2M1FN>, tensor<4xf8E8M0FNU>)
+ // CHECK-DAG: tosa.reshape_block_scaled %arg0, %arg1, %[[CONSTSHAPE2]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+ %0 = tosa.const_shape {values = dense<128> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1:2 = tosa.reshape_block_scaled %arg0, %arg1, %0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<1>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ %2 = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %3:2 = tosa.reshape_block_scaled %arg0, %arg1, %2 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+
+ return
+}
+
+// -----
+
// CHECK: @test_reduce_binary
func.func @test_reduce_binary(%arg0 : tensor<2x3x?x?xi1>) -> () {
// CHECK: tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x3x?x?xi1>) -> tensor<1x3x?x?xi1>
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 80d5bca039909..d62aa09734416 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1537,6 +1537,8 @@ func.func @test_conv2d_block_scaled_invalid_bias_size(%arg0: tensor<1x4x4x64xf4E
return %0 : tensor<1x4x4x8xf32>
}
+// -----
+
func.func @test_missmatched_ranks() {
%0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
%1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
@@ -1544,3 +1546,219 @@ func.func @test_missmatched_ranks() {
tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<1>, !tosa.shape<2>) -> ()
return
}
+
+// -----
+
+func.func @test_reshape_block_scaled_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
+ %s = tosa.const_shape {values = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op operand #0 must be variadic of tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<13x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
+ %s = tosa.const_shape {values = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op operand #0 must be variadic of tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
+ %1 = tosa.const_shape {values = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op operand #0 must be variadic of tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %1) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
+ %s = tosa.const_shape {values = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op result shape is inconsistent with new shape}}
+ %0 = "tosa.reshape_block_scaled" (%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x?xf32>, !tosa.shape<2>) -> tensor<?x4xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
+ %s = tosa.const_shape {values = dense<[3, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op cannot reshape 8 elements into 15}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<2x4xf32>, !tosa.shape<2>) -> tensor<3x5xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
+ %s = tosa.const_shape {values = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op cannot reshape 1 elements into 4}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<1xf32>, !tosa.shape<2>) -> tensor<?x4xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_inconsistent_dynamic_result(%arg0 : tensor<?xf32>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op result shape is inconsistent with new shape}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<?xf32>, !tosa.shape<3>) -> tensor<?x3x5xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_rank_mismatch(%arg0 : tensor<?xf32>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op result does not match new shape rank}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<?xf32>, !tosa.shape<2>) -> tensor<?xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_types_mismatch(%arg0 : tensor<2x4xf32>) -> () {
+ %s = tosa.const_shape {values = dense<[8]> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op expect input and output to have same element type, got 'f32' and 'i32'}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<2x4xf32>, !tosa.shape<1>) -> tensor<8xi32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_outputs_mismatch_inputs(%arg0 : tensor<64xf8E4M3FN>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op requires number of results to match inputs}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_inputs_mismatch_outputs(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op requires number of results to match inputs}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_no_inputs() -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op requires at least one input}}
+ %0 = "tosa.reshape_block_scaled"(%s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (!tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_too_many_inputs(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>, %arg2 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op requires at most two inputs}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %arg1, %arg2, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_scale_types_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op expect input and output to have same element type, got 'f8E8M0FNU' and 'f8E4M3FN'}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E4M3FN>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_input_ranks_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2x1xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op input shapes do not have same rank}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_scale_dims_mismatch(%arg0 : tensor<1x64xf8E4M3FN>, %arg1 : tensor<2x2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op input shapes for data and scale do not match on dimension 0}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<1x64xf8E4M3FN>, tensor<2x2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_block_size_mismatch(%arg0 : tensor<60xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op expect last dimension of input_data (60) to be divisible by block_size (32)}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<60xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_scale_blocks_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<3xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op expect last dimension of scale_data (3) to be 64/32}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<3xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_new_shape_block_size_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 30]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op expect last dimension of new shape (30) to be divisible by block_size (32)}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x30xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_inconsistent_scale_output_rank(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op result scale does not match new shape rank}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_inconsistent_scale(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op result scale shape is inconsistent with new shape}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<3x?xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_rank0_input(%arg0 : tensor<f8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op requires all input shapes have a rank greater than 0}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<f8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x2xf8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_rank0_scale_output(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op requires all result shapes have a rank greater than 0}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<f8E8M0FNU>)
+ return
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_rank0_scale_input(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+ %s = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op requires new shape to have a rank greater than 0}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<0>) -> (tensor<2x32xf8E4M3FN>, tensor<2x2xf8E8M0FNU>)
+ return
+}
>From 809ed65aa97fb31f2ef0a69635ced83b478b862e Mon Sep 17 00:00:00 2001
From: Jeremy Johnson <jeremy.johnson at arm.com>
Date: Mon, 13 Apr 2026 16:58:25 +0100
Subject: [PATCH 2/4] Fix up RESHAPE_BLOCK_SCALED
Add better description of tensor usage
Add checking for BLOCK_SIZE_1 and BLOCK_SIZE_32 usage
Change-Id: I4a625ffdd637bd7b391b631ff8a6e1e1453f13c7
Signed-off-by: Jeremy Johnson <jeremy.johnson at arm.com>
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 3 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 15 ++++-
mlir/test/Dialect/Tosa/verifier.mlir | 59 +++++++++++++++++++-
3 files changed, 72 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 951969d3f34f5..1f14c43136529 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2280,7 +2280,8 @@ def Tosa_ReshapeBlockScaledOp
let description = [{
Returns a tensor-list with the same type/values as the input, with a new
shape specified by the shape argument. Reshape may operate on block-scaled
- or non-block-scaled tensors of any rank. No data conversion happens during
+ tensors (values tensor followed by scale tensor) of rank 1 or higher, or a
+ single non-block-scaled tensor of any rank. No data conversion happens during
a reshape operation. Reshape must retain the relationship between values
and their scale in a block for block-scaled content.
}];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 76eb59ecf3cad..9d85f65e65da3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2160,6 +2160,8 @@ LogicalResult MatmulTBlockScaledOp::verify() {
// Verify C is a multiple of block size
const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+ return emitOpError("expect block size to be 32, got ") << blockSize;
if (ShapedType::isStatic(C) && C % blockSize != 0)
return emitOpError("expect C to be a multiple of block size, got C=")
<< C << ", block_size=" << blockSize;
@@ -2868,6 +2870,8 @@ llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
if (inputList.size() == 2) {
+ if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+ return emitOpError("expect block size to be 32, got ") << blockSize;
if (llvm::any_of(inputList, [](Value v) {
const auto input = cast<ShapedType>(v.getType());
return input.hasRank() && input.getRank() == 0;
@@ -2923,6 +2927,9 @@ llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
<< blockSize;
}
}
+ } else {
+ if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_1))
+ return emitOpError("expect block size to be 1, got ") << blockSize;
}
// Get the new value shape dimension values
@@ -4173,8 +4180,10 @@ LogicalResult Conv2DBlockScaledOp::verify() {
return failure();
}
- // Verify IC is a multiple of block size
const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+ return emitOpError("expect block size to be 32, got ") << blockSize;
+ // Verify IC is a multiple of block size
if (ShapedType::isStatic(IC) && IC % blockSize != 0)
return emitOpError("expect IC to be a multiple of block size, got IC=")
<< IC << ", block_size=" << blockSize;
@@ -4723,6 +4732,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
if (inputDataShape.hasRank()) {
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+ return emitOpError("expect block size to be 32, got ") << blockSize;
const int64_t inputDataLastDim =
inputDataShape.getDimSize(inputDataShape.getRank() - 1);
if (inputDataLastDim % blockSize != 0)
@@ -4796,6 +4807,8 @@ LogicalResult CastToBlockScaledOp::verify() {
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
+ return emitOpError("expect block size to be 32, got ") << blockSize;
const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
if (inputDataShape.hasRank()) {
const int64_t inputDataLastDim =
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index d62aa09734416..0666feb5b367e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1195,6 +1195,14 @@ func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3
// -----
+func.func @test_matmul_t_block_scaled_block_size_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ // expected-error at +1 {{'tosa.matmul_t_block_scaled' op expect block size to be 32, got 1}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
func.func @cast_from_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> {
// expected-error at +1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and output_data ('tensor<5x32xf32>')}}
%0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32>
@@ -1235,6 +1243,14 @@ func.func @cast_from_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32
// -----
+func.func @test_cast_from_block_scaled_block_size_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op expect block size to be 32, got 1}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
func.func @test_cast_to_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
// expected-error at +1 {{'tosa.cast_to_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf32>') and output_data ('tensor<5x32xf4E2M1FN>')}}
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
@@ -1275,6 +1291,14 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
// -----
+func.func @test_cast_to_block_scaled_block_size_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op expect block size to be 32, got 1}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
// expected-error at +1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
%0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
@@ -1539,6 +1563,17 @@ func.func @test_conv2d_block_scaled_invalid_bias_size(%arg0: tensor<1x4x4x64xf4E
// -----
+func.func @test_conv2d_block_scaled_block_size_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.conv2d_block_scaled' op expect block size to be 32, got 1}}
+ %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+ return %3 : tensor<*xf32>
+}
+
+// -----
+
func.func @test_missmatched_ranks() {
%0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
%1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
@@ -1633,7 +1668,7 @@ func.func @test_reshape_block_scaled_types_mismatch(%arg0 : tensor<2x4xf32>) ->
func.func @test_reshape_block_scaled_outputs_mismatch_inputs(%arg0 : tensor<64xf8E4M3FN>) -> () {
%s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.reshape_block_scaled' op requires number of results to match inputs}}
- %0:2 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<64xf8E4M3FN>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
return
}
@@ -1693,7 +1728,7 @@ func.func @test_reshape_block_scaled_scale_dims_mismatch(%arg0 : tensor<1x64xf8E
// -----
-func.func @test_reshape_block_scaled_block_size_mismatch(%arg0 : tensor<60xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+func.func @test_reshape_block_scaled_block_size_dim_mismatch(%arg0 : tensor<60xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
%s = tosa.const_shape {values = dense<[2, 32]> : tensor<2xindex>} : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.reshape_block_scaled' op expect last dimension of input_data (60) to be divisible by block_size (32)}}
%0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<60xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x32xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
@@ -1711,7 +1746,7 @@ func.func @test_reshape_block_scaled_scale_blocks_mismatch(%arg0 : tensor<64xf8E
// -----
-func.func @test_reshape_block_scaled_new_shape_block_size_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
+func.func @test_reshape_block_scaled_new_shape_block_size_dim_mismatch(%arg0 : tensor<64xf8E4M3FN>, %arg1 : tensor<2xf8E8M0FNU>) -> () {
%s = tosa.const_shape {values = dense<[2, 30]> : tensor<2xindex>} : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.reshape_block_scaled' op expect last dimension of new shape (30) to be divisible by block_size (32)}}
%0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x30xf8E4M3FN>, tensor<2x1xf8E8M0FNU>)
@@ -1762,3 +1797,21 @@ func.func @test_reshape_block_scaled_rank0_scale_input(%arg0 : tensor<64xf8E4M3F
%0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<64xf8E4M3FN>, tensor<2xf8E8M0FNU>, !tosa.shape<0>) -> (tensor<2x32xf8E4M3FN>, tensor<2x2xf8E8M0FNU>)
return
}
+
+// -----
+
+func.func @test_reshape_non_block_scaled_block_size_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
+ %s = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op expect block size to be 1, got 32}}
+ %0 = "tosa.reshape_block_scaled"(%arg0, %s) {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+ return %0 : tensor<1x819xf32>
+}
+
+// -----
+
+func.func @test_reshape_block_scaled_block_size_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>) {
+ %s = tosa.const_shape {values = dense<[2, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape_block_scaled' op expect block size to be 32, got 1}}
+ %0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
>From 0f33e145bbe9561ace28570b490631ec7f33eed4 Mon Sep 17 00:00:00 2001
From: Jeremy Johnson <jeremy.johnson at arm.com>
Date: Wed, 15 Apr 2026 08:54:37 +0100
Subject: [PATCH 3/4] Missed formatting
Change-Id: I5959d4586a9b3a7cbc5e2bdc5794baf1dcbda0ee
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 621e6c2b71bc1..35ffe6c9eccf7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2906,7 +2906,7 @@ llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
if (inputList.size() == 2) {
if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
- return emitOpError("expect block size to be 32, got ") << blockSize;
+ return emitOpError("expect block size to be 32, got ") << blockSize;
if (llvm::any_of(inputList, [](Value v) {
const auto input = cast<ShapedType>(v.getType());
return input.hasRank() && input.getRank() == 0;
@@ -4217,7 +4217,7 @@ LogicalResult Conv2DBlockScaledOp::verify() {
const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
- return emitOpError("expect block size to be 32, got ") << blockSize;
+ return emitOpError("expect block size to be 32, got ") << blockSize;
// Verify IC is a multiple of block size
if (ShapedType::isStatic(IC) && IC % blockSize != 0)
return emitOpError("expect IC to be a multiple of block size, got IC=")
>From a71cf699253dd6ab061901f19ed11cd8abdc3f6c Mon Sep 17 00:00:00 2001
From: Jeremy Johnson <jeremy.johnson at arm.com>
Date: Wed, 15 Apr 2026 13:12:30 +0100
Subject: [PATCH 4/4] Fix up bad merge
Change-Id: I1ef6af7e52b0b925776a37a139f49bf180349cb7
---
mlir/test/Dialect/Tosa/verifier.mlir | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 5c16f33a43b0a..4fd42aff6988a 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1951,6 +1951,10 @@ func.func @test_reshape_block_scaled_block_size_mismatch(%arg0: tensor<4x32xf4E2
// expected-error at +1 {{'tosa.reshape_block_scaled' op expect block size to be 32, got 1}}
%0:2 = "tosa.reshape_block_scaled"(%arg0, %arg1, %s) {block_size = #tosa.block_size<BLOCK_SIZE_1> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>, !tosa.shape<2>) -> (tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>)
return %0#0, %0#1 : tensor<2x64xf4E2M1FN>, tensor<2x2xf8E8M0FNU>
+}
+
+// -----
+
func.func @test_maxpool2d_adaptive_invalid_kernel(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x2x32x8xf32> {
%kernel = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
More information about the Mlir-commits
mailing list