[Mlir-commits] [mlir] 1023f6c - [mlir][tosa] Support boolean types for clamp folder (#151653)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 1 04:46:43 PDT 2025
Author: Longsheng Mou
Date: 2025-08-01T12:46:39+01:00
New Revision: 1023f6c9db6c722fd0f8d631b114419cd522594f
URL: https://github.com/llvm/llvm-project/commit/1023f6c9db6c722fd0f8d631b114419cd522594f
DIFF: https://github.com/llvm/llvm-project/commit/1023f6c9db6c722fd0f8d631b114419cd522594f.diff
LOG: [mlir][tosa] Support boolean types for clamp folder (#151653)
This PR fixes several bugs in the `ClampIsNoOp` pattern.
- The static shape check is unnecessary, so it is removed.
- i1 values are now zero-extended so that clamps on boolean types
can be folded.
Fixes #130016.
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 6d2cbb5539e14..e3cba38871909 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -452,18 +452,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
auto inputElementType = inputType.getElementType();
- if (!inputType.hasStaticShape()) {
- return failure();
- }
-
if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
- auto minClamp =
+ const auto minClamp =
llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
- auto maxClamp =
+ const auto maxClamp =
llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
- bool isMin = minClamp.isNegInfinity();
- bool isMax = maxClamp.isInfinity();
+ const bool isMin = minClamp.isNegInfinity();
+ const bool isMax = maxClamp.isInfinity();
if (isMin && isMax) {
rewriter.replaceOp(op, input);
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
return failure();
}
- if (inputElementType.isUnsignedInteger()) {
- int64_t minClamp =
- llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
- int64_t maxClamp =
- llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
+ // i1 types are boolean in TOSA
+ const bool isBoolean = inputElementType.isInteger(1);
+ if (inputElementType.isUnsignedInteger() || isBoolean) {
+ const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
+ .getValue()
+ .getZExtValue();
+ const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
+ .getValue()
+ .getZExtValue();
- int64_t intMin =
- APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
- .getZExtValue();
- int64_t intMax =
- APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
- .getZExtValue();
+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+ const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
+ const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
@@ -493,17 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}
if (llvm::isa<IntegerType>(inputElementType)) {
- int64_t minClamp =
+ const int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
- int64_t maxClamp =
+ const int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
- int64_t intMin =
- APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
- .getSExtValue();
- int64_t intMax =
- APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
- .getSExtValue();
+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+ const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
+ const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6b55442a82a0a..5150ee36e9e5e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// -----
+// CHECK-LABEL: @clamp_boolean_is_noop
+func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1>
+ return %0 : tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_boolean_dynamic_is_noop
+func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1>
+ return %0 : tensor<?xi1>
+}
+
+// -----
+
// CHECK-LABEL: @clamp_int8_is_noop
func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: return %arg0
More information about the Mlir-commits
mailing list