[Mlir-commits] [mlir] [mlir][tosa] Add log2_ceil/log2_floor/exp2_shape ops (PR #175057)
Yuvaraj Venkatesh
llvmlistbot at llvm.org
Fri Jan 9 11:03:47 PST 2026
https://github.com/Yuvaraj-Venkatesh updated https://github.com/llvm/llvm-project/pull/175057
From d42f3ea2a8e164da20e42fcae256ff692823eb3e Mon Sep 17 00:00:00 2001
From: Yuvaraj Venkatesh <yuvaraj.venkatesh at arm.com>
Date: Wed, 26 Nov 2025 16:39:15 +0000
Subject: [PATCH 1/2] [mlir][tosa] Add log2_ceil/log2_floor/exp2_shape ops
This commit introduces the following new operations for the TOSA EXT-SHAPE extension:
- LOG2_CEIL_SHAPE
- LOG2_FLOOR_SHAPE
- EXP2_SHAPE
These additions include the operator definitions, same-rank
verification, and level checks during validation.
Co-authored-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: I0a3ba4113771ededbc9a60b92d1e39768aa394f3
---
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 72 +++++++++++++++++++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 32 +++++++++
.../Tosa/Transforms/TosaProfileCompliance.cpp | 3 +
.../Tosa/Transforms/TosaValidation.cpp | 26 ++++++-
mlir/test/Dialect/Tosa/level_check.mlir | 37 +++++++++-
mlir/test/Dialect/Tosa/ops.mlir | 24 +++++++
.../tosa-validation-version-1p1-valid.mlir | 16 +++++
mlir/test/Dialect/Tosa/verifier.mlir | 27 +++++++
8 files changed, 235 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index be5789c1be7bf..4b1631db0bc45 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -175,6 +175,78 @@ def Tosa_DivFloorShapeOp : Tosa_ElementwiseShapeOp<"div_floor_shape", [Pure]> {
let results = (outs Tosa_Shape:$output);
}
+//===----------------------------------------------------------------------===//
+// Operator: Exp2Shape
+//===----------------------------------------------------------------------===//
+def Tosa_Exp2ShapeOp : Tosa_ElementwiseShapeOp<"exp2_shape", [Pure]> {
+ let summary = "Elementwise base-2 exponential of shapes.";
+
+ let description = [{
+ Computation of raising two to the power of each element in input.
+ }];
+
+ let arguments = (ins
+ Tosa_Shape:$input
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_SHAPE]>
+ ];
+
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: Log2CeilShape
+//===----------------------------------------------------------------------===//
+def Tosa_Log2CeilShapeOp : Tosa_ElementwiseShapeOp<"log2_ceil_shape", [Pure]> {
+ let summary = "Elementwise ceil base-2 logarithm of shapes.";
+
+ let description = [{
+ Computation of the base two logarithm of each element in input. Result is rounded up.
+ }];
+
+ let arguments = (ins
+ Tosa_Shape:$input
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_SHAPE]>
+ ];
+
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: Log2FloorShape
+//===----------------------------------------------------------------------===//
+def Tosa_Log2FloorShapeOp : Tosa_ElementwiseShapeOp<"log2_floor_shape", [Pure]> {
+ let summary = "Elementwise floor base-2 logarithm of shapes.";
+
+ let description = [{
+ Computation of the base two logarithm of each element in input. Result is rounded down.
+ }];
+
+ let arguments = (ins
+ Tosa_Shape:$input
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_SHAPE]>
+ ];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: MaxShape
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5656f3de698c5..f3b9fc4a90965 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4685,6 +4685,27 @@ LogicalResult tosa::ConcatShapeOp::verify() {
return success();
}
+LogicalResult tosa::Exp2ShapeOp::verify() {
+ SmallVector<int64_t> shapes;
+ // Verification is done only for inputs with a const_shape.
+ if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
+ if (llvm::any_of(shapes, [](int64_t s) { return s < 0; }))
+ return emitOpError("input elements must be >= 0 for exp2");
+ }
+ return success();
+}
+
+LogicalResult tosa::Log2FloorShapeOp::verify() {
+ SmallVector<int64_t> shapes;
+ // Verification is done only for inputs with a const_shape.
+ if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
+ if (llvm::any_of(shapes, [](int64_t s) { return s < 1; }))
+ return emitOpError("input elements must be >= 1 for log2");
+ }
+
+ return success();
+}
+
LogicalResult tosa::SliceShapeOp::verify() {
std::optional<int32_t> start;
DenseIntElementsAttr startAttr;
@@ -4727,6 +4748,17 @@ LogicalResult tosa::SliceShapeOp::verify() {
return success();
}
+LogicalResult tosa::Log2CeilShapeOp::verify() {
+ SmallVector<int64_t> shapes;
+ // Verification is done only for inputs with a const_shape.
+ if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
+ if (llvm::any_of(shapes, [](int64_t s) { return s < 1; }))
+ return emitOpError("input elements must be >= 1 for log2");
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 46c8c32324313..08c702bd2f29f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -329,6 +329,9 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_SKIP(ConstShape)
POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
+ POPULATE_PROFILE_INFO_SKIP(Exp2Shape)
+ POPULATE_PROFILE_INFO_SKIP(Log2CeilShape)
+ POPULATE_PROFILE_INFO_SKIP(Log2FloorShape)
POPULATE_PROFILE_INFO_SKIP(MaxShape)
POPULATE_PROFILE_INFO_SKIP(MinShape)
POPULATE_PROFILE_INFO_SKIP(ModShape)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index dc2c90bbf1199..309d2d163ebdd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -437,6 +437,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return success();
}
+ // Exp2Shape op: level check max shape
+ LogicalResult levelCheckExp2Shape(Operation *op) {
+ if (auto exp2Shape = dyn_cast<tosa::Exp2ShapeOp>(op)) {
+ SmallVector<int64_t> shapes;
+ // Level check is done only for inputs with a const_shape.
+ if (tosa::getConstShapeValues(exp2Shape.getInput().getDefiningOp(),
+ shapes)) {
+ for (int64_t shape : shapes) {
+ if (shape >= targetEnv.getLevel().MAX_LOG2_SIZE)
+ return op->emitOpError("failed level check: shape < MAX_LOG2_SIZE");
+ }
+ }
+ }
+ return success();
+ }
+
// Recursively perform a bottom-up search to determine the maximum nesting
// depth, starting from a specific operation and continuing up to the function
// or module scope. Tosa nesting_depth starts at 0 and increments by one each
@@ -673,6 +689,13 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(VariableRead);
// Shape Operators
CHECK_RANKS_AND_SIZES(Dim);
+ CHECK_RANKS_AND_SIZES(DivCeilShape);
+ CHECK_RANKS_AND_SIZES(DivFloorShape);
+ CHECK_RANKS_AND_SIZES(Exp2Shape);
+ CHECK_RANKS_AND_SIZES(Log2CeilShape);
+ CHECK_RANKS_AND_SIZES(Log2FloorShape);
+ CHECK_RANKS_AND_SIZES(MulShape);
+ CHECK_RANKS_AND_SIZES(SubShape);
// For the following operators, check whether the size of each tensor
// operand is valid in a given Level.
@@ -785,7 +808,8 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
- failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
+ failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
+ failed(levelCheckExp2Shape(op))) {
return failure();
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 5e681c1ef75c8..d062ebdc2dd19 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1691,7 +1691,6 @@ func.func @test_dim(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
return %0 : !tosa.shape<1>
}
-
// -----
func.func @test_concat_shape_invalid_list_size() {
@@ -1722,6 +1721,42 @@ func.func @test_concat_shape_invalid_list_size() {
// -----
+func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<7> {
+ %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+ // expected-error@+1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+ %1 = tosa.exp2_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
+ return %1 : !tosa.shape<7>
+}
+
+// -----
+
+func.func @test_exp2_shape_invalid_shape() -> !tosa.shape<4> {
+ %0 = tosa.const_shape {values = dense<[1, 2, 3, 32]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.exp2_shape' op failed level check: shape < MAX_LOG2_SIZE}}
+ %1 = tosa.exp2_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %1 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<7> {
+ %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+ // expected-error@+1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+ %1 = tosa.log2_floor_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
+ return %1 : !tosa.shape<7>
+}
+
+// -----
+
+func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<7> {
+ %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+ // expected-error@+1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+ %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
+ return %1 : !tosa.shape<7>
+}
+
+// -----
+
func.func @test_mod_shape_invalid_rank() -> !tosa.shape<9> {
%a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9>
%b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9>
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 8d7d703e2f450..5d6384ca34cac 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1481,6 +1481,30 @@ func.func @test_slice_shape_dynamic(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>)
return %3 : !tosa.shape<3>
}
+// -----
+// CHECK-LABEL: test_exp2_shape
+func.func @test_exp2_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.exp2_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %b : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_log2_ceil_shape
+func.func @test_log2_ceil_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.log2_ceil_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %b : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_log2_floor_shape
+func.func @test_log2_floor_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.log2_floor_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %b : !tosa.shape<4>
+}
+
// -----
// CHECK-LABEL: test_max_shape
func.func @test_max_shape() -> !tosa.shape<4> {
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 1b2403c60fbbd..d11c2b512b273 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -167,6 +167,22 @@ func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
return %0 : !tosa.shape<1>
}
+// -----
+// CHECK-LABEL: test_exp2_shape
+func.func @test_exp2_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.exp2_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %b : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_log2_ceil_shape
+func.func @test_log2_ceil_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.log2_ceil_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %b : !tosa.shape<4>
+}
+
// -----
// CHECK-LABEL: test_mod_shape
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index e444664cf2b93..d17d9f6287cc5 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1284,6 +1284,33 @@ func.func @test_concat_shape_rank_mismatch() -> !tosa.shape<4> {
// -----
+func.func @test_exp2_negative_shape() -> !tosa.shape<4> {
+ %0 = tosa.const_shape {values = dense<[-1, 0, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error @+1 {{'tosa.exp2_shape' op input elements must be >= 0 for exp2}}
+ %1 = tosa.exp2_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %1 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_log2_floor_zero_shape() -> !tosa.shape<4> {
+ %0 = tosa.const_shape {values = dense<[5, 0, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error @+1 {{'tosa.log2_floor_shape' op input elements must be >= 1 for log2}}
+ %1 = tosa.log2_floor_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %1 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_log2_ceil_zero_shape() -> !tosa.shape<4> {
+ %0 = tosa.const_shape {values = dense<[5, 0, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error @+1 {{'tosa.log2_ceil_shape' op input elements must be >= 1 for log2}}
+ %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
+ return %1 : !tosa.shape<4>
+}
+
+// -----
+
func.func @test_slice_shape_negative_start() -> !tosa.shape<3> {
%0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
%1 = "tosa.const"() {values = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
From c78e00e3cddda5828cbf5fedf53fdc68aba27107 Mon Sep 17 00:00:00 2001
From: Yuvaraj Venkatesh <yuvaraj.venkatesh at arm.com>
Date: Fri, 9 Jan 2026 11:24:29 +0000
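Subject: [PATCH 2/2] [mlir][tosa] Remove verifier and level check
Remove the input-value verifiers of exp2_shape, log2_ceil_shape and log2_floor_shape, the exp2_shape MAX_LOG2_SIZE level check, and the explicit availability lists. Drop the CHECK_RANKS_AND_SIZES entries added for the shape operators, and level-check the three new operators with CHECK_RANKS instead.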
Change-Id: Ib3b00a68d37fd41deb27abffc05f8f17ac230664
---
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 21 ------------
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 32 -------------------
.../Tosa/Transforms/TosaValidation.cpp | 29 +++--------------
mlir/test/Dialect/Tosa/level_check.mlir | 9 ------
mlir/test/Dialect/Tosa/verifier.mlir | 27 ----------------
5 files changed, 4 insertions(+), 114 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 4b1631db0bc45..d8597151714c3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -190,13 +190,6 @@ def Tosa_Exp2ShapeOp : Tosa_ElementwiseShapeOp<"exp2_shape", [Pure]> {
);
let results = (outs Tosa_Shape:$output);
-
- list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[Tosa_EXT_SHAPE]>
- ];
-
- let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -214,13 +207,6 @@ def Tosa_Log2CeilShapeOp : Tosa_ElementwiseShapeOp<"log2_ceil_shape", [Pure]> {
);
let results = (outs Tosa_Shape:$output);
-
- list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[Tosa_EXT_SHAPE]>
- ];
-
- let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -238,13 +224,6 @@ def Tosa_Log2FloorShapeOp : Tosa_ElementwiseShapeOp<"log2_floor_shape", [Pure]>
);
let results = (outs Tosa_Shape:$output);
-
- list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[Tosa_EXT_SHAPE]>
- ];
-
- let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f3b9fc4a90965..5656f3de698c5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4685,27 +4685,6 @@ LogicalResult tosa::ConcatShapeOp::verify() {
return success();
}
-LogicalResult tosa::Exp2ShapeOp::verify() {
- SmallVector<int64_t> shapes;
- // Verification is done only for inputs with a const_shape.
- if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
- if (llvm::any_of(shapes, [](int64_t s) { return s < 0; }))
- return emitOpError("input elements must be >= 0 for exp2");
- }
- return success();
-}
-
-LogicalResult tosa::Log2FloorShapeOp::verify() {
- SmallVector<int64_t> shapes;
- // Verification is done only for inputs with a const_shape.
- if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
- if (llvm::any_of(shapes, [](int64_t s) { return s < 1; }))
- return emitOpError("input elements must be >= 1 for log2");
- }
-
- return success();
-}
-
LogicalResult tosa::SliceShapeOp::verify() {
std::optional<int32_t> start;
DenseIntElementsAttr startAttr;
@@ -4748,17 +4727,6 @@ LogicalResult tosa::SliceShapeOp::verify() {
return success();
}
-LogicalResult tosa::Log2CeilShapeOp::verify() {
- SmallVector<int64_t> shapes;
- // Verification is done only for inputs with a const_shape.
- if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
- if (llvm::any_of(shapes, [](int64_t s) { return s < 1; }))
- return emitOpError("input elements must be >= 1 for log2");
- }
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 309d2d163ebdd..a900aef04f753 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -437,22 +437,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return success();
}
- // Exp2Shape op: level check max shape
- LogicalResult levelCheckExp2Shape(Operation *op) {
- if (auto exp2Shape = dyn_cast<tosa::Exp2ShapeOp>(op)) {
- SmallVector<int64_t> shapes;
- // Level check is done only for inputs with a const_shape.
- if (tosa::getConstShapeValues(exp2Shape.getInput().getDefiningOp(),
- shapes)) {
- for (int64_t shape : shapes) {
- if (shape >= targetEnv.getLevel().MAX_LOG2_SIZE)
- return op->emitOpError("failed level check: shape < MAX_LOG2_SIZE");
- }
- }
- }
- return success();
- }
-
// Recursively perform a bottom-up search to determine the maximum nesting
// depth, starting from a specific operation and continuing up to the function
// or module scope. Tosa nesting_depth starts at 0 and increments by one each
@@ -689,13 +673,6 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(VariableRead);
// Shape Operators
CHECK_RANKS_AND_SIZES(Dim);
- CHECK_RANKS_AND_SIZES(DivCeilShape);
- CHECK_RANKS_AND_SIZES(DivFloorShape);
- CHECK_RANKS_AND_SIZES(Exp2Shape);
- CHECK_RANKS_AND_SIZES(Log2CeilShape);
- CHECK_RANKS_AND_SIZES(Log2FloorShape);
- CHECK_RANKS_AND_SIZES(MulShape);
- CHECK_RANKS_AND_SIZES(SubShape);
// For the following operators, check whether the size of each tensor
// operand is valid in a given Level.
@@ -731,6 +708,9 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS(ConcatShape);
CHECK_RANKS(DivCeilShape);
CHECK_RANKS(DivFloorShape);
+ CHECK_RANKS(Exp2Shape);
+ CHECK_RANKS(Log2CeilShape);
+ CHECK_RANKS(Log2FloorShape);
CHECK_RANKS(MaxShape);
CHECK_RANKS(MinShape);
CHECK_RANKS(ModShape);
@@ -808,8 +788,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
- failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
- failed(levelCheckExp2Shape(op))) {
+ failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
return failure();
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d062ebdc2dd19..d874ab6d23a50 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1730,15 +1730,6 @@ func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<7> {
// -----
-func.func @test_exp2_shape_invalid_shape() -> !tosa.shape<4> {
- %0 = tosa.const_shape {values = dense<[1, 2, 3, 32]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // expected-error@+1 {{'tosa.exp2_shape' op failed level check: shape < MAX_LOG2_SIZE}}
- %1 = tosa.exp2_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
- return %1 : !tosa.shape<4>
-}
-
-// -----
-
func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<7> {
%0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
// expected-error@+1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index d17d9f6287cc5..e444664cf2b93 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1284,33 +1284,6 @@ func.func @test_concat_shape_rank_mismatch() -> !tosa.shape<4> {
// -----
-func.func @test_exp2_negative_shape() -> !tosa.shape<4> {
- %0 = tosa.const_shape {values = dense<[-1, 0, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // expected-error @+1 {{'tosa.exp2_shape' op input elements must be >= 0 for exp2}}
- %1 = tosa.exp2_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
- return %1 : !tosa.shape<4>
-}
-
-// -----
-
-func.func @test_log2_floor_zero_shape() -> !tosa.shape<4> {
- %0 = tosa.const_shape {values = dense<[5, 0, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // expected-error @+1 {{'tosa.log2_floor_shape' op input elements must be >= 1 for log2}}
- %1 = tosa.log2_floor_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
- return %1 : !tosa.shape<4>
-}
-
-// -----
-
-func.func @test_log2_ceil_zero_shape() -> !tosa.shape<4> {
- %0 = tosa.const_shape {values = dense<[5, 0, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // expected-error @+1 {{'tosa.log2_ceil_shape' op input elements must be >= 1 for log2}}
- %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
- return %1 : !tosa.shape<4>
-}
-
-// -----
-
func.func @test_slice_shape_negative_start() -> !tosa.shape<3> {
%0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
%1 = "tosa.const"() {values = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
More information about the Mlir-commits
mailing list