[Mlir-commits] [mlir] f9cefc7 - [mlir][tosa] Add tosa.max_pool2d as no-op canonicalization
Rob Suderman
llvmlistbot at llvm.org
Thu Dec 16 15:30:25 PST 2021
Author: not-jenni
Date: 2021-12-16T15:27:26-08:00
New Revision: f9cefc7b9089bc915121ef5890c641b95cc55819
URL: https://github.com/llvm/llvm-project/commit/f9cefc7b9089bc915121ef5890c641b95cc55819
DIFF: https://github.com/llvm/llvm-project/commit/f9cefc7b9089bc915121ef5890c641b95cc55819.diff
LOG: [mlir][tosa] Add tosa.max_pool2d as no-op canonicalization
When both the input and output of a tosa.max_pool2d op are 1x1 in their spatial dimensions, the op can be canonicalized away as a no-op.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D115908
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 173f26db6c934..982880e027e04 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file defines the operation set for the TOSA dialect as defined in
-// the TOSA specfication (https://developer.mlplatform.org/w/tosa/).
+// the TOSA specification (https://developer.mlplatform.org/w/tosa/).
//
//===----------------------------------------------------------------------===//
@@ -58,7 +58,7 @@ def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
//===----------------------------------------------------------------------===//
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
+ ["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Performs max pooling on the input.";
@@ -275,6 +275,8 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
let results = (outs
Tosa_Tensor4D:$output
);
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -326,9 +328,9 @@ def Tosa_ClampOp : Tosa_Op<"clamp", [
let description = [{
Clamp to an arbitrary minimum and maximum value.
- Maximum and minimum values are specified as values in the range of the
+ Maximum and minimum values are specified as values in the range of the
input type.
- No zero point subtraction is done to the values, thus to clamp to the zero
+ No zero point subtraction is done to the values, thus to clamp to the zero
point value, the zero point itself should be supplied as the minimum value.
}];
@@ -488,7 +490,7 @@ def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [
let description = [{
Elementwise bitwise AND of input1 and input2. Axis of size 1
- will be broadcast as necessary.
+ will be broadcast as necessary.
}];
let arguments = (ins
@@ -1379,7 +1381,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
let summary = "Concatenates tensors along one dimension.";
let description = [{
- Concatenate a variadic amount of tensors along a given axis. No data
+ Concatenate a variadic amount of tensors along a given axis. No data
conversion happens during a concat operation.
}];
@@ -1405,7 +1407,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
let summary = "Pads a tensor with value specified.";
let description = [{
- Pads a tensor along borders of each dimension with pad_value.
+ Pads a tensor along borders of each dimension with pad_value.
}];
let arguments = (ins
@@ -1510,7 +1512,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
//===----------------------------------------------------------------------===//
def Tosa_TileOp: Tosa_Op<"tile", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
+ ["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Tile operator";
@@ -1534,7 +1536,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
//===----------------------------------------------------------------------===//
def Tosa_TransposeOp : Tosa_Op<"transpose", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
+ ["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Transpose operator";
@@ -1565,7 +1567,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
//===----------------------------------------------------------------------===//
def Tosa_GatherOp : Tosa_Op<"gather", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
+ ["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Gather operation,";
@@ -1697,7 +1699,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect,
//===----------------------------------------------------------------------===//
// Operator: rescale
//===----------------------------------------------------------------------===//
-def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect,
+def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>]> {
let summary = "Tosa rescale operator";
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 68c0d015c3a60..9809e57e3a9a0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -614,6 +614,41 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(
results.insert<DepthwiseConv2DMulOptimization>(context);
}
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+ using OpRewritePattern::OpRewritePattern;
+ /// Forwards the input to all uses of the result when the pool is a no-op.
+ LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+ PatternRewriter &rewriter) const override {
+ Value input = op.input();
+ Value output = op.output();
+ ShapedType inputType = input.getType().cast<ShapedType>();
+ ShapedType outputType = output.getType().cast<ShapedType>();
+ // Only fully static input and output shapes are considered.
+ if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+ return failure();
+ }
+ // Dims 1 and 2 (presumably H and W, NHWC — confirm) are indexed directly.
+ // If the output and input are both 1x1 in those dims, this is a no-op.
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ if (outputShape[1] != 1 || outputShape[2] != 1) {
+ return failure();
+ }
+ // The input must be 1x1 in the same dims as well.
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ if (inputShape[1] != 1 || inputShape[2] != 1) {
+ return failure();
+ }
+ // NOTE(review): assumes the op keeps input/output types otherwise equal.
+ rewriter.replaceOp(op, input);
+ return success();
+ }
+};
+/// Registers the MaxPool2dIsNoOp canonicalization for tosa.max_pool2d.
+void MaxPool2dOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<MaxPool2dIsNoOp>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 91c2e3ce7feb6..fa5304bcfb042 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -181,7 +181,17 @@ func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x
return %0 : tensor<4x10x10x6xf32>
}
-// ----
+// -----
+// A max_pool2d with 1x1 spatial input and output is replaced by its input.
+// CHECK-LABEL: @max_pool2d_is_noop
+func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> {
+ // CHECK-NOT: tosa.max_pool2d
+ // CHECK: return %arg0
+ %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32>
+ return %0 : tensor<10x1x1x3xf32>
+}
+
+// -----
// CHECK-LABEL: @pad_noop
func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -191,7 +201,7 @@ func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
return %1 : tensor<?x?xf32>
}
-// ----
+// -----
// CHECK-LABEL: @pad_determine_val_i32
func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
@@ -202,7 +212,7 @@ func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) ->
return %1 : tensor<?x?xi32>
}
-// ----
+// -----
// CHECK-LABEL: @pad_determine_val_f32
func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
@@ -213,7 +223,7 @@ func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) ->
return %1 : tensor<?x?xf32>
}
-// ----
+// -----
// CHECK-LABEL: @pad_determine_val_quant
func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
More information about the Mlir-commits
mailing list