[Mlir-commits] [mlir] [mlir][tosa] Add table size check for Table Op (PR #135262)
Tai Ly
llvmlistbot at llvm.org
Fri Apr 11 14:20:00 PDT 2025
https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/135262
>From 50fde323a4599229ee64cb4dcb1baee9a3677b6e Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Mon, 7 Apr 2025 22:37:22 +0000
Subject: [PATCH] [mlir][tosa] Add table size check for Table Op
Add a table size check for the TOSA Table op
and add lit tests to error_if_check.mlir.
Also correct existing tests that violated the
table size check.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I34b3dd95d90c473622ae5f18320b688fe4da0b0a
---
.../Tosa/Transforms/TosaValidation.cpp | 24 ++++++++++++++++++-
mlir/test/Dialect/Tosa/dynamic_extension.mlir | 4 ++--
mlir/test/Dialect/Tosa/error_if_check.mlir | 16 +++++++++++++
mlir/test/Dialect/Tosa/invalid_extension.mlir | 4 ++--
4 files changed, 43 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 11eb0d969d78b..ef9d27f8df0ad 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1012,8 +1012,30 @@ bool checkErrorIfMul(Operation *op) {
return true;
}
+bool checkErrorIfTable(Operation *op) {
+  // Only tosa.table is constrained here; any other op passes trivially.
+  auto tableOp = dyn_cast<tosa::TableOp>(op);
+  if (!tableOp)
+    return true;
+  // REQUIRE(length(table) == TABLE_SIZE): an i8 input expects 256 entries,
+  // every other input element type expects 513 entries.
+  const auto elemTy = getElementTypeOrSelf(tableOp.getInput1().getType());
+  const int expectedSize = elemTy.isInteger(8) ? 256 : 513;
+  // A table whose shape is not fully static cannot be measured; accept it.
+  const ShapeAdaptor shape(tableOp.getTable().getType());
+  if (!shape.hasStaticShape())
+    return true;
+  const auto actualSize = shape.getNumElements();
+  if (actualSize == expectedSize)
+    return true;
+
+  op->emitOpError() << "requires table size of " << expectedSize
+                    << ", got " << actualSize;
+  return false;
+}
+
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
+ // Evaluate the ERROR_IF checks in order; the first one returning false
+ // short-circuits the remaining checks and fails validation of `op`.
- if (!checkErrorIfResize(op) || !checkErrorIfMul(op))
+ if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index 0ec46022157d7..25e1aa195c3a0 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -13,8 +13,8 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8
// -----
-func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
- %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+// The table length must match the input type: 256 entries for an i8 input.
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () {
+ %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8>
return
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index f7ca0faa8bc9e..65a69be91e0c8 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -113,3 +113,19 @@ func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8
%mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
return %mul : tensor<1x8x8x8xi32>
}
+
+// -----
+// CHECK-LABEL: test_i16_table_size
+func.func @test_i16_table_size(%arg0: tensor<2x64xi16>, %arg1: tensor<256xi16>) -> tensor<2x64xi32> {
+  // An i16 input requires a 513-entry table, so this 256-entry table is rejected.
+  // expected-error@+1 {{'tosa.table' op requires table size of 513, got 256}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi16>, tensor<256xi16>) -> tensor<2x64xi32>
+  return %0 : tensor<2x64xi32>
+}
+
+// -----
+// CHECK-LABEL: test_i8_table_size
+func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) -> tensor<2x64xi8> {
+  // An i8 input requires a 256-entry table, so this 513-entry table is rejected.
+  // expected-error@+1 {{'tosa.table' op requires table size of 256, got 513}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
+  return %0 : tensor<2x64xi8>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 241e603e91c61..7386b1ba9df99 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -497,9 +497,9 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8
// -----
-func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
+// The table length must match the input type: 256 entries for an i8 input.
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () {
   // expected-error@+1 {{'tosa.table' op expected compile time resolvable constant, but got variable value for operand #1}}
-  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8>
   return
 }
More information about the Mlir-commits
mailing list