[Mlir-commits] [mlir] 728aa16 - [mlir][tosa]: Add Unary Shape Ops folders (#180762)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 11 02:07:01 PST 2026


Author: Udaya Ranga
Date: 2026-02-11T10:06:56Z
New Revision: 728aa1666536d3d7197825aa4b2feba11341f528

URL: https://github.com/llvm/llvm-project/commit/728aa1666536d3d7197825aa4b2feba11341f528
DIFF: https://github.com/llvm/llvm-project/commit/728aa1666536d3d7197825aa4b2feba11341f528.diff

LOG: [mlir][tosa]: Add Unary Shape Ops folders (#180762)

* EXP2_SHAPE
* LOG2_CEIL_SHAPE
* LOG2_FLOOR_SHAPE

Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant_folding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index cbcc2a017ac3a..45e03b6579970 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -198,6 +198,8 @@ def Tosa_Exp2ShapeOp : Tosa_ElementwiseShapeOp<"exp2_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -215,6 +217,8 @@ def Tosa_Log2CeilShapeOp : Tosa_ElementwiseShapeOp<"log2_ceil_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -232,6 +236,8 @@ def Tosa_Log2FloorShapeOp : Tosa_ElementwiseShapeOp<"log2_floor_shape", [Pure]>
   );
 
   let results = (outs Tosa_Shape:$output);
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 42033ce8a3b02..bb715a90b2ea2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -986,6 +986,43 @@ binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
 
   return {};
 }
+
+/// Constant-folds an element-wise unary integer operation over `val`.
+///
+/// `Folder` must expose
+///   static FailureOr<APInt> fold(const APInt &value, bool isUnsigned);
+/// which computes the result for one element or returns failure() when that
+/// element cannot be folded.
+///
+/// \param val             Constant operand; a null attribute yields no fold.
+/// \param returnTy        Shaped type used to build the resulting attribute.
+/// \param foldDenseValues When true, non-splat integer/index attributes are
+///                        folded element by element; otherwise only splat
+///                        attributes with an IntegerType element are folded.
+/// \returns The folded attribute, or a null attribute if `val` is null, the
+///          element type is unsupported, or any single element fails to fold.
+template <typename Folder>
+static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
+                                     bool foldDenseValues = false) {
+  if (!val)
+    return {};
+
+  const mlir::Type elemTy = val.getElementType();
+  // Determine signedness once so the splat path and the dense path hand the
+  // folder the same interpretation. Index and signless types report false.
+  const bool isUnsigned = elemTy.isUnsignedInteger();
+
+  // Splat fast path: fold the single repeated value. Only IntegerType is
+  // handled here; index-typed splats fall through to the dense path below.
+  if (val.isSplat() && llvm::isa<IntegerType>(elemTy)) {
+    const auto maybeResult =
+        Folder::fold(val.getSplatValue<APInt>(), isUnsigned);
+    if (failed(maybeResult))
+      return {};
+    return DenseElementsAttr::get(returnTy, maybeResult.value());
+  }
+
+  if (foldDenseValues && elemTy.isIntOrIndex()) {
+    SmallVector<APInt> resultValues;
+    resultValues.reserve(val.getNumElements());
+    for (const APInt &v : val.getValues<APInt>()) {
+      const auto maybeResult = Folder::fold(v, isUnsigned);
+      // One unfoldable element aborts the whole fold: a partially folded
+      // constant cannot be represented.
+      if (failed(maybeResult))
+        return {};
+      resultValues.push_back(maybeResult.value());
+    }
+    return DenseElementsAttr::get(returnTy, resultValues);
+  }
+
+  // Folding arbitrarily sized tensor operations is not supported
+  return {};
+}
+
 struct AddFoldAdaptor {
   static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                                const bool isUnsigned) {
@@ -1142,6 +1179,38 @@ struct MinFoldAdaptor {
   }
 };
 
