[Mlir-commits] [mlir] [mlir][tosa] Support boolean types for clamp folder (PR #151653)
Longsheng Mou
llvmlistbot at llvm.org
Fri Aug 1 04:34:29 PDT 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/151653
>From bd470953daf5f6560d8fdc675a35c8f9205e45c8 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 1 Aug 2025 14:12:30 +0800
Subject: [PATCH 1/3] [mlir][tosa] Support boolean types for clamp folder
This PR fixes several bugs in the `ClampIsNoOp` pattern.
- The static shape check is not needed.
- Ensure i1 values are zero-extended so that clamps on boolean types can be folded.
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 36 ++++++++-----------
1 file changed, 15 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 6d2cbb5539e14..c83aab199e067 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -452,10 +452,6 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
auto inputElementType = inputType.getElementType();
- if (!inputType.hasStaticShape()) {
- return failure();
- }
-
if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
auto minClamp =
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
return failure();
}
- if (inputElementType.isUnsignedInteger()) {
- int64_t minClamp =
- llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
- int64_t maxClamp =
- llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
+ // i1 types are boolean in TOSA
+ const bool isBoolean = inputElementType.isInteger(1);
+ if (inputElementType.isUnsignedInteger() || isBoolean) {
+ int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
+ .getValue()
+ .getZExtValue();
+ int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
+ .getValue()
+ .getZExtValue();
- int64_t intMin =
- APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
- .getZExtValue();
- int64_t intMax =
- APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
- .getZExtValue();
+ unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+ int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
+ int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
@@ -498,12 +495,9 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
- int64_t intMin =
- APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
- .getSExtValue();
- int64_t intMax =
- APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
- .getSExtValue();
+ unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+ int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
+ int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
>From cf72618b471706144e9ab44d2c6a86e2d9aee55e Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 1 Aug 2025 14:15:33 +0800
Subject: [PATCH 2/3] add test
---
mlir/test/Dialect/Tosa/canonicalize.mlir | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6b55442a82a0a..5150ee36e9e5e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// -----
+// CHECK-LABEL: @clamp_boolean_is_noop
+func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1>
+ return %0 : tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_boolean_dynamic_is_noop
+func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1>
+ return %0 : tensor<?xi1>
+}
+
+// -----
+
// CHECK-LABEL: @clamp_int8_is_noop
func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: return %arg0
>From 73ca9707e15a3abe8ebaa170ded272d4e9d499d8 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 1 Aug 2025 19:33:57 +0800
Subject: [PATCH 3/3] add const
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 36 +++++++++----------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c83aab199e067..e3cba38871909 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -454,12 +454,12 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
- auto minClamp =
+ const auto minClamp =
llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
- auto maxClamp =
+ const auto maxClamp =
llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
- bool isMin = minClamp.isNegInfinity();
- bool isMax = maxClamp.isInfinity();
+ const bool isMin = minClamp.isNegInfinity();
+ const bool isMax = maxClamp.isInfinity();
if (isMin && isMax) {
rewriter.replaceOp(op, input);
@@ -471,16 +471,16 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
// i1 types are boolean in TOSA
const bool isBoolean = inputElementType.isInteger(1);
if (inputElementType.isUnsignedInteger() || isBoolean) {
- int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
- .getValue()
- .getZExtValue();
- int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
- .getValue()
- .getZExtValue();
+ const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
+ .getValue()
+ .getZExtValue();
+ const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
+ .getValue()
+ .getZExtValue();
- unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
- int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
- int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+ const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
+ const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
@@ -490,14 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}
if (llvm::isa<IntegerType>(inputElementType)) {
- int64_t minClamp =
+ const int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
- int64_t maxClamp =
+ const int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
- unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
- int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
- int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+ const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
+ const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
More information about the Mlir-commits
mailing list