[Mlir-commits] [mlir] b67b024 - [mlir][tosa] Update tosa.avg_pool2d for bit-exact TOSA behavior
Robert Suderman
llvmlistbot at llvm.org
Wed Jan 25 10:52:25 PST 2023
Author: Rob Suderman
Date: 2023-01-25T18:51:12Z
New Revision: b67b024d583d328420d7a46f0897e02cdd4ebd68
URL: https://github.com/llvm/llvm-project/commit/b67b024d583d328420d7a46f0897e02cdd4ebd68
DIFF: https://github.com/llvm/llvm-project/commit/b67b024d583d328420d7a46f0897e02cdd4ebd68.diff
LOG: [mlir][tosa] Update tosa.avg_pool2d for bit-exact TOSA behavior
The normalization component of average pool has a very specific
rounding behavior for computing the division on floating-point
values. Updated so that the bit-exact version is implemented.
Also includes a fix for computing the stride part of the average pool
operation.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D141339
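For context, the division is folded into a fixed-point (multiplier, shift)
pair consumed by tosa.apply_scale. Below is a minimal standalone C++ sketch
of that computation, mirroring the "Compute:" comments in the diff; the
helper name reciprocalScale and its free-function form are illustrative
assumptions, not code from the patch.

  #include <bit>
  #include <cstdint>

  // Illustrative sketch (not from the patch): computes (multiplier, shift)
  // so that (x * multiplier) >> shift approximates x / count. Assumes
  // count >= 1 and the 32-bit arithmetic used by the lowering.
  void reciprocalScale(uint32_t count, uint32_t &multiplier, uint8_t &shift) {
    // k = 32 - count_leading_zeros(count - 1). std::countl_zero(0u) == 32,
    // which matches math.ctlz semantics for a zero input.
    int k = 32 - std::countl_zero(count - 1);
    // numerator = ((1 << 30) + 1) << k, widened to 64 bits to avoid overflow.
    uint64_t numerator = ((1ull << 30) + 1) << k;
    // scale.multiplier = numerator / count; scale.shift = 30 + k.
    multiplier = static_cast<uint32_t>(numerator / count);
    shift = static_cast<uint8_t>(30 + k);
  }

For a 4x4 kernel fully inside the input (count == 16) this yields
multiplier = 1073741825 and shift = 34, i.e. a bit-exact x/16 after
rounding.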
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index f9732efaa27eb..fac8d65eb172e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -817,12 +817,17 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Normalize the summed value by the number of elements grouped in each
// pool.
- auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
- auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+ Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
+ Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);
+
+ auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ iH = rewriter.create<arith::SubIOp>(loc, iH, one);
+ iW = rewriter.create<arith::SubIOp>(loc, iW, one);
Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
loc, resultTy.getShape(), resultETy, dynamicDims);
+ auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
ValueRange{genericEmptyTensor},
@@ -830,60 +835,59 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &b, Location loc, ValueRange args) {
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto iH = rewriter.create<arith::ConstantIndexOp>(
- loc, poolingOpTy.getDimSize(1) - 1);
- auto iW = rewriter.create<arith::ConstantIndexOp>(
- loc, poolingOpTy.getDimSize(2) - 1);
-
- // Compute the indices from either end.
- auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
- auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
- auto y1 = rewriter.create<arith::SubIOp>(loc, iH, y0);
- auto x1 = rewriter.create<arith::SubIOp>(loc, iW, x0);
// Determines what portion of the valid input is covered by the
// kernel.
- auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+ auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
if (pad == 0)
- return v;
+ return valid;
auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
- Value dx = rewriter.create<arith::SubIOp>(loc, x, padVal);
+ Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
+
+ Value cmp = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, dpos, zero);
+ Value offset =
+ rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
+ return rewriter.create<arith::AddIOp>(loc, valid, offset)
+ ->getResult(0);
+ };
+ auto coverageFn = [&](int64_t i, Value isize) -> Value {
+ Value strideVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
+ Value val =
+ rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);
+
+ // Find the position relative to the input tensor's ends.
+ Value left = rewriter.create<linalg::IndexOp>(loc, i);
+ Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
+ left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
+ right = rewriter.create<arith::MulIOp>(loc, right, strideVal);
+
+ // Determine how much padding was included.
+ val = padFn(val, left, pad[i * 2]);
+ val = padFn(val, right, pad[i * 2 + 1]);
Value cmp = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, dx, zero);
- Value offset = rewriter.create<arith::SelectOp>(loc, cmp, dx, zero);
- return rewriter.create<arith::AddIOp>(loc, v, offset)->getResult(0);
+ loc, arith::CmpIPredicate::slt, val, one);
+ return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
};
- // Compute the vertical component of coverage.
- auto kH0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[0]);
- auto kH1 = padFn(kH0, y0, pad[2]);
- auto kH2 = padFn(kH1, y1, pad[3]);
- auto kHCmp = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, kH2, one);
- auto kH3 = rewriter.create<arith::SelectOp>(loc, kHCmp, one, kH2);
-
- // compute the horizontal component of coverage.
- auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
- auto kW1 = padFn(kW0, x0, pad[4]);
- auto kW2 = padFn(kW1, x1, pad[5]);
- auto kWCmp = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, kW2, one);
- auto kW3 = rewriter.create<arith::SelectOp>(loc, kWCmp, one, kW2);
+ // Compute the valid coverage in each spatial dimension.
+ Value kH3 = coverageFn(1, iH);
+ Value kW3 = coverageFn(2, iW);
// Compute the total number of elements and normalize.
- Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);
- auto countI = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), count);
+ auto count = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.create<arith::MulIOp>(loc, kH3, kW3));
// Divide by the number of summed values. For floats this is just
// a div; however, for quantized values, input normalization needs
// to be applied.
Value poolVal = args[0];
if (accETy.isa<FloatType>()) {
- auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, countI);
+ auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
->getResult(0);
} else {
@@ -895,33 +899,52 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
Value offset =
- rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
+ rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
poolVal =
rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
}
- // Compute the multiplier and shift values for the quantization
- // normalization. Preferably we would want to compute more bits
- // however 32-bits should be enough for compute. Honestly we
- // should probably straight divide.
- int64_t numerator = ((1 << 30) + 1);
- int64_t shift = 30;
-
- Value numeratorVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(numerator));
- Value multiplierVal =
- rewriter
- .create<arith::DivUIOp>(loc, rewriter.getI32Type(),
- numeratorVal, countI)
- .getResult();
- Value shiftVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI8IntegerAttr(shift));
+ // Compute: k = 32 - count_leading_zeros(value - 1)
+ Value one32 = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(1));
+ Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(32));
+
+ Value countSubOne =
+ rewriter.create<arith::SubIOp>(loc, count, one32);
+ Value leadingZeros =
+ rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
+ Value k =
+ rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);
+
+ // Compute: numerator = ((1 << 30) + 1) << k
+ Value k64 =
+ rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
+ Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
+ Value numerator =
+ rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);
+
+ // Compute: scale.multiplier = numerator / value;
+ Value count64 = rewriter.create<arith::ExtUIOp>(
+ loc, rewriter.getI64Type(), count);
+ Value multiplier =
+ rewriter.create<arith::DivUIOp>(loc, numerator, count64);
+ multiplier = rewriter.create<arith::TruncIOp>(
+ loc, rewriter.getI32Type(), multiplier);
+
+ // Compute: scale.shift = 30 + k
+ Value k8 =
+ rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
+ Value thirty8 = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI8IntegerAttr(30));
+ Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
auto scaled =
rewriter
- .create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), poolVal, multiplierVal,
- shiftVal, rewriter.getBoolAttr(false))
+ .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
+ poolVal, multiplier, shift,
+ rewriter.getBoolAttr(false))
.getResult();
// If we have quantization information we need to apply output
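For reference, tosa.apply_scale with double_round = false behaves like the
following scalar sketch (my paraphrase of the TOSA apply_scale_32
pseudocode, not code from this patch):

  #include <cstdint>

  // Rounded fixed-point multiply: (value * multiplier + 2^(shift-1)) >> shift.
  int32_t applyScale32(int32_t value, int32_t multiplier, int8_t shift) {
    int64_t round = INT64_C(1) << (shift - 1);     // round-half-up bias
    int64_t result = int64_t(value) * multiplier + round;
    return static_cast<int32_t>(result >> shift);  // arithmetic shift
  }

With the multiplier/shift pair computed above, this performs the bit-exact
division by the element count.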
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5a28597052c58..06d548f347da8 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -200,144 +200,160 @@ func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
%0 = "tosa.max_pool2d"(%arg0) {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
return
}
-// -----
-
-// CHECK-LABEL: @avg_pool
-func.func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
- // Initial piece computes the sum of the pooling region, with appropriate padding.
- // CHECK: [[CONST:%.+]] = arith.constant 0
- // CHECK: [[PAD:%.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
- // CHECK: [[CONST:%.+]] = arith.constant 0
- // CHECK: [[POOLINIT:%.+]] = tensor.empty()
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[POOLINIT]]
- // CHECK: [[KERNEL:%.+]] = tensor.empty()
- // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
- // CHECK: [[INIT:%.+]] = tensor.empty()
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[POOL]] : tensor<1x5x33x62xf32>) outs([[INIT]] : tensor<1x5x33x62xf32>)
- // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: f32,
- // CHECK: [[ZERO:%.0]] = arith.constant 0
- // CHECK: [[ONE:%.+]] = arith.constant 1
- // CHECK: [[HEIGHT:%.+]] = arith.constant 4
- // CHECK: [[WIDTH:%.+]] = arith.constant 32
- // CHECK: [[IDX1:%.+]] = linalg.index 1
- // CHECK: [[IDX2:%.+]] = linalg.index 2
-
- // The large block below computes what portion of the kernel is within non-padded input.
- // CHECK: [[NY:%.+]] = arith.subi [[HEIGHT]], [[IDX1]]
- // CHECK: [[NX:%.+]] = arith.subi [[WIDTH]], [[IDX2]]
- // CHECK: [[KH:%.+]] = arith.constant 4
- // CHECK: [[PAD0:%.+]] = arith.constant 1
- // CHECK: [[SUBP0:%.+]] = arith.subi [[IDX1]], [[PAD0]]
- // CHECK: [[P0CMP:%.+]] = arith.cmpi slt, [[SUBP0]], [[ZERO]]
- // CHECK: [[SELP0:%.+]] = arith.select [[P0CMP]], [[SUBP0]], [[ZERO]]
- // CHECK: [[ADDP0:%.+]] = arith.addi [[KH]], [[SELP0]]
- // CHECK: [[PAD1:%.+]] = arith.constant 1
- // CHECK: [[SUBP1:%.+]] = arith.subi [[NY]], [[PAD1]]
- // CHECK: [[P1CMP:%.+]] = arith.cmpi slt, [[SUBP1]], [[ZERO]]
- // CHECK: [[SELP1:%.+]] = arith.select [[P1CMP]], [[SUBP1]], [[ZERO]]
- // CHECK: [[ADDP1:%.+]] = arith.addi [[ADDP0]], [[SELP1]]
- // CHECK: [[YCMP:%.+]] = arith.cmpi slt, [[ADDP1]], [[ONE]]
- // CHECK: [[YSEL:%.+]] = arith.select [[YCMP]], [[ONE]], [[ADDP1]]
- // CHECK: [[KW:%.+]] = arith.constant 4 : index
- // CHECK: [[PAD2:%.+]] = arith.constant 1 : index
- // CHECK: [[SUBP2:%.+]] = arith.subi [[IDX2]], [[PAD2]]
- // CHECK: [[P2CMP:%.+]] = arith.cmpi slt, [[SUBP2]], [[ZERO]]
- // CHECK: [[SELP2:%.+]] = arith.select [[P2CMP]], [[SUBP2]], [[ZERO]]
- // CHECK: [[ADDP2:%.+]] = arith.addi [[KW]], [[SELP2]]
- // CHECK: [[PAD3:%.+]] = arith.constant 1 : index
- // CHECK: [[SUBP3:%.+]] = arith.subi [[NX]], [[PAD3]]
- // CHECK: [[P3CMP:%.+]] = arith.cmpi slt, [[SUBP3]], [[ZERO]]
- // CHECK: [[SELP3:%.+]] = arith.select [[P3CMP]], [[SUBP3]], [[ZERO]]
- // CHECK: [[ADDP3:%.+]] = arith.addi [[ADDP2]], [[SELP3]]
- // CHECK: [[XCMP:%.+]] = arith.cmpi slt, [[ADDP3]], [[ONE]]
- // CHECK: [[XSEL:%.+]] = arith.select [[XCMP]], [[ONE]], [[ADDP3]]
-
- // Given the valid coverage of the pooling region, normalize the summation.
- // CHECK: [[C:%.+]] = arith.muli [[YSEL]], [[XSEL]]
- // CHECK: [[CI:%.+]] = arith.index_cast [[C]]
- // CHECK: [[CF:%.+]] = arith.sitofp [[CI]]
- // CHECK: [[RESULT:%.+]] = arith.divf %[[BBARG1]], [[CF]]
- // CHECK: linalg.yield [[RESULT]]
- %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
- return %0 : tensor<1x5x33x62xf32>
-}
// -----
-// CHECK-LABEL: @avg_pool_dyn
-func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
- // The calculations remain the same as above, only testing for dyn behavior
- // CHECK: %[[C0:.+]] = arith.constant 0
- // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+// CHECK-LABEL: @avg_pool_f32
+func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
+ // Apply padding to the input:
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
- // CHECK: %[[POOLINIT:.+]] = tensor.empty(%[[BATCH]])
- // CHECK: %[[FILL:.+]] = linalg.fill
- // CHECK: %[[KERNEL:.+]] = tensor.empty()
- // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>) outs(%[[FILL]] : tensor<?x5x33x62xf32>)
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]])
- // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor<?x5x33x62xf32>) outs(%[[INIT]] : tensor<?x5x33x62xf32>)
- %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
- return %0 : tensor<?x5x33x62xf32>
+ // CHECK: tensor.yield %[[F0]] : f32
+
+ // Fill the pooling target:
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+ // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+
+ // Compute the pooling sum:
+ // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+ // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+ // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ // CHECK-SAME: ins(%[[PAD]], %[[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>)
+ // CHECK-SAME: outs(%[[FILL]] : tensor<1x5x33x62xf32>)
+
+ // Compute dimension based constants:
+ // CHECK: %[[I1:.+]] = arith.constant 1 : index
+ // CHECK: %[[DIM1:.+]] = tensor.dim %[[POOL]], %[[I1]]
+ // CHECK: %[[I2:.+]] = arith.constant 2 : index
+ // CHECK: %[[DIM2:.+]] = tensor.dim %[[POOL]], %[[I2]]
+ // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+ // CHECK: %[[HEIGHT:.+]] = arith.subi %[[DIM1]], %[[ONE]] : index
+ // CHECK: %[[WIDTH:.+]] = arith.subi %[[DIM2]], %[[ONE]] : index
+
+ // Divide the sum pooling by the number of summed values.
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xf32>)
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f32)
+ // CHECK: %[[ZERO:.+]] = arith.constant 0
+
+ // Compute how much of the height does not include padding:
+ // CHECK: %[[STRIDE:.+]] = arith.constant 1
+ // CHECK: %[[KSIZE:.+]] = arith.constant 4
+ // CHECK: %[[START:.+]] = linalg.index 1
+ // CHECK: %[[END:.+]] = arith.subi %[[HEIGHT]], %[[START]]
+ // CHECK: %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+ // CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+ // CHECK: %[[PAD_START:.+]] = arith.constant 1
+ // CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+ // CHECK: %[[PAD_END:.+]] = arith.constant 1
+ // CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+ // CHECK: %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+ // Compute how much of the width does not include padding:
+ // CHECK: %[[STRIDE:.+]] = arith.constant 1
+ // CHECK: %[[KSIZE:.+]] = arith.constant 4
+ // CHECK: %[[START:.+]] = linalg.index 2
+ // CHECK: %[[END:.+]] = arith.subi %[[WIDTH]], %[[START]]
+ // CHECK: %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+ // CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+ // CHECK: %[[PAD_START:.+]] = arith.constant 1
+ // CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+ // CHECK: %[[PAD_END:.+]] = arith.constant 1
+ // CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+ // CHECK: %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+ // Divide the summed value by the number of values summed.
+ // CHECK: %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[COUNT]]
+ // CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]]
+ // CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
+ // CHECK: linalg.yield %[[DIV]]
+ %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
+ return %0 : tensor<1x5x33x62xf32>
}
// -----
-// CHECK-LABEL: @avg_pool_i8
-func.func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
-
- // CHECK: linalg.pooling_nhwc_sum
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: i32,
-
- // CHECK: %[[INZP:.+]] = arith.constant -128
- // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
- // CHECK: %[[OFFSETED:.+]] = arith.subi %[[BBARG1]], %[[INZP_OFF]]
- // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
- // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
- // CHECK: %[[SHIFT:.+]] = arith.constant 30
- // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
- // CHECK: %[[OUTZP:.+]] = arith.constant -128
- // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
- // CHECK: %[[MIN:.+]] = arith.constant -128
- // CHECK: %[[MAX:.+]] = arith.constant 127
- // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
- // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
- // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
- // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
- // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
+// CHECK-LABEL: @avg_pool_i8
+func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xi32>)
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xi8>)
+ // CHECK: ^bb0(%[[IN:.+]]: i32, %{{.+}}: i8)
+
+ // The only different behavior is how the division is performed.
+ // First we compute the mul and shift values for average pool:
+ // CHECK: %[[COUNT:.+]] = arith.muli %21, %35
+ // CHECK: %[[ICAST:.+]] = arith.index_cast %[[COUNT]]
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[C32:.+]] = arith.constant 32
+ // CHECK: %[[ISUB:.+]] = arith.subi %[[ICAST]], %[[C1]]
+ // CHECK: %[[CTLZ:.+]] = math.ctlz %[[ISUB]]
+ // CHECK: %[[SUB:.+]] = arith.subi %[[C32]], %[[CTLZ]]
+ // CHECK: %[[EXT:.+]] = arith.extui %[[SUB]]
+ // CHECK: %[[CBIG:.+]] = arith.constant 1073741825
+ // CHECK: %[[SHL:.+]] = arith.shli %[[CBIG]], %[[EXT]]
+ // CHECK: %[[IEXT:.+]] = arith.extui %[[ICAST]]
+ // CHECK: %[[DIV:.+]] = arith.divui %[[SHL]], %[[IEXT]]
+ // CHECK: %[[TRUNC_MUL:.+]] = arith.trunci %[[DIV]]
+ // CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
+ // CHECK: %[[C30:.+]] = arith.constant 30
+ // CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
+ // CHECK: %[[SCALED:.+]] = "tosa.apply_scale"(%[[IN]], %[[TRUNC_MUL]], %[[SHIFT]]) {double_round = false}
+
+ // Perform the normalization.
+ // CHECK: %[[CMIN:.+]] = arith.constant -128
+ // CHECK: %[[CMAX:.+]] = arith.constant 127
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[SCALED]], %[[CMIN]]
+ // CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[CMIN]], %[[SCALED]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[CMAX]], %[[SCALED]]
+ // CHECK: %[[CLAMP:.+]] = arith.select %[[CMP]], %[[CMAX]], %[[SEL]]
+ // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
// CHECK: linalg.yield %[[TRUNC]]
- %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 4>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = -128>, stride = array<i64: 4, 4>} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
- return
+ %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>)
+ return %0 : tensor<1x5x33x62xi8>
}
// -----
-// CHECK-LABEL: @avg_pool_i16
-func.func @avg_pool_i16(%arg0 : tensor<1x128x128x2xi16>) -> () {
-
- // CHECK: linalg.pooling_nhwc_sum
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: i32,
-
- // CHECK: %[[INZP:.+]] = arith.constant -128
- // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
- // CHECK: %[[OFFSETED:.+]] = arith.subi %[[BBARG1]], %[[INZP_OFF]]
- // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
- // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
- // CHECK: %[[SHIFT:.+]] = arith.constant 30
- // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
- // CHECK: %[[OUTZP:.+]] = arith.constant -128
- // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
- // CHECK: %[[MIN:.+]] = arith.constant -32768
- // CHECK: %[[MAX:.+]] = arith.constant 32767
- // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
- // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
- // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
- // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
- // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
- // CHECK: linalg.yield %[[TRUNC]]
- %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 4>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = -128>, stride = array<i64: 4, 4>} : (tensor<1x128x128x2xi16>) -> tensor<1x32x32x2xi16>
- return
+// CHECK-LABEL: @avg_pool_dyn
+func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+ // CHECK: tensor.yield %[[F0]]
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
+ // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<?x5x33x62xf32>)
+ // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+ // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+ // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>
+ // CHECK-SAME: ins(%[[PADDED]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>)
+ // CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
+ return %0 : tensor<?x5x33x62xf32>
}
// -----
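As a closing aside (not part of the commit), the per-output element count
that the linalg.generic regions above compute can be expressed as a scalar
sketch; the function coverage and its parameter names are assumptions for
illustrating the padFn/coverageFn logic:

  #include <algorithm>
  #include <cstdint>

  // Number of non-padded input elements the kernel covers at output index
  // `pos` along one spatial dimension; avg_pool2d divides the pooled sum
  // by the product of the two per-dimension coverages.
  int64_t coverage(int64_t pos, int64_t outSize, int64_t kernel,
                   int64_t stride, int64_t padLo, int64_t padHi) {
    int64_t left = pos * stride;                   // distance from the start
    int64_t right = (outSize - 1 - pos) * stride;  // distance from the end
    int64_t valid = kernel;
    valid += std::min<int64_t>(left - padLo, 0);   // clip leading padding
    valid += std::min<int64_t>(right - padHi, 0);  // clip trailing padding
    return std::max<int64_t>(valid, 1);            // never divide by zero
  }

For the @avg_pool_f32 test above (outSize = 5, kernel = 4, stride = 1,
pad = 1 on both sides), coverage(0, ...) == 3 and coverage(2, ...) == 4,
matching the select-clamped counts the generic region computes.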