[Mlir-commits] [mlir] [mlir][tosa] Add several level checks (PR #128074)

Luke Hutton llvmlistbot at llvm.org
Fri Feb 28 04:17:36 PST 2025


================
@@ -111,133 +116,212 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     constCheckers.emplace_back(checkConstantOperandPad);
   }
 
-  bool levelCheckKernel(Operation *op, int32_t v,
-                        const std::string &checkDesc) {
+  /// Checks `v` against the level's MAX_KERNEL limit. On failure, emits
+  /// "failed level check: <checkDesc>" on `op` and returns false; otherwise
+  /// returns true.
+  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
     if (v > tosaLevel.MAX_KERNEL) {
       op->emitOpError() << "failed level check: " << checkDesc;
       return false;
     }
     return true;
   }
 
-  bool levelCheckStride(Operation *op, int32_t v,
-                        const std::string &checkDesc) {
+  /// Checks `v` against the level's MAX_STRIDE limit. On failure, emits
+  /// "failed level check: <checkDesc>" on `op` and returns false; otherwise
+  /// returns true.
+  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
     if (v > tosaLevel.MAX_STRIDE) {
       op->emitOpError() << "failed level check: " << checkDesc;
       return false;
     }
     return true;
   }
 
-  bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
+  /// Checks `v` against the level's MAX_SCALE limit. On failure, emits
+  /// "failed level check: <checkDesc>" on `op` and returns false; otherwise
+  /// returns true.
+  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
     if (v > tosaLevel.MAX_SCALE) {
       op->emitOpError() << "failed level check: " << checkDesc;
       return false;
     }
     return true;
   }
 
