[Mlir-commits] [mlir] [mlir][tosa] Add row_gather_block_scaled op (PR #192272)
Peng Sun
llvmlistbot at llvm.org
Wed Apr 15 08:10:14 PDT 2026
https://github.com/psunn created https://github.com/llvm/llvm-project/pull/192272
Add `tosa.row_gather_block_scaled` to the MLIR TOSA dialect, aligned with the current TOSA 1.1 draft spec and the implementation in `tosa-tools`.
This includes:
- op definition
- verifier and shape inference support
- validation / profile compliance wiring
- availability and extension handling
- lit tests for parsing, verification, shape inference, and version / extension gating
#### Notes
The op supports both spec-defined forms:
- non-block-scaled: 1 input value tensor, `BLOCK_SIZE_1`, 1 output
- block-scaled: data + scale tensor list, non-`BLOCK_SIZE_1`, 2 outputs
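This also tightens the existing block-scaled-only ops (`MatmulTBlockScaledOp`, `Conv2DBlockScaledOp`, `CastFromBlockScaledOp`, and `CastToBlockScaledOp`) to reject `BLOCK_SIZE_1`, now that it is part of the shared `BlockSize` enum.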
>From b5cc923d882de2809fc8163c46438c96c2299612 Mon Sep 17 00:00:00 2001
From: Peng Sun <peng.sun at arm.com>
Date: Mon, 13 Apr 2026 13:59:30 +0100
Subject: [PATCH] [mlir][tosa] Add row_gather_block_scaled op
Signed-off-by: Peng Sun <peng.sun at arm.com>
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 70 +++++++
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 6 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 38 ++++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 193 ++++++++++++++++++
.../Tosa/Transforms/TosaProfileCompliance.cpp | 22 +-
.../Tosa/Transforms/TosaValidation.cpp | 12 ++
mlir/test/Dialect/Tosa/availability.mlir | 11 +-
mlir/test/Dialect/Tosa/invalid_extension.mlir | 8 +
mlir/test/Dialect/Tosa/ops.mlir | 16 ++
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 20 ++
.../tosa-validation-version-1p0-invalid.mlir | 11 +-
...a-validation-version-1p1-pro-fp-valid.mlir | 9 +
.../tosa-validation-version-1p1-valid.mlir | 11 +-
mlir/test/Dialect/Tosa/verifier.mlir | 36 ++++
14 files changed, 453 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index d3e2cd129028e..50bb9f69c6242 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -396,6 +396,25 @@ profileComplianceMap = {
{{i16T, i64T, i16T}, SpecificationVersion::V_1_1_DRAFT},
{{i32T, i64T, i32T}, SpecificationVersion::V_1_1_DRAFT}},
anyOf}}},
+ {"tosa.row_gather_block_scaled", // Type tuples: {values, indices, row_count, output}.
+ {{{Profile::pro_int},
+ {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Profile::pro_fp},
+ {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp16T, i32T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, i32T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Profile::pro_fp, Profile::pro_int},
+ {{{boolT, i32T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT},
+ {{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+ {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}},
+ anyOf}}},
{"tosa.scatter",
{{{Profile::pro_int},
{{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -890,6 +909,57 @@ extensionComplianceMap = {
{{Extension::bf16, Extension::int64},
{{{bf16T, i64T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
allOf}}},
+ {"tosa.row_gather_block_scaled", // Type tuples: {values..., indices, row_count, output...}.
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::int64},
+ {{{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp16T, i64T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, i64T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e4m3, Extension::int64},
+ {{{fp8e4m3T, i64T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}},
+ allOf},
+ {{Extension::fp8e5m2, Extension::int64},
+ {{{fp8e5m2T, i64T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}},
+ allOf},
+ {{Extension::bf16, Extension::int64},
+ {{{bf16T, i64T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
+ allOf},
+ {{Extension::mxfp}, // Block-scaled tuples: {data, scale, indices, row_count, out data, out scale}.
+ {{{fp4e2m1T, fp8ue8m0T, i32T, i32T, fp4e2m1T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, i32T, i32T, fp6e2m3T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, i32T, i32T, fp6e3m2T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, i32T, i32T, fp8e4m3T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, i32T, i32T, fp8e5m2T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{mxint8T, fp8ue8m0T, i32T, i32T, mxint8T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::mxfp, Extension::int64}, // Block-scaled tuples with i64 indices.
+ {{{fp4e2m1T, fp8ue8m0T, i64T, i32T, fp4e2m1T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, i64T, i32T, fp6e2m3T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, i64T, i32T, fp6e3m2T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, i64T, i32T, fp8e4m3T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, i64T, i32T, fp8e5m2T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{mxint8T, fp8ue8m0T, i64T, i32T, mxint8T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT}},
+ allOf}}},
{"tosa.scatter",
{{{Extension::fp8e4m3},
{{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1f05aee3e5eec..591073e9985ae 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -484,11 +484,12 @@ def Tosa_RoundingModeAttr
: Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
[Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
+def Tosa_BLOCK_SIZE_1 : I32EnumAttrCase<"BLOCK_SIZE_1", 1>; // Non-block-scaled (data-only) form.
def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
def Tosa_BlockSizeAttr
- : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
- [Tosa_BLOCK_SIZE_32]> {
+ : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats",
+ "block_size", [Tosa_BLOCK_SIZE_1, Tosa_BLOCK_SIZE_32]> {
let extraClassDeclaration = [{
static uint32_t getBlockSizeValue(BlockSize blockSize) {
return static_cast<uint32_t>(blockSize);
@@ -496,7 +497,6 @@ def Tosa_BlockSizeAttr
}];
}
-
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 45d1388a28749..ba750ef4438db 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2487,6 +2487,60 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather", [NoMemoryEffect]> {
"operands attr-dict `:` functional-type(operands, results)";
}
+//===----------------------------------------------------------------------===//
+// Operator: row_gather_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_RowGatherBlockScaledOp
+    : Tosa_InferShapedTypeOp<"row_gather_block_scaled", [NoMemoryEffect]> {
+  let summary =
+      "Row gather operation for block-scaled and non-block-scaled data.";
+
+  let description = [{
+    Generate a tensor-list which contains a data tensor and an optional scale
+    tensor based on the input indices and row_count. The number of consecutive
+    rows gathered for each index is specified in row_count.
+
+    The op has two forms, selected by the length of the `values` list:
+
+    - Non-block-scaled: `values` holds one data tensor of shape `[N, K, C]`
+      and `block_size` must be `BLOCK_SIZE_1`. The op produces one output of
+      shape `[N, W * row_count, C]`.
+    - Block-scaled: `values` holds a data tensor of shape `[N, K, C]` followed
+      by a scale tensor of shape `[N, K, C / block_size]`, and `block_size`
+      must not be `BLOCK_SIZE_1`. C must be a multiple of `block_size`. The op
+      produces a data output of shape `[N, W * row_count, C]` and a scale
+      output of shape `[N, W * row_count, C / block_size]`.
+
+    In both forms `indices` has shape `[N, W]`, `row_count` must be greater
+    than zero, and each output has the same element type as the corresponding
+    `values` tensor.
+
+    This operation follows the TOSA 1.1 draft specification and may evolve as
+    the specification is updated.
+
+    This operation is not pure. Undefined behaviour may occur if the specified
+    indices are out of range.
+  }];
+
+  let arguments = (ins Variadic<Tosa_Tensor3D>:$values,
+                   Tosa_IndexTensor2D:$indices,
+                   Tosa_ScalarInt32Tensor:$row_count,
+                   Tosa_BlockSizeAttr:$block_size);
+
+  let results = (outs Variadic<Tosa_Tensor3D>:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16,
+               Tosa_EXT_MXFP, Tosa_EXT_INT64]>,
+  ];
+
+  let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
//===----------------------------------------------------------------------===//
// Operator: scatter
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2754d3b21d4a6..360ab25875bd7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2201,6 +2201,8 @@ LogicalResult MatmulTBlockScaledOp::verify() {
// Verify C is a multiple of block size
const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize == 1)
+ return emitOpError("requires block_size to not be BLOCK_SIZE_1");
if (ShapedType::isStatic(C) && C % blockSize != 0)
return emitOpError("expect C to be a multiple of block size, got C=")
<< C << ", block_size=" << blockSize;
@@ -2848,6 +2850,18 @@ static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
return -1;
}
+/// Returns the sign-extended value of `val` when it is defined by a constant
+/// holding exactly one integer element; fails otherwise.
+static FailureOr<int64_t> getConstantScalarIntValue(Value val) {
+  ElementsAttr constAttr;
+  if (!matchPattern(val, m_Constant(&constAttr)) ||
+      !llvm::isa<IntegerType>(constAttr.getElementType()) ||
+      constAttr.getNumElements() != 1)
+    return failure();
+
+  return (*constAttr.getValues<APInt>().begin()).getSExtValue();
+}
+
template <typename T>
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
const std::string &operand) {
@@ -3096,6 +3110,68 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
return success();
}
+/// Infers the output shapes of row_gather_block_scaled:
+///   output[0] = [N, W * row_count, C]
+///   output[1] = [N, W * row_count, C / block_size] (block-scaled form only)
+/// N comes from values or indices and W from indices; W * row_count is only
+/// known when row_count is a positive constant. C comes from the data tensor
+/// or, when that is dynamic, from the scale tensor as scale_C * block_size.
+LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RowGatherBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const auto values = adaptor.getValues();
+  if (values.empty() || values.size() > 2)
+    return failure();
+
+  const uint32_t blockSize =
+      BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+
+  SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
+  const ShapeAdaptor valuesShape(values.front().getType());
+  if (valuesShape.hasRank()) {
+    dataShape[0] = valuesShape.getDimSize(0);
+    dataShape[2] = valuesShape.getDimSize(2);
+  }
+
+  // The scale tensor shares N with the data tensor and has C / block_size
+  // channels, so it can fill in dimensions the data tensor leaves dynamic.
+  if (values.size() == 2) {
+    const ShapeAdaptor scaleInShape(values[1].getType());
+    if (scaleInShape.hasRank()) {
+      if (ShapedType::isDynamic(dataShape[0]))
+        dataShape[0] = scaleInShape.getDimSize(0);
+      const int64_t scaleC = scaleInShape.getDimSize(2);
+      if (ShapedType::isDynamic(dataShape[2]) && ShapedType::isStatic(scaleC))
+        dataShape[2] = scaleC * static_cast<int64_t>(blockSize);
+    }
+  }
+
+  const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (ShapedType::isDynamic(dataShape[0]))
+      dataShape[0] = indicesShape.getDimSize(0);
+
+    if (auto rowCount = getConstantScalarIntValue(adaptor.getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0) {
+      const int64_t indicesW = indicesShape.getDimSize(1);
+      if (ShapedType::isStatic(indicesW))
+        dataShape[1] = indicesW * rowCount.value();
+    }
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
+  if (values.size() == 1)
+    return success();
+
+  SmallVector<int64_t> scaleShape = dataShape;
+  if (ShapedType::isStatic(dataShape[2]))
+    scaleShape[2] = dataShape[2] / blockSize;
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
+  return success();
+}
+
LogicalResult tosa::GatherOp::verify() {
if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
/* outType = */ getOutput().getType())
@@ -3145,6 +3201,149 @@ LogicalResult tosa::GatherOp::verify() {
return success();
}
+/// Verifies both forms of row_gather_block_scaled:
+///   - data only:    values = [data],        block_size == BLOCK_SIZE_1
+///   - data + scale: values = [data, scale], block_size != BLOCK_SIZE_1
+/// with data [N, K, C], scale [N, K, C / block_size], indices [N, W] and
+/// outputs [N, W * row_count, C] and [N, W * row_count, C / block_size].
+LogicalResult tosa::RowGatherBlockScaledOp::verify() {
+  const OperandRange values = getValues();
+  const ResultRange output = getOutput();
+  if (values.empty() || values.size() > 2)
+    return emitOpError()
+           << "expects values tensor list length to be 1 or 2, got "
+           << values.size();
+  if (output.size() != values.size())
+    return emitOpError()
+           << "expects output tensor list length to match values tensor list "
+              "length, got "
+           << output.size() << " results for " << values.size()
+           << " input tensors";
+
+  const bool hasScale = values.size() == 2;
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (!hasScale && blockSize != 1)
+    return emitOpError()
+           << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
+              "length is 1";
+  if (hasScale && blockSize == 1)
+    return emitOpError()
+           << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
+              "list length is 2";
+
+  if (failed(verifySameElementTypes(*this, values[0].getType(),
+                                    output[0].getType(), "values[0]",
+                                    "output[0]")))
+    return failure();
+  if (hasScale && failed(verifySameElementTypes(
+                      *this, values[1].getType(), output[1].getType(),
+                      "values[1]", "output[1]")))
+    return failure();
+
+  // row_count can only be checked when it is a constant. Look it up once; the
+  // value is reused for the output row checks below.
+  const FailureOr<int64_t> rowCount = getConstantScalarIntValue(getRowCount());
+  if (succeeded(rowCount) && *rowCount <= 0)
+    return emitOpError() << "requires row_count to be > 0, got " << *rowCount;
+
+  int64_t n = ShapedType::kDynamic;
+  int64_t k = ShapedType::kDynamic;
+  int64_t c = ShapedType::kDynamic;
+  int64_t w = ShapedType::kDynamic;
+  int64_t multiplesOfC = ShapedType::kDynamic;
+  int64_t outputRows = ShapedType::kDynamic;
+
+  const ShapeAdaptor valuesDataShape(values[0].getType());
+  if (valuesDataShape.hasRank()) {
+    n = valuesDataShape.getDimSize(0);
+    k = valuesDataShape.getDimSize(1);
+    c = valuesDataShape.getDimSize(2);
+  }
+
+  if (ShapedType::isStatic(c) && c % blockSize != 0)
+    return emitOpError() << "expects channels of values[0] (" << c
+                         << ") to be divisible by block_size (" << blockSize
+                         << ")";
+
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
+                                     "indices", "batch")))
+      return failure();
+    w = indicesShape.getDimSize(1);
+  }
+
+  // Every output has W * row_count rows. When that product is unknown, the
+  // data and scale outputs must still agree with each other.
+  const int64_t expectedOutputRows =
+      succeeded(rowCount) && ShapedType::isStatic(w) ? w * *rowCount
+                                                     : ShapedType::kDynamic;
+  auto verifyOutputRows = [&](ShapeAdaptor shape,
+                              StringRef name) -> LogicalResult {
+    const int64_t rows = shape.getDimSize(1);
+    if (ShapedType::isStatic(expectedOutputRows) &&
+        ShapedType::isStatic(rows) && rows != expectedOutputRows)
+      return emitOpError() << "requires " << name
+                           << " dimension 1 to have size " << expectedOutputRows
+                           << ", got " << rows;
+    return tryUpdateDimOrFailure(*this, outputRows, rows, name, "rows");
+  };
+
+  const ShapeAdaptor outputDataShape(output[0].getType());
+  if (outputDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
+                                     "output[0]", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
+                                     "output[0]", "channels")) ||
+        failed(verifyOutputRows(outputDataShape, "output[0]")))
+      return failure();
+  }
+
+  // C may only be known from output[0] (e.g. values[0] is unranked), so the
+  // block_size divisibility check is repeated once output[0] has been seen.
+  if (ShapedType::isStatic(c) && c % blockSize != 0)
+    return emitOpError() << "expects channels of output[0] (" << c
+                         << ") to be divisible by block_size (" << blockSize
+                         << ")";
+
+  if (hasScale) {
+    const ShapeAdaptor valuesScaleShape(values[1].getType());
+    if (valuesScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
+                                       "values[1]", "batch")) ||
+          failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
+                                       "values[1]", "rows")))
+        return failure();
+      multiplesOfC = valuesScaleShape.getDimSize(2);
+    }
+
+    const ShapeAdaptor outputScaleShape(output[1].getType());
+    if (outputScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
+                                       "output[1]", "batch")) ||
+          failed(verifyOutputRows(outputScaleShape, "output[1]")))
+        return failure();
+
+      const int64_t outputScaleC = outputScaleShape.getDimSize(2);
+      if (ShapedType::isDynamic(multiplesOfC))
+        multiplesOfC = outputScaleC;
+      else if (ShapedType::isStatic(outputScaleC) &&
+               multiplesOfC != outputScaleC)
+        return emitOpError()
+               << "expected channels of output[1] to match size "
+               << multiplesOfC << ", got " << outputScaleC;
+    }
+
+    if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
+        multiplesOfC != c / blockSize)
+      return emitOpError()
+             << "expects channels of scale tensors to equal C/block_size (" << c
+             << "/" << blockSize << "), got " << multiplesOfC;
+  }
+
+  return success();
+}
+
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ResizeOp::Adaptor adaptor,
@@ -3987,6 +4174,8 @@ LogicalResult Conv2DBlockScaledOp::verify() {
// Verify IC is a multiple of block size
const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize == 1)
+ return emitOpError("requires block_size to not be BLOCK_SIZE_1");
if (ShapedType::isStatic(IC) && IC % blockSize != 0)
return emitOpError("expect IC to be a multiple of block size, got IC=")
<< IC << ", block_size=" << blockSize;
@@ -4577,6 +4766,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
if (inputDataShape.hasRank()) {
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize == 1)
+ return emitOpError("requires block_size to not be BLOCK_SIZE_1");
const int64_t inputDataLastDim =
inputDataShape.getDimSize(inputDataShape.getRank() - 1);
if (inputDataLastDim % blockSize != 0)
@@ -4650,6 +4841,8 @@ LogicalResult CastToBlockScaledOp::verify() {
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize == 1)
+ return emitOpError("requires block_size to not be BLOCK_SIZE_1");
const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
if (inputDataShape.hasRank()) {
const int64_t inputDataLastDim =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 01c85be4f704f..4ea225b860f6c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -185,6 +185,18 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
return success();
}
+/// Records the types of all operands (values..., indices, row_count) followed
+/// by all results, matching the row_gather_block_scaled compliance tuples.
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::RowGatherBlockScaledOp op) {
+  for (Value operand : op->getOperands())
+    addValue(operand);
+  for (Value result : op->getResults())
+    addValue(result);
+  return success();
+}
+
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
addValue(op.getValuesIn());
@@ -288,6 +300,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Tile)
POPULATE_PROFILE_INFO_CUSTOM(Transpose)
POPULATE_PROFILE_INFO_CUSTOM(Gather)
+ POPULATE_PROFILE_INFO_CUSTOM(RowGatherBlockScaled)
POPULATE_PROFILE_INFO_CUSTOM(Scatter)
POPULATE_PROFILE_INFO_CUSTOM(Resize)
POPULATE_PROFILE_INFO_CUSTOM(Select)
@@ -598,10 +611,11 @@ SmallVector<OpComplianceInfo<T>> TosaProfileCompliance::findMatchedEntries(
SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
for (const auto &set : sets) {
SmallVector<TypeInfo> expected = set.first;
- assert(present.size() == expected.size() &&
- "the entries for profile-based compliance do not match between "
- "the generated metadata and the type definition retrieved from "
- " the operation");
+ // Tensor-list operators can legitimately have multiple valid signatures
+ // with different operand/result counts, e.g. data-only and data+scale
+ // forms. Treat those as non-matches instead of asserting.
+ if (present.size() != expected.size())
+ continue;
bool isFound = true;
// Compare the type signature between the given operation and the
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3e2cda9d37666..177878b98cbaa 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -114,6 +114,20 @@ static LogicalResult checkConstantOperandMatMul(Operation *op,
return success();
}
+/// Without the dynamic extension, row_count must be a constant operand.
+static LogicalResult
+checkConstantOperandRowGatherBlockScaled(Operation *op, const TargetEnv &env) {
+  if (env.allows(Extension::dynamic))
+    return success();
+  auto rowGatherOp = dyn_cast<tosa::RowGatherBlockScaledOp>(op);
+  if (!rowGatherOp)
+    return success();
+  // Take row_count's operand index from the op definition instead of counting
+  // the variadic values operands by hand.
+  return checkConstantOperands(
+      op, {rowGatherOp.getRowCountMutable().getOperandNumber()});
+}
+
static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
@@ -199,6 +210,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
constCheckers.emplace_back(
checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
constCheckers.emplace_back(checkConstantOperandMatMul);
+ constCheckers.emplace_back(checkConstantOperandRowGatherBlockScaled);
constCheckers.emplace_back(checkConstantOperandAvgPool2d);
constCheckers.emplace_back(checkConstantOperandAvgPool2dAdaptive);
constCheckers.emplace_back(checkConstantOperandNegate);
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 34de532639994..81294f2c0c308 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -595,6 +595,16 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
return %0 : tensor<13x26x3xf32>
}
+// -----
+// CHECK-LABEL: row_gather_block_scaled
+func.func @test_row_gather_block_scaled(%data: tensor<13x21x32xf4E2M1FN>, %scale: tensor<13x21x1xf8E8M0FNU>, %indices: tensor<13x26xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>) {
+  %rows = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: profiles: [ [pro_int, pro_fp] ]
+  // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, mxfp, int64] ]
+  %gathered:2 = tosa.row_gather_block_scaled %data, %scale, %indices, %rows {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x32xf4E2M1FN>, tensor<13x21x1xf8E8M0FNU>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>)
+  return %gathered#0, %gathered#1 : tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>
+}
+
// -----
// CHECK-LABEL: scatter
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
@@ -727,4 +737,3 @@ func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = BLOCK_SIZE_32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
}
-
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 8a9b014864c74..295d7172bc2c4 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -249,6 +249,14 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
return %0 : tensor<13x26x3xbf16>
}
+// -----
+func.func @test_row_gather_block_scaled(%arg0: tensor<13x21x32xf4E2M1FN>, %arg1: tensor<13x21x1xf8E8M0FNU>, %arg2: tensor<13x26xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.row_gather_block_scaled' op illegal: requires any of [mxfp] profiles/extensions to be specified in the target environment}}
+  %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x32xf4E2M1FN>, tensor<13x21x1xf8E8M0FNU>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>
+}
+
// -----
func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
// expected-error at +1 {{'tosa.scatter' op illegal: requires any of [bf16] profiles/extensions to be specified in the target environment}}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b30e92c4a9621..fbdaf2e42d022 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -906,6 +906,22 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
return %0 : tensor<13x26x3xf32>
}
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled
+func.func @test_row_gather_block_scaled(%values: tensor<13x21x3xf32>, %indices: tensor<13x26xi32>) -> tensor<13x52x3xf32> {
+  %rows = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %gathered = tosa.row_gather_block_scaled %values, %indices, %rows {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xf32>)
+  return %gathered : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_mxfp
+func.func @test_row_gather_block_scaled_mxfp(%data: tensor<13x21x32xf4E2M1FN>, %scale: tensor<13x21x1xf8E8M0FNU>, %indices: tensor<13x26xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>) {
+  %rows = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %gathered:2 = tosa.row_gather_block_scaled %data, %scale, %indices, %rows {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x32xf4E2M1FN>, tensor<13x21x1xf8E8M0FNU>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>)
+  return %gathered#0, %gathered#1 : tensor<13x52x32xf4E2M1FN>, tensor<13x52x1xf8E8M0FNU>
+}
+
// -----
// CHECK-LABEL: scatter
func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 408300fa7034b..3bdac624a42a9 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -674,6 +674,26 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
// -----
+// CHECK-LABEL: @row_gather_block_scaled_static
+func.func @row_gather_block_scaled_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>) {
+  %rows = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: tosa.row_gather_block_scaled %arg0, %arg1, %[[ROW_COUNT:.+]] {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<1xi32>) -> tensor<3x12x5xi32>
+  %gathered = tosa.row_gather_block_scaled %arg0, %arg1, %rows {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<1xi32>) -> (tensor<?x?x?xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @row_gather_block_scaled_mxfp_static
+func.func @row_gather_block_scaled_mxfp_static(%arg0 : tensor<3x4x32xf4E2M1FN>, %arg1 : tensor<3x4x1xf8E8M0FNU>, %arg2 : tensor<3x6xi32>) {
+  %rows = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %[[ROW_COUNT:.+]] {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<3x4x32xf4E2M1FN>, tensor<3x4x1xf8E8M0FNU>, tensor<3x6xi32>, tensor<1xi32>) -> (tensor<3x12x32xf4E2M1FN>, tensor<3x12x1xf8E8M0FNU>)
+  %gathered:2 = tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %rows {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<3x4x32xf4E2M1FN>, tensor<3x4x1xf8E8M0FNU>, tensor<3x6xi32>, tensor<1xi32>) -> (tensor<?x?x?xf4E2M1FN>, tensor<?x?x?xf8E8M0FNU>)
+  return
+}
+
+// -----
+
// CHECK-LABEL: @scatter_static
func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
index eed9621edc7b2..693fd21e9cf4b 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -78,6 +78,15 @@ func.func @test_gather_bool_i32(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi
// -----
+func.func @test_row_gather_block_scaled_i8_i32(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x26xi32>) -> tensor<13x52x3xi8> {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.row_gather_block_scaled' op illegal: requires specification version compatible with 1.1 (got 1.0) OR requires specification version compatible with 1.1 (got 1.0) to be specified in the target environment}}
+  %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xi8>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xi8>)
+  return %0 : tensor<13x52x3xi8>
+}
+
+// -----
+
func.func @test_scatter_bool_i64(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
// expected-error at +1 {{'tosa.scatter' op illegal: requires specification version compatible with 1.1 (got 1.0) and requires any of [int64] profiles/extensions to be specified in the target environment}}
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi64>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
@@ -191,4 +200,4 @@ func.func @test_maxpool2d_adaptive(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir
index 0f2bbb71ee4b7..57802fd147f23 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir
@@ -13,3 +13,12 @@ func.func @test_scatter_i8_i32(%input: tensor<13x27x3xi8>, %indices: tensor<13x2
%scatter = tosa.scatter %input, %indices, %updates : (tensor<13x27x3xi8>, tensor<13x26xi32>, tensor<13x26x3xi8>) -> tensor<13x27x3xi8>
return %scatter : tensor<13x27x3xi8>
}
+
+// -----
+
+// CHECK-LABEL: test_row_gather_block_scaled_i8_i32
+func.func @test_row_gather_block_scaled_i8_i32(%values: tensor<13x27x3xi8>, %idx: tensor<13x26xi32>) -> tensor<13x52x3xi8> {
+  %rows = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %out = tosa.row_gather_block_scaled %values, %idx, %rows {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x27x3xi8>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xi8>)
+  return %out : tensor<13x52x3xi8>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 49ac4904002b3..1bac8bacbaf40 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -223,6 +223,15 @@ func.func @test_gather_i8_i32_indices(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1
// -----
+// CHECK-LABEL: test_row_gather_block_scaled_i8_i32_indices
+func.func @test_row_gather_block_scaled_i8_i32_indices(%values: tensor<13x21x3xi8>, %indices: tensor<13x26xi32>) -> tensor<13x52x3xi8> {
+  %rows_per_index = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> // presumably 26 indices * 2 = 52 output rows
+  %gathered = tosa.row_gather_block_scaled %values, %indices, %rows_per_index {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xi8>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xi8>)
+  return %gathered : tensor<13x52x3xi8>
+}
+
+// -----
+
// CHECK-LABEL: test_scatter_i8_i32_indices
func.func @test_scatter_i8_i32_indices(%arg0: tensor<13x27x3xi8>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi8>) -> tensor<13x27x3xi8> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi8>, tensor<13x26xi32>, tensor<13x26x3xi8>) -> tensor<13x27x3xi8>
@@ -430,4 +439,4 @@ func.func @test_maxpool2d_adaptive(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 5f3aa8764664d..1c92c9ae335f5 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -454,6 +454,42 @@ func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1
return %0 : tensor<13x26x8xf32>
}
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_output_count_mismatch
+func.func @test_row_gather_block_scaled_output_count_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> (tensor<13x52x3xf32>, tensor<13x52x3xf32>) {
+  %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> // one values tensor, but the op below declares two results
+  // expected-error@+1 {{'tosa.row_gather_block_scaled' op expects output tensor list length to match values tensor list length, got 2 results for 1 input tensors}}
+  %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xf32>, tensor<13x52x3xf32>)
+  return %0#0, %0#1 : tensor<13x52x3xf32>, tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_block_size_one_required
+func.func @test_row_gather_block_scaled_block_size_one_required(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x52x3xf32> {
+ %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error at +1 {{'tosa.row_gather_block_scaled' op requires block_size to be BLOCK_SIZE_1 when values tensor list length is 1}}
+ %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x3xf32>)
+ return %0 : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_output_rows_mismatch
+func.func @test_row_gather_block_scaled_output_rows_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x53x3xf32> {
+ %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error at +1 {{'tosa.row_gather_block_scaled' op requires output[0] dimension 1 to have size 52, got 53}}
+ %0 = tosa.row_gather_block_scaled %arg0, %arg1, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_1>} : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x53x3xf32>)
+ return %0 : tensor<13x53x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_row_gather_block_scaled_scale_channel_mismatch
+func.func @test_row_gather_block_scaled_scale_channel_mismatch(%arg0: tensor<13x21x32xf4E2M1FN>, %arg1: tensor<13x21x2xf8E8M0FNU>, %arg2: tensor<13x26xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x2xf8E8M0FNU>) {
+ %row_count = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error at +1 {{'tosa.row_gather_block_scaled' op expects channels of scale tensors to equal C/block_size (32/32), got 2}}
+ %0:2 = tosa.row_gather_block_scaled %arg0, %arg1, %arg2, %row_count {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<13x21x32xf4E2M1FN>, tensor<13x21x2xf8E8M0FNU>, tensor<13x26xi32>, tensor<1xi32>) -> (tensor<13x52x32xf4E2M1FN>, tensor<13x52x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<13x52x32xf4E2M1FN>, tensor<13x52x2xf8E8M0FNU>
+}
+
// -----
func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
More information about the Mlir-commits
mailing list