[Mlir-commits] [mlir] f2832c2 - [mlir][tosa] Added shape propagation for TOSA pool operations.

Rob Suderman llvmlistbot at llvm.org
Mon Jul 12 15:41:49 PDT 2021


Author: Rob Suderman
Date: 2021-07-12T15:40:49-07:00
New Revision: f2832c2295c6076b51a35d0d7b304c08e1b41c29

URL: https://github.com/llvm/llvm-project/commit/f2832c2295c6076b51a35d0d7b304c08e1b41c29
DIFF: https://github.com/llvm/llvm-project/commit/f2832c2295c6076b51a35d0d7b304c08e1b41c29.diff

LOG: [mlir][tosa] Added shape propagation for TOSA pool operations.

Both pool operations use the same shape propagation. Added the shape
propagation and tests for avg_pool2d and max_pool2d.

Differential Revision: https://reviews.llvm.org/D105665

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 76cd66aac064e..eafce2c378433 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -56,7 +56,10 @@ def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
 //===----------------------------------------------------------------------===//
 // Operator: avg_pool2d
 //===----------------------------------------------------------------------===//
-def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [NoSideEffect]> {
+def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>, 
+    NoSideEffect]> {
   let summary = "Performs max pooling on the input.";
 
   let description = [{
@@ -233,7 +236,10 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [
 //===----------------------------------------------------------------------===//
 // Operator: max_pool2d
 //===----------------------------------------------------------------------===//
-def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [NoSideEffect]> {
+def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Performs max pooling on the input.";
 
   let description = [{

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 9126f1776ca21..75f26f6f23cb5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -845,6 +845,62 @@ NARY_SHAPE_INFER(tosa::TanhOp)
 NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 
+// Shared result-shape inference for tosa.avg_pool2d and tosa.max_pool2d.
+static LogicalResult poolingInferReturnTypes(
+    ValueRange operands, DictionaryAttr attributes,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  RankedTensorType inputTy = operands[0].getType().dyn_cast<RankedTensorType>();
+  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
+
+  // For an unranked input only the result rank (4) is known.
+  if (!inputTy) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    return success();
+  }
+
+  // Batch and number of channels are identical for pooling layer.
+  outputShape[0] = inputTy.getDimSize(0);
+  outputShape[3] = inputTy.getDimSize(3);
+
+  int64_t height = inputTy.getDimSize(1);
+  int64_t width = inputTy.getDimSize(2);
+
+  llvm::SmallVector<int64_t> kernel;
+  llvm::SmallVector<int64_t> stride;
+  llvm::SmallVector<int64_t> pad;
+
+  getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
+  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+  getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);
+  // out = (in + pad_lo + pad_hi - kernel) / stride + 1 (needs stride > 0).
+  if (!ShapedType::isDynamic(height) && stride[0] > 0) {
+    int64_t padded = height + pad[0] + pad[1] - kernel[0];
+    outputShape[1] = padded / stride[0] + 1;
+  }
+
+  if (!ShapedType::isDynamic(width) && stride[1] > 0) {
+    int64_t padded = width + pad[2] + pad[3] - kernel[1];
+    outputShape[2] = padded / stride[1] + 1;
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult AvgPool2dOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
+}
+
+LogicalResult MaxPool2dOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index bfbbe07d42fde..a5134aca388ea 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -660,3 +660,51 @@ func @scatter_minimum_static(%arg0 : tensor<?x4x?xi32>, %arg1 : tensor<3x?xi32>,
   %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x4x?xi32>, tensor<3x?xi32>, tensor<?x?x5xi32>)  -> (tensor<?x?x?xi32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_pool_static
+func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) {
+  // CHECK: -> tensor<3x2x4x7xf32>
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: -> tensor<3x2x4x7xf32>
+  %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_pool_dynamic_input
+func @test_pool_dynamic_input(%arg0: tensor<?x5x?x?xf32>) {
+  // CHECK: -> tensor<?x2x?x?xf32>
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<?x5x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: -> tensor<?x2x?x?xf32>
+  %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<?x5x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_pool_padded
+func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) {
+  // CHECK: -> tensor<3x5x11x7xf32>
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [1, 2, 3, 4], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: -> tensor<3x5x11x7xf32>
+  %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [1, 2, 3, 4], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_pool_stride
+func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) {
+  // CHECK: -> tensor<3x4x4x7xf32>
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [2, 3]} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
+
+  // CHECK: -> tensor<3x4x4x7xf32>
+  %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [2, 3]} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
+  return
+}


        


More information about the Mlir-commits mailing list