+/// Per-element folder computing 2^value.
+///
+/// `value` is the exponent. `isUnsigned` selects whether the exponent and the
+/// result are interpreted as unsigned or as two's-complement signed integers
+/// of `value`'s bit width. Returns failure() when the exponent is negative or
+/// when 2^value is not representable in that width, so the caller leaves the
+/// operation unfolded.
+struct Exp2FoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    const unsigned numBits = value.getBitWidth();
+    // A zero-width value has no bit that could be set.
+    if (numBits == 0)
+      return failure();
+    // A negative exponent has no integer result.
+    if (!isUnsigned && value.isNegative())
+      return failure();
+    // In the signed interpretation the top bit is the sign bit: setting it
+    // would produce the minimum negative value rather than 2^(numBits - 1),
+    // so the exponent limit is one lower than in the unsigned case.
+    const unsigned exponentLimit = isUnsigned ? numBits : numBits - 1;
+    // Compare on the APInt directly so no 64-bit extraction happens before
+    // the exponent is known to be small.
+    if (value.uge(exponentLimit))
+      return failure();
+    return APInt::getOneBitSet(numBits,
+                               static_cast<unsigned>(value.getZExtValue()));
+  }
+};
+
+/// Per-element folder computing ceil(log2(value)).
+///
+/// `isUnsigned` selects how `value` is interpreted. Returns failure() for
+/// inputs that are not greater than zero under that interpretation, so the
+/// caller leaves the operation unfolded. The result has the same bit width as
+/// `value`.
+struct Log2CeilFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    // With unsigned semantics a set top bit is magnitude, not a sign, so only
+    // zero has to be rejected; with signed semantics negatives are rejected
+    // as well.
+    const bool isPositive =
+        isUnsigned ? !value.isZero() : value.isStrictlyPositive();
+    if (!isPositive)
+      return failure();
+    return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
+  }
+};
+
+/// Per-element folder computing floor(log2(value)).
+///
+/// `isUnsigned` selects how `value` is interpreted. Returns failure() for
+/// inputs that are not greater than zero under that interpretation, so the
+/// caller leaves the operation unfolded. The result has the same bit width as
+/// `value`.
+struct Log2FloorFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    // With unsigned semantics a set top bit is magnitude, not a sign, so only
+    // zero has to be rejected; with signed semantics negatives are rejected
+    // as well.
+    const bool isPositive =
+        isUnsigned ? !value.isZero() : value.isStrictlyPositive();
+    if (!isPositive)
+      return failure();
+    return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
+  }
+};
+
 struct GreaterFoldAdaptor {
   static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                                const bool isUnsigned) {
@@ -1250,7 +1319,8 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
   if (lhsTy != rhsTy)
     return {};
 
-  // IntDivOp inputs must be integer type, no need to check for quantized type
+  // IntDivOp inputs must be integer type, no need to check for quantized
+  // type
   auto resultETy = resultTy.getElementType();
   auto lhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
@@ -1300,7 +1370,8 @@ std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
     auto round = APInt(64, 1) << (shift - 1);
     result += round;
     result.ashrInPlace(shift);
-    // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+    // REQUIRE(product >= minimum_s<i32_t>() && product <=
+    // maximum_s<i32_t>())
     if (!(result.getSExtValue() >= INT32_MIN &&
           result.getSExtValue() <= INT32_MAX)) {
       // REQUIRE failed
@@ -1356,8 +1427,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  // Result right shift on i32_t data type only. For simplification, synthesize
-  // a zero shift for other data type.
+  // Result right shift on i32_t data type only. For simplification,
+  // synthesize a zero shift for other data type.
   int32_t shift = 0;
   if (resultETy.isInteger(32)) {
     ElementsAttr shift_elem;
@@ -1449,8 +1520,8 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   Value rhs = getInput2();
   auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
 
-  // If we are comparing an integer value to itself it is always true. We can
-  // not do this with float due to float values.
+  // If we are comparing an integer value to itself it is always true. We
+  // cannot do this for floats because NaN never compares equal to itself.
   if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
       resultTy.hasStaticShape() && lhs == rhs) {
     return DenseElementsAttr::get(resultTy, true);
@@ -1561,9 +1632,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!inputTy || !outputTy)
     return {};
 
-  // Fold when the input and output types are the same. This is only safe when
-  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
-  // there may still be a productive reshape.
+  // Fold when the input and output types are the same. This is only safe
+  // when there is at most 1 dynamic dimension. For 2 or more dynamic
+  // dimensions, there may still be a productive reshape.
   if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
     return getInput1();
 
@@ -1882,6 +1953,19 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+/// Folds a unary shape operation whose operand is produced by a
+/// tosa.const_shape.
+///
+/// `Op` must provide `getInput()`; `OpFoldAdaptor` supplies the per-element
+/// computation consumed by unaryFolder. Returns a null OpFoldResult when the
+/// operand is not defined by a ConstShapeOp (including block arguments, which
+/// have no defining operation) or when any element cannot be folded.
+template <typename Op, typename OpFoldAdaptor>
+OpFoldResult unaryShapeFold(Op *op) {
+  // Value::getDefiningOp<OpTy>() tolerates a missing defining op, unlike a
+  // plain dyn_cast on the raw Operation pointer, which must not be null.
+  auto inputConstShape =
+      op->getInput().template getDefiningOp<tosa::ConstShapeOp>();
+  if (!inputConstShape)
+    return {};
+
+  const auto inputAttr = cast<DenseElementsAttr>(inputConstShape.getValues());
+
+  // The result reuses the operand attribute's type.
+  return unaryFolder<OpFoldAdaptor>(inputAttr, inputAttr.getType(),
+                                    /*foldDenseValues=*/true);
+}
+
 template <typename Op, typename OpFoldAdaptor>
 OpFoldResult binaryFold(Op *op) {
   auto input1ConstShape =
@@ -1894,8 +1978,9 @@ OpFoldResult binaryFold(Op *op) {
   const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
   const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
 
-  return binaryFolder<OpFoldAdaptor>(
-      input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
+  return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
+                                     input1Attr.getType(),
+                                     /*foldDenseValues=*/true);
 }
 
 OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
@@ -1944,3 +2029,15 @@ OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
 OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
   return binaryFold<MinShapeOp, MinFoldAdaptor>(this);
 }
+
+/// Folds tosa.exp2_shape by delegating to the shared unary shape folder with
+/// Exp2FoldAdaptor as the per-element computation.
+OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
+  using ElementFolder = Exp2FoldAdaptor;
+  return unaryShapeFold<Exp2ShapeOp, ElementFolder>(this);
+}
+
+/// Folds tosa.log2_ceil_shape by delegating to the shared unary shape folder
+/// with Log2CeilFoldAdaptor as the per-element computation.
+OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
+  using ElementFolder = Log2CeilFoldAdaptor;
+  return unaryShapeFold<Log2CeilShapeOp, ElementFolder>(this);
+}
+
+/// Folds tosa.log2_floor_shape by delegating to the shared unary shape folder
+/// with Log2FloorFoldAdaptor as the per-element computation.
+OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
+  using ElementFolder = Log2FloorFoldAdaptor;
+  return unaryShapeFold<Log2FloorShapeOp, ElementFolder>(this);
+}

diff  --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index c3186279a30ae..9afa668b78d9f 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -897,3 +897,93 @@ func.func @test_min_shape() -> !tosa.shape<6> {
   %c = tosa.min_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
   return %c : !tosa.shape<6>
 }
+
+// -----
+
+// Positive case: every exponent is small and non-negative, so exp2_shape is
+// expected to be replaced by a const_shape holding the element-wise powers of
+// two (2 -> 4, 3 -> 8, ...).
+// CHECK-LABEL: @test_exp2_shape
+// CHECK: tosa.const_shape  {values = dense<[4, 8, 2, 16, 64, 32]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_exp2_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[2, 3, 1, 4, 6, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.exp2_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Negative case: the first exponent is negative (-10), so the op is expected
+// to remain in the output instead of being folded.
+// CHECK-LABEL: @test_neg_exp2_shape
+// CHECK: tosa.exp2_shape
+func.func @test_neg_exp2_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[-10, 3, 1, 4, 6, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.exp2_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Negative case: the input contains large exponents (32 and 64) and the op is
+// expected to remain unfolded. Presumably 64 is the rejected element because
+// 2^64 does not fit the index storage width - confirm against the target's
+// index bit width.
+// CHECK-LABEL: @test_high_exp2_shape
+// CHECK: tosa.exp2_shape
+func.func @test_high_exp2_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[32, 64, 1, 4, 6, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.exp2_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Positive case: all inputs are strictly positive, so log2_ceil_shape is
+// expected to fold to the element-wise rounded-up base-2 logarithm
+// (e.g. 9 -> 4, 1 -> 0).
+// CHECK-LABEL: @test_log2_ceil_shape
+// CHECK: tosa.const_shape  {values = dense<[2, 4, 4, 0, 3, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_log2_ceil_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[4, 9, 14, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.log2_ceil_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Negative case: one input element is 0, for which the folder reports
+// failure, so the op is expected to remain unfolded.
+// CHECK-LABEL: @test_log2_ceil_shape_zero
+// CHECK: tosa.log2_ceil_shape
+func.func @test_log2_ceil_shape_zero() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[4, 9, 0, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.log2_ceil_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Negative case: one input element is negative (-123), for which the folder
+// reports failure, so the op is expected to remain unfolded.
+// CHECK-LABEL: @test_log2_ceil_shape_neg
+// CHECK: tosa.log2_ceil_shape
+func.func @test_log2_ceil_shape_neg() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[4, 9, -123, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.log2_ceil_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Positive case: all inputs are strictly positive, so log2_floor_shape is
+// expected to fold to the element-wise rounded-down base-2 logarithm
+// (e.g. 9 -> 3, 1 -> 0).
+// CHECK-LABEL: @test_log2_floor_shape
+// CHECK: tosa.const_shape  {values = dense<[2, 3, 3, 0, 2, 4]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_log2_floor_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[4, 9, 14, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.log2_floor_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Negative case: one input element is 0, for which the folder reports
+// failure, so the op is expected to remain unfolded.
+// CHECK-LABEL: @test_log2_floor_shape_zero
+// CHECK: tosa.log2_floor_shape
+func.func @test_log2_floor_shape_zero() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[4, 9, 0, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.log2_floor_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Negative case: one input element is negative (-123), for which the folder
+// reports failure, so the op is expected to remain unfolded.
+// CHECK-LABEL: @test_log2_floor_shape_neg
+// CHECK: tosa.log2_floor_shape
+func.func @test_log2_floor_shape_neg() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[4, 9, -123, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.log2_floor_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}


        


More information about the Mlir-commits mailing list