[Mlir-commits] [mlir] [mlir][tosa] Add row_gather_block_scaled op (PR #192272)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 15 08:10:51 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Peng Sun (psunn)
<details>
<summary>Changes</summary>
Add `tosa.row_gather_block_scaled` to the MLIR TOSA dialect, aligned with the current TOSA 1.1 draft spec and the implementation in `tosa-tools`.
This includes:
- op definition
- verifier and shape inference support
- validation / profile compliance wiring
- availability and extension handling
- lit tests for parsing, verification, shape inference, and version / extension gating
#### Notes
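The op supports both spec-defined forms, distinguished by the length of the `values` tensor list:
- non-block-scaled: `values` holds 1 tensor (data only), `block_size` must be `BLOCK_SIZE_1`, and the op has 1 output
- block-scaled: `values` holds 2 tensors (data + scale), `block_size` must not be `BLOCK_SIZE_1`, and the op has 2 outputs (data + scale)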
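This also tightens the verifiers of the existing block-scaled-only ops (e.g. `MatmulTBlockScaledOp`, `Conv2DBlockScaledOp`, `CastFromBlockScaledOp`, `CastToBlockScaledOp`) so that they reject `BLOCK_SIZE_1`, now that it is part of the shared `BlockSize` enum.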
---
Patch is 36.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/192272.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+70)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+3-3)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+38)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+193)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+18-4)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+12)
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+10-1)
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+8)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+20)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+10-1)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir (+9)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10-1)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+36)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index d3e2cd129028e..50bb9f69c6242 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -396,6 +396,25 @@ profileComplianceMap = {
{{i16T, i64T, i16T}, SpecificationVersion::V_1_1_DRAFT},
{{i32T, i64T, i32T}, SpecificationVersion::V_1_1_DRAFT}},
anyOf}}},
+ {"tosa.row_gather_block_scaled",
+ {{{Profile::pro_int},
+ {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Profile::pro_fp},
+ {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp16T, i32T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, i32T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Profile::pro_fp, Profile::pro_int},
+ {{{boolT, i32T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT},
+ {{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+ {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}},
+ anyOf}}},
{"tosa.scatter",
{{{Profile::pro_int},
{{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -890,6 +909,57 @@ extensionComplianceMap = {
{{Extension::bf16, Extension::int64},
{{{bf16T, i64T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
allOf}}},
+ {"tosa.row_gather_block_scaled",
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::int64},
+ {{{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp16T, i64T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, i64T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e4m3, Extension::int64},
+ {{{fp8e4m3T, i64T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}},
+ allOf},
+ {{Extension::fp8e5m2, Extension::int64},
+ {{{fp8e5m2T, i64T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}},
+ allOf},
+ {{Extension::bf16, Extension::int64},
+ {{{bf16T, i64T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
+ allOf},
+ {{Extension::mxfp},
+ {{{fp4e2m1T, fp8ue8m0T, i32T, i32T, fp4e2m1T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, i32T, i32T, fp6e2m3T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, i32T, i32T, fp6e3m2T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, i32T, i32T, fp8e4m3T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, i32T, i32T, fp8e5m2T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{mxint8T, fp8ue8m0T, i32T, i32T, mxint8T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::mxfp, Extension::int64},
+ {{{fp4e2m1T, fp8ue8m0T, i64T, i32T, fp4e2m1T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, i64T, i32T, fp6e2m3T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, i64T, i32T, fp6e3m2T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, i64T, i32T, fp8e4m3T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, i64T, i32T, fp8e5m2T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{mxint8T, fp8ue8m0T, i64T, i32T, mxint8T, fp8ue8m0T},
+ SpecificationVersion::V_1_1_DRAFT}},
+ allOf}}},
{"tosa.scatter",
{{{Extension::fp8e4m3},
{{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1f05aee3e5eec..591073e9985ae 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -484,11 +484,12 @@ def Tosa_RoundingModeAttr
: Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
[Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
+def Tosa_BLOCK_SIZE_1 : I32EnumAttrCase<"BLOCK_SIZE_1", 1>;
def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
def Tosa_BlockSizeAttr
- : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
- [Tosa_BLOCK_SIZE_32]> {
+ : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats",
+ "block_size", [Tosa_BLOCK_SIZE_1, Tosa_BLOCK_SIZE_32]> {
let extraClassDeclaration = [{
static uint32_t getBlockSizeValue(BlockSize blockSize) {
return static_cast<uint32_t>(blockSize);
@@ -496,7 +497,6 @@ def Tosa_BlockSizeAttr
}];
}
-
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 45d1388a28749..ba750ef4438db 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2487,6 +2487,44 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather", [NoMemoryEffect]> {
"operands attr-dict `:` functional-type(operands, results)";
}
+//===----------------------------------------------------------------------===//
+// Operator: row_gather_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_RowGatherBlockScaledOp
+ : Tosa_InferShapedTypeOp<"row_gather_block_scaled", [NoMemoryEffect]> {
+ let summary =
+ "Row gather operation for block-scaled and non-block-scaled data.";
+
+ let description = [{
+ Generate a tensor-list which contains a data tensor and an optional scale
+ tensor based on the input indices and row_count. The number of consecutive
+ rows gathered for each index is specified in row_count.
+
+ This operation follows the TOSA 1.1 draft specification and may evolve as
+ the specification is updated.
+
+ This operation is not pure. Undefined behaviour may occur if the specified
+ indices are out of range.
+ }];
+
+ let arguments = (ins Variadic<Tosa_Tensor3D>:$values,
+ Tosa_IndexTensor2D:$indices, Tosa_ScalarInt32Tensor:$row_count,
+ Tosa_BlockSizeAttr:$block_size);
+
+ let results = (outs Variadic<Tosa_Tensor3D>:$output);
+
+ list<Availability> availability = [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_FP8E4M3,
+ Tosa_EXT_FP8E5M2, Tosa_EXT_BF16,
+ Tosa_EXT_MXFP, Tosa_EXT_INT64]>,
+ ];
+
+ let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
+}
+
//===----------------------------------------------------------------------===//
// Operator: scatter
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2754d3b21d4a6..360ab25875bd7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2201,6 +2201,8 @@ LogicalResult MatmulTBlockScaledOp::verify() {
// Verify C is a multiple of block size
const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize == 1)
+ return emitOpError("requires block_size to not be BLOCK_SIZE_1");
if (ShapedType::isStatic(C) && C % blockSize != 0)
return emitOpError("expect C to be a multiple of block size, got C=")
<< C << ", block_size=" << blockSize;
@@ -2848,6 +2850,18 @@ static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
return -1;
}
+static FailureOr<int64_t> getConstantScalarIntValue(Value val) {
+ ElementsAttr attr;
+ if (!matchPattern(val, m_Constant(&attr)))
+ return failure();
+
+ if (!llvm::isa<IntegerType>(attr.getElementType()) ||
+ attr.getNumElements() != 1)
+ return failure();
+
+ return attr.getValues<APInt>()[0].getSExtValue();
+}
+
template <typename T>
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
const std::string &operand) {
@@ -3096,6 +3110,48 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ RowGatherBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const auto values = adaptor.getValues();
+ if (values.empty() || values.size() > 2)
+ return failure();
+
+ SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
+ const ShapeAdaptor valuesShape(values.front().getType());
+ if (valuesShape.hasRank()) {
+ dataShape[0] = valuesShape.getDimSize(0);
+ dataShape[2] = valuesShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
+ if (indicesShape.hasRank()) {
+ if (dataShape[0] == ShapedType::kDynamic)
+ dataShape[0] = indicesShape.getDimSize(0);
+
+ if (auto rowCount = getConstantScalarIntValue(adaptor.getRowCount());
+ succeeded(rowCount) && rowCount.value() > 0) {
+ const int64_t indicesW = indicesShape.getDimSize(1);
+ if (ShapedType::isStatic(indicesW))
+ dataShape[1] = indicesW * rowCount.value();
+ }
+ }
+
+ inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
+ if (values.size() == 1)
+ return success();
+
+ SmallVector<int64_t> scaleShape = dataShape;
+ const uint32_t blockSize =
+ BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+ if (ShapedType::isStatic(dataShape[2]))
+ scaleShape[2] = dataShape[2] / blockSize;
+
+ inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
+ return success();
+}
+
LogicalResult tosa::GatherOp::verify() {
if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
/* outType = */ getOutput().getType())
@@ -3145,6 +3201,137 @@ LogicalResult tosa::GatherOp::verify() {
return success();
}
+LogicalResult tosa::RowGatherBlockScaledOp::verify() {
+ const OperandRange values = getValues();
+ const ResultRange output = getOutput();
+ if (values.empty() || values.size() > 2)
+ return emitOpError()
+ << "expects values tensor list length to be 1 or 2, got "
+ << values.size();
+ if (output.size() != values.size())
+ return emitOpError()
+ << "expects output tensor list length to match values tensor list "
+ "length, got "
+ << output.size() << " results for " << values.size()
+ << " input tensors";
+
+ const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (values.size() == 1 && blockSize != 1)
+ return emitOpError()
+ << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
+ "length is 1";
+ if (values.size() == 2 && blockSize == 1)
+ return emitOpError()
+ << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
+ "list length is 2";
+
+ if (failed(verifySameElementTypes(*this, values[0].getType(),
+ output[0].getType(), "values[0]",
+ "output[0]")))
+ return failure();
+ if (values.size() == 2 && failed(verifySameElementTypes(
+ *this, values[1].getType(), output[1].getType(),
+ "values[1]", "output[1]")))
+ return failure();
+
+ if (auto rowCount = getConstantScalarIntValue(getRowCount());
+ succeeded(rowCount) && rowCount.value() <= 0)
+ return emitOpError() << "requires row_count to be > 0, got "
+ << rowCount.value();
+
+ int64_t n = ShapedType::kDynamic;
+ int64_t k = ShapedType::kDynamic;
+ int64_t c = ShapedType::kDynamic;
+ int64_t w = ShapedType::kDynamic;
+ int64_t multiplesOfC = ShapedType::kDynamic;
+
+ const ShapeAdaptor valuesDataShape(values[0].getType());
+ if (valuesDataShape.hasRank()) {
+ n = valuesDataShape.getDimSize(0);
+ k = valuesDataShape.getDimSize(1);
+ c = valuesDataShape.getDimSize(2);
+ }
+
+ if (ShapedType::isStatic(c) && c % blockSize != 0)
+ return emitOpError() << "expects channels of values[0] (" << c
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+
+ const ShapeAdaptor indicesShape(getIndices().getType());
+ if (indicesShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
+ "indices", "batch")))
+ return failure();
+ w = indicesShape.getDimSize(1);
+ }
+
+ const ShapeAdaptor outputDataShape(output[0].getType());
+ if (outputDataShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
+ "output[0]", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
+ "output[0]", "channels")))
+ return failure();
+
+ if (auto rowCount = getConstantScalarIntValue(getRowCount());
+ succeeded(rowCount) && rowCount.value() > 0 &&
+ ShapedType::isStatic(w)) {
+ const int64_t expectedOutputRows = w * rowCount.value();
+ if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
+ outputDataShape.getDimSize(1) != expectedOutputRows)
+ return emitOpError() << "requires output[0] dimension 1 to have size "
+ << expectedOutputRows << ", got "
+ << outputDataShape.getDimSize(1);
+ }
+ }
+
+ if (values.size() == 2) {
+ const ShapeAdaptor valuesScaleShape(values[1].getType());
+ if (valuesScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
+ "values[1]", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
+ "values[1]", "rows")))
+ return failure();
+ multiplesOfC = valuesScaleShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor outputScaleShape(output[1].getType());
+ if (outputScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
+ "output[1]", "batch")))
+ return failure();
+
+ if (auto rowCount = getConstantScalarIntValue(getRowCount());
+ succeeded(rowCount) && rowCount.value() > 0 &&
+ ShapedType::isStatic(w)) {
+ const int64_t expectedOutputRows = w * rowCount.value();
+ if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
+ outputScaleShape.getDimSize(1) != expectedOutputRows)
+ return emitOpError() << "requires output[1] dimension 1 to have size "
+ << expectedOutputRows << ", got "
+ << outputScaleShape.getDimSize(1);
+ }
+
+ if (ShapedType::isDynamic(multiplesOfC))
+ multiplesOfC = outputScaleShape.getDimSize(2);
+ else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
+ multiplesOfC != outputScaleShape.getDimSize(2))
+ return emitOpError()
+ << "expected channels of output[1] to match size "
+ << multiplesOfC << ", got " << outputScaleShape.getDimSize(2);
+ }
+
+ if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
+ multiplesOfC != c / blockSize)
+ return emitOpError()
+ << "expects channels of scale tensors to equal C/block_size (" << c
+ << "/" << blockSize << "), got " << multiplesOfC;
+ }
+
+ return success();
+}
+
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ResizeOp::Adaptor adaptor,
@@ -3987,6 +4174,8 @@ LogicalResult Conv2DBlockScaledOp::verify() {
// Verify IC is a multiple of block size
const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize == 1)
+ return emitOpError("requires block_size to not be BLOCK_SIZE_1");
if (ShapedType::isStatic(IC) && IC % blockSize != 0)
return emitOpError("expect IC to be a multiple of block size, got IC=")
<< IC << ", block_size=" << blockSize;
@@ -4577,6 +4766,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
if (inputDataShape.hasRank()) {
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize == 1)
+ return emitOpError("requires block_size to not be BLOCK_SIZE_1");
const int64_t inputDataLastDim =
inputDataShape.getDimSize(inputDataShape.getRank() - 1);
if (inputDataLastDim % blockSize != 0)
@@ -4650,6 +4841,8 @@ LogicalResult CastToBlockScaledOp::verify() {
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (blockSize == 1)
+ return emitOpError("requires block_size to not be BLOCK_SIZE_1");
const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
if (inputDataShape.hasRank()) {
const int64_t inputDataLastDim =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 01c85be4f704f..4ea225b860f6c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -185,6 +185,18 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
return success();
}
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::RowGatherBlockScaledOp op) {
+ for (Value value : op.getValues())
+ addValue(value);
+ addValue(op.getIndices());
+ addValue(op.getRowCount());
+ for (Value result : op.getOutput())
+ addValue(result);
+ return success();
+}
+
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
addValue(op.getValuesIn());
@@ -288,6 +300,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Tile)
POPULATE_PROFILE_INFO_CUSTOM(Transpose)
POPULATE_PROFILE_INFO_CUSTOM(Gather)
+ POPULATE_PROFILE_INFO_CUSTOM(RowGatherBlockScaled)
POPULATE_PROFILE_INFO_CUSTOM(Scatter)
POPULATE_PROFILE_INFO_CUSTOM(Resize)
POPULATE_PROFILE_INFO_CUSTOM(Select)
@@ -598,10 +611,11 @@ SmallVector<OpComplianceInfo<T>> TosaProfileCompliance::findMatchedEntries(
SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
for (const auto &set : sets) {
SmallVector<TypeInfo> expected = set.first;
- assert(present.size() == expected.size() &&
- "the entries for profile-based compliance do not match between...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/192272
More information about the Mlir-commits mailing list