[Mlir-commits] [mlir] 41d05e2 - [mlir][tosa] Add tosa.clamp as no-op canonicalization

Rob Suderman llvmlistbot at llvm.org
Tue Jan 18 23:17:54 PST 2022


Author: not-jenni
Date: 2022-01-18T23:15:40-08:00
New Revision: 41d05e29c04fb86a240169c9e10bad629f5897e7

URL: https://github.com/llvm/llvm-project/commit/41d05e29c04fb86a240169c9e10bad629f5897e7
DIFF: https://github.com/llvm/llvm-project/commit/41d05e29c04fb86a240169c9e10bad629f5897e7.diff

LOG: [mlir][tosa] Add tosa.clamp as no-op canonicalization

When the min/max bounds cover the entire range of the input's element type, the
clamp is a no-op because the values are already restricted to that range.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D117625

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 597813481d684..094d136999376 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -341,6 +341,8 @@ def Tosa_ClampOp : Tosa_Op<"clamp", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 303e433e1b16b..f93e5b2052b9d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -458,6 +458,79 @@ void MaxPool2dOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<MaxPool2dIsNoOp>(context);
 }
 
+// Rewrites a tosa.clamp whose bounds cover the whole range of its input
+// element type into its input: such a clamp cannot change any element.
+struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Replaces `op` with its input and returns success when the clamp is a
+  // no-op; otherwise returns failure and leaves the IR untouched.
+  LogicalResult matchAndRewrite(tosa::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+
+    // The input may be unranked (dyn_cast yields null) or dynamically shaped,
+    // and the result type must match exactly for the replacement to be valid.
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    if (!inputType || !inputType.hasStaticShape() ||
+        input.getType() != op.output().getType())
+      return failure();
+    Type inputElementType = inputType.getElementType();
+
+    if (inputElementType.isa<FloatType>()) {
+      // Floats can represent +/-inf, so only infinite bounds are a no-op:
+      // clamping to +/-largest would turn an infinite element into a finite
+      // one.
+      APFloat minClamp = op.min_fp();
+      APFloat maxClamp = op.max_fp();
+      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
+      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
+      if (!isMin || !isMax)
+        return failure();
+      rewriter.replaceOp(op, input);
+      return success();
+    }
+
+    // Integer element types are checked against min_int/max_int; any other
+    // element type is left unchanged.
+    auto intType = inputElementType.dyn_cast<IntegerType>();
+    if (!intType)
+      return failure();
+
+    // min_int/max_int are 64-bit attributes, so the element type's range is
+    // computed as int64_t too. The maximum of ui64 (or of any integer wider
+    // than 64 bits) does not fit, so those clamps are conservatively kept.
+    unsigned bitWidth = intType.getWidth();
+    bool isUnsigned = intType.isUnsigned();
+    if (bitWidth > 64 || (isUnsigned && bitWidth == 64))
+      return failure();
+
+    int64_t intMin, intMax;
+    if (isUnsigned) {
+      intMin = APInt::getMinValue(bitWidth).getZExtValue();
+      intMax = APInt::getMaxValue(bitWidth).getZExtValue();
+    } else {
+      // Signed and signless integers both use the signed range.
+      intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
+      intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
+    }
+
+    // The clamp is a no-op only if [min_int, max_int] spans that range.
+    int64_t minClamp = op.min_int();
+    int64_t maxClamp = op.max_int();
+    if (minClamp > intMin || maxClamp < intMax)
+      return failure();
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void ClampOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<ClampIsNoOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 1f2f9a1c66a45..c1e828cc7fca4 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -49,6 +49,59 @@ func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
 
 // -----
 
+// CHECK-LABEL: @clamp_not_noop
+func @clamp_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = 1 : i64, max_int = 4 : i64, min_fp = 1.0 : f32, max_fp = 4.0 : f32} : (tensor<4xi32>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_float_is_noop
+func @clamp_float_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: return %arg0
+  // CHECK-NOT: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} :  (tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_float_not_noop
+func @clamp_float_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // Finite +/-largest bounds would still clip +/-inf inputs.
+  // CHECK: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} :  (tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_int8_is_noop
+func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: return %arg0
+  // CHECK-NOT: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} :  (tensor<4xi8>) -> tensor<4xi8>
+  return %0 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_int16_is_noop
+func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
+  // CHECK: return %arg0
+  // CHECK-NOT: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = -32768 : i64, max_int = 32767 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} :  (tensor<4xi16>) -> tensor<4xi16>
+  return %0 : tensor<4xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_uint8_is_noop
+func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
+  // CHECK: return %arg0
+  // CHECK-NOT: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 255 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} :  (tensor<4xui8>) -> tensor<4xui8>
+  return %0 : tensor<4xui8>
+}
+
+// -----
+
 // CHECK-LABEL: @concat_fold
 func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0


        


More information about the Mlir-commits mailing list