[Mlir-commits] [mlir] [mlir][tosa] Add support for cast_from/to_block_scaled (PR #163436)
Luke Hutton
llvmlistbot at llvm.org
Thu Oct 23 08:19:49 PDT 2025
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/163436
From 04ed197ad515de3bcf9077a6dd19038c63a013fc Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 15 Sep 2025 21:52:24 +0000
Subject: [PATCH] [mlir][tosa] Add support for cast_from/to_block_scaled
This commit adds support for the cast_from/to_block_scaled
operations from the ext-mxfp extension. This includes:
- Operation definitions in TosaOps.td
- Definition of the supported micro-scaling types
- Shape inference and verifiers
- Validation pass checks to ensure these operations are only
  valid when the target environment includes ext-mxfp and at
  least v1.1.draft of the specification.
Note: support for mxint8 is currently excluded; it will be
added in a later commit.
Note: this commit adds support as defined in the specification change at
https://review.mlplatform.org/c/tosa/specification/+/15362. The EXT_MXFP
extension is considered experimental and subject to breaking changes.
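For illustration, a minimal usage sketch of the two new operations is
shown below. The shapes and attribute syntax are taken from the tests
added in this patch; the value names are placeholders, and the exact
form may change while EXT_MXFP remains experimental.

  // Apply one f8E8M0FNU scale per block of 32 f4E2M1FN values, producing fp32.
  %0 = tosa.cast_from_block_scaled %data, %scale {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>

  // Split fp32 values back into f4E2M1FN data plus one scale per block of 32.
  %1:2 = tosa.cast_to_block_scaled %0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)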
Co-authored-by: Tat Wai Chong <tatwai.chong at arm.com>
Change-Id: I490645ce99b7ccd7021ed06acaf1530b4fbf6dfd
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 26 +++
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 63 +++++++
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 2 +-
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 10 ++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 159 +++++++++++++++++-
.../Tosa/Transforms/TosaProfileCompliance.cpp | 32 +---
.../Tosa/Transforms/TosaValidation.cpp | 2 +
mlir/test/Dialect/Tosa/availability.mlir | 18 ++
mlir/test/Dialect/Tosa/invalid_extension.mlir | 16 ++
mlir/test/Dialect/Tosa/level_check.mlir | 33 +++-
mlir/test/Dialect/Tosa/ops.mlir | 28 +++
.../Tosa/profile_pro_fp_unsupported.mlir | 14 ++
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 45 +++++
.../tosa-validation-version-1p1-valid.mlir | 24 +++
mlir/test/Dialect/Tosa/verifier.mlir | 88 +++++++++-
15 files changed, 526 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 294fb9d99fdb6..fa058661bf5a2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -753,6 +753,32 @@ extensionComplianceMap = {
{{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
{{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
{{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cast_from_block_scaled",
+ {{{Extension::bf16, Extension::mxfp},
+ {{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::mxfp},
+ {{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
+ {"tosa.cast_to_block_scaled",
+ {{{Extension::mxfp},
+ {{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::bf16, Extension::mxfp},
+ {{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}},
{"tosa.rescale",
{{{Extension::int16},
{{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6f07247b478c8..ba65baa0e9aaa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2470,6 +2470,69 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: cast_from_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_CastFromBlockScaledOp: Tosa_InferShapedTypeOp<"cast_from_block_scaled"> {
+ let summary = "Apply scales from a scale tensor to the values in a value tensor";
+
+ let description = [{
+ Apply the scales from a scale tensor to the values in a value tensor, casting
+ the result to the output type. The block dimension must be the last dimension
+ of the tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_MXFPDataTensorAtLeast1D:$input_data,
+ Tosa_MXFPScaleTensorAtLeast1D:$input_scale,
+ Tosa_BlockSizeAttr:$block_size
+ );
+
+ let results = (outs
+ Tosa_TensorAtLeast1D: $output_data
+ );
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>,
+ ];
+
+ let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: cast_to_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled"> {
+ let summary = "Calculate scale tensor values per block, output to separate scale and data tensors.";
+
+ let description = [{
+ Calculate a scale value per block of input values and use that to calculate
+ scaled data values from an input tensor. The output tensors are cast to the
+ specified scale and value types. The block dimension will be the last dimension
+ of the tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorAtLeast1D:$input_data,
+ Tosa_BlockSizeAttr:$block_size
+ );
+
+ let results = (outs
+ Tosa_MXFPDataTensorAtLeast1D:$output_data,
+ Tosa_MXFPScaleTensorAtLeast1D:$output_scale
+ );
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
+ ];
+
+ let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: rescale
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 79df1b888b40e..4a899e3c787e6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -79,7 +79,7 @@ class ProfileInfoDepot {
LogicalResult populatationDispatch(Operation *op);
- LogicalResult populateProfileInfo(ValueRange operands, Value output);
+ LogicalResult populateProfileInfo(ValueRange operands, ValueRange output);
// Base
template <typename T>
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 20bb961482ad8..93843e86fd378 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -199,6 +199,16 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
]>;
+def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
+ TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+ TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
+ "tosa-conformant tensor of at least rank 1", "::mlir::TensorType"
+>;
+def Tosa_MXFPScaleTensorAtLeast1D : AnyTypeOf<[
+ TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+ TosaRankedTensorOf<[Tosa_MXFPScaleNumber], [AtLeastRankOne]>],
+ "tosa-conformant tensor of at least rank 1", "::mlir::TensorType"
+>;
//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6cd0eaea3ce6c..0aff67f0b5eba 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
result.operands)))
return failure();
- result.addTypes(fnTy.getResult(0));
+ result.addTypes(fnTy.getResults());
result.addAttributes(attrs);
return success();
@@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
return success();
}
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastFromBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ return success();
+}
+
+LogicalResult CastFromBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult().getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+
+ if (inputDataShape.hasRank()) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+
+ const Type inputScaleType = getInputScale().getType();
+ const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+
+ if (inputScaleShape.hasRank()) {
+ SmallVector<int64_t> inputDataDims, inputScaleDims;
+ inputDataShape.getDims(inputDataDims);
+ inputScaleShape.getDims(inputScaleDims);
+
+ if (inputDataDims.size() != inputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(inputDataDims).drop_back(1),
+ ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "input_scale (" << inputScaleType
+ << ") except for the last dimension";
+
+ const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+ inputScaleDims.back()};
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of input_scale ("
+ << inputScaleDims.back()
+ << ") to be equal to last dimension of input_data / block_size ("
+ << inputDataDims.back() / blockSize << ")";
+ }
+ }
+
+ return success();
+}
+
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastToBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ if (!inputShape.hasRank())
+ return success();
+
+ // Calculate output_scale shape if ranked input provided
+ SmallVector<int64_t> outputScaleShape;
+ inputShape.getDims(outputScaleShape);
+ const int64_t lastDimLoc = inputShape.getRank() - 1;
+ const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+ if (ShapedType::isStatic(lastDimSize)) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+ outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+ return success();
+}
+
+LogicalResult CastToBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult(0).getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+ if (inputDataShape.hasRank()) {
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+ }
+
+ const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+ const Type outputScaleType = getResult(1).getType();
+ const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+ if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+ SmallVector<int64_t> outputDataDims, outputScaleDims;
+ outputDataShape.getDims(outputDataDims);
+ outputScaleShape.getDims(outputScaleDims);
+
+ if (outputDataDims.size() != outputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(outputDataDims).drop_back(1),
+ ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for output_data ("
+ << outputDataType << ") and "
+ << "output_scale (" << outputScaleType
+ << ") except for the last dimension";
+
+ const int64_t outputDataLastDim = outputDataDims.back();
+ const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+ outputScaleDims.back()};
+ if (ShapedType::isStatic(outputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of output_scale ("
+ << outputScaleDims.back()
+ << ") to be equal to last dimension of output_data / block_size ("
+ << outputDataDims.back() / blockSize << ")";
+ }
+
+ return success();
+}
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index e965ae0cf9888..92d5bac9c2653 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -50,10 +50,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
// Base populating function
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
- Value output) {
- for (auto operand : operands)
+ ValueRange outputs) {
+ for (const auto &operand : operands)
addValue(operand);
- addValue(output);
+ for (const auto &output : outputs)
+ addValue(output);
return success();
}
@@ -175,23 +176,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
return success();
}
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getInputImag());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
addValue(op.getOnTrue());
@@ -245,7 +229,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function populates the info for all operands.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- return populateProfileInfo(op->getOperands(), op->getResult(0)); \
+ return populateProfileInfo(op->getOperands(), op->getResults()); \
}
// Skip irrelevant operands when they are independent and not tied to any
@@ -256,8 +240,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
POPULATE_PROFILE_INFO_CUSTOM(Mul)
- POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
- POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
POPULATE_PROFILE_INFO_CUSTOM(Concat)
POPULATE_PROFILE_INFO_CUSTOM(Pad)
POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -276,7 +258,11 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(FFT2d)
+ POPULATE_PROFILE_INFO_COMMON(RFFT2d)
POPULATE_PROFILE_INFO_COMMON(Cast)
+ POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
POPULATE_PROFILE_INFO_COMMON(Sub)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3f874d94ab9be..a142926bf87e2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Transpose);
// Type Conversion
CHECK_RANKS_AND_SIZES(Cast);
+ CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
+ CHECK_RANKS_AND_SIZES(CastToBlockScaled);
CHECK_RANKS_AND_SIZES(Rescale);
// Control Flow Operators
CHECK_RANKS_AND_SIZES(If);
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 600c4c717922a..d92d433a7d185 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -696,3 +696,21 @@ func.func @test_const_shape() -> !tosa.shape<4> {
return %cst : !tosa.shape<4>
}
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled
+func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // CHECK: profiles: [ [pro_fp] ]
+ // CHECK: extensions: [ [bf16, mxfp] ]
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled
+func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // CHECK: profiles: [ [pro_fp] ]
+ // CHECK: extensions: [ [bf16, mxfp] ]
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = BLOCK_SIZE_32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 005601d4017b8..fff31c294a3f7 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -546,3 +546,19 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: ten
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 8771e6e2476e4..cd392fcc20ea1 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1625,9 +1625,40 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor
// -----
-// CHECK-LABEL: test_matmul_t_block_scaled_invalid_size
func.func @test_matmul_t_block_scaled_invalid_size(%arg0: tensor<4x8x536870912xf4E2M1FN>, %arg1: tensor<4x8x16777216xf8E8M0FNU>, %arg2: tensor<4x16x536870912xf4E2M1FN>, %arg3: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> {
// expected-error at +1 {{'tosa.matmul_t_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// -----
+
+func.func @test_cast_from_block_scaled_invalid_size(%arg0: tensor<536870912x32xf6E2M3FN>, %arg1: tensor<536870912x1xf8E8M0FNU>) -> tensor<536870912x32xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) -> tensor<536870912x32xf32>
+ return %0 : tensor<536870912x32xf32>
+}
+
+// -----
+
+func.func @test_cast_from_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, %arg1: tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32>
+ return %0 : tensor<1x2x3x4x5x6x7x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_invalid_size(%arg0: tensor<536870912x32xf32>) -> (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<536870912x32xf32>) -> (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 9bf36b5fd4c7d..865f712ce1a5a 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1268,3 +1268,31 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>,
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<?x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled_static
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled_unranked
+func.func @test_cast_from_block_scaled_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 0271d71561a52..7de7b85bcaedf 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -332,3 +332,17 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: ten
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 72479fe21ade8..54556a0eb08e0 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1628,3 +1628,48 @@ func.func @test_matmul_t_block_scaled_broadcast_b_scale(%arg0: tensor<*xf8E4M3FN
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_static
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<*xf32> {
+ // CHECK: -> tensor<4x32xf32>
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_unranked_input_scale
+func.func @test_cast_from_block_scaled_unranked_input_scale(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+ // CHECK: -> tensor<4x32xf32>
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_dynamic_scales
+func.func @test_cast_to_block_scaled_dynamic_scales(%arg0: tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<4x?xf4E2M1FN>, tensor<4x?xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 2040a4bc7e6af..8b6cdc07925f0 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -26,3 +26,27 @@ func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %a
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32
+func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_bf16
+func.func @test_cast_from_block_scaled_fp8e5m2_bf16(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16>
+ return %0 : tensor<4x32xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 4be5d725ad612..6cf76cdc7ad8e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1033,7 +1033,6 @@ module {
// -----
-// CHECK-LABEL: @scatter_invalid_indices_N
func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
%1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1042,7 +1041,6 @@ func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3
// -----
-// CHECK-LABEL: @scatter_invalid_input_N
func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1051,7 +1049,6 @@ func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2
// -----
-// CHECK-LABEL: @scatter_invalid_out_N
func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
@@ -1060,7 +1057,6 @@ func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_out_K
func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
@@ -1069,7 +1065,6 @@ func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_input_W
func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
@@ -1078,7 +1073,6 @@ func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2
// -----
-// CHECK-LABEL: @scatter_invalid_input_C
func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
@@ -1087,7 +1081,6 @@ func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2
// -----
-// CHECK-LABEL: @scatter_invalid_out_C
func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
@@ -1096,7 +1089,6 @@ func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_K_W
func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
// expected-error at +1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
@@ -1150,3 +1142,83 @@ func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
+
+// -----
+
+func.func @cast_from_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and output_data ('tensor<5x32xf32>')}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32>
+ return %0 : tensor<5x32xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_not_scalar(%arg0: tensor<f4E2M1FN>, %arg1: tensor<f8E8M0FNU>) -> tensor<f32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f4E2M1FN>'}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<f4E2M1FN>, tensor<f8E8M0FNU>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32>
+ return %0 : tensor<4x33xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and input_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error at +1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_scale (2) to be equal to last dimension of input_data / block_size (1)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf32>') and output_data ('tensor<5x32xf4E2M1FN>')}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_not_scalar(%arg0: tensor<f32>) -> (tensor<f4E2M1FN>, tensor<f8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<f32>) -> (tensor<f4E2M1FN>, tensor<f8E8M0FNU>)
+ return %0#0, %0#1 : tensor<f4E2M1FN>, tensor<f8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op require compatible shapes for output_data ('tensor<4x32xf4E2M1FN>') and output_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) {
+ // expected-error at +1 {{'tosa.cast_to_block_scaled' op expect last dimension of output_scale (2) to be equal to last dimension of output_data / block_size (1)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
+}