[Mlir-commits] [mlir] [mlir][tosa] Add constant folding for tosa.add_shape operation (PR #173112)
Luke Hutton
llvmlistbot at llvm.org
Tue Jan 20 04:47:20 PST 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/173112
>From 56a03ff7d28e62c5fbc80eec902863d34d5bbe83 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 17 Dec 2025 13:28:15 +0000
Subject: [PATCH 1/2] [mlir][tosa] Check for overflow in integer folders
To be TOSA compliant, the integer add and sub folders need to
check for overflow. This commit adds those checks, so folding
is skipped whenever an overflow is detected.
This commit also fixes the greater/greater_equal folders so
that unsigned integer operands are compared as unsigned.
Change-Id: I2b5a5b92fb840d6c34a1f2faa18ae68a20d0ecdf
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 177 +++++++++++++-----
mlir/test/Dialect/Tosa/constant_folding.mlir | 121 ++++++++++++
2 files changed, 246 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..3e9d803a916a9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -889,33 +889,141 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
// Operator Folders.
//===----------------------------------------------------------------------===//
-template <typename IntFolder, typename FloatFolder>
+template <typename Folder>
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
DenseElementsAttr rhs,
RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
- auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+ const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+ const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
if (lETy != rETy)
return {};
- if (llvm::isa<IntegerType>(lETy)) {
- APInt l = lhs.getSplatValue<APInt>();
- APInt r = rhs.getSplatValue<APInt>();
- auto result = IntFolder()(l, r);
- return DenseElementsAttr::get(returnTy, result);
+ if (const auto lIntTy = dyn_cast<IntegerType>(lETy)) {
+ const APInt l = lhs.getSplatValue<APInt>();
+ const APInt r = rhs.getSplatValue<APInt>();
+ const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
+ if (failed(maybeResult))
+ return {};
+ return DenseElementsAttr::get(returnTy, maybeResult.value());
}
if (llvm::isa<FloatType>(lETy)) {
- APFloat l = lhs.getSplatValue<APFloat>();
- APFloat r = rhs.getSplatValue<APFloat>();
- auto result = FloatFolder()(l, r);
- return DenseElementsAttr::get(returnTy, result);
+ const APFloat l = lhs.getSplatValue<APFloat>();
+ const APFloat r = rhs.getSplatValue<APFloat>();
+ const auto maybeResult = Folder::fold(l, r);
+ if (failed(maybeResult))
+ return {};
+ return DenseElementsAttr::get(returnTy, maybeResult.value());
}
}
return {};
}
+struct AddFoldAdaptor {
+  /// Folds `lhs + rhs` for integer operands of the same bit width.
+  ///
+  /// The operands are treated as unsigned when `isUnsigned` is set and as
+  /// two's-complement signed values otherwise. Returns failure (leaving the
+  /// operation unfolded) when the exact sum is not representable in the
+  /// operand bit width, instead of silently wrapping around.
+  ///
+  /// \param lhs left operand.
+  /// \param rhs right operand; must have the same bit width as `lhs`.
+  /// \param isUnsigned selects unsigned (true) or signed (false) semantics.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    // The *_ov helpers compute the wrapped result at the operands' own bit
+    // width and report whether the exact result was out of range, so no
+    // intermediate widening is needed and any bit width is supported.
+    bool overflow = false;
+    const APInt result =
+        isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
+
+    // Check for overflow
+    if (overflow)
+      return failure();
+    return result;
+  }
+
+  /// Folds `lhs + rhs` for floating-point operands.
+  ///
+  /// Floating-point addition is always folded: no overflow check is
+  /// applied, so an out-of-range result is whatever APFloat's `operator+`
+  /// produces.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs + rhs;
+  }
+};
+
+struct SubFoldAdaptor {
+  /// Folds `lhs - rhs` for integer operands of the same bit width.
+  ///
+  /// The operands are treated as unsigned when `isUnsigned` is set and as
+  /// two's-complement signed values otherwise. Returns failure (leaving the
+  /// operation unfolded) when the exact difference is not representable in
+  /// the operand bit width (e.g. unsigned `lhs < rhs`) instead of wrapping.
+  ///
+  /// \param lhs left operand (minuend).
+  /// \param rhs right operand; must have the same bit width as `lhs`.
+  /// \param isUnsigned selects unsigned (true) or signed (false) semantics.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    // The *_ov helpers compute the wrapped result at the operands' own bit
+    // width and report whether the exact result was out of range, so no
+    // intermediate widening or explicit min/max bounds are needed.
+    bool overflow = false;
+    const APInt result =
+        isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
+
+    // Check for overflow
+    if (overflow)
+      return failure();
+    return result;
+  }
+
+  /// Folds `lhs - rhs` for floating-point operands.
+  ///
+  /// Floating-point subtraction is always folded: no overflow check is
+  /// applied, so an out-of-range result is whatever APFloat's `operator-`
+  /// produces.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs - rhs;
+  }
+};
+
+/// `lhs > rhs` as a 1-bit APInt; integers compare unsigned if `isUnsigned`.
+struct FoldGreaterAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return APInt(1, isUnsigned ? lhs.ugt(rhs) : lhs.sgt(rhs));
+  }
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs > rhs);
+  }
+};
+
+/// `lhs >= rhs` as a 1-bit APInt; integers compare unsigned if `isUnsigned`.
+struct FoldGreaterEqualAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return APInt(1, isUnsigned ? lhs.uge(rhs) : lhs.sge(rhs));
+  }
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs >= rhs);
+  }
+};
+
+/// `lhs == rhs` as a 1-bit APInt; integer equality ignores signedness.
+struct FoldEqualAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool /*isUnsigned*/) {
+    return APInt(1, lhs == rhs);
+  }
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs == rhs);
+  }
+};
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
@@ -963,8 +1071,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -1145,38 +1252,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
-namespace {
-template <typename Cmp>
-struct ComparisonFold {
- ComparisonFold() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, Cmp()(l, r));
- }
-
- APInt operator()(const APFloat &l, const APFloat &r) {
- return APInt(1, Cmp()(l, r));
- }
-};
-
-struct APIntFoldGreater {
- APIntFoldGreater() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, l.sgt(r));
- }
-};
-
-struct APIntFoldGreaterEqual {
- APIntFoldGreaterEqual() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, l.sge(r));
- }
-};
-} // namespace
-
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr =
@@ -1187,8 +1265,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
- lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
@@ -1201,9 +1278,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreaterEqual,
- ComparisonFold<std::greater_equal<APFloat>>>(
- lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
@@ -1226,9 +1301,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 68bfd4c7e4980..8c375b6c528ef 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -92,6 +92,50 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
// -----
+// CHECK-LABEL: @fold_add_splat_i32_positive_overflow
+func.func @fold_add_splat_i32_positive_overflow() -> tensor<10xi32> {
+  %lhs = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.add
+  %add = tosa.add %lhs, %rhs : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> // INT32_MAX + 1
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_i32_negative_overflow
+func.func @fold_add_splat_i32_negative_overflow() -> tensor<10xi32> {
+  %lhs = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.add
+  %add = tosa.add %lhs, %rhs : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> // -1 + INT32_MIN
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_ui8
+func.func @fold_add_splat_ui8() -> tensor<10xui8> {
+  %lhs = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %rhs = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: "tosa.const"() <{values = dense<255> : tensor<10xui8>}> : () -> tensor<10xui8>
+  %add = tosa.add %lhs, %rhs : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> // 1 + 254 == UINT8_MAX
+  return %add : tensor<10xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_ui8_overflow
+func.func @fold_add_splat_ui8_overflow() -> tensor<10xui8> {
+  %lhs = "tosa.const"() {values = dense<2> : tensor<10xui8>} : () -> tensor<10xui8>
+  %rhs = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: tosa.add
+  %add = tosa.add %lhs, %rhs : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> // 2 + 254 > UINT8_MAX
+  return %add : tensor<10xui8>
+}
+
+// -----
+
// CHECK-LABEL: @fold_div_zero_lhs_i32
func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -288,6 +332,50 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
// -----
+// CHECK-LABEL: @fold_sub_splat_i32_positive_overflow
+func.func @fold_sub_splat_i32_positive_overflow() -> tensor<10xi32> {
+  %lhs = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %lhs, %rhs : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> // INT32_MAX - (-1)
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_i32_negative_overflow
+func.func @fold_sub_splat_i32_negative_overflow() -> tensor<10xi32> {
+  %lhs = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %lhs, %rhs : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> // INT32_MIN - 1
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_ui8
+func.func @fold_sub_splat_ui8() -> tensor<10xui8> {
+  %lhs = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8>
+  %rhs = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: "tosa.const"() <{values = dense<2> : tensor<10xui8>}> : () -> tensor<10xui8>
+  %sub = tosa.sub %lhs, %rhs : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> // 255 - 253 == 2
+  return %sub : tensor<10xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_ui8_overflow
+func.func @fold_sub_splat_ui8_overflow() -> tensor<10xui8> {
+  %lhs = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %rhs = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %lhs, %rhs : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> // 1 - 253 < 0
+  return %sub : tensor<10xui8>
+}
+
+// -----
+
// CHECK-LABEL: @fold_greater_splat_f32
func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
@@ -320,6 +408,23 @@ func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
// -----
+// CHECK-LABEL: @fold_greater_splat_ui8
+func.func @fold_greater_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>, tensor<10xi1>) {
+ %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+ %1 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+ %2 = "tosa.const"() {values = dense<246> : tensor<10xui8>} : () -> tensor<10xui8>
+ %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+ %true = tosa.greater %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 246 > 245
+ %false = tosa.greater %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 1 > 1
+ %false2 = tosa.greater %0, %2 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 1 > 246 is false as unsigned; a signed compare (246 as -10) would give true
+ // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+ // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+ // CHECK: return %[[TRUE]], %[[FALSE]], %[[FALSE]]
+ return %true, %false, %false2 : tensor<10xi1>, tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
// CHECK-LABEL: @fold_greater_eq_splat_f32
func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
@@ -352,6 +457,22 @@ func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
// -----
+// CHECK-LABEL: @fold_greater_eq_splat_ui8
+func.func @fold_greater_eq_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>) {
+ %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+ %1 = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8>
+ %2 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+ %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+ %true = tosa.greater_equal %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 245 >= 245
+ %false = tosa.greater_equal %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 1 >= 255 is false as unsigned; a signed compare (255 as -1) would give true
+ // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+ // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+ // CHECK: return %[[TRUE]], %[[FALSE]]
+ return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
// CHECK-LABEL: @fold_eq_splat_f32
func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
>From b60d3753157442da5d62a8bd63d081e779b49800 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 17 Dec 2025 16:22:28 +0000
Subject: [PATCH 2/2] [mlir][tosa] Add constant folding for tosa.add_shape
operation
This commit introduces constant folding for the tosa.add_shape
operation. When both operands are produced by tosa.const_shape,
the element-wise sum is evaluated at compile time; the operation
is left unfolded if any element of the sum overflows.
Change-Id: I5567fae8290bf238f809088573d40666fe3bdf51
---
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 2 +
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 58 ++++++++++++++-----
mlir/test/Dialect/Tosa/constant_folding.mlir | 33 +++++++++++
3 files changed, 79 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d8597151714c3..6b2e1045cd0dd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -67,6 +67,8 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
);
let results = (outs Tosa_Shape:$output);
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 3e9d803a916a9..f0a02fea5863a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -890,16 +890,28 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
template <typename Folder>
-static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
- DenseElementsAttr rhs,
- RankedTensorType returnTy) {
- if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
- const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
- if (lETy != rETy)
- return {};
+static DenseElementsAttr
+binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
+ bool foldDenseValues = false) {
+ if (!lhs || !rhs)
+ return {};
- if (const auto lIntTy = dyn_cast<IntegerType>(lETy)) {
+ const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+ const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+ if (lETy != rETy)
+ return {};
+
+ if (lhs.isSplat() && rhs.isSplat()) {
+ if (isa<FloatType>(lETy)) {
+ const APFloat l = lhs.getSplatValue<APFloat>();
+ const APFloat r = rhs.getSplatValue<APFloat>();
+ const auto maybeResult = Folder::fold(l, r);
+ if (failed(maybeResult))
+ return {};
+ return DenseElementsAttr::get(returnTy, maybeResult.value());
+ }
+
+ if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
const APInt l = lhs.getSplatValue<APInt>();
const APInt r = rhs.getSplatValue<APInt>();
const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
@@ -907,15 +919,18 @@ static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
return {};
return DenseElementsAttr::get(returnTy, maybeResult.value());
}
+ }
- if (llvm::isa<FloatType>(lETy)) {
- const APFloat l = lhs.getSplatValue<APFloat>();
- const APFloat r = rhs.getSplatValue<APFloat>();
- const auto maybeResult = Folder::fold(l, r);
+ if (foldDenseValues) {
+ SmallVector<APInt> resultValues;
+ for (auto [l, r] :
+ llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
+ const auto maybeResult = Folder::fold(l, r, false);
if (failed(maybeResult))
return {};
- return DenseElementsAttr::get(returnTy, maybeResult.value());
+ resultValues.push_back(maybeResult.value());
}
+ return DenseElementsAttr::get(returnTy, resultValues);
}
return {};
@@ -1723,3 +1738,18 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
return {};
}
+
+OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
+  // Fold only when both operands are tosa.const_shape results. Use
+  // getDefiningOp<OpTy>(), which yields null for block arguments, rather
+  // than dyn_cast on a possibly-null defining op (which would assert).
+  auto input1ConstShape = getInput1().getDefiningOp<tosa::ConstShapeOp>();
+  auto input2ConstShape = getInput2().getDefiningOp<tosa::ConstShapeOp>();
+  if (!input1ConstShape || !input2ConstShape)
+    return {};
+  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
+  const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
+  // Element-wise signed add; binaryFolder returns null (no fold) on overflow.
+  return binaryFolder<AddFoldAdaptor>(
+      input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 8c375b6c528ef..1007af6c8bd82 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -650,3 +650,36 @@ func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>)
%1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32>
return %1 : tensor<44x57xi32>
}
+
+// -----
+
+// CHECK-LABEL: @test_fold_add_shape
+// CHECK: tosa.const_shape {values = dense<[2, 4, 6, 8, 10, 12]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_fold_add_shape() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> // element-wise %a + %b
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_add_shape_positive_overflow
+// CHECK: tosa.add_shape
+func.func @test_no_fold_add_shape_positive_overflow() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 9223372036854775807]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> // last element: INT64_MAX + 1 overflows (index is folded as signed)
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_add_shape_negative_overflow
+// CHECK: tosa.add_shape
+func.func @test_no_fold_add_shape_negative_overflow() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -9223372036854775808]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> // last element: INT64_MIN + (-1) overflows (index is folded as signed)
+ return %c : !tosa.shape<6>
+}
More information about the Mlir-commits
mailing list