[Mlir-commits] [mlir] [TOSA] Fix avgpool2d accum in wider type (PR #80849)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 6 07:35:33 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Thomas Preud'homme (RoboTux)
<details>
<summary>Changes</summary>
Truncate the result of avgpool when accumulation is done in a wider type
than the result element type, such as when doing an f16 avgpool2d with an
f32 accumulator type.
---
Full diff: https://github.com/llvm/llvm-project/pull/80849.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+4)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+91)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8dc2d27bd545ff..607a603cca810f 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -890,6 +890,10 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
->getResult(0);
+ if (accETy.getIntOrFloatBitWidth() >
+ resultETy.getIntOrFloatBitWidth())
+ poolVal =
+ rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
} else {
// If we have quantization information we need to apply an offset
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 6616ea7cf699fa..51ebcad0797807 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -306,6 +306,97 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
// -----
+// CHECK-LABEL: @avg_pool_f16_f32acc
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x62xf16>) {
+ // Apply padding to the input:
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+ // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+ // CHECK: tensor.yield %[[F0]] : f16
+
+ // Fill the pooling target:
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+ // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+
+ // Compute the sum padding:
+ // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+ // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+ // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ // CHECK-SAME: ins(%[[PAD]], %[[KERNEL]] : tensor<1x8x36x62xf16>, tensor<4x4xf32>)
+ // CHECK-SAME: outs(%[[FILL]] : tensor<1x5x33x62xf32>)
+
+ // Compute dimension based constants:
+ // CHECK: %[[I1:.+]] = arith.constant 1 : index
+ // CHECK: %[[DIM1:.+]] = tensor.dim %[[POOL]], %[[I1]]
+ // CHECK: %[[I2:.+]] = arith.constant 2 : index
+ // CHECK: %[[DIM2:.+]] = tensor.dim %[[POOL]], %[[I2]]
+ // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+ // CHECK: %[[HEIGHT:.+]] = arith.subi %[[DIM1]], %[[ONE]] : index
+ // CHECK: %[[WIDTH:.+]] = arith.subi %[[DIM2]], %[[ONE]] : index
+
+ // Divide the sum pooling by the number of summed values.
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf16>
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xf32>)
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xf16>)
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f16)
+ // CHECK: %[[ZERO:.+]] = arith.constant 0
+
+ // Compute how much of the height does not include padding:
+ // CHECK: %[[STRIDE:.+]] = arith.constant 1
+ // CHECK: %[[KSIZE:.+]] = arith.constant 4
+ // CHECK: %[[START:.+]] = linalg.index 1
+ // CHECK: %[[END:.+]] = arith.subi %[[HEIGHT]], %[[START]]
+ // CHECK: %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+ // CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+ // CHECK: %[[PAD_START:.+]] = arith.constant 1
+ // CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+ // CHECK: %[[PAD_END:.+]] = arith.constant 1
+ // CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+ // CHECK: %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+ // Compute how much of the width does not include padding:
+ // CHECK: %[[STRIDE:.+]] = arith.constant 1
+ // CHECK: %[[KSIZE:.+]] = arith.constant 4
+ // CHECK: %[[START:.+]] = linalg.index 2
+ // CHECK: %[[END:.+]] = arith.subi %[[WIDTH]], %[[START]]
+ // CHECK: %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+ // CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+ // CHECK: %[[PAD_START:.+]] = arith.constant 1
+ // CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+ // CHECK: %[[PAD_END:.+]] = arith.constant 1
+ // CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+ // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+ // CHECK: %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+ // Divide the summed value by the number of values summed.
+ // CHECK: %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[COUNT]]
+ // CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]]
+ // CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
+ // CHECK: %[[TRUNC:.+]] = arith.truncf %[[DIV]]
+ // CHECK: linalg.yield %[[TRUNC]]
+ %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16>
+ return %0 : tensor<1x5x33x62xf16>
+}
+
+// -----
+
// CHECK-LABEL: @avg_pool_i8
func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// CHECK: %[[GENERIC:.+]] = linalg.generic
``````````
</details>
https://github.com/llvm/llvm-project/pull/80849
More information about the Mlir-commits
mailing list