[Mlir-commits] [mlir] [mlir][tosa]: Add MIN_SHAPE, MAX_SHAPE Ops folders (PR #179488)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 3 08:03:05 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Udaya Ranga (udaya-ranga)

<details>
<summary>Changes</summary>

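Add constant folders for `tosa.max_shape` and `tosa.min_shape`. The patch also renames the comparison fold adaptors to the `*FoldAdaptor` naming scheme and makes the `tosa.intdiv` folder honour unsigned result element types.

Change-Id: Ifc7d27cad875c22931351178c276f142a12e4bde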

---
Full diff: https://github.com/llvm/llvm-project/pull/179488.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+4) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+42-7) 
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+22) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 104f2741e5678..cbcc2a017ac3a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -250,6 +250,8 @@ def Tosa_MaxShapeOp : Tosa_ElementwiseShapeOp<"max_shape", [Pure]> {
   );

   let results = (outs Tosa_Shape:$output);
+  // Folder is defined as tosa::MaxShapeOp::fold in TosaCanonicalizations.cpp.
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -268,6 +270,8 @@ def Tosa_MinShapeOp : Tosa_ElementwiseShapeOp<"min_shape", [Pure]> {
   );

   let results = (outs Tosa_Shape:$output);
+  // Folder is defined as tosa::MinShapeOp::fold in TosaCanonicalizations.cpp.
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index ed1584f93a367..42033ce8a3b02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1116,7 +1116,43 @@ struct ModFoldAdaptor {
   }
 };

-struct FoldGreaterAdaptor {
+/// Folds the maximum of two constants. Integers compare as signed unless
+/// `isUnsigned` is set; floats use llvm::maximum, which propagates NaN.
+struct MaxFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    // APInt comparisons require operands of equal bit width.
+    if (lhs.getBitWidth() != rhs.getBitWidth())
+      return failure();
+    // Compare as APInts so wide and unsigned values are ordered correctly.
+    const bool lhsIsMax = isUnsigned ? lhs.uge(rhs) : lhs.sge(rhs);
+    return lhsIsMax ? lhs : rhs;
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return llvm::maximum(lhs, rhs);
+  }
+};
+
+/// Folds the minimum of two constants. Integers compare as signed unless
+/// `isUnsigned` is set; floats use llvm::minimum, which propagates NaN.
+struct MinFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    // APInt comparisons require operands of equal bit width.
+    if (lhs.getBitWidth() != rhs.getBitWidth())
+      return failure();
+    // Compare as APInts so wide and unsigned values are ordered correctly.
+    const bool lhsIsMin = isUnsigned ? lhs.ule(rhs) : lhs.sle(rhs);
+    return lhsIsMin ? lhs : rhs;
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return llvm::minimum(lhs, rhs);
+  }
+};
+
+struct GreaterFoldAdaptor {
   static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                                const bool isUnsigned) {
     return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
@@ -1127,7 +1163,7 @@ struct FoldGreaterAdaptor {
     }
 };

-struct FoldGreaterEqualAdaptor {
+struct GreaterEqualFoldAdaptor {
   static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                                const bool isUnsigned) {
     return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
@@ -1138,7 +1174,7 @@ struct FoldGreaterEqualAdaptor {
     }
 };

-struct FoldEqualAdaptor {
+struct EqualFoldAdaptor {
   static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                                const bool isUnsigned) {
     return APInt(1, lhs == rhs);
@@ -1247,8 +1283,11 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
     APInt l = lhsAttr.getSplatValue<APInt>();
     APInt r = rhsAttr.getSplatValue<APInt>();
     if (!r.isZero()) {
+      // Only unsigned integer element types use unsigned division; signless
+      // and signed types fold as signed. No cast result is dereferenced.
+      const bool isUnsigned = resultETy.isUnsignedInteger();
       auto const result =
-          DivFoldAdaptor</*Ceil*/ false>::fold(l, r, /*isUnsigned*/ false);
+          DivFoldAdaptor</*Ceil*/ false>::fold(l, r, isUnsigned);
       if (failed(result))
         return {};
       return DenseElementsAttr::get(resultTy, result.value());
@@ -1396,7 +1435,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};

-  return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
 }

 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
@@ -1409,7 +1448,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};

-  return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
 }

 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
@@ -1432,7 +1471,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};

-  return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
 }

 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
@@ -1909,3 +1948,11 @@ OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
 OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
   return binaryFold<ModShapeOp, ModFoldAdaptor>(this);
 }
+
+OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
+  return binaryFold<MaxShapeOp, MaxFoldAdaptor>(this);
+}
+
+OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
+  return binaryFold<MinShapeOp, MinFoldAdaptor>(this);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index a8bcb9d52f000..c3186279a30ae 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -875,3 +875,25 @@ func.func @test_no_fold_mod_shape_negative_overflow() -> !tosa.shape<6> {
   %c = tosa.mod_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
   return %c : !tosa.shape<6>
 }
+
+// -----
+// Element-wise max of two const_shape operands folds to a single const_shape.
+// CHECK-LABEL: @test_max_shape
+// CHECK: tosa.const_shape  {values = dense<[24, 7, 65, 33, 39, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_max_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.max_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+// Element-wise min of two const_shape operands folds to a single const_shape.
+// CHECK-LABEL: @test_min_shape
+// CHECK: tosa.const_shape  {values = dense<[11, 2, 12, 13, 15, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_min_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.min_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/179488


More information about the Mlir-commits mailing list