[Mlir-commits] [mlir] [mlir][tosa] Add row_gather_block_scaled op (PR #192272)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 15 08:10:51 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Peng Sun (psunn)

<details>
<summary>Changes</summary>

Add `tosa.row_gather_block_scaled` to the MLIR TOSA dialect, aligned with the current TOSA 1.1 draft spec and the implementation in `tosa-tools`.

  This includes:
  - op definition
  - verifier and shape inference support
  - validation / profile compliance wiring
  - availability and extension handling
  - lit tests for parsing, verification, shape inference, and version / extension gating

  #### Notes

  The op supports both spec-defined forms:
  - non-block-scaled: 1 input value tensor, `BLOCK_SIZE_1`, 1 output
  - block-scaled: data + scale tensor list, non-`BLOCK_SIZE_1`, 2 outputs

  This also tightens existing block-scaled-only ops to reject `BLOCK_SIZE_1` now that it is part of the shared enum.

---

Patch is 36.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/192272.diff


14 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+70) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+3-3) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+38) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+193) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+18-4) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+12) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+10-1) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+8) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+20) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+10-1) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir (+9) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10-1) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+36) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index d3e2cd129028e..50bb9f69c6242 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -396,6 +396,25 @@ profileComplianceMap = {
         {{i16T, i64T, i16T}, SpecificationVersion::V_1_1_DRAFT},
         {{i32T, i64T, i32T}, SpecificationVersion::V_1_1_DRAFT}},
        anyOf}}},
