[Mlir-commits] [mlir] b67b024 - [mlir][tosa] Update tosa.avg_pool2d for bit-exact TOSA behavior

Robert Suderman llvmlistbot at llvm.org
Wed Jan 25 10:52:25 PST 2023


Author: Rob Suderman
Date: 2023-01-25T18:51:12Z
New Revision: b67b024d583d328420d7a46f0897e02cdd4ebd68

URL: https://github.com/llvm/llvm-project/commit/b67b024d583d328420d7a46f0897e02cdd4ebd68
DIFF: https://github.com/llvm/llvm-project/commit/b67b024d583d328420d7a46f0897e02cdd4ebd68.diff

LOG: [mlir][tosa] Update tosa.avg_pool2d for bit-exact TOSA behavior

The normalization component of average pool has a very specific
rounding behavior for computing the division for floating
point values. Updated so that the bit-exact version is implemented.

Also includes a fix for computing the stride part of the average pool
operation.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D141339

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index f9732efaa27eb..fac8d65eb172e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -817,12 +817,17 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
     // Normalize the summed value by the number of elements grouped in each
     // pool.
-    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
-    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
+    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);
+
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
+    iW = rewriter.create<arith::SubIOp>(loc, iW, one);
 
     Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultETy, dynamicDims);
 
+    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
     auto genericOp = rewriter.create<linalg::GenericOp>(
         loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
         ValueRange{genericEmptyTensor},
@@ -830,60 +835,59 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
         getNParallelLoopsAttrs(resultTy.getRank()),
         [&](OpBuilder &b, Location loc, ValueRange args) {
           auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-          auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-          auto iH = rewriter.create<arith::ConstantIndexOp>(
-              loc, poolingOpTy.getDimSize(1) - 1);
-          auto iW = rewriter.create<arith::ConstantIndexOp>(
-              loc, poolingOpTy.getDimSize(2) - 1);
-
-          // Compute the indices from either end.
-          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
-          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
-          auto y1 = rewriter.create<arith::SubIOp>(loc, iH, y0);
-          auto x1 = rewriter.create<arith::SubIOp>(loc, iW, x0);
 
           // Determines what the portion of valid input is covered by the
           // kernel.
-          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
             if (pad == 0)
-              return v;
+              return valid;
 
             auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
-            Value dx = rewriter.create<arith::SubIOp>(loc, x, padVal);
+            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
+
+            Value cmp = rewriter.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::slt, dpos, zero);
+            Value offset =
+                rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
+            return rewriter.create<arith::AddIOp>(loc, valid, offset)
+                ->getResult(0);
+          };
 
+          auto coverageFn = [&](int64_t i, Value isize) -> Value {
+            Value strideVal =
+                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
+            Value val =
+                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);
+
+            // Find the position relative to the input tensor's ends.
+            Value left = rewriter.create<linalg::IndexOp>(loc, i);
+            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
+            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
+            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);
+
+            // Determine how much padding was included.
+            val = padFn(val, left, pad[i * 2]);
+            val = padFn(val, right, pad[i * 2 + 1]);
             Value cmp = rewriter.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::slt, dx, zero);
-            Value offset = rewriter.create<arith::SelectOp>(loc, cmp, dx, zero);
-            return rewriter.create<arith::AddIOp>(loc, v, offset)->getResult(0);
+                loc, arith::CmpIPredicate::slt, val, one);
+            return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
           };
 
-          // Compute the vertical component of coverage.
-          auto kH0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[0]);
-          auto kH1 = padFn(kH0, y0, pad[2]);
-          auto kH2 = padFn(kH1, y1, pad[3]);
-          auto kHCmp = rewriter.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::slt, kH2, one);
-          auto kH3 = rewriter.create<arith::SelectOp>(loc, kHCmp, one, kH2);
-
-          // compute the horizontal component of coverage.
-          auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
-          auto kW1 = padFn(kW0, x0, pad[4]);
-          auto kW2 = padFn(kW1, x1, pad[5]);
-          auto kWCmp = rewriter.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::slt, kW2, one);
-          auto kW3 = rewriter.create<arith::SelectOp>(loc, kWCmp, one, kW2);
+          // Compute the indices from either end.
+          Value kH3 = coverageFn(1, iH);
+          Value kW3 = coverageFn(2, iW);
 
           // Compute the total number of elements and normalize.
