[Mlir-commits] [mlir] [mlir][tosa] Add more verifiers for the following operators (PR #127923)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 27 13:24:24 PST 2025
================
@@ -850,6 +850,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::ConcatOp::verify() {
+ // check that each input has same element type as output
+ auto outType = getOutput().getType();
+ const Operation::operand_range inputList = getInput1();
+
+ if (!llvm::all_of(inputList, [&](auto input) {
+ return succeeded(verifySameElementTypes(
+ *this, /* inType = */ input.getType(), outType));
+ })) {
+ return failure();
+ }
+
+ // Check there is at least one input
+ if (inputList.empty())
+ return emitOpError("expect at least one input");
+
+ const Type firstInputType = inputList.front().getType();
+ const ShapeAdaptor firstInputShape(firstInputType);
+ const int32_t axis = getAxis();
+
+ if (firstInputShape.hasRank()) {
+ // Check axis is in expected range
+ if (axis < 0 || axis >= firstInputShape.getRank())
+ return emitOpError("expect axis to be within range 0 < axis < "
+ "rank(input1[0]), got ")
+ << axis;
+ }
+
+ const auto allOperandsHasRank = [](const Value input) {
+ return ShapeAdaptor(input.getType()).hasRank();
+ };
+ if (llvm::all_of(inputList, allOperandsHasRank)) {
+ const int64_t firstInputRank = firstInputShape.getRank();
+
+ for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
+ const ShapeAdaptor inputShape(input.getType());
+ const int64_t inputRank = inputShape.getRank();
+ const size_t operandNum = index + 1;
+
+ // Check that each operand has the same rank
+ if (inputRank != firstInputRank)
+ return emitOpError(
+ "expect all operands to have the same rank, but got ")
+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
+ << operandNum;
+
+ // Check non-axis dims match
+ for (int i = 0; i < inputRank; i++) {
----------------
Jerry-Ge wrote:
We can, but I think it would be more code and less clear than the for loop.
If we wanted to use llvm::all_of:
- we would have to run all_of, then find the offending index, and then emit the error, which ends up being more code.
I personally find the for loop much clearer and more straightforward: as soon as we hit the wrong index, we simply emit the error.
https://github.com/llvm/llvm-project/pull/127923
More information about the Mlir-commits
mailing list