+    {"tosa.row_gather_block_scaled",
+     {{{Profile::pro_int},
+       {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp},
+       {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp16T, i32T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, i32T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp, Profile::pro_int},
+       {{{boolT, i32T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT},
+        {{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}},
+       anyOf}}},
     {"tosa.scatter",
      {{{Profile::pro_int},
        {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -890,6 +909,57 @@ extensionComplianceMap = {
       {{Extension::bf16, Extension::int64},
        {{{bf16T, i64T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
        allOf}}},
+    {"tosa.row_gather_block_scaled",
+     {{{Extension::fp8e4m3},
+       {{{fp8e4m3T, i32T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e5m2},
+       {{{fp8e5m2T, i32T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::bf16},
+       {{{bf16T, i32T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::int64},
+       {{{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp16T, i64T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, i64T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e4m3, Extension::int64},
+       {{{fp8e4m3T, i64T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::fp8e5m2, Extension::int64},
+       {{{fp8e5m2T, i64T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::bf16, Extension::int64},
+       {{{bf16T, i64T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::mxfp},
+       {{{fp4e2m1T, fp8ue8m0T, i32T, i32T, fp4e2m1T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, i32T, i32T, fp6e2m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, i32T, i32T, fp6e3m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, i32T, i32T, fp8e4m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, i32T, i32T, fp8e5m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, i32T, i32T, mxint8T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::mxfp, Extension::int64},
+       {{{fp4e2m1T, fp8ue8m0T, i64T, i32T, fp4e2m1T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, i64T, i32T, fp6e2m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, i64T, i32T, fp6e3m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, i64T, i32T, fp8e4m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, i64T, i32T, fp8e5m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, i64T, i32T, mxint8T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT}},
+       allOf}}},
     {"tosa.scatter",
      {{{Extension::fp8e4m3},
        {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1f05aee3e5eec..591073e9985ae 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -484,11 +484,12 @@ def Tosa_RoundingModeAttr
     : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
                     [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
 
+def Tosa_BLOCK_SIZE_1 : I32EnumAttrCase<"BLOCK_SIZE_1", 1>;
 def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
 
 def Tosa_BlockSizeAttr
-    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
-                    [Tosa_BLOCK_SIZE_32]> {
+    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats",
+                       "block_size", [Tosa_BLOCK_SIZE_1, Tosa_BLOCK_SIZE_32]> {
   let extraClassDeclaration = [{
     static uint32_t getBlockSizeValue(BlockSize blockSize) {
       return static_cast<uint32_t>(blockSize);
@@ -496,7 +497,6 @@ def Tosa_BlockSizeAttr
   }];
 }
 
-
 //===----------------------------------------------------------------------===//
 // TOSA Interfaces.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 45d1388a28749..ba750ef4438db 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2487,6 +2487,44 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather", [NoMemoryEffect]> {
       "operands attr-dict `:` functional-type(operands, results)";
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: row_gather_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_RowGatherBlockScaledOp
+    : Tosa_InferShapedTypeOp<"row_gather_block_scaled", [NoMemoryEffect]> {
+  let summary =
+      "Row gather operation for block-scaled and non-block-scaled data.";
+
+  let description = [{
+    Generate a tensor-list which contains a data tensor and an optional scale
+    tensor based on the input indices and row_count. The number of consecutive
+    rows gathered for each index is specified in row_count.
+
+    This operation follows the TOSA 1.1 draft specification and may evolve as
+    the specification is updated.
+
+    This operation is not pure. Undefined behaviour may occur if the specified
+    indices are out of range.
+  }];
+
+  let arguments = (ins Variadic<Tosa_Tensor3D>:$values,
+      Tosa_IndexTensor2D:$indices, Tosa_ScalarInt32Tensor:$row_count,
+      Tosa_BlockSizeAttr:$block_size);
+
+  let results = (outs Variadic<Tosa_Tensor3D>:$output);
+
+  list<Availability> availability = [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+                                     Extension<[Tosa_EXT_FP8E4M3,
+                                                Tosa_EXT_FP8E5M2, Tosa_EXT_BF16,
+                                                Tosa_EXT_MXFP, Tosa_EXT_INT64]>,
+  ];
+
+  let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: scatter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2754d3b21d4a6..360ab25875bd7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2201,6 +2201,8 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   // Verify C is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   if (ShapedType::isStatic(C) && C % blockSize != 0)
     return emitOpError("expect C to be a multiple of block size, got C=")
            << C << ", block_size=" << blockSize;
@@ -2848,6 +2850,18 @@ static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
   return -1;
 }
 
+static FailureOr<int64_t> getConstantScalarIntValue(Value val) {
+  ElementsAttr attr;
+  if (!matchPattern(val, m_Constant(&attr)))
+    return failure();
+
+  if (!llvm::isa<IntegerType>(attr.getElementType()) ||
+      attr.getNumElements() != 1)
+    return failure();
+
+  return attr.getValues<APInt>()[0].getSExtValue();
+}
+
 template <typename T>
 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
                                      const std::string &operand) {
@@ -3096,6 +3110,48 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RowGatherBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const auto values = adaptor.getValues();
+  if (values.empty() || values.size() > 2)
+    return failure();
+
+  SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
+  const ShapeAdaptor valuesShape(values.front().getType());
+  if (valuesShape.hasRank()) {
+    dataShape[0] = valuesShape.getDimSize(0);
+    dataShape[2] = valuesShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (dataShape[0] == ShapedType::kDynamic)
+      dataShape[0] = indicesShape.getDimSize(0);
+
+    if (auto rowCount = getConstantScalarIntValue(adaptor.getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0) {
+      const int64_t indicesW = indicesShape.getDimSize(1);
+      if (ShapedType::isStatic(indicesW))
+        dataShape[1] = indicesW * rowCount.value();
+    }
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
+  if (values.size() == 1)
+    return success();
+
+  SmallVector<int64_t> scaleShape = dataShape;
+  const uint32_t blockSize =
+      BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+  if (ShapedType::isStatic(dataShape[2]))
+    scaleShape[2] = dataShape[2] / blockSize;
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
+  return success();
+}
+
 LogicalResult tosa::GatherOp::verify() {
   if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
                              /* outType = */ getOutput().getType())
@@ -3145,6 +3201,137 @@ LogicalResult tosa::GatherOp::verify() {
   return success();
 }
 
+LogicalResult tosa::RowGatherBlockScaledOp::verify() {
+  const OperandRange values = getValues();
+  const ResultRange output = getOutput();
+  if (values.empty() || values.size() > 2)
+    return emitOpError()
+           << "expects values tensor list length to be 1 or 2, got "
+           << values.size();
+  if (output.size() != values.size())
+    return emitOpError()
+           << "expects output tensor list length to match values tensor list "
+              "length, got "
+           << output.size() << " results for " << values.size()
+           << " input tensors";
+
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (values.size() == 1 && blockSize != 1)
+    return emitOpError()
+           << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
+              "length is 1";
+  if (values.size() == 2 && blockSize == 1)
+    return emitOpError()
+           << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
+              "list length is 2";
+
+  if (failed(verifySameElementTypes(*this, values[0].getType(),
+                                    output[0].getType(), "values[0]",
+                                    "output[0]")))
+    return failure();
+  if (values.size() == 2 && failed(verifySameElementTypes(
+                                *this, values[1].getType(), output[1].getType(),
+                                "values[1]", "output[1]")))
+    return failure();
+
+  if (auto rowCount = getConstantScalarIntValue(getRowCount());
+      succeeded(rowCount) && rowCount.value() <= 0)
+    return emitOpError() << "requires row_count to be > 0, got "
+                         << rowCount.value();
+
+  int64_t n = ShapedType::kDynamic;
+  int64_t k = ShapedType::kDynamic;
+  int64_t c = ShapedType::kDynamic;
+  int64_t w = ShapedType::kDynamic;
+  int64_t multiplesOfC = ShapedType::kDynamic;
+
+  const ShapeAdaptor valuesDataShape(values[0].getType());
+  if (valuesDataShape.hasRank()) {
+    n = valuesDataShape.getDimSize(0);
+    k = valuesDataShape.getDimSize(1);
+    c = valuesDataShape.getDimSize(2);
+  }
+
+  if (ShapedType::isStatic(c) && c % blockSize != 0)
+    return emitOpError() << "expects channels of values[0] (" << c
+                         << ") to be divisible by block_size (" << blockSize
+                         << ")";
+
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
+                                     "indices", "batch")))
+      return failure();
+    w = indicesShape.getDimSize(1);
+  }
+
+  const ShapeAdaptor outputDataShape(output[0].getType());
+  if (outputDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
+                                     "output[0]", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
+                                     "output[0]", "channels")))
+      return failure();
+
+    if (auto rowCount = getConstantScalarIntValue(getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0 &&
+        ShapedType::isStatic(w)) {
+      const int64_t expectedOutputRows = w * rowCount.value();
+      if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
+          outputDataShape.getDimSize(1) != expectedOutputRows)
+        return emitOpError() << "requires output[0] dimension 1 to have size "
+                             << expectedOutputRows << ", got "
+                             << outputDataShape.getDimSize(1);
+    }
+  }
+
+  if (values.size() == 2) {
+    const ShapeAdaptor valuesScaleShape(values[1].getType());
+    if (valuesScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
+                                       "values[1]", "batch")) ||
+          failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
+                                       "values[1]", "rows")))
+        return failure();
+      multiplesOfC = valuesScaleShape.getDimSize(2);
+    }
+
+    const ShapeAdaptor outputScaleShape(output[1].getType());
+    if (outputScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
+                                       "output[1]", "batch")))
+        return failure();
+
+      if (auto rowCount = getConstantScalarIntValue(getRowCount());
+          succeeded(rowCount) && rowCount.value() > 0 &&
+          ShapedType::isStatic(w)) {
+        const int64_t expectedOutputRows = w * rowCount.value();
+        if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
+            outputScaleShape.getDimSize(1) != expectedOutputRows)
+          return emitOpError() << "requires output[1] dimension 1 to have size "
+                               << expectedOutputRows << ", got "
+                               << outputScaleShape.getDimSize(1);
+      }
+
+      if (ShapedType::isDynamic(multiplesOfC))
+        multiplesOfC = outputScaleShape.getDimSize(2);
+      else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
+               multiplesOfC != outputScaleShape.getDimSize(2))
+        return emitOpError()
+               << "expected channels of output[1] to match size "
+               << multiplesOfC << ", got " << outputScaleShape.getDimSize(2);
+    }
+
+    if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
+        multiplesOfC != c / blockSize)
+      return emitOpError()
+             << "expects channels of scale tensors to equal C/block_size (" << c
+             << "/" << blockSize << "), got " << multiplesOfC;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -3987,6 +4174,8 @@ LogicalResult Conv2DBlockScaledOp::verify() {
 
   // Verify IC is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   if (ShapedType::isStatic(IC) && IC % blockSize != 0)
     return emitOpError("expect IC to be a multiple of block size, got IC=")
            << IC << ", block_size=" << blockSize;
@@ -4577,6 +4766,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
   if (inputDataShape.hasRank()) {
     const unsigned int blockSize =
         BlockSizeAttr::getBlockSizeValue(getBlockSize());
+    if (blockSize == 1)
+      return emitOpError("requires block_size to not be BLOCK_SIZE_1");
     const int64_t inputDataLastDim =
         inputDataShape.getDimSize(inputDataShape.getRank() - 1);
     if (inputDataLastDim % blockSize != 0)
@@ -4650,6 +4841,8 @@ LogicalResult CastToBlockScaledOp::verify() {
 
   const unsigned int blockSize =
       BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
   if (inputDataShape.hasRank()) {
     const int64_t inputDataLastDim =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 01c85be4f704f..4ea225b860f6c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -185,6 +185,18 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::RowGatherBlockScaledOp op) {
+  for (Value value : op.getValues())
+    addValue(value);
+  addValue(op.getIndices());
+  addValue(op.getRowCount());
+  for (Value result : op.getOutput())
+    addValue(result);
+  return success();
+}
+
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
   addValue(op.getValuesIn());
@@ -288,6 +300,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Tile)
   POPULATE_PROFILE_INFO_CUSTOM(Transpose)
   POPULATE_PROFILE_INFO_CUSTOM(Gather)
+  POPULATE_PROFILE_INFO_CUSTOM(RowGatherBlockScaled)
   POPULATE_PROFILE_INFO_CUSTOM(Scatter)
   POPULATE_PROFILE_INFO_CUSTOM(Resize)
   POPULATE_PROFILE_INFO_CUSTOM(Select)
@@ -598,10 +611,11 @@ SmallVector<OpComplianceInfo<T>> TosaProfileCompliance::findMatchedEntries(
     SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
     for (const auto &set : sets) {
       SmallVector<TypeInfo> expected = set.first;
-      assert(present.size() == expected.size() &&
-             "the entries for profile-based compliance do not match between...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/192272


More information about the Mlir-commits mailing list