-          Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);
-          auto countI = rewriter.create<arith::IndexCastOp>(
-              loc, rewriter.getI32Type(), count);
+          auto count = rewriter.create<arith::IndexCastOp>(
+              loc, rewriter.getI32Type(),
+              rewriter.create<arith::MulIOp>(loc, kH3, kW3));
 
           // Divide by the number of summed values. For floats this is just
           // a div however for quantized values input normalization had
           // to be applied.
           Value poolVal = args[0];
           if (accETy.isa<FloatType>()) {
-            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, countI);
+            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
             poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                           ->getResult(0);
           } else {
@@ -895,33 +899,52 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
               auto inputZp = rewriter.create<arith::ConstantOp>(
                   loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
               Value offset =
-                  rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
+                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
               poolVal =
                   rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
             }
 
-            // Compute the multiplier and shift values for the quantization
-            // normalization. Preferably we would want to compute more bits
-            // however 32-bits should be enough for compute. Honestly we
-            // should probably straight divide.
-            int64_t numerator = ((1 << 30) + 1);
-            int64_t shift = 30;
-
-            Value numeratorVal = rewriter.create<arith::ConstantOp>(
-                loc, rewriter.getI32IntegerAttr(numerator));
-            Value multiplierVal =
-                rewriter
-                    .create<arith::DivUIOp>(loc, rewriter.getI32Type(),
-                                            numeratorVal, countI)
-                    .getResult();
-            Value shiftVal = rewriter.create<arith::ConstantOp>(
-                loc, rewriter.getI8IntegerAttr(shift));
+            // Compute: k = 32 - count_leading_zeros(value - 1)
+            Value one32 = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getI32IntegerAttr(1));
+            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getI32IntegerAttr(32));
+
+            Value countSubOne =
+                rewriter.create<arith::SubIOp>(loc, count, one32);
+            Value leadingZeros =
+                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
+            Value k =
+                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);
+
+            // Compute: numerator = ((1 << 30) + 1) << k
+            Value k64 =
+                rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
+            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
+            Value numerator =
+                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);
+
+            // Compute: scale.multiplier = numerator / value;
+            Value count64 = rewriter.create<arith::ExtUIOp>(
+                loc, rewriter.getI64Type(), count);
+            Value multiplier =
+                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
+            multiplier = rewriter.create<arith::TruncIOp>(
+                loc, rewriter.getI32Type(), multiplier);
+
+            // Compute: scale.shift = 30 + k
+            Value k8 =
+                rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
+            Value thirty8 = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getI8IntegerAttr(30));
+            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
 
             auto scaled =
                 rewriter
-                    .create<tosa::ApplyScaleOp>(
-                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
-                        shiftVal, rewriter.getBoolAttr(false))
+                    .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
+                                                poolVal, multiplier, shift,
+                                                rewriter.getBoolAttr(false))
                     .getResult();
 
             // If we have quantization information we need to apply output

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5a28597052c58..06d548f347da8 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -200,144 +200,160 @@ func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
   %0 = "tosa.max_pool2d"(%arg0) {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi32>)  -> (tensor<1x4x32x62xi32>)
   return
 }
