[Mlir-commits] [mlir] 08068dd - [mlir][tosa] Fix tosa.avg_pool2d lowering to normalize correctly

Rob Suderman llvmlistbot at llvm.org
Mon May 17 10:06:30 PDT 2021


Author: Rob Suderman
Date: 2021-05-17T10:00:43-07:00
New Revision: 08068ddba7f52255fa39968207309f3d1ad98223

URL: https://github.com/llvm/llvm-project/commit/08068ddba7f52255fa39968207309f3d1ad98223
DIFF: https://github.com/llvm/llvm-project/commit/08068ddba7f52255fa39968207309f3d1ad98223.diff

LOG: [mlir][tosa] Fix tosa.avg_pool2d lowering to normalize correctly

Initial version of pooling assumed normalization was across all elements
equally. TOSA actually requires the normalization to be performed by how
many elements were summed (edges are not artificially dimmer). Updated
the lowering to reflect this change with corresponding tests.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D102540

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 06b23cc9dd7ac..a68c8bca8a74a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2263,7 +2263,7 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
     pad.resize(2, 0);
     getValuesFromIntArrayAttribute(op.pad(), pad);
     pad.resize(pad.size() + 2, 0);
-    input = applyPad(loc, input, pad, initialAttr, rewriter);
+    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
 
     Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
 
@@ -2273,7 +2273,6 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
 
     Attribute strideAttr = rewriter.getI64VectorAttr(stride);
     Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
-    int64_t kernelSize = kernel[0] * kernel[1];
 
     // Create the linalg op that performs pooling.
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
@@ -2290,7 +2289,7 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
           rewriter
               .create<std::remove_pointer_t<decltype(typePtr)>>(
                   loc, ArrayRef<Type>{resultTy},
-                  ValueRange{input, fakeWindowDims}, filledInitTensor,
+                  ValueRange{paddedInput, fakeWindowDims}, filledInitTensor,
                   dilationAttr, strideAttr)
               .getOperation());
     };
@@ -2324,14 +2323,76 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
     }
 
     if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
-      linalg::LinalgOp poolingOp =
-          createOp(static_cast<linalg::PoolingNHWCSumFOp *>(nullptr));
-      auto constAttr = DenseElementsAttr::get(
-          resultTy, static_cast<float>(1.0 / kernelSize));
-      auto constant = rewriter.create<ConstantOp>(loc, constAttr);
-      auto mul = rewriter.create<tosa::MulOp>(
-          loc, resultTy, poolingOp->getResult(0), constant, 0);
-      rewriter.replaceOp(op, mul.output());
+      Value poolingOp =
+          createOp(static_cast<linalg::PoolingNHWCSumFOp *>(nullptr))
+              ->getResult(0);
+      auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
+      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+      auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
+          loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
+          ArrayRef<AffineMap>({affineMap}),
+          getNParallelLoopsAttrs(resultTy.getRank()),
+          [&](OpBuilder &b, Location loc, ValueRange indices, ValueRange args) {
+            auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+            auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+            auto iH = rewriter.create<ConstantIndexOp>(
+                loc, poolingOpTy.getDimSize(1) - 1);
+            auto iW = rewriter.create<ConstantIndexOp>(
+                loc, poolingOpTy.getDimSize(2) - 1);
+
+            // Compute the indices from either end.
+            auto y0 = indices[1];
+            auto x0 = indices[2];
+            auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
+            auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
+
+            // Determines what the portion of valid input is covered by the
+            // kernel.
+            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+              if (pad == 0)
+                return v;
+
+              auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
+              Value dx = rewriter.create<SubIOp>(loc, x, padVal);
+
+              Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+                                                        dx, zero);
+              Value offset =
+                  rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
+              return rewriter.create<mlir::AddIOp>(loc, v, offset)
+                  ->getResult(0);
+            };
+
+            // Compute the vertical component of coverage.
+            auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
+            auto kH1 = padFn(kH0, y0, pad[2]);
+            auto kH2 = padFn(kH1, y1, pad[3]);
+            auto kHCmp =
+                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
+            auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
+
+            // compute teh horizontal component of coverage.
+            auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
+            auto kW1 = padFn(kW0, x0, pad[4]);
+            auto kW2 = padFn(kW1, x1, pad[5]);
+            auto kWCmp =
+                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
+            auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
+
+            // Compute the total number of elements and normalize.
+            Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
+            auto countI = rewriter.create<mlir::IndexCastOp>(
+                loc, rewriter.getI32Type(), count);
+            auto countF =
+                rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
+
+            auto div =
+                rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);
+
+            rewriter.create<linalg::YieldOp>(loc, div);
+          });
+
+      rewriter.replaceOp(op, genericOp.getResult(0));
       return success();
     }
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index c2e6b07dce46b..b789072874e6f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1087,17 +1087,59 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
 // -----
 
 // CHECK-LABEL: @avg_pool
