[Mlir-commits] [mlir] 648dfdf - [mlir][tosa] Add tosa.avg_pool2d lowering

Rob Suderman llvmlistbot at llvm.org
Wed Apr 21 19:14:18 PDT 2021


Author: Rob Suderman
Date: 2021-04-21T19:07:52-07:00
New Revision: 648dfdfc2481bf0205181991f6eb9be13a3d9174

URL: https://github.com/llvm/llvm-project/commit/648dfdfc2481bf0205181991f6eb9be13a3d9174
DIFF: https://github.com/llvm/llvm-project/commit/648dfdfc2481bf0205181991f6eb9be13a3d9174.diff

LOG: [mlir][tosa] Add tosa.avg_pool2d lowering

Added the float (f32) lowering for tosa.avg_pool2d, with corresponding tests.

Differential Revision: https://reviews.llvm.org/D100793

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index de27feca3e103..8ef186e257ec2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1626,18 +1626,19 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
   }
 };
 
-class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
+template <typename SrcOp>
+class Pool2dConverter : public OpRewritePattern<SrcOp> {
 public:
-  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+  LogicalResult matchAndRewrite(SrcOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
     Value input = op.input();
     ShapedType inputTy = input.getType().cast<ShapedType>();
     Type inElementTy = inputTy.getElementType();
 
-    ShapedType resultTy = op.getType().cast<ShapedType>();
+    ShapedType resultTy = op.getType().template cast<ShapedType>();
     Type outElementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
@@ -1646,17 +1647,20 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 
     // Determine what the initial value needs to be for the max pool op.
     Attribute initialAttr;
-    if (outElementTy.isF32())
+    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
       initialAttr = rewriter.getFloatAttr(
           outElementTy,
           APFloat::getLargest(
               outElementTy.cast<FloatType>().getFloatSemantics(), true));
 
-    if (outElementTy.isa<IntegerType>())
+    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
       initialAttr = rewriter.getIntegerAttr(
           outElementTy,
           APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
 
+    if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
+      initialAttr = rewriter.getZeroAttr(outElementTy);
+
     if (!initialAttr)
       return rewriter.notifyMatchFailure(
           op, "Unsupported initial value for tosa.maxpool_2d op");
@@ -1670,6 +1674,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 
     Attribute strideAttr = rewriter.getI64VectorAttr(stride);
     Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
+    int64_t kernelSize = kernel[0] * kernel[1];
 
     // If non-zero padding we need to pad the input
     if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
@@ -1716,34 +1721,46 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
               .getOperation());
     };
 
-    if (inElementTy.isF32()) {
+    if (isa<tosa::MaxPool2dOp>(op) && inElementTy.isF32()) {
       linalg::LinalgOp poolingOp =
           createOp(static_cast<linalg::PoolingNHWCMaxFOp *>(nullptr));
       rewriter.replaceOp(op, poolingOp->getResult(0));
       return success();
     }
 
-    if (inElementTy.isInteger(8)) {
+    if (isa<tosa::MaxPool2dOp>(op) && inElementTy.isInteger(8)) {
       linalg::LinalgOp poolingOp =
           createOp(static_cast<linalg::PoolingNHWCMaxI8Op *>(nullptr));
       rewriter.replaceOp(op, poolingOp->getResult(0));
       return success();
     }
 
-    if (inElementTy.isInteger(16)) {
+    if (isa<tosa::MaxPool2dOp>(op) && inElementTy.isInteger(16)) {
       linalg::LinalgOp poolingOp =
           createOp(static_cast<linalg::PoolingNHWCMaxI16Op *>(nullptr));
       rewriter.replaceOp(op, poolingOp->getResult(0));
       return success();
     }
 
-    if (inElementTy.isInteger(32)) {
+    if (isa<tosa::MaxPool2dOp>(op) && inElementTy.isInteger(32)) {
       linalg::LinalgOp poolingOp =
           createOp(static_cast<linalg::PoolingNHWCMaxI32Op *>(nullptr));
       rewriter.replaceOp(op, poolingOp->getResult(0));
       return success();
     }
 
+    if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
+      linalg::LinalgOp poolingOp =
+          createOp(static_cast<linalg::PoolingNHWCSumFOp *>(nullptr));
+      auto constAttr = DenseElementsAttr::get(
+          resultTy, static_cast<float>(1.0 / kernelSize));
+      auto constant = rewriter.create<ConstantOp>(loc, constAttr);
+      auto mul = rewriter.create<tosa::MulOp>(
+          loc, resultTy, poolingOp->getResult(0), constant, 0);
+      rewriter.replaceOp(op, mul.output());
+      return success();
+    }
+
     return failure();
   }
 };
@@ -1805,7 +1822,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       TileConverter,
       TransposeConverter,
       MatMulConverter,
-      MaxPool2dConverter,
+      Pool2dConverter<tosa::AvgPool2dOp>,
+      Pool2dConverter<tosa::MaxPool2dOp>,
       FullyConnectedConverter>(patterns->getContext());
-      // clang-format on
+  // clang-format on
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 119ceea14afe3..09392e6cef5db 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -923,6 +923,21 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
   %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>)  -> (tensor<1x4x32x62xi32>)
   return
 }
+// -----
+
+// CHECK-LABEL: @avg_pool
+func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
+  // CHECK-DAG: [[CONST:%.+]] = constant 0
+  // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 3, 31, 62]
+  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
+  // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
+  // CHECK: linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x3x31x62xf32>)
+  // CHECK: constant dense<6.250000e-02>
+  // CHECK: linalg.generic
+  // CHECK: mulf
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x3x31x62xf32>)
+  return
+}
 
 // -----
 


        


More information about the Mlir-commits mailing list