-// -----
-
-// CHECK-LABEL: @avg_pool
-func.func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
-  // Initial piece computes the sum of the pooling region, with appropriate padding.
-  // CHECK: [[CONST:%.+]] = arith.constant 0
-  // CHECK: [[PAD:%.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
-  // CHECK: [[CONST:%.+]] = arith.constant 0
-  // CHECK: [[POOLINIT:%.+]] = tensor.empty()
-  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[POOLINIT]]
-  // CHECK: [[KERNEL:%.+]] = tensor.empty()
-  // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
-  // CHECK: [[INIT:%.+]] = tensor.empty()
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[POOL]] : tensor<1x5x33x62xf32>) outs([[INIT]] : tensor<1x5x33x62xf32>)
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: f32,
-  // CHECK:   [[ZERO:%.0]] = arith.constant 0
-  // CHECK:   [[ONE:%.+]] = arith.constant 1
-  // CHECK:   [[HEIGHT:%.+]] = arith.constant 4
-  // CHECK:   [[WIDTH:%.+]] = arith.constant 32
-  // CHECK:   [[IDX1:%.+]] = linalg.index 1
-  // CHECK:   [[IDX2:%.+]] = linalg.index 2
-
-  // The large block below computes what portion of the kernel is within non-padded input.
-  // CHECK:   [[NY:%.+]] = arith.subi [[HEIGHT]], [[IDX1]]
-  // CHECK:   [[NX:%.+]] = arith.subi [[WIDTH]], [[IDX2]]
-  // CHECK:   [[KH:%.+]] = arith.constant 4
-  // CHECK:   [[PAD0:%.+]] = arith.constant 1
-  // CHECK:   [[SUBP0:%.+]] = arith.subi [[IDX1]], [[PAD0]]
-  // CHECK:   [[P0CMP:%.+]] = arith.cmpi slt, [[SUBP0]], [[ZERO]]
-  // CHECK:   [[SELP0:%.+]] = arith.select [[P0CMP]], [[SUBP0]], [[ZERO]]
-  // CHECK:   [[ADDP0:%.+]] = arith.addi [[KH]], [[SELP0]]
-  // CHECK:   [[PAD1:%.+]] = arith.constant 1
-  // CHECK:   [[SUBP1:%.+]] = arith.subi [[NY]], [[PAD1]]
-  // CHECK:   [[P1CMP:%.+]] = arith.cmpi slt, [[SUBP1]], [[ZERO]]
-  // CHECK:   [[SELP1:%.+]] = arith.select [[P1CMP]], [[SUBP1]], [[ZERO]]
-  // CHECK:   [[ADDP1:%.+]] = arith.addi [[ADDP0]], [[SELP1]]
-  // CHECK:   [[YCMP:%.+]] = arith.cmpi slt, [[ADDP1]], [[ONE]]
-  // CHECK:   [[YSEL:%.+]] = arith.select [[YCMP]], [[ONE]], [[ADDP1]]
-  // CHECK:   [[KW:%.+]] = arith.constant 4 : index
-  // CHECK:   [[PAD2:%.+]] = arith.constant 1 : index
-  // CHECK:   [[SUBP2:%.+]] = arith.subi [[IDX2]], [[PAD2]]
-  // CHECK:   [[P2CMP:%.+]] = arith.cmpi slt, [[SUBP2]], [[ZERO]]
-  // CHECK:   [[SELP2:%.+]] = arith.select [[P2CMP]], [[SUBP2]], [[ZERO]]
-  // CHECK:   [[ADDP2:%.+]] = arith.addi [[KW]], [[SELP2]]
-  // CHECK:   [[PAD3:%.+]] = arith.constant 1 : index
-  // CHECK:   [[SUBP3:%.+]] = arith.subi [[NX]], [[PAD3]]
-  // CHECK:   [[P3CMP:%.+]] = arith.cmpi slt, [[SUBP3]], [[ZERO]]
-  // CHECK:   [[SELP3:%.+]] = arith.select [[P3CMP]], [[SUBP3]], [[ZERO]]
-  // CHECK:   [[ADDP3:%.+]] = arith.addi [[ADDP2]], [[SELP3]]
-  // CHECK:   [[XCMP:%.+]] = arith.cmpi slt, [[ADDP3]], [[ONE]]
-  // CHECK:   [[XSEL:%.+]] = arith.select [[XCMP]], [[ONE]], [[ADDP3]]
-
-  // Given the valid coverage of the pooling region, normalize the summation.
-  // CHECK:   [[C:%.+]] = arith.muli [[YSEL]], [[XSEL]]
-  // CHECK:   [[CI:%.+]] = arith.index_cast [[C]]
-  // CHECK:   [[CF:%.+]] = arith.sitofp [[CI]]
-  // CHECK:   [[RESULT:%.+]] = arith.divf %[[BBARG1]], [[CF]]
-  // CHECK:   linalg.yield [[RESULT]]
-  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>)  -> (tensor<1x5x33x62xf32>)
-  return %0 : tensor<1x5x33x62xf32>
-}
 
 // -----
 
