[Mlir-commits] [mlir] [mlir][tosa] Add several level checks (PR #128074)

TatWai Chong llvmlistbot at llvm.org
Wed Feb 26 11:16:38 PST 2025


================
@@ -147,107 +152,149 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  bool levelCheckRank(Operation *op, const Value &v,
-                      const std::string &checkDesc) {
+  /// Checks `v` (presumably the length of a tensor list) against the level's
+  /// MAX_TENSOR_LIST_SIZE. On failure, emits an op error whose text ends with
+  /// `checkDesc` and returns false; otherwise returns true.
+  bool levelCheckListSize(Operation *op, int32_t v,
+                          const std::string &checkDesc) {
+    const bool withinLimit = v <= tosaLevel.MAX_TENSOR_LIST_SIZE;
+    if (withinLimit)
+      return true;
+
+    op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
+                      << checkDesc;
+    return false;
+  }
+
+  /// Level-checks a single operand or result `v` of `op`. If `v` has a
+  /// ShapedType it must be ranked, have rank <= MAX_RANK, have only static
+  /// dimensions each <= (1 << MAX_LOG2_SIZE) - 1, and occupy at most
+  /// (1 << (MAX_LOG2_SIZE + 1)) - 1 bytes. Values of any other type are not
+  /// checked. `operandOrResult` names the value's role in the diagnostic.
+  /// Emits an op error and returns false on the first failed check.
+  bool levelCheckRankAndSizes(Operation *op, const Value &v,
+                              const std::string &operandOrResult) {
     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
       if (!type.hasRank()) {
         op->emitOpError() << "failed level check: unranked tensor";
         return false;
       }
       if (type.getRank() > tosaLevel.MAX_RANK) {
-        op->emitOpError() << "failed level check: " << checkDesc;
+        op->emitOpError() << "failed level check: " << operandOrResult
+                          << " rank(shape) <= MAX_RANK";
+        return false;
+      }
+
+      // Largest permitted extent of one dimension, and largest permitted
+      // tensor size in bytes. NOTE: these shifts require
+      // MAX_LOG2_SIZE + 1 <= 62, which also guarantees that maxSize + 1
+      // (used for saturation below) fits in int64_t.
+      const int64_t maxDim = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
+      const int64_t maxSize =
+          (INT64_C(1) << (tosaLevel.MAX_LOG2_SIZE + 1)) - 1;
+
+      auto shape = type.getShape();
+      for (int64_t dim : shape) {
+        if (mlir::ShapedType::isDynamic(dim)) {
+          op->emitOpError() << "failed level check: " << operandOrResult
+                            << " shape dimension cannot be dynamic";
+          return false;
+        }
+        if (dim > maxDim) {
+          op->emitOpError() << "failed level check: " << operandOrResult
+                            << " shape dimension <= (1<<MAX_LOG2_SIZE) - 1";
+          return false;
+        }
+      }
+
+      // Bytes per element, rounding sub-byte widths (e.g. i1, i4) up to a
+      // whole byte. NOTE(review): getElementTypeBitWidth() expects an
+      // int/float-like element type -- confirm no other element types reach
+      // this check.
+      const int64_t elementBits = type.getElementTypeBitWidth();
+      const int64_t elementBytes =
+          std::max(INT64_C(1), (elementBits + 7) / 8);
+
+      // Every dimension is now known to be in [0, maxDim], but the product of
+      // up to MAX_RANK of them can still exceed INT64_MAX. Multiplying them
+      // out directly (e.g. via getNumElements()) could therefore overflow and
+      // wrap to a value that passes the limit. Accumulate one dimension at a
+      // time instead, saturating at maxSize + 1 once the limit is known to be
+      // exceeded. A zero-extent dimension makes the tensor empty.
+      int64_t size = elementBytes;
+      for (int64_t dim : shape) {
+        if (dim == 0) {
+          size = 0;
+          break;
+        }
+        // size <= maxSize / dim  <=>  size * dim <= maxSize (dim > 0), so the
+        // multiplication in the else-branch cannot overflow.
+        size = size > maxSize / dim ? maxSize + 1 : size * dim;
+      }
+      if (size > maxSize) {
+        op->emitOpError()
+            << "failed level check: " << operandOrResult
+            << " tensor size (in bytes) <= (1<<MAX_LOG2_SIZE+1) - 1";
         return false;
       }
     }
     return true;
   }
 
   template <typename T>
-  bool levelCheckRanksFor(Operation *op) {
-    if (dyn_cast<T>(op)) {
-      // level check ranks of all operands and results
+  /// If `op` is of type T, level-checks the rank and sizes of every operand
+  /// and result via levelCheckRankAndSizes, stopping at the first failure.
+  /// Ops of any other type pass trivially.
+  bool levelCheckRanksAndSizesFor(Operation *op) {
+    // Only the type test is needed here, so use isa<> rather than an unused
+    // dyn_cast<> result.
+    if (isa<T>(op)) {
+      // level check ranks and sizes of all operands and results
       for (auto v : op->getOperands()) {
-        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
+        if (!levelCheckRankAndSizes(op, v, "operand"))
           return false;
       }
       for (auto v : op->getResults()) {
-        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
+        if (!levelCheckRankAndSizes(op, v, "result"))
           return false;
       }
     }
     return true;
   }
 
-  bool levelCheckRanks(Operation *op) {
-#define CHECK_RANKS_FOR(tosaOp)                                                \
-  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
+  bool levelCheckRanksAndSizes(Operation *op) {
+#define CHECK_RANKS_AND_SIZES_FOR(tosaOp)                                      \
+  if (!levelCheckRanksAndSizesFor<tosaOp##Op>(op))                             \
     return false;
 
     // tensor operators:
-    CHECK_RANKS_FOR(ArgMax);
+    CHECK_RANKS_AND_SIZES_FOR(ArgMax);
     // all activation functions:
-    CHECK_RANKS_FOR(Clamp);
-    CHECK_RANKS_FOR(Sigmoid);
-    CHECK_RANKS_FOR(Tanh);
+    CHECK_RANKS_AND_SIZES_FOR(Clamp);
+    CHECK_RANKS_AND_SIZES_FOR(Erf);
+    CHECK_RANKS_AND_SIZES_FOR(Sigmoid);
+    CHECK_RANKS_AND_SIZES_FOR(Tanh);
     // all elementwise binary operators:
-    CHECK_RANKS_FOR(Add);
-    CHECK_RANKS_FOR(ArithmeticRightShift);
-    CHECK_RANKS_FOR(BitwiseAnd);
-    CHECK_RANKS_FOR(BitwiseOr);
-    CHECK_RANKS_FOR(BitwiseXor);
-    CHECK_RANKS_FOR(IntDiv);
-    CHECK_RANKS_FOR(LogicalAnd);
-    CHECK_RANKS_FOR(LogicalLeftShift);
-    CHECK_RANKS_FOR(LogicalRightShift);
-    CHECK_RANKS_FOR(LogicalOr);
-    CHECK_RANKS_FOR(LogicalXor);
-    CHECK_RANKS_FOR(Maximum);
-    CHECK_RANKS_FOR(Minimum);
-    CHECK_RANKS_FOR(Mul);
-    CHECK_RANKS_FOR(Pow);
-    CHECK_RANKS_FOR(Sub);
-    CHECK_RANKS_FOR(Table);
+    CHECK_RANKS_AND_SIZES_FOR(Add);
+    CHECK_RANKS_AND_SIZES_FOR(ArithmeticRightShift);
+    CHECK_RANKS_AND_SIZES_FOR(BitwiseAnd);
+    CHECK_RANKS_AND_SIZES_FOR(BitwiseOr);
+    CHECK_RANKS_AND_SIZES_FOR(BitwiseXor);
+    CHECK_RANKS_AND_SIZES_FOR(IntDiv);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalAnd);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalLeftShift);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalRightShift);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalOr);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalXor);
+    CHECK_RANKS_AND_SIZES_FOR(Maximum);
+    CHECK_RANKS_AND_SIZES_FOR(Minimum);
+    CHECK_RANKS_AND_SIZES_FOR(Mul);
+    CHECK_RANKS_AND_SIZES_FOR(Pow);
+    CHECK_RANKS_AND_SIZES_FOR(Sub);
+    CHECK_RANKS_AND_SIZES_FOR(Table);
     // all elementwise unary operators:
-    CHECK_RANKS_FOR(Abs);
-    CHECK_RANKS_FOR(BitwiseNot);
-    CHECK_RANKS_FOR(Ceil);
-    CHECK_RANKS_FOR(Clz);
-    CHECK_RANKS_FOR(Exp);
-    CHECK_RANKS_FOR(Floor);
-    CHECK_RANKS_FOR(Log);
-    CHECK_RANKS_FOR(LogicalNot);
-    CHECK_RANKS_FOR(Negate);
-    CHECK_RANKS_FOR(Reciprocal);
-    CHECK_RANKS_FOR(Rsqrt);
+    CHECK_RANKS_AND_SIZES_FOR(Abs);
+    CHECK_RANKS_AND_SIZES_FOR(BitwiseNot);
+    CHECK_RANKS_AND_SIZES_FOR(Ceil);
+    CHECK_RANKS_AND_SIZES_FOR(Clz);
+    CHECK_RANKS_AND_SIZES_FOR(Cos);
+    CHECK_RANKS_AND_SIZES_FOR(Exp);
+    CHECK_RANKS_AND_SIZES_FOR(Floor);
+    CHECK_RANKS_AND_SIZES_FOR(Log);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalNot);
+    CHECK_RANKS_AND_SIZES_FOR(Negate);
+    CHECK_RANKS_AND_SIZES_FOR(Reciprocal);
+    CHECK_RANKS_AND_SIZES_FOR(Rsqrt);
+    CHECK_RANKS_AND_SIZES_FOR(Sin);
     // all elementwise ternary operators:
-    CHECK_RANKS_FOR(Select);
+    CHECK_RANKS_AND_SIZES_FOR(Select);
     // all comparison operators:
-    CHECK_RANKS_FOR(Equal);
-    CHECK_RANKS_FOR(Greater);
-    CHECK_RANKS_FOR(GreaterEqual);
+    CHECK_RANKS_AND_SIZES_FOR(Equal);
+    CHECK_RANKS_AND_SIZES_FOR(Greater);
+    CHECK_RANKS_AND_SIZES_FOR(GreaterEqual);
     // all reduction operators:
-    CHECK_RANKS_FOR(ReduceAll);
-    CHECK_RANKS_FOR(ReduceAny);
-    CHECK_RANKS_FOR(ReduceMax);
-    CHECK_RANKS_FOR(ReduceMin);
-    CHECK_RANKS_FOR(ReduceProd);
-    CHECK_RANKS_FOR(ReduceSum);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceAll);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceAny);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceMax);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceMin);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceProd);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceSum);
     // all data layout operators:
-    CHECK_RANKS_FOR(Concat);
-    CHECK_RANKS_FOR(Pad);
-    CHECK_RANKS_FOR(Reshape);
-    CHECK_RANKS_FOR(Reverse);
-    CHECK_RANKS_FOR(Slice);
-    CHECK_RANKS_FOR(Tile);
-    CHECK_RANKS_FOR(Transpose);
+    CHECK_RANKS_AND_SIZES_FOR(Concat);
+    CHECK_RANKS_AND_SIZES_FOR(Pad);
+    CHECK_RANKS_AND_SIZES_FOR(Reshape);
+    CHECK_RANKS_AND_SIZES_FOR(Reverse);
+    CHECK_RANKS_AND_SIZES_FOR(Slice);
+    CHECK_RANKS_AND_SIZES_FOR(Tile);
+    CHECK_RANKS_AND_SIZES_FOR(Transpose);
     // all type conversion operators:
-    CHECK_RANKS_FOR(Cast);
-    CHECK_RANKS_FOR(Rescale);
+    CHECK_RANKS_AND_SIZES_FOR(Cast);
+    CHECK_RANKS_AND_SIZES_FOR(Rescale);
     // all data nodes operators:
-    CHECK_RANKS_FOR(Const);
-    CHECK_RANKS_FOR(Identity);
+    CHECK_RANKS_AND_SIZES_FOR(Const);
+    CHECK_RANKS_AND_SIZES_FOR(Identity);
----------------
tatwaichong wrote:

Yes. I added a dedicated level check for the `condition` operand of `COND_IF`.

https://github.com/llvm/llvm-project/pull/128074


More information about the Mlir-commits mailing list