[Mlir-commits] [mlir] f9cefc7 - [mlir][tosa] Add tosa.max_pool2d as no-op canonicalization

Rob Suderman llvmlistbot at llvm.org
Thu Dec 16 15:30:25 PST 2021


Author: not-jenni
Date: 2021-12-16T15:27:26-08:00
New Revision: f9cefc7b9089bc915121ef5890c641b95cc55819

URL: https://github.com/llvm/llvm-project/commit/f9cefc7b9089bc915121ef5890c641b95cc55819
DIFF: https://github.com/llvm/llvm-project/commit/f9cefc7b9089bc915121ef5890c641b95cc55819.diff

LOG: [mlir][tosa] Add tosa.max_pool2d as no-op canonicalization

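When the input and output of a tosa.max_pool2d op are both statically shaped and both have 1x1 spatial dimensions (dims 1 and 2), the op is a no-op, so it is canonicalized away by replacing it with its input.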

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D115908

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 173f26db6c934..982880e027e04 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file defines the operation set for the TOSA dialect as defined in
-// the TOSA specfication (https://developer.mlplatform.org/w/tosa/). 
+// the TOSA specfication (https://developer.mlplatform.org/w/tosa/).
 //
 //===----------------------------------------------------------------------===//
 
@@ -58,7 +58,7 @@ def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
 //===----------------------------------------------------------------------===//
 def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, 
+                              ["inferReturnTypeComponents"]>,
     NoSideEffect]> {
   let summary = "Performs max pooling on the input.";
 
@@ -275,6 +275,8 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
   let results = (outs
     Tosa_Tensor4D:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -326,9 +328,9 @@ def Tosa_ClampOp : Tosa_Op<"clamp", [
 
   let description = [{
     Clamp to an arbitrary minimum and maximum value.
-    Maximum and minimum values are specified as values in the range of the 
+    Maximum and minimum values are specified as values in the range of the
     input type.
-    No zero point subtraction is done to the values, thus to clamp to the zero 
+    No zero point subtraction is done to the values, thus to clamp to the zero
     point value, the zero point itself should be supplied as the minimum value.
   }];
 
@@ -488,7 +490,7 @@ def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [
 
   let description = [{
     Elementwise bitwise AND of input1 and input2. Axis of size 1
-    will be broadcast as necessary. 
+    will be broadcast as necessary.
   }];
 
   let arguments = (ins
@@ -1379,7 +1381,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
   let summary = "Concatenates tensors along one dimension.";
 
   let description = [{
-    Concatenate a variadic amount of tensors along a given axis. No data 
+    Concatenate a variadic amount of tensors along a given axis. No data
     conversion happens during a concat operation.
   }];
 
@@ -1405,7 +1407,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
   let summary = "Pads a tensor with value specified.";
 
   let description = [{
-    Pads a tensor along borders of each dimension with pad_value. 
+    Pads a tensor along borders of each dimension with pad_value.
   }];
 
   let arguments = (ins
@@ -1510,7 +1512,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
 //===----------------------------------------------------------------------===//
 def Tosa_TileOp: Tosa_Op<"tile", [
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, 
+                              ["inferReturnTypeComponents"]>,
       NoSideEffect]> {
   let summary = "Tile operator";
 
@@ -1534,7 +1536,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_Op<"transpose", [
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, 
+                              ["inferReturnTypeComponents"]>,
       NoSideEffect]> {
   let summary = "Transpose operator";
 
@@ -1565,7 +1567,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
 //===----------------------------------------------------------------------===//
 def Tosa_GatherOp : Tosa_Op<"gather", [
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, 
+                              ["inferReturnTypeComponents"]>,
       NoSideEffect]> {
   let summary = "Gather operation,";
 
@@ -1697,7 +1699,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect,
 //===----------------------------------------------------------------------===//
 // Operator: rescale
 //===----------------------------------------------------------------------===//
-def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect, 
+def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect,
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>]> {
   let summary = "Tosa rescale operator";

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 68c0d015c3a60..9809e57e3a9a0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -614,6 +614,41 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(
   results.insert<DepthwiseConv2DMulOptimization>(context);
 }
 
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value output = op.output();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType outputType = output.getType().cast<ShapedType>();
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // If the output and input shapes are 1x1, then this is a no op.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<MaxPool2dIsNoOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 91c2e3ce7feb6..fa5304bcfb042 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -181,7 +181,17 @@ func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x
   return %0 : tensor<4x10x10x6xf32>
 }
 
-// ----
+// -----
+
+// CHECK-LABEL: @max_pool2d_is_noop
+func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> {
+  // CHECK-NOT: "tosa.max_pool2d"
+  // CHECK: return %arg0
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32>
+  return %0 : tensor<10x1x1x3xf32>
+}
+
+// -----
 
 // CHECK-LABEL: @pad_noop
 func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -191,7 +201,7 @@ func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   return %1 : tensor<?x?xf32>
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: @pad_determine_val_i32
 func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
@@ -202,7 +212,7 @@ func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) ->
   return %1 : tensor<?x?xi32>
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: @pad_determine_val_f32
 func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
@@ -213,7 +223,7 @@ func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) ->
   return %1 : tensor<?x?xf32>
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: @pad_determine_val_quant
 func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {


        


More information about the Mlir-commits mailing list