[Mlir-commits] [mlir] [mlir][tosa] Add row_gather_block_scaled op (PR #192272)

Peng Sun llvmlistbot at llvm.org
Wed Apr 15 08:10:14 PDT 2026


https://github.com/psunn created https://github.com/llvm/llvm-project/pull/192272

Add `tosa.row_gather_block_scaled` to the MLIR TOSA dialect, aligned with the current TOSA 1.1 draft spec and the implementation in `tosa-tools`.

  This includes:
  - op definition
  - verifier and shape inference support
  - validation / profile compliance wiring
  - availability and extension handling
  - lit tests for parsing, verification, shape inference, and version / extension gating

  #### Notes

  The op supports both spec-defined forms:
  - non-block-scaled: 1 input value tensor, `BLOCK_SIZE_1`, 1 output
  - block-scaled: data + scale tensor list, non-`BLOCK_SIZE_1`, 2 outputs

  This also tightens existing block-scaled-only ops to reject `BLOCK_SIZE_1` now that it is part of the shared enum.

>From b5cc923d882de2809fc8163c46438c96c2299612 Mon Sep 17 00:00:00 2001
From: Peng Sun <peng.sun at arm.com>
Date: Mon, 13 Apr 2026 13:59:30 +0100
Subject: [PATCH] [mlir][tosa] Add row_gather_block_scaled op

Signed-off-by: Peng Sun <peng.sun at arm.com>
---
 .../Dialect/Tosa/IR/TosaComplianceData.h.inc  |  70 +++++++
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |   6 +-
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  38 ++++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 193 ++++++++++++++++++
 .../Tosa/Transforms/TosaProfileCompliance.cpp |  22 +-
 .../Tosa/Transforms/TosaValidation.cpp        |  12 ++
 mlir/test/Dialect/Tosa/availability.mlir      |  11 +-
 mlir/test/Dialect/Tosa/invalid_extension.mlir |   8 +
 mlir/test/Dialect/Tosa/ops.mlir               |  16 ++
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  20 ++
 .../tosa-validation-version-1p0-invalid.mlir  |  11 +-
 ...a-validation-version-1p1-pro-fp-valid.mlir |   9 +
 .../tosa-validation-version-1p1-valid.mlir    |  11 +-
 mlir/test/Dialect/Tosa/verifier.mlir          |  36 ++++
 14 files changed, 453 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index d3e2cd129028e..50bb9f69c6242 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -396,6 +396,25 @@ profileComplianceMap = {
         {{i16T, i64T, i16T}, SpecificationVersion::V_1_1_DRAFT},
         {{i32T, i64T, i32T}, SpecificationVersion::V_1_1_DRAFT}},
        anyOf}}},
