[Mlir-commits] [mlir] [MLIR][TOSA] Fix f16/bf16 support for MaxPool2D (PR #69332)

Dhruv Chauhan llvmlistbot at llvm.org
Tue Oct 17 06:42:26 PDT 2023


https://github.com/dchauhan-arm created https://github.com/llvm/llvm-project/pull/69332

Currently, the MaxPool2D operation in the TOSA MLIR dialect does not accept half-precision f16 and bf16 tensors, contrary to what is stated in the [TOSA Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d).

This patch fixes the verifier to accept both data types for input and output tensors, and adds corresponding LIT test cases to Tosa/ops.mlir.
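
The conversion code changed by this patch seeds the max-pool reduction with `APFloat::getLargest(..., true)`, that is, the most negative finite value of the result element type. The following is a minimal standalone sketch of why that seed also works for f16 and bf16. It does not use LLVM; the helpers `decodeF16`, `decodeBF16`, and `maxOverWindow` are hypothetical names introduced only for illustration, and the bit patterns `0xFBFF` and `0xFF7F` are the standard encodings of the most negative finite f16 and bf16 values.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: decode an IEEE-754 binary16 (f16) bit pattern.
float decodeF16(uint16_t bits) {
  const int sign = (bits >> 15) & 0x1;
  const int exponent = (bits >> 10) & 0x1F;
  const int mantissa = bits & 0x3FF;
  float magnitude;
  if (exponent == 0)
    magnitude = std::ldexp(static_cast<float>(mantissa), -24); // subnormal
  else if (exponent == 0x1F)
    magnitude = mantissa ? NAN : INFINITY;
  else
    magnitude = std::ldexp(static_cast<float>(mantissa + 1024), exponent - 25);
  return sign ? -magnitude : magnitude;
}

// Hypothetical helper: bf16 is the upper 16 bits of an IEEE-754 binary32.
float decodeBF16(uint16_t bits) {
  const uint32_t wide = static_cast<uint32_t>(bits) << 16;
  float value;
  std::memcpy(&value, &wide, sizeof(value));
  return value;
}

// Max reduction over one pooling window, starting from a seed value.
float maxOverWindow(const std::vector<float> &window, float init) {
  float acc = init;
  for (float v : window)
    acc = std::max(acc, v);
  return acc;
}

// Sanity check: any finite element replaces the most negative finite seed.
void checkSeeds() {
  const float f16Seed = decodeF16(0xFBFF);   // most negative finite f16
  const float bf16Seed = decodeBF16(0xFF7F); // most negative finite bf16
  assert(maxOverWindow({-3.0f, -7.5f}, f16Seed) == -3.0f);
  assert(maxOverWindow({-3.0f, -7.5f}, bf16Seed) == -3.0f);
}
```

Calling `checkSeeds()` from a `main` function exercises both seeds. Because the seed is the lowest finite value of the element type, every finite input in the window compares greater than or equal to it, so the result depends only on the window contents.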

From 73805761acd6ba6fe79e28582630a5ddd5d1b112 Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan <dhruv.chauhan at arm.com>
Date: Tue, 17 Oct 2023 13:24:51 +0100
Subject: [PATCH] [MLIR][TOSA] Fix f16/bf16 support for MaxPool2D

Currently, the MaxPool2D operation in the TOSA MLIR dialect does not
accept half-precision f16 and bf16 tensors, contrary to what is stated
in the [TOSA Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d).

This patch fixes the verifier to accept both data types for
input and output tensors, and adds corresponding LIT test cases to
Tosa/ops.mlir.
---
 .../TosaToLinalg/TosaToLinalgNamed.cpp         |  2 +-
 mlir/test/Dialect/Tosa/ops.mlir                | 18 ++++++++++++++++--
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 4214bb57563285c..ee8f52deadbd152 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -691,7 +691,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
-    if (resultETy.isF32())
+    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
       initialAttr = rewriter.getFloatAttr(
           resultETy, APFloat::getLargest(
                          cast<FloatType>(resultETy).getFloatSemantics(), true));
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e62bea515d06baa..8ce8fb73f29a504 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -97,12 +97,26 @@ func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -
 }
 
 // -----
-// CHECK-LABEL: max_pool2d
-func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+// CHECK-LABEL: max_pool2d_f32
+func.func @test_max_pool2d_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
   return %0 : tensor<1x32x32x8xf32>
 }
 
+// -----
+// CHECK-LABEL: max_pool2d_bf16
+func.func @test_max_pool2d_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16>
+  return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f16
+func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16>
+  return %0 : tensor<1x32x32x8xf16>
+}
+
 // -----
 // CHECK-LABEL: rfft2d
 func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {


