[Mlir-commits] [mlir] [mlir][tosa] Use `LogicalResult` in validation functions (PR #160052)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 22 01:34:04 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
This commit changes functions that previously returned `bool` to indicate validation success or failure so that they return `LogicalResult` instead.
Note: this PR also contains the changes from https://github.com/llvm/llvm-project/pull/159754, so it shouldn't be merged before https://github.com/llvm/llvm-project/pull/159754.
---
Patch is 51.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160052.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+335-388)
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (-33)
- (added) mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir (+34)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 790bbf77877bc..6ea4e7736f78c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -205,148 +205,142 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
constCheckers.emplace_back(checkConstantOperandNegate);
}
- bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_KERNEL) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckKernel(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_KERNEL)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_STRIDE) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckStride(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_STRIDE)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_SCALE) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckScale(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_SCALE)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
- op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
- << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckListSize(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+ return op->emitOpError()
+ << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
+ return success();
}
// Perform the Level Rank check on the tensor type.
- bool levelCheckRank(Operation *op, const Type typeToCheck,
- const StringRef operandOrResult, int32_t highest_rank) {
+ LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
+ const StringRef operandOrResult,
+ int32_t highest_rank) {
if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
- if (!type.hasRank()) {
- op->emitOpError() << "failed level check: unranked tensor";
- return false;
- }
- if (type.getRank() > highest_rank) {
- op->emitOpError() << "failed level check: " << operandOrResult
- << " rank(shape) <= MAX_RANK";
- return false;
- }
+ if (!type.hasRank())
+ return op->emitOpError() << "failed level check: unranked tensor";
+ if (type.getRank() > highest_rank)
+ return op->emitOpError() << "failed level check: " << operandOrResult
+ << " rank(shape) <= MAX_RANK";
}
- return true;
+ return success();
}
// Perform the Level Rank check on the tensor value.
- bool levelCheckRank(Operation *op, const Value &v,
- const StringRef operandOrResult, int32_t highest_rank) {
+ LogicalResult levelCheckRank(Operation *op, const Value &v,
+ const StringRef operandOrResult,
+ int32_t highest_rank) {
return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
}
// Perform the Level tensor size check on the tensor type.
- bool levelCheckSize(Operation *op, const Type &typeToCheck,
- const StringRef operandOrResult);
+ LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
+ const StringRef operandOrResult);
// Perform the Level tensor size check on the tensor value.
- bool levelCheckSize(Operation *op, const Value &v,
- const StringRef operandOrResult) {
+ LogicalResult levelCheckSize(Operation *op, const Value &v,
+ const StringRef operandOrResult) {
return levelCheckSize(op, v.getType(), operandOrResult);
}
// Level check sizes of all operands and results of the operation.
template <typename T>
- bool levelCheckSizes(T tosaOp) {
+ LogicalResult levelCheckSizes(T tosaOp) {
auto op = tosaOp.getOperation();
for (auto v : op->getOperands()) {
- if (!levelCheckSize(op, v, "operand"))
- return false;
+ if (failed(levelCheckSize(op, v, "operand")))
+ return failure();
}
for (auto v : op->getResults()) {
- if (!levelCheckSize(op, v, "result"))
- return false;
+ if (failed(levelCheckSize(op, v, "result")))
+ return failure();
}
- return true;
+ return success();
}
// Level check ranks of all operands, attribute and results of the operation.
template <typename T>
- bool levelCheckRanks(T tosaOp) {
+ LogicalResult levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
for (auto v : op->getOperands()) {
- if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
+ return failure();
}
for (auto v : op->getResults()) {
- if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
+ return failure();
}
- return true;
+ return success();
}
// Level check ranks and sizes.
- bool levelCheckRanksAndSizes(Operation *op);
+ LogicalResult levelCheckRanksAndSizes(Operation *op);
// Pool Op: level check kernel/stride/pad values
template <typename T>
- bool levelCheckPool(Operation *op) {
+ LogicalResult levelCheckPool(Operation *op) {
if (auto poolOp = dyn_cast<T>(op)) {
for (auto k : poolOp.getKernel()) {
- if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : poolOp.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
for (auto p : poolOp.getPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
}
- return true;
+ return success();
}
// Conv Op: level check dilation/stride/pad values
template <typename T>
- bool levelCheckConv(Operation *op) {
+ LogicalResult levelCheckConv(Operation *op) {
if (auto convOp = dyn_cast<T>(op)) {
for (auto k : convOp.getDilation()) {
- if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto p : convOp.getPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : convOp.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
auto dilation = convOp.getDilation();
@@ -356,100 +350,100 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
if (isa<tosa::Conv2DOp>(op)) {
assert(shape.size() == 4);
assert(dilation.size() == 2);
- if (!levelCheckKernel(op, dilation[0] * shape[1],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[2],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[2],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
} else if (isa<tosa::Conv3DOp>(op)) {
assert(shape.size() == 5);
assert(dilation.size() == 3);
- if (!levelCheckKernel(op, dilation[0] * shape[1],
- "dilation_d * KD <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[2],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[2] * shape[3],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+ "dilation_d * KD <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[2],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[2] * shape[3],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
} else if (isa<tosa::DepthwiseConv2DOp>(op)) {
assert(shape.size() == 4);
assert(dilation.size() == 2);
- if (!levelCheckKernel(op, dilation[0] * shape[0],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[1],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[0],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[1],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
}
}
}
- return true;
+ return success();
}
// FFT op: level check H, W in input shape [N,H,W]
template <typename T>
- bool levelCheckFFT(Operation *op) {
+ LogicalResult levelCheckFFT(Operation *op) {
if (isa<T>(op)) {
for (auto v : op->getOperands()) {
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
auto shape = type.getShape();
assert(shape.size() == 3);
- if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
- !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
+ failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
+ return failure();
}
}
}
}
- return true;
+ return success();
}
// TransposeConv2d op: level check kH/kW, outpad, and stride
- bool levelCheckTransposeConv2d(Operation *op) {
+ LogicalResult levelCheckTransposeConv2d(Operation *op) {
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
if (ShapedType filterType =
dyn_cast<ShapedType>(transpose.getWeight().getType())) {
auto shape = filterType.getShape();
assert(shape.size() == 4);
// level check kernel sizes for kH and KW
- if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
- !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
+ failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto p : transpose.getOutPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : transpose.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
}
- return true;
+ return success();
}
// Resize op: level check max scales
- bool levelCheckResize(Operation *op) {
+ LogicalResult levelCheckResize(Operation *op) {
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
SmallVector<int64_t> scale;
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
scale)) {
- return false;
+ return failure();
}
const int64_t scaleYN = scale[0];
const int64_t scaleYD = scale[1];
const int64_t scaleXN = scale[2];
const int64_t scaleXD = scale[3];
- if (!levelCheckScale(op, scaleYN / scaleYD,
- "scale_y_n/scale_y_d <= MAX_SCALE") ||
- !levelCheckScale(op, scaleXN / scaleXD,
- "scale_x_n/scale_x_d <= MAX_SCALE")) {
- return false;
+ if (failed(levelCheckScale(op, scaleYN / scaleYD,
+ "scale_y_n/scale_y_d <= MAX_SCALE")) ||
+ failed(levelCheckScale(op, scaleXN / scaleXD,
+ "scale_x_n/scale_x_d <= MAX_SCALE"))) {
+ return failure();
}
}
- return true;
+ return success();
}
// Recursively perform a bottom-up search to determine the maximum nesting
@@ -468,62 +462,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
getMaxNestedDepth(op, depth);
}
- bool levelCheckMaxNesting(Operation *op) {
+ LogicalResult levelCheckMaxNesting(Operation *op) {
int32_t maxNestedDepth = 0;
getMaxNestedDepth(op, maxNestedDepth);
if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
op->emitOpError() << "failed level check: " << maxNestedDepth
<< " >= MAX_NESTING";
- return false;
+ return failure();
}
- return true;
+ return success();
}
- bool levelCheckListSize(Operation *op) {
+ LogicalResult levelCheckListSize(Operation *op) {
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
return levelCheckListSize(op, concat.getInput1().size(), "input1");
}
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
- if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
- !levelCheckListSize(op, custom.getOutputList().size(),
- "output_list")) {
- return false;
+ if (failed(levelCheckListSize(op, custom.getInputList().size(),
+ "input_list")) ||
+ failed(levelCheckListSize(op, custom.getOutputList().size(),
+ "output_list"))) {
+ return failure();
}
}
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
- if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
- !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
- return false;
+ if (failed(
+ levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
+ failed(levelCheckListSize(op, condIf.getOutputList().size(),
+ "outputs"))) {
+ return failure();
}
}
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
- if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
- !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
- return false;
+ if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
+ failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
+ return failure();
}
}
- return true;
+ return success();
}
- bool attributeCheckRescale(Operation *op) {
+ LogicalResult attributeCheckRescale(Operation *op) {
if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
!targetEnv.allows(Extension::doubleround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = DOUBLE_ROUND "
<< "requires extension [doubleround]";
- return false;
+ return failure();
}
if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
!targetEnv.allows(Extension::inexactround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = INEXACT_ROUND "
<< "requires extension [inexactround]";
- return false;
+ return failure();
}
}
- return true;
+ return success();
}
// configure profile and level values from pass options profileName and
@@ -563,8 +560,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}
- bool CheckVariable(Operation *op);
- bool CheckVariableReadOrWrite(Operation *op);
+ LogicalResult CheckVariable(Operation *op);
+ LogicalResult CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type, const bool allowUnsigned = false);
SmallVector<
@@ -577,62 +574,66 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
};
template <>
-bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto op = tosaOp.getOperation();
- if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(
+ levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+ return failure();
// rank(output) = rank(input) - 1
- if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
- return false;
+ if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
+ tosaLevel.MAX_RANK - 1)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
auto op = tosaOp.getOperation();
// Only the condition input has rank limitation.
- if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
+ tosaLevel.MAX_RANK)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
auto op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
- if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, variableType, "variable type",
+ tosaLevel.MAX_RANK)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
auto op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
- if (!levelCheckSize(op, variableType, "variable type"))
- return false;
+ if (failed(levelCheckSize(op, variableType, "variable type")))
+ ret...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/160052
More information about the Mlir-commits
mailing list