[Mlir-commits] [mlir] 728aa16 - [mlir][tosa]: Add Unary Shape Ops folders (#180762)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 11 02:07:01 PST 2026
Author: Udaya Ranga
Date: 2026-02-11T10:06:56Z
New Revision: 728aa1666536d3d7197825aa4b2feba11341f528
URL: https://github.com/llvm/llvm-project/commit/728aa1666536d3d7197825aa4b2feba11341f528
DIFF: https://github.com/llvm/llvm-project/commit/728aa1666536d3d7197825aa4b2feba11341f528.diff
LOG: [mlir][tosa]: Add Unary Shape Ops folders (#180762)
* EXP2_SHAPE
* LOG2_CEIL_SHAPE
* LOG2_FLOOR_SHAPE
Signed-off-by: Udaya Ranga <udaya.ranga at arm.com>
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant_folding.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index cbcc2a017ac3a..45e03b6579970 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -198,6 +198,8 @@ def Tosa_Exp2ShapeOp : Tosa_ElementwiseShapeOp<"exp2_shape", [Pure]> {
);
let results = (outs Tosa_Shape:$output);
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -215,6 +217,8 @@ def Tosa_Log2CeilShapeOp : Tosa_ElementwiseShapeOp<"log2_ceil_shape", [Pure]> {
);
let results = (outs Tosa_Shape:$output);
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -232,6 +236,8 @@ def Tosa_Log2FloorShapeOp : Tosa_ElementwiseShapeOp<"log2_floor_shape", [Pure]>
);
let results = (outs Tosa_Shape:$output);
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 42033ce8a3b02..bb715a90b2ea2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -986,6 +986,43 @@ binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
return {};
}
+
+/// Constant-folds a unary element-wise operation over an integer or index
+/// DenseElementsAttr by applying `Folder::fold(APInt, isUnsigned)` to every
+/// element.
+///
+/// \param val             the constant operand; a null attribute yields {}.
+/// \param returnTy        the shaped type of the folded result.
+/// \param foldDenseValues when false, only splat operands are folded; when
+///                        true, non-splat operands are folded element-wise.
+/// \returns the folded attribute, or a null attribute if the element type is
+///          not integer/index or if `Folder::fold` fails for any element.
+template <typename Folder>
+static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
+                                     bool foldDenseValues = false) {
+  if (!val)
+    return {};
+
+  const mlir::Type elemTy = val.getElementType();
+  if (!elemTy.isIntOrIndex())
+    return {};
+
+  // Index and signless integers carry no unsigned marker, so they are folded
+  // with signed semantics. Computing this once keeps the splat and dense
+  // paths in agreement (previously the dense path always assumed signed).
+  const auto intTy = llvm::dyn_cast<IntegerType>(elemTy);
+  const bool isUnsigned = intTy && intTy.isUnsigned();
+
+  if (val.isSplat()) {
+    const auto maybeResult =
+        Folder::fold(val.getSplatValue<APInt>(), isUnsigned);
+    if (failed(maybeResult))
+      return {};
+    return DenseElementsAttr::get(returnTy, maybeResult.value());
+  }
+
+  // Folding arbitrarily sized non-splat tensors is opt-in.
+  if (!foldDenseValues)
+    return {};
+
+  SmallVector<APInt> resultValues;
+  resultValues.reserve(val.getNumElements());
+  for (const APInt &v : val.getValues<APInt>()) {
+    const auto maybeResult = Folder::fold(v, isUnsigned);
+    // A single unfoldable element aborts the whole fold.
+    if (failed(maybeResult))
+      return {};
+    resultValues.push_back(maybeResult.value());
+  }
+  return DenseElementsAttr::get(returnTy, resultValues);
+}
+
struct AddFoldAdaptor {
static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
const bool isUnsigned) {
@@ -1142,6 +1179,38 @@ struct MinFoldAdaptor {
}
};
+struct Exp2FoldAdaptor {
+  /// Computes 2^value as an APInt with the same bit width as `value`.
+  ///
+  /// Fails when the result is not representable:
+  ///  - unsigned: the exponent must lie in [0, bitwidth - 1];
+  ///  - signed (including index): the exponent must lie in [0, bitwidth - 2],
+  ///    because 2^(bitwidth - 1) sets the sign bit and wraps to a negative
+  ///    value (e.g. exp2(63) on a 64-bit value would produce INT64_MIN).
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    const unsigned numBits = value.getBitWidth();
+    // Compare in APInt space so values wider than 64 bits never reach the
+    // asserting getZExtValue()/getSExtValue() conversions before the range
+    // check.
+    if (isUnsigned) {
+      if (value.uge(numBits))
+        return failure();
+    } else if (value.isNegative() ||
+               value.sge(static_cast<int64_t>(numBits) - 1)) {
+      return failure();
+    }
+    // The exponent is now known to be a small non-negative value.
+    return APInt::getOneBitSet(numBits, value.getZExtValue());
+  }
+};
+
+struct Log2CeilFoldAdaptor {
+  /// Computes ceil(log2(value)) as an APInt with the same bit width as
+  /// `value`. Fails for zero, and for negative values when `value` is signed.
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    // For unsigned types the top bit is magnitude, not sign, so only zero is
+    // outside the domain of log2. ceilLogBase2() already treats the value as
+    // unsigned.
+    const bool inDomain =
+        isUnsigned ? !value.isZero() : value.isStrictlyPositive();
+    if (!inDomain)
+      return failure();
+    return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
+  }
+};
+
+struct Log2FloorFoldAdaptor {
+  /// Computes floor(log2(value)) as an APInt with the same bit width as
+  /// `value`. Fails for zero, and for negative values when `value` is signed.
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    // For unsigned types the top bit is magnitude, not sign, so only zero is
+    // outside the domain of log2. logBase2() already treats the value as
+    // unsigned.
+    const bool inDomain =
+        isUnsigned ? !value.isZero() : value.isStrictlyPositive();
+    if (!inDomain)
+      return failure();
+    return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
+  }
+};
+
struct GreaterFoldAdaptor {
static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
const bool isUnsigned) {
@@ -1250,7 +1319,8 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
if (lhsTy != rhsTy)
return {};
- // IntDivOp inputs must be integer type, no need to check for quantized type
+ // IntDivOp inputs must be integer type, no need to check for quantized
+ // type
auto resultETy = resultTy.getElementType();
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
@@ -1300,7 +1370,8 @@ std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
auto round = APInt(64, 1) << (shift - 1);
result += round;
result.ashrInPlace(shift);
- // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+ // REQUIRE(product >= minimum_s<i32_t>() && product <=
+ // maximum_s<i32_t>())
if (!(result.getSExtValue() >= INT32_MIN &&
result.getSExtValue() <= INT32_MAX)) {
// REQUIRE failed
@@ -1356,8 +1427,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
- // Result right shift on i32_t data type only. For simplification, synthesize
- // a zero shift for other data type.
+ // Result right shift on i32_t data type only. For simplification,
+ // synthesize a zero shift for other data type.
int32_t shift = 0;
if (resultETy.isInteger(32)) {
ElementsAttr shift_elem;
@@ -1449,8 +1520,8 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
Value rhs = getInput2();
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
- // If we are comparing an integer value to itself it is always true. We can
- // not do this with float due to float values.
+ // If we are comparing an integer value to itself it is always true. We
+ // can not do this with float due to float values.
if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
resultTy.hasStaticShape() && lhs == rhs) {
return DenseElementsAttr::get(resultTy, true);
@@ -1561,9 +1632,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!inputTy || !outputTy)
return {};
- // Fold when the input and output types are the same. This is only safe when
- // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
- // there may still be a productive reshape.
+ // Fold when the input and output types are the same. This is only safe
+ // when there is at most 1 dynamic dimension. For 2 or more dynamic
+ // dimensions, there may still be a productive reshape.
if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
return getInput1();
@@ -1882,6 +1953,19 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Folds a unary element-wise shape op whose operand is produced by a
+/// tosa.const_shape, applying `OpFoldAdaptor::fold` to each element.
+/// Returns a null OpFoldResult when the operand is not a constant shape
+/// (e.g. it is a block argument) or when any element cannot be folded.
+template <typename Op, typename OpFoldAdaptor>
+OpFoldResult unaryShapeFold(Op *op) {
+  // getDefiningOp() returns null for block arguments; plain dyn_cast asserts
+  // on null, so use the null-tolerant variant.
+  auto inputConstShape = llvm::dyn_cast_if_present<tosa::ConstShapeOp>(
+      op->getInput().getDefiningOp());
+  if (!inputConstShape)
+    return {};
+
+  const auto inputAttr =
+      llvm::dyn_cast<DenseElementsAttr>(inputConstShape.getValues());
+  if (!inputAttr)
+    return {};
+
+  // Unary element-wise shape ops preserve the operand's shape and type.
+  return unaryFolder<OpFoldAdaptor>(inputAttr, inputAttr.getType(),
+                                    /*foldDenseValues=*/true);
+}
+
template <typename Op, typename OpFoldAdaptor>
OpFoldResult binaryFold(Op *op) {
auto input1ConstShape =
@@ -1894,8 +1978,9 @@ OpFoldResult binaryFold(Op *op) {
const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
- return binaryFolder<OpFoldAdaptor>(
- input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
+ return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
+ input1Attr.getType(),
+ /*foldDenseValues=*/true);
}
OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
@@ -1944,3 +2029,15 @@ OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
return binaryFold<MinShapeOp, MinFoldAdaptor>(this);
}
+
+// The unary shape-op folders below all delegate to unaryShapeFold with the
+// matching per-element adaptor; none of them reads the FoldAdaptor.
+OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor /*adaptor*/) {
+  return unaryShapeFold<tosa::Exp2ShapeOp, Exp2FoldAdaptor>(this);
+}
+
+OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor /*adaptor*/) {
+  return unaryShapeFold<tosa::Log2CeilShapeOp, Log2CeilFoldAdaptor>(this);
+}
+
+OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor /*adaptor*/) {
+  return unaryShapeFold<tosa::Log2FloorShapeOp, Log2FloorFoldAdaptor>(this);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index c3186279a30ae..9afa668b78d9f 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -897,3 +897,93 @@ func.func @test_min_shape() -> !tosa.shape<6> {
%c = tosa.min_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
return %c : !tosa.shape<6>
}
+
+// -----
+
+// exp2_shape on a constant shape whose exponents are all in range folds to a
+// tosa.const_shape holding 2^x for each element.
+// CHECK-LABEL: @test_exp2_shape
+// CHECK: tosa.const_shape {values = dense<[4, 8, 2, 16, 64, 32]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_exp2_shape() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[2, 3, 1, 4, 6, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.exp2_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// A negative exponent (-10) is rejected by the element folder. One failing
+// element aborts the whole fold, so the op must remain.
+// CHECK-LABEL: @test_neg_exp2_shape
+// CHECK: tosa.exp2_shape
+func.func @test_neg_exp2_shape() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[-10, 3, 1, 4, 6, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.exp2_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// An exponent of 64 is expected to be out of range for the index element
+// bit width (presumably 64), so the op must remain unfolded. 32 alone would
+// still be foldable.
+// CHECK-LABEL: @test_high_exp2_shape
+// CHECK: tosa.exp2_shape
+func.func @test_high_exp2_shape() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[32, 64, 1, 4, 6, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.exp2_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// log2_ceil_shape on strictly positive constants folds to ceil(log2(x))
+// element-wise (e.g. 9 -> 4, 1 -> 0).
+// CHECK-LABEL: @test_log2_ceil_shape
+// CHECK: tosa.const_shape {values = dense<[2, 4, 4, 0, 3, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_log2_ceil_shape() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[4, 9, 14, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.log2_ceil_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// log2 is undefined for 0. A single zero element blocks folding of the
+// whole shape.
+// CHECK-LABEL: @test_log2_ceil_shape_zero
+// CHECK: tosa.log2_ceil_shape
+func.func @test_log2_ceil_shape_zero() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[4, 9, 0, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.log2_ceil_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Negative index elements are likewise rejected, so the op must remain.
+// CHECK-LABEL: @test_log2_ceil_shape_neg
+// CHECK: tosa.log2_ceil_shape
+func.func @test_log2_ceil_shape_neg() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[4, 9, -123, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.log2_ceil_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// log2_floor_shape on strictly positive constants folds to floor(log2(x))
+// element-wise (e.g. 9 -> 3, 1 -> 0).
+// CHECK-LABEL: @test_log2_floor_shape
+// CHECK: tosa.const_shape {values = dense<[2, 3, 3, 0, 2, 4]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_log2_floor_shape() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[4, 9, 14, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.log2_floor_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// log2 is undefined for 0. A single zero element blocks folding of the
+// whole shape.
+// CHECK-LABEL: @test_log2_floor_shape_zero
+// CHECK: tosa.log2_floor_shape
+func.func @test_log2_floor_shape_zero() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[4, 9, 0, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.log2_floor_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// Negative index elements are likewise rejected, so the op must remain.
+// CHECK-LABEL: @test_log2_floor_shape_neg
+// CHECK: tosa.log2_floor_shape
+func.func @test_log2_floor_shape_neg() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[4, 9, -123, 1, 7, 30]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.log2_floor_shape %a : (!tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
More information about the Mlir-commits
mailing list