-// CHECK-LABEL: @avg_pool_dyn
-func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
-  // The calculations remain the same as above, only testing for dyn behavior
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+// CHECK-LABEL: @avg_pool_f32
+func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
+  // Apply padding to the input:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
   // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
-  // CHECK: %[[POOLINIT:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[KERNEL:.+]] = tensor.empty()
-  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>) outs(%[[FILL]] : tensor<?x5x33x62xf32>)
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor<?x5x33x62xf32>) outs(%[[INIT]] : tensor<?x5x33x62xf32>)
-  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>)  -> (tensor<?x5x33x62xf32>)
-  return %0 : tensor<?x5x33x62xf32>
+  // CHECK:   tensor.yield %[[F0]] : f32
+
+  // Fill the pooling target:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+
+  // Compute the sum padding:
+  // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum 
+  // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+  // CHECK-SAME: ins(%[[PAD]], %[[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<1x5x33x62xf32>)
+
+  // Compute dimension based constants:
+  // CHECK: %[[I1:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.+]] = tensor.dim %[[POOL]], %[[I1]]
+  // CHECK: %[[I2:.+]] = arith.constant 2 : index
+  // CHECK: %[[DIM2:.+]] = tensor.dim %[[POOL]], %[[I2]]
+  // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+  // CHECK: %[[HEIGHT:.+]] = arith.subi %[[DIM1]], %[[ONE]] : index
+  // CHECK: %[[WIDTH:.+]] = arith.subi %[[DIM2]], %[[ONE]] : index
+
+  // Divide the sum pooling by the number of summed values.
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f32)
+  // CHECK:   %[[ZERO:.+]] = arith.constant 0
+
+  // Compute how much of the height does not include padding:
+  // CHECK:   %[[STRIDE:.+]] = arith.constant 1
+  // CHECK:   %[[KSIZE:.+]] = arith.constant 4
+  // CHECK:   %[[START:.+]] = linalg.index 1
+  // CHECK:   %[[END:.+]] = arith.subi %[[HEIGHT]], %[[START]]
+  // CHECK:   %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK:   %[[PAD_START:.+]] = arith.constant 1
+  // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK:   %[[PAD_END:.+]] = arith.constant 1
+  // CHECK:   %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK:   %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+  // Compute how much of the width does not include padding:
+  // CHECK:   %[[STRIDE:.+]] = arith.constant 1
+  // CHECK:   %[[KSIZE:.+]] = arith.constant 4
+  // CHECK:   %[[START:.+]] = linalg.index 2
+  // CHECK:   %[[END:.+]] = arith.subi %[[WIDTH]], %[[START]]
+  // CHECK:   %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK:   %[[PAD_START:.+]] = arith.constant 1
+  // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK:   %[[PAD_END:.+]] = arith.constant 1
+  // CHECK:   %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK:   %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+  // Divide the summed value by the number of values summed.
+  // CHECK:   %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[COUNT]]
+  // CHECK:   %[[FLT:.+]] = arith.sitofp %[[CAST]]
+  // CHECK:   %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
+  // CHECK:   linalg.yield %[[DIV]]
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>)  -> (tensor<1x5x33x62xf32>)
+  return %0 : tensor<1x5x33x62xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @avg_pool_i8
-func.func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
-
-  // CHECK: linalg.pooling_nhwc_sum
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: i32,
-
-  // CHECK: %[[INZP:.+]] = arith.constant -128
-  // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
-  // CHECK: %[[OFFSETED:.+]] = arith.subi %[[BBARG1]], %[[INZP_OFF]]
-  // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
-  // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
-  // CHECK: %[[SHIFT:.+]] = arith.constant 30
-  // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
-  // CHECK: %[[OUTZP:.+]] = arith.constant -128
-  // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
-  // CHECK: %[[MIN:.+]] = arith.constant -128
-  // CHECK: %[[MAX:.+]] = arith.constant 127
-  // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
-  // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
-  // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
-  // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
-  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
+// CHECK-LABLE: @avg_pool_i8
+func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xi32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xi8>)
+  // CHECK: ^bb0(%[[IN:.+]]: i32, %{{.+}}: i8)
+
+  // Only different behavior is how the division is performed.
+  // First we compute the mul and shift values for average pool:
+  // CHECK: %[[COUNT:.+]] = arith.muli %21, %35
+  // CHECK: %[[ICAST:.+]] = arith.index_cast %[[COUNT]]
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[C32:.+]] = arith.constant 32
+  // CHECK: %[[ISUB:.+]] = arith.subi %[[ICAST]], %[[C1]]
+  // CHECK: %[[CTLZ:.+]] = math.ctlz %[[ISUB]]
+  // CHECK: %[[SUB:.+]] = arith.subi %[[C32]], %[[CTLZ]]
+  // CHECK: %[[EXT:.+]] = arith.extui %[[SUB]]
+  // CHECK: %[[CBIG:.+]] = arith.constant 1073741825
+  // CHECK: %[[SHL:.+]] = arith.shli %[[CBIG]], %[[EXT]]
+  // CHECK: %[[IEXT:.+]] = arith.extui %[[ICAST]]
+  // CHECK: %[[DIV:.+]] = arith.divui %[[SHL]], %[[IEXT]]
+  // CHECK: %[[TRUNC_MUL:.+]] = arith.trunci %[[DIV]]
+  // CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
+  // CHECK: %[[C30:.+]] = arith.constant 30
+  // CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
+  // CHECK: %[[SCALED:.+]] = "tosa.apply_scale"(%[[IN]], %[[TRUNC_MUL]], %[[SHIFT]]) {double_round = false}
+
+  // Perform the normalization.
+  // CHECK: %[[CMIN:.+]] = arith.constant -128
+  // CHECK: %[[CMAX:.+]] = arith.constant 127
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[SCALED]], %[[CMIN]]
+  // CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[CMIN]], %[[SCALED]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[CMAX]], %[[SCALED]]
+  // CHECK: %[[CLAMP:.+]] = arith.select %[[CMP]], %[[CMAX]], %[[SEL]]
+  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
   // CHECK: linalg.yield %[[TRUNC]]
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 4>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = -128>, stride = array<i64: 4, 4>} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
-  return
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>)  -> (tensor<1x5x33x62xi8>)
+  return %0 : tensor<1x5x33x62xi8>
 }
 
 // -----
 
