[Mlir-commits] [mlir] [TOSA] Fix avgpool2d accum in wider type (PR #80849)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 6 07:35:33 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Thomas Preud'homme (RoboTux)

<details>
<summary>Changes</summary>

Truncate the result of avgpool when accumulation is done in a wider type
than the result element type, such as when doing an f16 avgpool2d with an
f32 accumulator type.


---
Full diff: https://github.com/llvm/llvm-project/pull/80849.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+4) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+91) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8dc2d27bd545ff..607a603cca810f 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -890,6 +890,10 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
             poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                           ->getResult(0);
+            if (accETy.getIntOrFloatBitWidth() >
+                resultETy.getIntOrFloatBitWidth())
+              poolVal =
+                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
           } else {
 
             // If we have quantization information we need to apply an offset
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 6616ea7cf699fa..51ebcad0797807 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -306,6 +306,97 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
 
 // -----
 
+// CHECK-LABEL: @avg_pool_f16_f32acc
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x62xf16>) {
+  // Apply padding to the input:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+  // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK:   tensor.yield %[[F0]] : f16
+
+  // Fill the pooling target:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+
+  // Compute the sum padding:
+  // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+  // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+  // CHECK-SAME: ins(%[[PAD]], %[[KERNEL]] : tensor<1x8x36x62xf16>, tensor<4x4xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<1x5x33x62xf32>)
+
+  // Compute dimension based constants:
+  // CHECK: %[[I1:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.+]] = tensor.dim %[[POOL]], %[[I1]]
+  // CHECK: %[[I2:.+]] = arith.constant 2 : index
+  // CHECK: %[[DIM2:.+]] = tensor.dim %[[POOL]], %[[I2]]
+  // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+  // CHECK: %[[HEIGHT:.+]] = arith.subi %[[DIM1]], %[[ONE]] : index
+  // CHECK: %[[WIDTH:.+]] = arith.subi %[[DIM2]], %[[ONE]] : index
+
+  // Divide the sum pooling by the number of summed values.
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf16>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xf16>)
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f16)
+  // CHECK:   %[[ZERO:.+]] = arith.constant 0
+
+  // Compute how much of the height does not include padding:
+  // CHECK:   %[[STRIDE:.+]] = arith.constant 1
+  // CHECK:   %[[KSIZE:.+]] = arith.constant 4
+  // CHECK:   %[[START:.+]] = linalg.index 1
+  // CHECK:   %[[END:.+]] = arith.subi %[[HEIGHT]], %[[START]]
+  // CHECK:   %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK:   %[[PAD_START:.+]] = arith.constant 1
+  // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK:   %[[PAD_END:.+]] = arith.constant 1
+  // CHECK:   %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK:   %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+  // Compute how much of the width does not include padding:
+  // CHECK:   %[[STRIDE:.+]] = arith.constant 1
+  // CHECK:   %[[KSIZE:.+]] = arith.constant 4
+  // CHECK:   %[[START:.+]] = linalg.index 2
+  // CHECK:   %[[END:.+]] = arith.subi %[[WIDTH]], %[[START]]
+  // CHECK:   %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK:   %[[PAD_START:.+]] = arith.constant 1
+  // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK:   %[[PAD_END:.+]] = arith.constant 1
+  // CHECK:   %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK:   %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+  // Divide the summed value by the number of values summed.
+  // CHECK:   %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[COUNT]]
+  // CHECK:   %[[FLT:.+]] = arith.sitofp %[[CAST]]
+  // CHECK:   %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
+  // CHECK:   %[[TRUNC:.+]] = arith.truncf %[[DIV]]
+  // CHECK:   linalg.yield %[[TRUNC]]
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16>
+  return %0 : tensor<1x5x33x62xf16>
+}
+
+// -----
+
 // CHECK-LABEL: @avg_pool_i8
 func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
   // CHECK: %[[GENERIC:.+]] = linalg.generic

``````````

</details>


https://github.com/llvm/llvm-project/pull/80849


More information about the Mlir-commits mailing list