-func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
-  // CHECK-DAG: [[CONST:%.+]] = constant 0
-  // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 3, 31, 62]
-  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
-  // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
-  // CHECK: linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x3x31x62xf32>)
-  // CHECK: constant dense<6.250000e-02>
-  // CHECK: linalg.generic
-  // CHECK: mulf
-  %0 = "tosa.avg_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x3x31x62xf32>)
-  return
+func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
+  // Initial piece computes the sum of the pooling region, with appropriate padding.
+  // CHECK: [[CONST:%.+]] = constant 0
+  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] 
+  // CHECK: [[CONST:%.+]] = constant 0
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
+  // CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
+  // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
+  // CHECK: [[GENERIC:%.+]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs([[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK:   [[ZERO:%.0]] = constant 0
+  // CHECK:   [[ONE:%.+]] = constant 1
+  // CHECK:   [[HEIGHT:%.+]] = constant 4
+  // CHECK:   [[WIDTH:%.+]] = constant 32
+
+  // The large block below computes what portion of the kernel is within non-padded input.
+  // CHECK:   [[NY:%.+]] = subi [[HEIGHT]], %arg2
+  // CHECK:   [[NX:%.+]] = subi [[WIDTH]], %arg3
+  // CHECK:   [[KH:%.+]] = constant 4
+  // CHECK:   [[PAD0:%.+]] = constant 1
+  // CHECK:   [[SUBP0:%.+]] = subi %arg2, [[PAD0]]
+  // CHECK:   [[P0CMP:%.+]] = cmpi slt, [[SUBP0]], [[ZERO]]
+  // CHECK:   [[SELP0:%.+]] = select [[P0CMP]], [[SUBP0]], [[ZERO]]
+  // CHECK:   [[ADDP0:%.+]] = addi [[KH]], [[SELP0]]
+  // CHECK:   [[PAD1:%.+]] = constant 1
+  // CHECK:   [[SUBP1:%.+]] = subi [[NY]], [[PAD1]]
+  // CHECK:   [[P1CMP:%.+]] = cmpi slt, [[SUBP1]], [[ZERO]]
+  // CHECK:   [[SELP1:%.+]] = select [[P1CMP]], [[SUBP1]], [[ZERO]]
+  // CHECK:   [[ADDP1:%.+]] = addi [[ADDP0]], [[SELP1]]
+  // CHECK:   [[YCMP:%.+]] = cmpi slt, [[ADDP1]], [[ONE]]
+  // CHECK:   [[YSEL:%.+]] = select [[YCMP]], [[ONE]], [[ADDP1]]
+  // CHECK:   [[KW:%.+]] = constant 4 : index
+  // CHECK:   [[PAD2:%.+]] = constant 1 : index
+  // CHECK:   [[SUBP2:%.+]] = subi %arg3, [[PAD2]]
+  // CHECK:   [[P2CMP:%.+]] = cmpi slt, [[SUBP2]], [[ZERO]]
+  // CHECK:   [[SELP2:%.+]] = select [[P2CMP]], [[SUBP2]], [[ZERO]]
+  // CHECK:   [[ADDP2:%.+]] = addi [[KW]], [[SELP2]]
+  // CHECK:   [[PAD3:%.+]] = constant 1 : index
+  // CHECK:   [[SUBP3:%.+]] = subi [[NX]], [[PAD3]]
+  // CHECK:   [[P3CMP:%.+]] = cmpi slt, [[SUBP3]], [[ZERO]]
+  // CHECK:   [[SELP3:%.+]] = select [[P3CMP]], [[SUBP3]], [[ZERO]]
+  // CHECK:   [[ADDP3:%.+]] = addi [[ADDP2]], [[SELP3]]
+  // CHECK:   [[XCMP:%.+]] = cmpi slt, [[ADDP3]], [[ONE]]
+  // CHECK:   [[XSEL:%.+]] = select [[XCMP]], [[ONE]], [[ADDP3]]
+
+  // Given the valid coverage of the pooling region, normalize the summation.
+  // CHECK:   [[C:%.+]] = muli [[YSEL]], [[XSEL]]
+  // CHECK:   [[CI:%.+]] = index_cast [[C]]
+  // CHECK:   [[CF:%.+]] = sitofp [[CI]]
+  // CHECK:   [[RESULT:%.+]] = divf %arg5, [[CF]]
+  // CHECK:   linalg.yield [[RESULT]]
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x5x33x62xf32>)
+  return %0 : tensor<1x5x33x62xf32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list