-  bool levelCheckRank(Operation *op, const Value &v,
-                      const std::string &checkDesc) {
+  /// Checks `v` against the level's MAX_TENSOR_LIST_SIZE limit. On failure,
+  /// emits "failed level check for MAX_TENSOR_LIST_SIZE: <checkDesc>" on `op`
+  /// and returns false; otherwise returns true.
+  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
+      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
+                        << checkDesc;
+      return false;
+    }
+    return true;
+  }
+
+  /// Level-checks a single operand or result `v` of `op`:
+  ///  - a shaped type must be ranked, with rank <= `highestRank`;
+  ///  - every dimension must be static;
+  ///  - its size in bytes must not exceed (1 << MAX_LOG2_SIZE) - 1.
+  /// Values whose type is not a ShapedType are accepted unchecked.
+  /// `operandOrResult` ("operand" or "result") prefixes the diagnostics.
+  /// Returns false after emitting an error on `op` if any check fails.
+  bool levelCheckRankAndSizes(Operation *op, const Value &v,
+                              const StringRef operandOrResult,
+                              int32_t highestRank) {
     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
       if (!type.hasRank()) {
         op->emitOpError() << "failed level check: unranked tensor";
         return false;
       }
-      if (type.getRank() > tosaLevel.MAX_RANK) {
-        op->emitOpError() << "failed level check: " << checkDesc;
+      if (type.getRank() > highestRank) {
+        op->emitOpError() << "failed level check: " << operandOrResult
+                          << " rank(shape) <= MAX_RANK";
+        return false;
+      }
+
+      // Every dimension must be static. Also note whether any dimension is
+      // zero: such a tensor has no elements and trivially meets the size
+      // limit below.
+      bool hasZeroDim = false;
+      for (int64_t dim : type.getShape()) {
+        if (mlir::ShapedType::isDynamic(dim)) {
+          op->emitOpError() << "failed level check: " << operandOrResult
+                            << " shape dimension cannot be dynamic";
+          return false;
+        }
+        hasZeroDim |= dim == 0;
+      }
+
+      // According to 1.11. Tensor Definitions of Tosa spec, the value of
+      // tensor_size_t is (1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
+      // defined in 1.7. Levels.
+      // For each tensor, the number of tensor elements multiplied by the
+      // element size in bytes must be representable as a tensor_size_t.
+      // Shifting a signed 64-bit 1 left by 63 or more is undefined behaviour
+      // before C++20, so the bound is clamped to INT64_MAX in that case.
+      const int64_t maxSize =
+          tosaLevel.MAX_LOG2_SIZE >= 63
+              ? INT64_MAX
+              : (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
+
+      bool tooLarge = false;
+      // getElementTypeBitWidth() asserts on element types that are not
+      // integer or float, so only those element types are size-checked.
+      // NOTE(review): other element types (e.g. quantized ones, if they reach
+      // here) skip this check; consider using their storage type's width.
+      if (!hasZeroDim && type.getElementType().isIntOrFloat()) {
+        // Round up to whole bytes so sub-byte types (i1, i4) count as one
+        // byte and widths that are not a multiple of 8 are not under-counted.
+        const int64_t elementBits = type.getElementTypeBitWidth();
+        int64_t size = std::max(INT64_C(1), (elementBits + 7) / 8);
+        tooLarge = size > maxSize;
+        // Multiply one dimension at a time and stop before the running
+        // product exceeds maxSize, instead of using getNumElements(), whose
+        // int64_t product can overflow for very large static shapes.
+        // Every dim is >= 1 here (static and non-zero), so the division is
+        // safe and size * dim <= maxSize whenever the check passes.
+        for (int64_t dim : type.getShape()) {
+          if (tooLarge || size > maxSize / dim) {
+            tooLarge = true;
+            break;
+          }
+          size *= dim;
+        }
+      }
+      if (tooLarge) {
+        op->emitOpError()
+            << "failed level check: " << operandOrResult
+            << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
         return false;
       }
     }
     return true;
   }
 
+  /// Default rank/size level check for a TOSA op: every operand and every
+  /// result of `tosaOp` is passed to levelCheckRankAndSizes with MAX_RANK as
+  /// the rank limit. Returns false as soon as one check fails (that check has
+  /// already emitted the error); otherwise returns true.
   template <typename T>
-  bool levelCheckRanksFor(Operation *op) {
-    if (dyn_cast<T>(op)) {
-      // level check ranks of all operands and results
-      for (auto v : op->getOperands()) {
-        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
-          return false;
-      }
-      for (auto v : op->getResults()) {
-        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
-          return false;
-      }
+  bool levelCheckRanksAndSizesFor(T tosaOp) {
+    // level check ranks of all operands and results
+    auto op = tosaOp.getOperation();
+    for (auto v : op->getOperands()) {
+      if (!levelCheckRankAndSizes(op, v, "operand", tosaLevel.MAX_RANK))
+        return false;
+    }
+
+    for (auto v : op->getResults()) {
+      if (!levelCheckRankAndSizes(op, v, "result", tosaLevel.MAX_RANK))
+        return false;
     }
     return true;
   }
 
-  bool levelCheckRanks(Operation *op) {
-#define CHECK_RANKS_FOR(tosaOp)                                                \
-  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
-    return false;
+  // Overload for ArgMaxOp: the output drops the reduced axis, so its rank is
+  // limited to MAX_RANK - 1 (rank(output) = rank(input) - 1).
+  //
+  // The per-op variants are plain overloads rather than `template <>`
+  // explicit specializations: GCC does not implement CWG 727 and rejects
+  // explicit specializations at class scope ("explicit specialization in
+  // non-namespace scope"), while overload resolution already prefers a
+  // non-template function over the generic template on an exact match.
+  bool levelCheckRanksAndSizesFor(tosa::ArgMaxOp tosaOp) {
+    auto op = tosaOp.getOperation();
+    if (!levelCheckRankAndSizes(op, tosaOp.getInput(), "operand",
+                                tosaLevel.MAX_RANK))
+      return false;
+
+    // rank(output) = rank(input) - 1
+    if (!levelCheckRankAndSizes(op, tosaOp.getOutput(), "result",
+                                tosaLevel.MAX_RANK - 1))
+      return false;
+
+    return true;
+  }
+
+  // Overload for IfOp: only the condition input is level-checked here; the
+  // other operands and the results are not.
+  bool levelCheckRanksAndSizesFor(tosa::IfOp tosaOp) {
+    auto op = tosaOp.getOperation();
+
+    // Only the condition input has a rank limitation.
+    if (!levelCheckRankAndSizes(op, tosaOp.getCond(), "operand",
+                                tosaLevel.MAX_RANK))
+      return false;
+
+    return true;
+  }
+
+  bool levelCheckRanksAndSizes(Operation *op) {
+#define CHECK_RANKS_AND_SIZES_FOR(tosaOp)                                      \
+  if (isa<tosa::tosaOp##Op>(op))                                               \
+    if (!levelCheckRanksAndSizesFor(cast<tosa::tosaOp##Op>(op)))               \
+      return false;
+
+#define CHECK_RANKS_AND_SIZES_SKIP(tosaOp)                                     \
----------------
lhutton1 wrote:

I think we can simply omit the operators that don't require MAX_RANK checks. WDYT?

https://github.com/llvm/llvm-project/pull/128074


More information about the Mlir-commits mailing list