-// CHECK-LABEL: @avg_pool_i16
-func.func @avg_pool_i16(%arg0 : tensor<1x128x128x2xi16>) -> () {
-
-  // CHECK: linalg.pooling_nhwc_sum
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: i32,
-
-  // CHECK: %[[INZP:.+]] = arith.constant -128
-  // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
-  // CHECK: %[[OFFSETED:.+]] = arith.subi %[[BBARG1]], %[[INZP_OFF]]
-  // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
-  // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
-  // CHECK: %[[SHIFT:.+]] = arith.constant 30
-  // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
-  // CHECK: %[[OUTZP:.+]] = arith.constant -128
-  // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
-  // CHECK: %[[MIN:.+]] = arith.constant -32768
-  // CHECK: %[[MAX:.+]] = arith.constant 32767
-  // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
-  // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
-  // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
-  // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
-  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
-  // CHECK: linalg.yield %[[TRUNC]]
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 4>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = -128>, stride = array<i64: 4, 4>} : (tensor<1x128x128x2xi16>) -> tensor<1x32x32x2xi16>
-  return
+// CHECK-LABEL: @avg_pool_dyn
+func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK:   tensor.yield %[[F0]]
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<?x5x33x62xf32>)
+  // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum 
+  // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>
+  // CHECK-SAME: ins(%[[PADDED]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>)  -> (tensor<?x5x33x62xf32>)
+  return %0 : tensor<?x5x33x62xf32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list