+    {"tosa.row_gather_block_scaled",
+     {{{Profile::pro_int},
+       {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp},
+       {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp16T, i32T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, i32T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp, Profile::pro_int},
+       {{{boolT, i32T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT},
+        {{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}},
+       anyOf}}},
     {"tosa.scatter",
      {{{Profile::pro_int},
        {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -890,6 +909,57 @@ extensionComplianceMap = {
       {{Extension::bf16, Extension::int64},
        {{{bf16T, i64T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
        allOf}}},
+    {"tosa.row_gather_block_scaled",
+     {{{Extension::fp8e4m3},
+       {{{fp8e4m3T, i32T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e5m2},
+       {{{fp8e5m2T, i32T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::bf16},
+       {{{bf16T, i32T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::int64},
+       {{{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp16T, i64T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, i64T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e4m3, Extension::int64},
+       {{{fp8e4m3T, i64T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::fp8e5m2, Extension::int64},
+       {{{fp8e5m2T, i64T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::bf16, Extension::int64},
+       {{{bf16T, i64T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::mxfp},
+       {{{fp4e2m1T, fp8ue8m0T, i32T, i32T, fp4e2m1T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, i32T, i32T, fp6e2m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, i32T, i32T, fp6e3m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, i32T, i32T, fp8e4m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, i32T, i32T, fp8e5m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, i32T, i32T, mxint8T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::mxfp, Extension::int64},
+       {{{fp4e2m1T, fp8ue8m0T, i64T, i32T, fp4e2m1T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, i64T, i32T, fp6e2m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, i64T, i32T, fp6e3m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, i64T, i32T, fp8e4m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, i64T, i32T, fp8e5m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, i64T, i32T, mxint8T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT}},
+       allOf}}},
     {"tosa.scatter",
      {{{Extension::fp8e4m3},
        {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1f05aee3e5eec..591073e9985ae 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -484,11 +484,12 @@ def Tosa_RoundingModeAttr
     : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
                     [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
 
+def Tosa_BLOCK_SIZE_1 : I32EnumAttrCase<"BLOCK_SIZE_1", 1>;
 def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
 
 def Tosa_BlockSizeAttr
-    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
-                    [Tosa_BLOCK_SIZE_32]> {
+    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats",
+                       "block_size", [Tosa_BLOCK_SIZE_1, Tosa_BLOCK_SIZE_32]> {
   let extraClassDeclaration = [{
     static uint32_t getBlockSizeValue(BlockSize blockSize) {
       return static_cast<uint32_t>(blockSize);
@@ -496,7 +497,6 @@ def Tosa_BlockSizeAttr
   }];
 }
 
-
 //===----------------------------------------------------------------------===//
 // TOSA Interfaces.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 45d1388a28749..ba750ef4438db 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2487,6 +2487,44 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather", [NoMemoryEffect]> {
       "operands attr-dict `:` functional-type(operands, results)";
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: row_gather_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_RowGatherBlockScaledOp
+    : Tosa_InferShapedTypeOp<"row_gather_block_scaled", [NoMemoryEffect]> {
+  let summary =
+      "Row gather operation for block-scaled and non-block-scaled data.";
+
+  let description = [{
+    Generate a tensor-list which contains a data tensor and an optional scale
+    tensor based on the input indices and row_count. The number of consecutive
+    rows gathered for each index is specified in row_count.
+
+    This operation follows the TOSA 1.1 draft specification and may evolve as
+    the specification is updated.
+
+    This operation is not pure. Undefined behaviour may occur if the specified
+    indices are out of range.
+  }];
+
+  let arguments = (ins Variadic<Tosa_Tensor3D>:$values,
+      Tosa_IndexTensor2D:$indices, Tosa_ScalarInt32Tensor:$row_count,
+      Tosa_BlockSizeAttr:$block_size);
+
+  let results = (outs Variadic<Tosa_Tensor3D>:$output);
+
+  list<Availability> availability = [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+                                     Extension<[Tosa_EXT_FP8E4M3,
+                                                Tosa_EXT_FP8E5M2, Tosa_EXT_BF16,
+                                                Tosa_EXT_MXFP, Tosa_EXT_INT64]>,
+  ];
+
+  let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: scatter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2754d3b21d4a6..360ab25875bd7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2201,6 +2201,8 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   // Verify C is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   if (ShapedType::isStatic(C) && C % blockSize != 0)
     return emitOpError("expect C to be a multiple of block size, got C=")
            << C << ", block_size=" << blockSize;
@@ -2848,6 +2850,18 @@ static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
   return -1;
 }
 
+/// Returns the single integer element of constant `val`, or failure.
+static FailureOr<int64_t> getConstantScalarIntValue(Value val) {
+  ElementsAttr constAttr;
+  const bool isScalarIntConstant =
+      matchPattern(val, m_Constant(&constAttr)) &&
+      llvm::isa<IntegerType>(constAttr.getElementType()) &&
+      constAttr.getNumElements() == 1;
+  if (!isScalarIntConstant)
+    return failure();
+  return constAttr.getValues<APInt>()[0].getSExtValue();
+}
+
 template <typename T>
 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
                                      const std::string &operand) {
@@ -3096,6 +3110,48 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RowGatherBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const auto values = adaptor.getValues();
+  if (values.empty() || values.size() > 2) // data, optionally plus scale
+    return failure();
+
+  SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
+  const ShapeAdaptor valuesShape(values.front().getType());
+  if (valuesShape.hasRank()) {
+    dataShape[0] = valuesShape.getDimSize(0); // batch
+    dataShape[2] = valuesShape.getDimSize(2); // channels
+  }
+
+  const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (dataShape[0] == ShapedType::kDynamic)
+      dataShape[0] = indicesShape.getDimSize(0); // fall back to indices batch
+
+    if (auto rowCount = getConstantScalarIntValue(adaptor.getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0) {
+      const int64_t indicesW = indicesShape.getDimSize(1);
+      if (ShapedType::isStatic(indicesW))
+        dataShape[1] = indicesW * rowCount.value(); // row_count rows per index
+    }
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
+  if (values.size() == 1)
+    return success(); // data-only form has a single output
+
+  SmallVector<int64_t> scaleShape = dataShape;
+  const uint32_t blockSize =
+      BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+  if (ShapedType::isStatic(dataShape[2]))
+    scaleShape[2] = dataShape[2] / blockSize; // C / block_size scale channels
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
+  return success();
+}
+
 LogicalResult tosa::GatherOp::verify() {
   if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
                              /* outType = */ getOutput().getType())
@@ -3145,6 +3201,137 @@ LogicalResult tosa::GatherOp::verify() {
   return success();
 }
 
+LogicalResult tosa::RowGatherBlockScaledOp::verify() {
+  const OperandRange values = getValues();
+  const ResultRange output = getOutput();
+  if (values.empty() || values.size() > 2)
+    return emitOpError()
+           << "expects values tensor list length to be 1 or 2, got "
+           << values.size();
+  if (output.size() != values.size())
+    return emitOpError()
+           << "expects output tensor list length to match values tensor list "
+              "length, got "
+           << output.size() << " results for " << values.size()
+           << " input tensors";
+
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (values.size() == 1 && blockSize != 1) // data-only form
+    return emitOpError()
+           << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
+              "length is 1";
+  if (values.size() == 2 && blockSize == 1) // data + scale form
+    return emitOpError()
+           << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
+              "list length is 2";
+
+  if (failed(verifySameElementTypes(*this, values[0].getType(),
+                                    output[0].getType(), "values[0]",
+                                    "output[0]")))
+    return failure();
+  if (values.size() == 2 && failed(verifySameElementTypes(
+                                *this, values[1].getType(), output[1].getType(),
+                                "values[1]", "output[1]")))
+    return failure();
+
+  if (auto rowCount = getConstantScalarIntValue(getRowCount());
+      succeeded(rowCount) && rowCount.value() <= 0)
+    return emitOpError() << "requires row_count to be > 0, got "
+                         << rowCount.value();
+
+  int64_t n = ShapedType::kDynamic; // batch, cross-checked on every tensor
+  int64_t k = ShapedType::kDynamic; // rows of values[0]
+  int64_t c = ShapedType::kDynamic; // channels of values[0]
+  int64_t w = ShapedType::kDynamic; // dimension 1 of indices
+  int64_t multiplesOfC = ShapedType::kDynamic; // channels of scale tensors
+
+  const ShapeAdaptor valuesDataShape(values[0].getType());
+  if (valuesDataShape.hasRank()) {
+    n = valuesDataShape.getDimSize(0);
+    k = valuesDataShape.getDimSize(1);
+    c = valuesDataShape.getDimSize(2);
+  }
+
+  if (ShapedType::isStatic(c) && c % blockSize != 0)
+    return emitOpError() << "expects channels of values[0] (" << c
+                         << ") to be divisible by block_size (" << blockSize
+                         << ")";
+
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
+                                     "indices", "batch")))
+      return failure();
+    w = indicesShape.getDimSize(1);
+  }
+
+  const ShapeAdaptor outputDataShape(output[0].getType());
+  if (outputDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
+                                     "output[0]", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
+                                     "output[0]", "channels")))
+      return failure();
+
+    if (auto rowCount = getConstantScalarIntValue(getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0 &&
+        ShapedType::isStatic(w)) {
+      const int64_t expectedOutputRows = w * rowCount.value();
+      if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
+          outputDataShape.getDimSize(1) != expectedOutputRows)
+        return emitOpError() << "requires output[0] dimension 1 to have size "
+                             << expectedOutputRows << ", got "
+                             << outputDataShape.getDimSize(1);
+    }
+  }
+
+  if (values.size() == 2) { // block-scaled form: also validate scale tensors
+    const ShapeAdaptor valuesScaleShape(values[1].getType());
+    if (valuesScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
+                                       "values[1]", "batch")) ||
+          failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
+                                       "values[1]", "rows")))
+        return failure();
+      multiplesOfC = valuesScaleShape.getDimSize(2);
+    }
+
+    const ShapeAdaptor outputScaleShape(output[1].getType());
+    if (outputScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
+                                       "output[1]", "batch")))
+        return failure();
+
+      if (auto rowCount = getConstantScalarIntValue(getRowCount());
+          succeeded(rowCount) && rowCount.value() > 0 &&
+          ShapedType::isStatic(w)) {
+        const int64_t expectedOutputRows = w * rowCount.value();
+        if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
+            outputScaleShape.getDimSize(1) != expectedOutputRows)
+          return emitOpError() << "requires output[1] dimension 1 to have size "
+                               << expectedOutputRows << ", got "
+                               << outputScaleShape.getDimSize(1);
+      }
+
+      if (ShapedType::isDynamic(multiplesOfC))
+        multiplesOfC = outputScaleShape.getDimSize(2);
+      else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
+               multiplesOfC != outputScaleShape.getDimSize(2))
+        return emitOpError()
+               << "expected channels of output[1] to match size "
+               << multiplesOfC << ", got " << outputScaleShape.getDimSize(2);
+    }
+
+    if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
+        multiplesOfC != c / blockSize)
+      return emitOpError()
+             << "expects channels of scale tensors to equal C/block_size (" << c
+             << "/" << blockSize << "), got " << multiplesOfC;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -3987,6 +4174,8 @@ LogicalResult Conv2DBlockScaledOp::verify() {
 
   // Verify IC is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   if (ShapedType::isStatic(IC) && IC % blockSize != 0)
     return emitOpError("expect IC to be a multiple of block size, got IC=")
            << IC << ", block_size=" << blockSize;
@@ -4577,6 +4766,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
   if (inputDataShape.hasRank()) {
     const unsigned int blockSize =
         BlockSizeAttr::getBlockSizeValue(getBlockSize());
+    if (blockSize == 1)
+      return emitOpError("requires block_size to not be BLOCK_SIZE_1");
     const int64_t inputDataLastDim =
         inputDataShape.getDimSize(inputDataShape.getRank() - 1);
     if (inputDataLastDim % blockSize != 0)
@@ -4650,6 +4841,8 @@ LogicalResult CastToBlockScaledOp::verify() {
 
   const unsigned int blockSize =
       BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
   if (inputDataShape.hasRank()) {
     const int64_t inputDataLastDim =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 01c85be4f704f..4ea225b860f6c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -185,6 +185,18 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::RowGatherBlockScaledOp op) {
+  for (Value value : op.getValues()) // every tensor in the values list
+    addValue(value);
+  addValue(op.getIndices());
+  addValue(op.getRowCount());
+  for (Value result : op.getOutput()) // every result in the output list
+    addValue(result);
+  return success();
+}
+
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
   addValue(op.getValuesIn());
@@ -288,6 +300,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Tile)
   POPULATE_PROFILE_INFO_CUSTOM(Transpose)
   POPULATE_PROFILE_INFO_CUSTOM(Gather)
+  POPULATE_PROFILE_INFO_CUSTOM(RowGatherBlockScaled)
   POPULATE_PROFILE_INFO_CUSTOM(Scatter)
   POPULATE_PROFILE_INFO_CUSTOM(Resize)
   POPULATE_PROFILE_INFO_CUSTOM(Select)
@@ -598,10 +611,11 @@ SmallVector<OpComplianceInfo<T>> TosaProfileCompliance::findMatchedEntries(
     SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
     for (const auto &set : sets) {
       SmallVector<TypeInfo> expected = set.first;
-      assert(present.size() == expected.size() &&
-             "the entries for profile-based compliance do not match between "
-             "the generated metadata and the type definition retrieved from "
-             " the operation");
+      // Tensor-list operators can legitimately have multiple valid signatures
+      // with different operand/result counts, e.g. data-only and data+scale
+      // forms. Treat those as non-matches instead of asserting.
+      if (present.size() != expected.size())
+        continue;
 
       bool isFound = true;
       // Compare the type signature between the given operation and the
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3e2cda9d37666..177878b98cbaa 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -114,6 +114,17 @@ static LogicalResult checkConstantOperandMatMul(Operation *op,
   return success();
 }
 
+static LogicalResult
+checkConstantOperandRowGatherBlockScaled(Operation *op, const TargetEnv &env) {
+  auto rowGatherOp = dyn_cast<tosa::RowGatherBlockScaledOp>(op);
+  if (!rowGatherOp || env.allows(Extension::dynamic))
+    return success();
+
+  // row_count is the operand after the variadic `values` list and `indices`.
+  const unsigned rowCountIndex = rowGatherOp.getValues().size() + 1;
+  return checkConstantOperands(op, {rowCountIndex});
+}
+
 static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
                                                    const TargetEnv &env) {
   if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
@@ -199,6 +210,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     constCheckers.emplace_back(
         checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
     constCheckers.emplace_back(checkConstantOperandMatMul);
+    constCheckers.emplace_back(checkConstantOperandRowGatherBlockScaled);
     constCheckers.emplace_back(checkConstantOperandAvgPool2d);
     constCheckers.emplace_back(checkConstantOperandAvgPool2dAdaptive);
     constCheckers.emplace_back(checkConstantOperandNegate);
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 34de532639994..81294f2c0c308 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -595,6 +595,16 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
   return %0 : tensor<13x26x3xf32>
 }
 
+// -----
+// CHECK-LABEL: row_gather_block_scaled
+func.func @test_row_gather_block_scaled(%arg0: tensor<13x21x32xf4E2M1FN>, %arg1: tensor<13x21x1xf8E8M0FNU>, %arg2: tensor<13x26xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, mxfp, int64] ]
+  %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x32xf4E2M1FN>, tensor<13x21x1xf8E8M0FNU>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>
+}
+
 // -----
 // CHECK-LABEL: scatter
 func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
@@ -727,4 +737,3 @@ func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = BLOCK_SIZE_32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
 }
-
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 8a9b014864c74..295d7172bc2c4 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -249,6 +249,14 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
   return %0 : tensor<13x26x3xbf16>
 }
 
+// -----
+func.func @test_row_gather_block_scaled(%arg0: tensor<13x21x32xf4E2M1FN>, %arg1: tensor<13x21x1xf8E8M0FNU>, %arg2: tensor<13x26xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error at +1 {{'tosa.row_gather_block_scaled' op illegal: requires any of [mxfp] profiles/extensions to be specified in the target environment}}
+  %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x32xf4E2M1FN>, tensor<13x21x1xf8E8M0FNU>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>
+}
+
 // -----
 func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
   // expected-error at +1 {{'tosa.scatter' op illegal: requires any of [bf16] profiles/extensions to be specified in the target environment}}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b30e92c4a9621..fbdaf2e42d022 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -906,6 +906,22 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
   return %0 : tensor<13x26x3xf32>
 }
 
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled
+func.func @test_row_gather_block_scaled(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x52x3xf32> {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xf32>)
+  return %0 : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_mxfp
+func.func @test_row_gather_block_scaled_mxfp(%arg0: tensor<13x21x32xf4E2M1FN>, %arg1: tensor<13x21x1xf8E8M0FNU>, %arg2: tensor<13x26xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x32xf4E2M1FN>, tensor<13x21x1xf8E8M0FNU>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>
+}
+
 // -----
 // CHECK-LABEL: scatter
 func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 408300fa7034b..3bdac624a42a9 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -674,6 +674,26 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
 
 // -----
 
+// CHECK-LABEL: @row_gather_block_scaled_static
+func.func @row_gather_block_scaled_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: tosa.row_gather_block_scaled %arg0, %arg1, %[[ROW_COUNT:.+]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<1xi32>) -> tensor<3x12x5xi32>
+  %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<1xi32>) -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @row_gather_block_scaled_mxfp_static
+func.func @row_gather_block_scaled_mxfp_static(%arg0 : tensor<3x4x32xf4E2M1FN>, %arg1 : tensor<3x4x1xf8E8M0FNU>, %arg2 : tensor<3x6xi32>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %[[ROW_COUNT:.+]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<3x4x32xf4E2M1FN>, tensor<3x4x1xf8E8M0FNU>, tensor<3x6xi32>, tensor<1xi32>) -> (tensor<3x12x32xf4E2M1FN>, tensor<3x12x1xf8E8M0FNU>)
+  %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<3x4x32xf4E2M1FN>, tensor<3x4x1xf8E8M0FNU>, tensor<3x6xi32>, tensor<1xi32>) -> (tensor<?x?x?xf4E2M1FN>, tensor<?x?x?xf8E8M0FNU>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @scatter_static
 func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
   // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
index eed9621edc7b2..693fd21e9cf4b 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -78,6 +78,15 @@ func.func @test_gather_bool_i32(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi
 
 // -----
 
+func.func @test_row_gather_block_scaled_i8_i32(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x26xi32>) -> tensor<13x52x3xi8> {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.row_gather_block_scaled' op illegal: requires specification version compatible with 1.1 (got 1.0) OR requires specification version compatible with 1.1 (got 1.0) to be specified in the target environment}}
+  %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xi8>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xi8>)
+  return %0 : tensor<13x52x3xi8>
+}
+
+// -----
+
 func.func @test_scatter_bool_i64(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires specification version compatible with 1.1 (got 1.0) and requires any of [int64] profiles/extensions to be specified in the target environment}}
   %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi64>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
@@ -191,4 +200,4 @@ func.func @test_maxpool2d_adaptive(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x
   %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
          (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
   return %0 : tensor<1x32x32x8xf32>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir
index 0f2bbb71ee4b7..57802fd147f23 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir
@@ -13,3 +13,12 @@ func.func @test_scatter_i8_i32(%input: tensor<13x27x3xi8>, %indices: tensor<13x2
   %scatter = tosa.scatter %input, %indices, %updates : (tensor<13x27x3xi8>, tensor<13x26xi32>, tensor<13x26x3xi8>) -> tensor<13x27x3xi8>
   return %scatter : tensor<13x27x3xi8>
 }
+
+// -----
+
+// CHECK-LABEL: test_row_gather_block_scaled_i8_i32
+func.func @test_row_gather_block_scaled_i8_i32(%input: tensor<13x27x3xi8>, %indices: tensor<13x26xi32>) -> tensor<13x52x3xi8> {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %gather = tosa.row_gather_block_scaled %input, %indices, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x27x3xi8>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xi8>)
+  return %gather : tensor<13x52x3xi8>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 49ac4904002b3..1bac8bacbaf40 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -223,6 +223,15 @@ func.func @test_gather_i8_i32_indices(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1
 
 // -----
 
+// CHECK-LABEL: test_row_gather_block_scaled_i8_i32_indices
+func.func @test_row_gather_block_scaled_i8_i32_indices(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x26xi32>) -> tensor<13x52x3xi8> {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xi8>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xi8>)
+  return %0 : tensor<13x52x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: test_scatter_i8_i32_indices
 func.func @test_scatter_i8_i32_indices(%arg0: tensor<13x27x3xi8>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi8>) -> tensor<13x27x3xi8> {
   %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi8>, tensor<13x26xi32>, tensor<13x26x3xi8>) -> tensor<13x27x3xi8>
@@ -430,4 +439,4 @@ func.func @test_maxpool2d_adaptive(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x
   %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
          (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
   return %0 : tensor<1x32x32x8xf32>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 5f3aa8764664d..1c92c9ae335f5 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -454,6 +454,42 @@ func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1
   return %0 : tensor<13x26x8xf32>
 }
 
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_output_count_mismatch
+func.func @test_row_gather_block_scaled_output_count_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> (tensor<13x52x3xf32>, tensor<13x52x3xf32>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.row_gather_block_scaled' op expects output tensor list length to match values tensor list length, got 2 results for 1 input tensors}}
+  %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xf32>, tensor<13x52x3xf32>)
+  return %0#0, %0#1 : tensor<13x52x3xf32>, tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_block_size_one_required
+func.func @test_row_gather_block_scaled_block_size_one_required(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x52x3xf32> {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.row_gather_block_scaled' op requires block_size to be BLOCK_SIZE_1 when values tensor list length is 1}}
+  %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xf32>)
+  return %0 : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_output_rows_mismatch
+func.func @test_row_gather_block_scaled_output_rows_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x53x3xf32> {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.row_gather_block_scaled' op requires output[0] dimension 1 to have size 52, got 53}}
+  %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x53x3xf32>)
+  return %0 : tensor<13x53x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_scale_channel_mismatch
+func.func @test_row_gather_block_scaled_scale_channel_mismatch(%arg0: tensor<13x21x32xf4E2M1FN>, %arg1: tensor<13x21x2xf8E8M0FNU>, %arg2: tensor<13x26xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x2xf8E8M0FNU>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.row_gather_block_scaled' op expects channels of scale tensors to equal C/block_size (32/32), got 2}}
+  %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x32xf4E2M1FN>, tensor<13x21x2xf8E8M0FNU>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x2xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<13x52x32xf4E2M1FN>, tensor<13x52x2xf8E8M0FNU>
+}
+
 // -----
 func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>



More information about the Mlir-commits mailing list