[Mlir-commits] [mlir] 2c9ddfc - [mlir][Tosa] fix fp16/bf16 support for AvgPool2d (#68718)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 13 08:42:44 PDT 2023


Author: fabrizio-indirli
Date: 2023-10-13T08:42:39-07:00
New Revision: 2c9ddfc7852ed88dd88bb38e9518404a623c70b5

URL: https://github.com/llvm/llvm-project/commit/2c9ddfc7852ed88dd88bb38e9518404a623c70b5
DIFF: https://github.com/llvm/llvm-project/commit/2c9ddfc7852ed88dd88bb38e9518404a623c70b5.diff

LOG: [mlir][Tosa] fix fp16/bf16 support for AvgPool2d (#68718)

Currently, the AvgPool2d operation in the TOSA MLIR dialect does not
accept half-precision FP16 and BF16 tensors, contrary to what is stated
in the [TOSA
specification](https://www.mlplatform.org/tosa/tosa_spec.html#_avg_pool2d).
This issue was previously reported on GitHub as #63424; it is caused by
a bug in the AvgPool2d verifier.

This patch fixes the AvgPool2d verifier to accept the fp16 and bf16 data
types for the input/output tensors and the accumulator (bf16 inputs require
an f32 accumulator), and it adds related LIT test cases in Tosa/ops.mlir.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a719171b2b359d2..6db04fe38bcd356 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -247,18 +247,20 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
     return emitOpError("accumulator type for integer tensor is not i32");
 
-  if ((inputETy.isBF16() || inputETy.isF16()) &&
-      !(accType.isF16() || accType.isF32()))
-    return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
+  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
+    return emitOpError("accumulator type for f16 tensor is not f16/f32");
+
+  if (inputETy.isBF16() && !accType.isF32())
+    return emitOpError("accumulator type for bf16 tensor is not f32");
 
   if (inputETy.isF32() && !accType.isF32())
     return emitOpError("accumulator type for f32 tensor is not f32");
 
-  if (inputETy.isF32() && resultETy.isF32())
-    return success();
-  if (inputETy.isInteger(8) && resultETy.isInteger(8))
-    return success();
-  if (inputETy.isInteger(16) && resultETy.isInteger(16))
+  if ((inputETy.isF32() && resultETy.isF32()) ||
+      (inputETy.isF16() && resultETy.isF16()) ||
+      (inputETy.isBF16() && resultETy.isBF16()) ||
+      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
+      (inputETy.isInteger(16) && resultETy.isInteger(16)))
     return success();
 
   return emitOpError("input/output element types are incompatible.");

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 7d7f2d31a4244cd..e62bea515d06baa 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -16,6 +16,20 @@ func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32
   return %0 : tensor<1x7x7x9xf32>
 }
 
+// -----
+// CHECK-LABEL: avg_pool2d_f16
+func.func @test_avg_pool2d_f16(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
+  return %0 : tensor<1x7x7x9xf16>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f16_accumf32
+func.func @test_avg_pool2d_f16_accumf32(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
+  return %0 : tensor<1x7x7x9xf16>
+}
+
 // -----
 // CHECK-LABEL: avg_pool2d_i8
 func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {


        


More information about the Mlir-commits mailing list