[Mlir-commits] [mlir] [mlir][tosa] Add more verifiers for the following operators (PR #127923)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 27 13:24:24 PST 2025


================
@@ -850,6 +850,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ConcatOp::verify() {
+  // check that each input has same element type as output
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  if (!llvm::all_of(inputList, [&](auto input) {
+        return succeeded(verifySameElementTypes(
+            *this, /* inType = */ input.getType(), outType));
+      })) {
+    return failure();
+  }
+
+  // Check there is at least one input
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  const Type firstInputType = inputList.front().getType();
+  const ShapeAdaptor firstInputShape(firstInputType);
+  const int32_t axis = getAxis();
+
+  if (firstInputShape.hasRank()) {
+    // Check axis is in expected range
+    if (axis < 0 || axis >= firstInputShape.getRank())
+      return emitOpError("expect axis to be within range 0 < axis < "
+                         "rank(input1[0]), got ")
+             << axis;
+  }
+
+  const auto allOperandsHasRank = [](const Value input) {
+    return ShapeAdaptor(input.getType()).hasRank();
+  };
+  if (llvm::all_of(inputList, allOperandsHasRank)) {
+    const int64_t firstInputRank = firstInputShape.getRank();
+
+    for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
+      const ShapeAdaptor inputShape(input.getType());
+      const int64_t inputRank = inputShape.getRank();
+      const size_t operandNum = index + 1;
+
+      // Check that each operand has the same rank
+      if (inputRank != firstInputRank)
+        return emitOpError(
+                   "expect all operands to have the same rank, but got ")
+               << firstInputRank << " vs " << inputRank << " on operands 0 and "
+               << operandNum;
+
+      // Check non-axis dims match
+      for (int i = 0; i < inputRank; i++) {
----------------
Jerry-Ge wrote:

We could, but I think that would be more code and less clear than the for loop.

If we wanted to use llvm::all_of,
- we would first have to run all_of, then find the offending index, and only then emit the error, which ends up being more code.

I personally find the for loop much clearer and more straightforward: as soon as we hit a mismatching index, we simply emit the error.




https://github.com/llvm/llvm-project/pull/127923


More information about the Mlir-commits mailing list