[Mlir-commits] [mlir] [mlir][tosa] Add constant folding for tosa.add_shape operation (PR #173112)

Luke Hutton llvmlistbot at llvm.org
Tue Jan 20 08:59:59 PST 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/173112

>From eb9fa252890bba92c0507052934d09db929c9440 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 17 Dec 2025 13:28:15 +0000
Subject: [PATCH 1/2] [mlir][tosa] Check for overflow in integer folders

For these folders to be TOSA compliant, they must check for
integer overflow. This commit adds those checks, so folding is
skipped whenever an overflow is detected.

This commit also fixes the greater/greater_equal folders
to account for unsigned types.

Change-Id: I2b5a5b92fb840d6c34a1f2faa18ae68a20d0ecdf
---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b15a3a4279064..2c8611bd7e542 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -920,6 +920,7 @@ static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
 
   return {};
 }
+
 struct FoldAddAdaptor {
   static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                                const bool isUnsigned) {

>From 3f88bd0c2345c15d9f59b640ecc8d7f571572f2c Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 17 Dec 2025 16:22:28 +0000
Subject: [PATCH 2/2] [mlir][tosa] Add constant folding for tosa.add_shape
 operation

This commit introduces constant folding for the tosa.add_shape
operation. When both operands are tosa.const_shape values, the
element-wise sum is computed at compile time; if any element
addition overflows, the operation is left unfolded.

Change-Id: I5567fae8290bf238f809088573d40666fe3bdf51
---
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      |  2 +
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 58 ++++++++++++++-----
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 33 +++++++++++
 3 files changed, 79 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d8597151714c3..6b2e1045cd0dd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -67,6 +67,8 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 2c8611bd7e542..e86fa3fee4e4c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -890,16 +890,28 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 template <typename Folder>
-static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
-                                      DenseElementsAttr rhs,
-                                      RankedTensorType returnTy) {
-  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-    const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
-    if (lETy != rETy)
-      return {};
+static DenseElementsAttr
+binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
+             bool foldDenseValues = false) {
+  if (!lhs || !rhs)
+    return {};

-    if (const auto lIntTy = dyn_cast<IntegerType>(lETy)) {
+  const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+  const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+  if (lETy != rETy)
+    return {};
+  // Splats fold once; foldDenseValues also folds non-splat int/index data.
+  if (lhs.isSplat() && rhs.isSplat()) {
+    if (isa<FloatType>(lETy)) {
+      const APFloat l = lhs.getSplatValue<APFloat>();
+      const APFloat r = rhs.getSplatValue<APFloat>();
+      const auto maybeResult = Folder::fold(l, r);
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
+    }
+
+    if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
       const APInt l = lhs.getSplatValue<APInt>();
       const APInt r = rhs.getSplatValue<APInt>();
       const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
@@ -907,15 +919,18 @@ static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
         return {};
       return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
+  }

-    if (llvm::isa<FloatType>(lETy)) {
-      const APFloat l = lhs.getSplatValue<APFloat>();
-      const APFloat r = rhs.getSplatValue<APFloat>();
-      const auto maybeResult = Folder::fold(l, r);
+  if (foldDenseValues && lETy.isIntOrIndex()) {
+    SmallVector<APInt> resultValues;
+    for (auto [l, r] :
+         llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
+      const auto maybeResult = Folder::fold(l, r, lETy.isUnsignedInteger());
       if (failed(maybeResult))
         return {};
-      return DenseElementsAttr::get(returnTy, maybeResult.value());
+      resultValues.push_back(maybeResult.value());
     }
+    return DenseElementsAttr::get(returnTy, resultValues);
   }

   return {};
@@ -1684,3 +1699,18 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+
+OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
+  // getDefiningOp<T>() is null-safe: operands may be block arguments.
+  auto input1ConstShape = getInput1().getDefiningOp<tosa::ConstShapeOp>();
+  auto input2ConstShape = getInput2().getDefiningOp<tosa::ConstShapeOp>();
+  if (!input1ConstShape || !input2ConstShape)
+    return {};
+
+  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
+  const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
+
+  // No fold (null result) if any element-wise add fails, e.g. on overflow.
+  return binaryFolder<FoldAddAdaptor>(
+      input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 8c375b6c528ef..1007af6c8bd82 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -650,3 +650,36 @@ func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>)
   %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32>
   return %1 : tensor<44x57xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_fold_add_shape
+// CHECK: tosa.const_shape  {values = dense<[2, 4, 6, 8, 10, 12]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_fold_add_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> // both operands constant: folds
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_add_shape_positive_overflow
+// CHECK: tosa.add_shape
+func.func @test_no_fold_add_shape_positive_overflow() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 9223372036854775807]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> // INT64_MAX + 1 overflows: no fold
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_add_shape_negative_overflow
+// CHECK: tosa.add_shape
+func.func @test_no_fold_add_shape_negative_overflow() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -9223372036854775808]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> // INT64_MIN + (-1) overflows: no fold
+  return %c : !tosa.shape<6>
+}



More information about the Mlir-commits mailing list