[Mlir-commits] [mlir] [mlir][tosa] Add log2_ceil/log2_floor/exp2_shape ops (PR #175057)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 8 11:34:57 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Yuvaraj Venkatesh (Yuvaraj-Venkatesh)

<details>
<summary>Changes</summary>

This commit introduces the following new operations from the TOSA EXT-SHAPE extension:
- LOG2_CEIL_SHAPE
- LOG2_FLOOR_SHAPE
- EXP2_SHAPE

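These additions include the operator definitions, same-rank verification, and level checks during validation. When the input is produced by a `tosa.const_shape`, the new verifiers also reject out-of-range elements: EXP2_SHAPE requires every element to be >= 0, and LOG2_CEIL_SHAPE / LOG2_FLOOR_SHAPE require every element to be >= 1. The validation pass additionally checks that each constant EXP2_SHAPE input element is below MAX_LOG2_SIZE. It also enables rank/size level checks for the new operations and for the existing DIV_CEIL_SHAPE, DIV_FLOOR_SHAPE, MUL_SHAPE and SUB_SHAPE operations.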


---
Full diff: https://github.com/llvm/llvm-project/pull/175057.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+72) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+32) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+3) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+25-1) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+36-1) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+24) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+16) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 79967b7c9585e..927d68a436116 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -175,6 +175,78 @@ def Tosa_DivFloorShapeOp : Tosa_ElementwiseShapeOp<"div_floor_shape", [Pure]> {
   let results = (outs Tosa_Shape:$output);
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: Exp2Shape
+//===----------------------------------------------------------------------===//
+def Tosa_Exp2ShapeOp : Tosa_ElementwiseShapeOp<"exp2_shape", [Pure]> {
+  let summary = "Elementwise base-2 exponential of shapes.";
+
+  let description = [{
+      Computation of raising two to the power of each element in input.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: Log2CeilShape
+//===----------------------------------------------------------------------===//
+def Tosa_Log2CeilShapeOp : Tosa_ElementwiseShapeOp<"log2_ceil_shape", [Pure]> {
+  let summary = "Elementwise ceil base-2 logarithm of shapes.";
+
+  let description = [{
+      Computation of the base two logarithm of each element in input. Result is rounded up.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: Log2FloorShape
+//===----------------------------------------------------------------------===//
+def Tosa_Log2FloorShapeOp : Tosa_ElementwiseShapeOp<"log2_floor_shape", [Pure]> {
+  let summary = "Elementwise floor base-2 logarithm of shapes.";
+
+  let description = [{
+      Computation of the base two logarithm of each element in input. Result is rounded down.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: MulShape
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5656f3de698c5..f3b9fc4a90965 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4685,6 +4685,27 @@ LogicalResult tosa::ConcatShapeOp::verify() {
   return success();
 }
 
+LogicalResult tosa::Exp2ShapeOp::verify() {
+  SmallVector<int64_t> shapes;
+  // Verification is done only for inputs with a const_shape.
+  if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
+    if (llvm::any_of(shapes, [](int64_t s) { return s < 0; }))
+      return emitOpError("input elements must be >= 0 for exp2");
+  }
+  return success();
+}
+
+LogicalResult tosa::Log2FloorShapeOp::verify() {
+  SmallVector<int64_t> shapes;
+  // Verification is done only for inputs with a const_shape.
+  if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
+    if (llvm::any_of(shapes, [](int64_t s) { return s < 1; }))
+      return emitOpError("input elements must be >= 1 for log2");
+  }
+
+  return success();
+}
+
 LogicalResult tosa::SliceShapeOp::verify() {
   std::optional<int32_t> start;
   DenseIntElementsAttr startAttr;
@@ -4727,6 +4748,17 @@ LogicalResult tosa::SliceShapeOp::verify() {
   return success();
 }
 
+LogicalResult tosa::Log2CeilShapeOp::verify() {
+  SmallVector<int64_t> shapes;
+  // Verification is done only for inputs with a const_shape.
+  if (tosa::getConstShapeValues(getInput().getDefiningOp(), shapes)) {
+    if (llvm::any_of(shapes, [](int64_t s) { return s < 1; }))
+      return emitOpError("input elements must be >= 1 for log2");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Attribute Definitions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index e8a24057a96ac..9ef3af615ed75 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -329,6 +329,9 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_SKIP(ConstShape)
   POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
   POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
+  POPULATE_PROFILE_INFO_SKIP(Exp2Shape)
+  POPULATE_PROFILE_INFO_SKIP(Log2CeilShape)
+  POPULATE_PROFILE_INFO_SKIP(Log2FloorShape)
   POPULATE_PROFILE_INFO_SKIP(MulShape)
   POPULATE_PROFILE_INFO_SKIP(SliceShape)
   POPULATE_PROFILE_INFO_SKIP(SubShape)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4b9c8b2030a49..4ae484f2d761c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -437,6 +437,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return success();
   }
 
+  // Exp2Shape op: level check max shape
+  LogicalResult levelCheckExp2Shape(Operation *op) {
+    if (auto exp2Shape = dyn_cast<tosa::Exp2ShapeOp>(op)) {
+      SmallVector<int64_t> shapes;
+      // Level check is done only for inputs with a const_shape.
+      if (tosa::getConstShapeValues(exp2Shape.getInput().getDefiningOp(),
+                                    shapes)) {
+        for (int64_t shape : shapes) {
+          if (shape >= targetEnv.getLevel().MAX_LOG2_SIZE)
+            return op->emitOpError("failed level check: shape < MAX_LOG2_SIZE");
+        }
+      }
+    }
+    return success();
+  }
+
   // Recursively perform a bottom-up search to determine the maximum nesting
   // depth, starting from a specific operation and continuing up to the function
   // or module scope. Tosa nesting_depth starts at 0 and increments by one each
@@ -673,6 +689,13 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(VariableRead);
   // Shape Operators
   CHECK_RANKS_AND_SIZES(Dim);
+  CHECK_RANKS_AND_SIZES(DivCeilShape);
+  CHECK_RANKS_AND_SIZES(DivFloorShape);
+  CHECK_RANKS_AND_SIZES(Exp2Shape);
+  CHECK_RANKS_AND_SIZES(Log2CeilShape);
+  CHECK_RANKS_AND_SIZES(Log2FloorShape);
+  CHECK_RANKS_AND_SIZES(MulShape);
+  CHECK_RANKS_AND_SIZES(SubShape);
 
   // For the following operators, check whether the size of each tensor
   // operand is valid in a given Level.
@@ -782,7 +805,8 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
       failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
       failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
       failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
-      failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
+      failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
+      failed(levelCheckExp2Shape(op))) {
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index ec540bda5e57d..a3069671f3ddb 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1691,7 +1691,6 @@ func.func @test_dim(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
   return %0 : !tosa.shape<1>
 }
 
-
 // -----
 
 func.func @test_concat_shape_invalid_list_size() {
@@ -1719,3 +1718,39 @@ func.func @test_concat_shape_invalid_list_size() {
                          ) -> !tosa.shape<0>
   return
 }
+
+// -----
+
+func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<7> {
+  %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  // expected-error at +1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+  %1 = tosa.exp2_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
+  return %1 : !tosa.shape<7>
+}
+
+// -----
+
+func.func @test_exp2_shape_invalid_shape() -> !tosa.shape<4> {
+  %0 = tosa.const_shape {values = dense<[1, 2, 3, 32]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.exp2_shape' op failed level check: shape < MAX_LOG2_SIZE}}
+  %1 = tosa.exp2_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %1 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<7> {
+  %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  // expected-error at +1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+  %1 = tosa.log2_floor_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
+  return %1 : !tosa.shape<7>
+}
+
+// -----
+
+func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<7> {
+  %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  // expected-error at +1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+  %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
+  return %1 : !tosa.shape<7>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 276eac4d6166d..12b5aff796d52 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1471,3 +1471,27 @@ func.func @test_slice_shape_dynamic(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>)
   %3 = tosa.slice_shape %0, %arg0, %arg1 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
   return %3 : !tosa.shape<3>
 }
+
+// -----
+// CHECK-LABEL: test_exp2_shape
+func.func @test_exp2_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.exp2_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %b : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_log2_ceil_shape
+func.func @test_log2_ceil_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.log2_ceil_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %b : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_log2_floor_shape
+func.func @test_log2_floor_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.log2_floor_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %b : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 63379ed8d8a4d..a4f835e4a4c46 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -166,3 +166,19 @@ func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
   %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
   return %0 : !tosa.shape<1>
 }
+
+// -----
+// CHECK-LABEL: test_exp2_shape
+func.func @test_exp2_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.exp2_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %b : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_log2_ceil_shape
+func.func @test_log2_ceil_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.log2_ceil_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %b : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index a51ed4f09400f..d3e883cdbf9b8 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1284,6 +1284,33 @@ func.func @test_concat_shape_rank_mismatch() -> !tosa.shape<4> {
 
 // -----
 
+func.func @test_exp2_negative_shape() -> !tosa.shape<4> {
+  %0 = tosa.const_shape {values = dense<[-1, 0, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error @+1 {{'tosa.exp2_shape' op input elements must be >= 0 for exp2}}
+  %1 = tosa.exp2_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %1 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_log2_floor_zero_shape() -> !tosa.shape<4> {
+  %0 = tosa.const_shape {values = dense<[5, 0, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error @+1 {{'tosa.log2_floor_shape' op input elements must be >= 1 for log2}}
+  %1 = tosa.log2_floor_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %1 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_log2_ceil_zero_shape() -> !tosa.shape<4> {
+  %0 = tosa.const_shape {values = dense<[5, 0, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error @+1 {{'tosa.log2_ceil_shape' op input elements must be >= 1 for log2}}
+  %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<4>) -> !tosa.shape<4>
+  return %1 : !tosa.shape<4>
+}
+
+// -----
+
 func.func @test_slice_shape_negative_start() -> !tosa.shape<3> {
   %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
   %1 = "tosa.const"() {values = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/175057


More information about the Mlir-commits mailing list