[Mlir-commits] [mlir] 1023f6c - [mlir][tosa] Support boolean types for clamp folder (#151653)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 1 04:46:43 PDT 2025


Author: Longsheng Mou
Date: 2025-08-01T12:46:39+01:00
New Revision: 1023f6c9db6c722fd0f8d631b114419cd522594f

URL: https://github.com/llvm/llvm-project/commit/1023f6c9db6c722fd0f8d631b114419cd522594f
DIFF: https://github.com/llvm/llvm-project/commit/1023f6c9db6c722fd0f8d631b114419cd522594f.diff

LOG: [mlir][tosa] Support boolean types for clamp folder (#151653)

This PR fixes several bugs in the `ClampIsNoOp` pattern:
- the static shape check is not needed and has been removed.
- i1 values are now zero-extended, so that clamps on boolean types
can be folded.

Fixes #130016.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 6d2cbb5539e14..e3cba38871909 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -452,18 +452,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
     auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
     auto inputElementType = inputType.getElementType();
 
-    if (!inputType.hasStaticShape()) {
-      return failure();
-    }
-
     if (isa<FloatType>(inputElementType)) {
       // Unlike integer types, floating point types can represent infinity.
-      auto minClamp =
+      const auto minClamp =
           llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
-      auto maxClamp =
+      const auto maxClamp =
           llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
-      bool isMin = minClamp.isNegInfinity();
-      bool isMax = maxClamp.isInfinity();
+      const bool isMin = minClamp.isNegInfinity();
+      const bool isMax = maxClamp.isInfinity();
 
       if (isMin && isMax) {
         rewriter.replaceOp(op, input);
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
       return failure();
     }
 
-    if (inputElementType.isUnsignedInteger()) {
-      int64_t minClamp =
-          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
-      int64_t maxClamp =
-          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
+    // i1 types are boolean in TOSA
+    const bool isBoolean = inputElementType.isInteger(1);
+    if (inputElementType.isUnsignedInteger() || isBoolean) {
+      const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
+                                   .getValue()
+                                   .getZExtValue();
+      const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
+                                   .getValue()
+                                   .getZExtValue();
 
-      int64_t intMin =
-          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
-              .getZExtValue();
-      int64_t intMax =
-          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
-              .getZExtValue();
+      const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+      const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
+      const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
 
       if (minClamp <= intMin && maxClamp >= intMax) {
         rewriter.replaceOp(op, input);
@@ -493,17 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
     }
 
     if (llvm::isa<IntegerType>(inputElementType)) {
-      int64_t minClamp =
+      const int64_t minClamp =
           llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
-      int64_t maxClamp =
+      const int64_t maxClamp =
           llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
 
-      int64_t intMin =
-          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
-              .getSExtValue();
-      int64_t intMax =
-          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
-              .getSExtValue();
+      const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+      const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
+      const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
 
       if (minClamp <= intMin && maxClamp >= intMax) {
         rewriter.replaceOp(op, input);

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6b55442a82a0a..5150ee36e9e5e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
 
 // -----
 
+// CHECK-LABEL: @clamp_boolean_is_noop
+func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> {
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.clamp
+  %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1>
+  return %0 : tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_boolean_dynamic_is_noop
+func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> {
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.clamp
+  %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1>
+  return %0 : tensor<?xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @clamp_int8_is_noop
 func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   // CHECK: return %arg0


        


More information about the Mlir-commits mailing list