[Mlir-commits] [mlir] [mlir][tosa] Add support for cast_from/to_block_scaled (PR #163436)

Luke Hutton llvmlistbot at llvm.org
Thu Oct 23 08:19:49 PDT 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/163436

>From 04ed197ad515de3bcf9077a6dd19038c63a013fc Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 15 Sep 2025 21:52:24 +0000
Subject: [PATCH] [mlir][tosa] Add support for cast_from/to_block_scaled

This commit adds support for the cast_from/to_block_scaled
operations from the ext-mxfp extension. This includes:
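- Operation definitions in TosaOps.td
- Definitions of the supported micro-scaling (MX) tensor types
- Shape inference and verifiers
- Profile compliance support for operations with multiple
  results (FFT2d and RFFT2d now use the common path)
- Validation pass checks to ensure usage is only valid when
  the target environment includes ext-mxfp and at least
  v1.1.draft of the specification.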

Note: support for mxint8 is not yet included; it will be
added in a later commit.

Note: this commit implements the operations as defined in the
specification change at
https://review.mlplatform.org/c/tosa/specification/+/15362. The EXT_MXFP
extension is considered experimental and subject to breaking changes.

Co-authored-by: Tat Wai Chong <tatwai.chong at arm.com>
Change-Id: I490645ce99b7ccd7021ed06acaf1530b4fbf6dfd
---
 .../Dialect/Tosa/IR/TosaComplianceData.h.inc  |  26 +++
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  63 +++++++
 .../Dialect/Tosa/IR/TosaProfileCompliance.h   |   2 +-
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  10 ++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 159 +++++++++++++++++-
 .../Tosa/Transforms/TosaProfileCompliance.cpp |  32 +---
 .../Tosa/Transforms/TosaValidation.cpp        |   2 +
 mlir/test/Dialect/Tosa/availability.mlir      |  18 ++
 mlir/test/Dialect/Tosa/invalid_extension.mlir |  16 ++
 mlir/test/Dialect/Tosa/level_check.mlir       |  33 +++-
 mlir/test/Dialect/Tosa/ops.mlir               |  28 +++
 .../Tosa/profile_pro_fp_unsupported.mlir      |  14 ++
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  45 +++++
 .../tosa-validation-version-1p1-valid.mlir    |  24 +++
 mlir/test/Dialect/Tosa/verifier.mlir          |  88 +++++++++-
 15 files changed, 526 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 294fb9d99fdb6..fa058661bf5a2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -753,6 +753,32 @@ extensionComplianceMap = {
         {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
         {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
         {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.cast_from_block_scaled",
+     {{{Extension::bf16, Extension::mxfp},
+       {{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+      {{Extension::mxfp},
+       {{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
+    {"tosa.cast_to_block_scaled",
+     {{{Extension::mxfp},
+       {{{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::bf16, Extension::mxfp},
+       {{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}},
     {"tosa.rescale",
      {{{Extension::int16},
        {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6f07247b478c8..ba65baa0e9aaa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2470,6 +2470,69 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: cast_from_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_CastFromBlockScaledOp : Tosa_InferShapedTypeOp<"cast_from_block_scaled"> {
+  let summary = "Apply scales from a scale tensor to the values in a value tensor";
+
+  let description = [{
+    Apply the scales from a scale tensor to the values in a value tensor, casting
+    the result to the output type. The block dimension must be the last dimension
+    of the tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_MXFPDataTensorAtLeast1D:$input_data,
+    Tosa_MXFPScaleTensorAtLeast1D:$input_scale,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_TensorAtLeast1D:$output_data
+  );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: cast_to_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled"> {
+  let summary = "Calculate scale tensor values per block, output to separate scale and data tensors.";
+
+  let description = [{
+    Calculate a scale value per block of input values and use that to calculate
+    scaled data values from an input tensor. The output tensors are cast to the
+    specified scale and value types. The block dimension will be the last dimension
+    of the tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorAtLeast1D:$input_data,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_MXFPDataTensorAtLeast1D:$output_data,
+    Tosa_MXFPScaleTensorAtLeast1D:$output_scale
+  );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: rescale
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 79df1b888b40e..4a899e3c787e6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -79,7 +79,7 @@ class ProfileInfoDepot {
 
   LogicalResult populatationDispatch(Operation *op);
 
-  LogicalResult populateProfileInfo(ValueRange operands, Value output);
+  LogicalResult populateProfileInfo(ValueRange operands, ValueRange outputs);
 
   // Base
   template <typename T>
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 20bb961482ad8..93843e86fd378 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -199,6 +199,16 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
   TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
 ]>;
+def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[ // Tosa_MXFPNumber elements; unranked or rank >= 1
+  TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+  TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
+  "tosa-conformant tensor of at least rank 1", "::mlir::TensorType"
+>;
+def Tosa_MXFPScaleTensorAtLeast1D : AnyTypeOf<[ // Tosa_MXFPScaleNumber elements; unranked or rank >= 1
+  TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+  TosaRankedTensorOf<[Tosa_MXFPScaleNumber], [AtLeastRankOne]>],
+  "tosa-conformant tensor of at least rank 1", "::mlir::TensorType"
+>;
 
 //===----------------------------------------------------------------------===//
 // Generic scalar, vector, or tensor of a particular type.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6cd0eaea3ce6c..0aff67f0b5eba 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
                                     result.operands)))
     return failure();
 
-  result.addTypes(fnTy.getResult(0));
+  result.addTypes(fnTy.getResults());
   result.addAttributes(attrs);
 
   return success();
@@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
   printWithEnumHandling(parser, *this);
 }
 
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &printer) {
+  printWithEnumHandling(printer, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &printer) {
+  printWithEnumHandling(printer, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastFromBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  return success();
+}
+
+LogicalResult CastFromBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getOutputData().getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "output_data (" << outputDataType << ")";
+
+  // Remaining checks need a ranked input_data; dynamic dims always pass.
+  const ShapeAdaptor inputDataShape(inputDataType);
+  if (!inputDataShape.hasRank())
+    return success();
+  const unsigned int blockSize =
+      BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  const int64_t inputDataLastDim =
+      inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+  if (ShapedType::isStatic(inputDataLastDim) &&
+      inputDataLastDim % blockSize != 0)
+    return emitOpError() << "expect last dimension of input_data ("
+                         << inputDataLastDim
+                         << ") to be divisible by block_size (" << blockSize
+                         << ")";
+
+  const Type inputScaleType = getInputScale().getType();
+  const ShapeAdaptor inputScaleShape(inputScaleType);
+  if (!inputScaleShape.hasRank())
+    return success();
+  // input_scale must match input_data in every dimension except the last.
+  SmallVector<int64_t> inputDataDims, inputScaleDims;
+  inputDataShape.getDims(inputDataDims);
+  inputScaleShape.getDims(inputScaleDims);
+  if (inputDataDims.size() != inputScaleDims.size() ||
+      failed(verifyCompatibleShape(
+          ArrayRef<int64_t>(inputDataDims).drop_back(1),
+          ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "input_scale (" << inputScaleType
+                         << ") except for the last dimension";
+
+  const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+                                            inputScaleDims.back()};
+  if (ShapedType::isStatic(inputDataLastDim) &&
+      failed(verifyCompatibleDims(dimsToCheck)))
+    return emitOpError()
+           << "expect last dimension of input_scale ("
+           << inputScaleDims.back()
+           << ") to be equal to last dimension of input_data / block_size ("
+           << inputDataDims.back() / blockSize << ")";
+
+  return success();
+}
+
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastToBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  if (!inputShape.hasRank()) {
+    // Emit one (unranked) component per result, including output_scale.
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  // output_scale: input_data's shape with a static last dim divided by
+  // block_size; a dynamic last dim stays dynamic.
+  SmallVector<int64_t> outputScaleShape;
+  inputShape.getDims(outputScaleShape);
+  int64_t &lastDim = outputScaleShape.back();
+  if (ShapedType::isStatic(lastDim))
+    lastDim /= BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+  return success();
+}
+
+LogicalResult CastToBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getOutputData().getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "output_data (" << outputDataType << ")";
+
+  const unsigned int blockSize =
+      BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+  if (inputDataShape.hasRank()) {
+    const int64_t inputDataLastDim =
+        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+    if (ShapedType::isStatic(inputDataLastDim) &&
+        inputDataLastDim % blockSize != 0)
+      return emitOpError() << "expect last dimension of input_data ("
+                           << inputDataLastDim
+                           << ") to be divisible by block_size (" << blockSize
+                           << ")";
+  }
+  // output_scale must match output_data except in the last dimension.
+  const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+  const Type outputScaleType = getOutputScale().getType();
+  const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+  if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+    SmallVector<int64_t> outputDataDims, outputScaleDims;
+    outputDataShape.getDims(outputDataDims);
+    outputScaleShape.getDims(outputScaleDims);
+
+    if (outputDataDims.size() != outputScaleDims.size() ||
+        failed(verifyCompatibleShape(
+            ArrayRef<int64_t>(outputDataDims).drop_back(1),
+            ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+      return emitOpError() << "require compatible shapes for output_data ("
+                           << outputDataType << ") and "
+                           << "output_scale (" << outputScaleType
+                           << ") except for the last dimension";
+
+    const int64_t outputDataLastDim = outputDataDims.back();
+    const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+                                              outputScaleDims.back()};
+    if (ShapedType::isStatic(outputDataLastDim) &&
+        failed(verifyCompatibleDims(dimsToCheck)))
+      return emitOpError()
+             << "expect last dimension of output_scale ("
+             << outputScaleDims.back()
+             << ") to be equal to last dimension of output_data / block_size ("
+             << outputDataDims.back() / blockSize << ")";
+  }
+
+  return success();
+}
+
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index e965ae0cf9888..92d5bac9c2653 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -50,10 +50,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
 
 // Base populating function
 LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
-                                                    Value output) {
-  for (auto operand : operands)
+                                                    ValueRange outputs) {
+  for (Value operand : operands)
     addValue(operand);
-  addValue(output);
+  for (Value output : outputs)
+    addValue(output);
   return success();
 }
 
@@ -175,23 +176,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
   return success();
 }
 
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
-  addValue(op.getInputReal());
-  addValue(op.getInputImag());
-  addValue(op.getOutputReal());
-  addValue(op.getOutputImag());
-  return success();
-}
-
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
-  addValue(op.getInputReal());
-  addValue(op.getOutputReal());
-  addValue(op.getOutputImag());
-  return success();
-}
-
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
   addValue(op.getOnTrue());
@@ -245,7 +229,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // This helper function populates the info for all operands.
 #define POPULATE_PROFILE_INFO_COMMON(tosaOp)                                   \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    return populateProfileInfo(op->getOperands(), op->getResult(0));           \
+    return populateProfileInfo(op->getOperands(), op->getResults());           \
   }
 
   // Skip irrelevant operands when they are independent and not tied to any
@@ -256,8 +240,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
   POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
   POPULATE_PROFILE_INFO_CUSTOM(Mul)
-  POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
-  POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
   POPULATE_PROFILE_INFO_CUSTOM(Concat)
   POPULATE_PROFILE_INFO_CUSTOM(Pad)
   POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -276,7 +258,11 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   // For the most of tosa operators, all operands are profile/extension related
   // and hence are all considered in this profile-based compilance check.
   POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+  POPULATE_PROFILE_INFO_COMMON(FFT2d)
+  POPULATE_PROFILE_INFO_COMMON(RFFT2d)
   POPULATE_PROFILE_INFO_COMMON(Cast)
+  POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+  POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
   POPULATE_PROFILE_INFO_COMMON(Const)
   POPULATE_PROFILE_INFO_COMMON(ArgMax)
   POPULATE_PROFILE_INFO_COMMON(Sub)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3f874d94ab9be..a142926bf87e2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(Transpose);
   // Type Conversion
   CHECK_RANKS_AND_SIZES(Cast);
+  CHECK_RANKS_AND_SIZES(CastFromBlockScaled); // ext-mxfp block-scaled cast
+  CHECK_RANKS_AND_SIZES(CastToBlockScaled);   // ext-mxfp block-scaled cast
   CHECK_RANKS_AND_SIZES(Rescale);
   // Control Flow Operators
   CHECK_RANKS_AND_SIZES(If);
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 600c4c717922a..d92d433a7d185 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -696,3 +696,21 @@ func.func @test_const_shape() -> !tosa.shape<4> {
   return %cst : !tosa.shape<4>
 }
 
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled
+func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16, mxfp] ]
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled
+func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+  // CHECK: profiles: [ [pro_fp] ]
+  // CHECK: extensions: [ [bf16, mxfp] ]
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 005601d4017b8..fff31c294a3f7 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -546,3 +546,19 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: ten
   %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
   return %0 : tensor<4x8x16xf32>
 }
+
+// -----
+
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 8771e6e2476e4..cd392fcc20ea1 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1625,9 +1625,40 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor
 
 // -----
 
-// CHECK-LABEL: test_matmul_t_block_scaled_invalid_size
 func.func @test_matmul_t_block_scaled_invalid_size(%arg0: tensor<4x8x536870912xf4E2M1FN>, %arg1: tensor<4x8x16777216xf8E8M0FNU>, %arg2: tensor<4x16x536870912xf4E2M1FN>, %arg3: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> {
   // expected-error@+1 {{'tosa.matmul_t_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
   %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
+
+// -----
+
+func.func @test_cast_from_block_scaled_invalid_size(%arg0: tensor<536870912x32xf6E2M3FN>, %arg1: tensor<536870912x1xf8E8M0FNU>) -> tensor<536870912x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) -> tensor<536870912x32xf32>
+  return %0 : tensor<536870912x32xf32>
+}
+
+// -----
+
+func.func @test_cast_from_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, %arg1: tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32>
+  return %0 : tensor<1x2x3x4x5x6x7x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_invalid_size(%arg0: tensor<536870912x32xf32>) -> (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<536870912x32xf32>) -> (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 9bf36b5fd4c7d..865f712ce1a5a 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1268,3 +1268,31 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>,
   %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<?x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
   return %0 : tensor<4x8x16xf32>
 }
+
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled_static
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled_unranked
+func.func @test_cast_from_block_scaled_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 0271d71561a52..7de7b85bcaedf 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -332,3 +332,17 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: ten
   %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
   return %0 : tensor<4x8x16xf32>
 }
+
+// -----
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 72479fe21ade8..54556a0eb08e0 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1628,3 +1628,48 @@ func.func @test_matmul_t_block_scaled_broadcast_b_scale(%arg0: tensor<*xf8E4M3FN
   %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
+
+// -----
+// Fully static data (4x32) and scale (4x1): the unranked result is refined to the data shape.
+// CHECK-LABEL: test_cast_from_block_scaled_static
+func.func @test_cast_from_block_scaled_static(%data: tensor<4x32xf4E2M1FN>, %scale: tensor<4x1xf8E8M0FNU>) -> tensor<*xf32> {
+  // CHECK: -> tensor<4x32xf32>
+  %out = tosa.cast_from_block_scaled %data, %scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<*xf32>
+  return %out : tensor<*xf32>
+}
+
+// -----
+// Unranked scale operand: the result shape is still taken from the static 4x32 data operand.
+// CHECK-LABEL: test_cast_from_block_scaled_unranked_input_scale
+func.func @test_cast_from_block_scaled_unranked_input_scale(%data: tensor<4x32xf4E2M1FN>, %scale: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+  // CHECK: -> tensor<4x32xf32>
+  %out = tosa.cast_from_block_scaled %data, %scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+  return %out : tensor<*xf32>
+}
+
+// -----
+// Static 4x32 input: the data result keeps 4x32; the scale result's last dim becomes 32 / block_size(32) = 1.
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%values: tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+  // CHECK: -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  %q:2 = tosa.cast_to_block_scaled %values {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  return %q#0, %q#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+// Unranked input: there is no shape to propagate, so both results stay unranked.
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+func.func @test_cast_to_block_scaled_unranked(%values: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+  // CHECK: -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  %q:2 = tosa.cast_to_block_scaled %values {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  return %q#0, %q#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+// Dynamic last input dimension: both inferred results keep 4 as the leading dim and a dynamic last dim.
+// CHECK-LABEL: test_cast_to_block_scaled_dynamic_scales
+func.func @test_cast_to_block_scaled_dynamic_scales(%values: tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+  // CHECK: -> (tensor<4x?xf4E2M1FN>, tensor<4x?xf8E8M0FNU>)
+  %q:2 = tosa.cast_to_block_scaled %values {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  return %q#0, %q#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 2040a4bc7e6af..8b6cdc07925f0 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -26,3 +26,27 @@ func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %a
   %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
   return %0 : tensor<4x8x16xf32>
 }
+
+// -----
+// fp8e5m2 block-scaled data converted to fp32; kept in this file as a valid version 1.1 case.
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32
+func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%data: tensor<4x32xf8E5M2>, %scale: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  %out = tosa.cast_from_block_scaled %data, %scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %out : tensor<4x32xf32>
+}
+
+// -----
+// fp8e5m2 block-scaled data converted to bf16; kept in this file as a valid version 1.1 case.
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_bf16
+func.func @test_cast_from_block_scaled_fp8e5m2_bf16(%data: tensor<4x32xf8E5M2>, %scale: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16> {
+  %out = tosa.cast_from_block_scaled %data, %scale {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16>
+  return %out : tensor<4x32xbf16>
+}
+
+// -----
+// fp32 input quantized to fp6e3m2 data plus e8m0 scales; kept in this file as a valid version 1.1 case.
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%values: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+  %q:2 = tosa.cast_to_block_scaled %values {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+  return %q#0, %q#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 4be5d725ad612..6cf76cdc7ad8e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1033,7 +1033,6 @@ module {
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_indices_N
 func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
   // expected-error at +1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
   %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1042,7 +1041,6 @@ func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_input_N
 func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
   // expected-error at +1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1051,7 +1049,6 @@ func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_out_N
 func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
   // expected-error at +1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
@@ -1060,7 +1057,6 @@ func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_out_K
 func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
   // expected-error at +1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
@@ -1069,7 +1065,6 @@ func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_input_W
 func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
   // expected-error at +1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
@@ -1078,7 +1073,6 @@ func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_input_C
 func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
   // expected-error at +1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
@@ -1087,7 +1081,6 @@ func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_out_C
 func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
   // expected-error at +1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
@@ -1096,7 +1089,6 @@ func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_K_W
 func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
   // expected-error at +1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
@@ -1150,3 +1142,83 @@ func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3
   %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
   return %0 : tensor<4x8x16xf32>
 }
+
+// -----
+// output_data (5x32) does not match input_data (4x32), so the shape-compatibility check must fire.
+func.func @cast_from_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and output_data ('tensor<5x32xf32>')}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32>
+  return %0 : tensor<5x32xf32>
+}
+
+// -----
+// Rank-0 operands are rejected by the operand type constraint (input_data needs rank >= 1).
+func.func @cast_from_block_scaled_not_scalar(%arg0: tensor<f4E2M1FN>, %arg1: tensor<f8E8M0FNU>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f4E2M1FN>'}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<f4E2M1FN>, tensor<f8E8M0FNU>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+// The last dim of input_data (33) is not a multiple of block_size (32).
+func.func @cast_from_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32>
+  return %0 : tensor<4x33xf32>
+}
+
+// -----
+// The leading dims of input_data (4) and input_scale (5) disagree.
+func.func @cast_from_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and input_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+// The last dim of input_scale must be 32 / block_size(32) = 1, but here it is 2.
+func.func @cast_from_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_scale (2) to be equal to last dimension of input_data / block_size (1)}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+// output_data (5x32) does not match input_data (4x32), so the shape-compatibility check must fire.
+func.func @test_cast_to_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf32>') and output_data ('tensor<5x32xf4E2M1FN>')}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+// A rank-0 input is rejected by the operand type constraint (input_data needs rank >= 1).
+func.func @test_cast_to_block_scaled_not_scalar(%arg0: tensor<f32>) -> (tensor<f4E2M1FN>, tensor<f8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<f32>) -> (tensor<f4E2M1FN>, tensor<f8E8M0FNU>)
+  return %0#0, %0#1 : tensor<f4E2M1FN>, tensor<f8E8M0FNU>
+}
+
+// -----
+// The last dim of input_data (33) is not a multiple of block_size (32).
+func.func @test_cast_to_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+// The leading dims of output_data (4) and output_scale (5) disagree.
+func.func @test_cast_to_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for output_data ('tensor<4x32xf4E2M1FN>') and output_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>
+}
+
+// -----
+// The last dim of output_scale must be 32 / block_size(32) = 1, but here it is 2.
+func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of output_scale (2) to be equal to last dimension of output_data / block_size (1)}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
+}



More information about the Mlir-commits mailing list