[Mlir-commits] [mlir] dc741f2 - [mlir][tosa] Add constant folding for tosa.add_shape operation (#173112)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 23 04:50:22 PST 2026


Author: Luke Hutton
Date: 2026-01-23T12:50:17Z
New Revision: dc741f2f7e88edef573cc95882934ea8b44812cd

URL: https://github.com/llvm/llvm-project/commit/dc741f2f7e88edef573cc95882934ea8b44812cd
DIFF: https://github.com/llvm/llvm-project/commit/dc741f2f7e88edef573cc95882934ea8b44812cd.diff

LOG: [mlir][tosa] Add constant folding for tosa.add_shape operation (#173112)

This commit introduces constant folding for the tosa.add_shape
operation. When both operands of the add_shape operation are constant
shapes, the operation is evaluated at compile-time.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant_folding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 8fd176d3ea390..1783a5ef7c961 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -67,6 +67,8 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b15a3a4279064..14639ef5925ae 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -890,16 +890,28 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 template <typename Folder>
-static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
-                                      DenseElementsAttr rhs,
-                                      RankedTensorType returnTy) {
-  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-    const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
-    if (lETy != rETy)
-      return {};
+static DenseElementsAttr
+binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
+             bool foldDenseValues = false) {
+  if (!lhs || !rhs)
+    return {};
 
-    if (const auto lIntTy = dyn_cast<IntegerType>(lETy)) {
+  const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+  const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+  if (lETy != rETy)
+    return {};
+
+  if (lhs.isSplat() && rhs.isSplat()) {
+    if (isa<FloatType>(lETy)) {
+      const APFloat l = lhs.getSplatValue<APFloat>();
+      const APFloat r = rhs.getSplatValue<APFloat>();
+      const auto maybeResult = Folder::fold(l, r);
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
+    }
+
+    if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
       const APInt l = lhs.getSplatValue<APInt>();
       const APInt r = rhs.getSplatValue<APInt>();
       const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
@@ -907,15 +919,20 @@ static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
         return {};
       return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
+  }
 
-    if (llvm::isa<FloatType>(lETy)) {
-      const APFloat l = lhs.getSplatValue<APFloat>();
-      const APFloat r = rhs.getSplatValue<APFloat>();
-      const auto maybeResult = Folder::fold(l, r);
+  if (foldDenseValues) {
+    assert(lETy.isIntOrIndex() &&
+           "Only integer types are currently supported.");
+    SmallVector<APInt> resultValues;
+    for (auto [l, r] :
+         llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
+      const auto maybeResult = Folder::fold(l, r, false);
       if (failed(maybeResult))
         return {};
-      return DenseElementsAttr::get(returnTy, maybeResult.value());
+      resultValues.push_back(maybeResult.value());
     }
+    return DenseElementsAttr::get(returnTy, resultValues);
   }
 
   return {};
@@ -1683,3 +1700,18 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+
+// Folds add_shape when both operands are produced by tosa.const_shape.
+OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
+  // getDefiningOp<T>() is null-safe: a block-argument operand has no defining
+  // op, and dyn_cast on the resulting null pointer would assert.
+  auto input1ConstShape = getInput1().getDefiningOp<tosa::ConstShapeOp>();
+  auto input2ConstShape = getInput2().getDefiningOp<tosa::ConstShapeOp>();
+  if (!input1ConstShape || !input2ConstShape)
+    return {};
+  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
+  const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
+  // Element-wise add; binaryFolder yields a null attribute if a fold fails.
+  return binaryFolder<FoldAddAdaptor>(
+      input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
+}

diff  --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 8c375b6c528ef..1007af6c8bd82 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -650,3 +650,36 @@ func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>)
   %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32>
   return %1 : tensor<44x57xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_fold_add_shape
+// CHECK: tosa.const_shape  {values = dense<[2, 4, 6, 8, 10, 12]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_fold_add_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_add_shape_positive_overflow
+// CHECK: tosa.add_shape
+func.func @test_no_fold_add_shape_positive_overflow() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 9223372036854775807]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_add_shape_negative_overflow
+// CHECK: tosa.add_shape
+func.func @test_no_fold_add_shape_negative_overflow() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -9223372036854775808]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}


        


More information about the Mlir-commits mailing list