[Mlir-commits] [mlir] [mlir][tosa]: Add MIN_SHAPE, MAX_SHAPE Ops folders (PR #179488)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 3 08:03:05 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Udaya Ranga (udaya-ranga)
<details>
<summary>Changes</summary>
Change-Id: Ifc7d27cad875c22931351178c276f142a12e4bde
---
Full diff: https://github.com/llvm/llvm-project/pull/179488.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+4)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+42-7)
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+22)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 104f2741e5678..cbcc2a017ac3a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -250,6 +250,8 @@ def Tosa_MaxShapeOp : Tosa_ElementwiseShapeOp<"max_shape", [Pure]> {
);
let results = (outs Tosa_Shape:$output);
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -268,6 +270,8 @@ def Tosa_MinShapeOp : Tosa_ElementwiseShapeOp<"min_shape", [Pure]> {
);
let results = (outs Tosa_Shape:$output);
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index ed1584f93a367..42033ce8a3b02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1116,7 +1116,49 @@ struct ModFoldAdaptor {
   }
 };
 
-struct FoldGreaterAdaptor {
+/// Fold kernel returning the maximum of two scalar constants.
+struct MaxFoldAdaptor {
+  /// Returns the larger operand, compared as unsigned when `isUnsigned` is
+  /// set and as signed otherwise. Fails on mismatched bit widths.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    if (lhs.getBitWidth() != rhs.getBitWidth())
+      return failure();
+    // Compare as APInts: getSExtValue() asserts on values that do not fit in
+    // 64 bits and forces a signed compare regardless of `isUnsigned`.
+    return isUnsigned ? llvm::APIntOps::umax(lhs, rhs)
+                      : llvm::APIntOps::smax(lhs, rhs);
+  }
+
+  /// Returns the larger operand. llvm::maximum propagates a NaN from either
+  /// side; a plain `>=` select keeps a NaN rhs but drops a NaN lhs.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return llvm::maximum(lhs, rhs);
+  }
+};
+
+/// Fold kernel returning the minimum of two scalar constants.
+struct MinFoldAdaptor {
+  /// Returns the smaller operand, compared as unsigned when `isUnsigned` is
+  /// set and as signed otherwise. Fails on mismatched bit widths.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    if (lhs.getBitWidth() != rhs.getBitWidth())
+      return failure();
+    // Compare as APInts: getSExtValue() asserts on values that do not fit in
+    // 64 bits and forces a signed compare regardless of `isUnsigned`.
+    return isUnsigned ? llvm::APIntOps::umin(lhs, rhs)
+                      : llvm::APIntOps::smin(lhs, rhs);
+  }
+
+  /// Returns the smaller operand. llvm::minimum propagates a NaN from either
+  /// side; a plain `<=` select keeps a NaN rhs but drops a NaN lhs.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return llvm::minimum(lhs, rhs);
+  }
+};
+
+struct GreaterFoldAdaptor {
   static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                                const bool isUnsigned) {
     return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
@@ -1127,7 +1153,7 @@ struct FoldGreaterAdaptor {
}
};
-struct FoldGreaterEqualAdaptor {
+struct GreaterEqualFoldAdaptor {
static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
const bool isUnsigned) {
return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
@@ -1138,7 +1164,7 @@ struct FoldGreaterEqualAdaptor {
}
};
-struct FoldEqualAdaptor {
+struct EqualFoldAdaptor {
static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
const bool isUnsigned) {
return APInt(1, lhs == rhs);
@@ -1247,8 +1273,11 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
     APInt l = lhsAttr.getSplatValue<APInt>();
     APInt r = rhsAttr.getSplatValue<APInt>();
     if (!r.isZero()) {
+      // Type::isUnsignedInteger() needs no cast and is false for non-integer
+      // types, so there is no unchecked dyn_cast result to dereference.
+      const bool isUnsigned = resultETy.isUnsignedInteger();
       auto const result =
-          DivFoldAdaptor</*Ceil*/ false>::fold(l, r, /*isUnsigned*/ false);
+          DivFoldAdaptor</*Ceil*/ false>::fold(l, r, isUnsigned);
       if (failed(result))
         return {};
       return DenseElementsAttr::get(resultTy, result.value());
@@ -1396,7 +1423,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
@@ -1409,7 +1436,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
@@ -1432,7 +1459,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
@@ -1909,3 +1936,11 @@ OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
return binaryFold<ModShapeOp, ModFoldAdaptor>(this);
}
+
+OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
+  return binaryFold<MaxShapeOp, MaxFoldAdaptor>(this); // Per-element max.
+}
+
+OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
+  return binaryFold<MinShapeOp, MinFoldAdaptor>(this); // Per-element min.
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index a8bcb9d52f000..c3186279a30ae 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -875,3 +875,25 @@ func.func @test_no_fold_mod_shape_negative_overflow() -> !tosa.shape<6> {
%c = tosa.mod_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
return %c : !tosa.shape<6>
}
+
+// -----
+
+// CHECK-LABEL: @test_max_shape
+// CHECK: tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_max_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.max_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> // expected: per-element max of %a and %b
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_min_shape
+// CHECK: tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+func.func @test_min_shape() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.min_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> // expected: per-element min of %a and %b
+  return %c : !tosa.shape<6>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/179488
More information about the Mlir-commits mailing list