[Mlir-commits] [mlir] [mlir][tosa] Add fp16 support to `tosa.resize` (PR #73019)
Georgios Pinitas
llvmlistbot at llvm.org
Tue Nov 21 09:17:31 PST 2023
https://github.com/GeorgeARM created https://github.com/llvm/llvm-project/pull/73019
None
>From 940946969672d151eada200b840a7212f499c802 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Tue, 21 Nov 2023 16:37:35 +0000
Subject: [PATCH] [mlir][tosa] Add fp16 support to `tosa.resize`
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 22 +++++++--------
.../TosaToLinalg/tosa-to-linalg-resize.mlir | 28 +++++++++++++++----
2 files changed, 34 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96ff..c07fd47d7d8d710 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1503,6 +1503,9 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
auto resultTy = cast<ShapedType>(op.getType());
auto resultETy = resultTy.getElementType();
+ bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
+ auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
+
auto imageH = inputTy.getShape()[1];
auto imageW = inputTy.getShape()[2];
@@ -1536,16 +1539,13 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value zeroI32 =
b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
- Value zeroFp32 =
- b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
+ Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
- bool floatingPointMode = resultETy.isF32();
-
ArrayRef<int64_t> offset = op.getOffset();
ArrayRef<int64_t> border = op.getBorder();
ArrayRef<int64_t> scale = op.getScale();
@@ -1568,16 +1568,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
int size, ImplicitLocOpBuilder &b) {
if (size == 1) {
index = zeroI32;
- delta = zeroFp32;
+ delta = zeroFp;
return;
}
// x = x * scale_d + offset;
// ix = floor(x / scale_n)
// dx = x / scale_n - ix
- Value val = b.create<arith::UIToFPOp>(b.getF32Type(), in);
- scaleN = b.create<arith::UIToFPOp>(b.getF32Type(), scaleN);
- scaleD = b.create<arith::UIToFPOp>(b.getF32Type(), scaleD);
- offset = b.create<arith::SIToFPOp>(b.getF32Type(), offset);
+ Value val = b.create<arith::UIToFPOp>(floatTy, in);
+ scaleN = b.create<arith::UIToFPOp>(floatTy, scaleN);
+ scaleD = b.create<arith::UIToFPOp>(floatTy, scaleD);
+ offset = b.create<arith::SIToFPOp>(floatTy, offset);
val = b.create<arith::MulFOp>(val, scaleD);
val = b.create<arith::AddFOp>(val, offset);
val = b.create<arith::DivFOp>(val, scaleN);
@@ -1626,7 +1626,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value pred;
if (floatingPointMode) {
- auto h = b.create<arith::ConstantOp>(b.getF32FloatAttr(0.5f));
+ auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
} else {
Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
@@ -1682,7 +1682,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
input, ValueRange{batch, y1, x1, channel});
if (floatingPointMode) {
- auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
+ auto oneVal = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
auto interpolate = [&](Value val0, Value val1, Value delta,
int inputSize,
ImplicitLocOpBuilder &b) -> Value {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index e7db61a7888397d..361161db863ce6d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -o -| FileCheck %s
-// CHECK-LABEL: @unary_resize_nearest_fp
-func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
+// CHECK-LABEL: @unary_resize_nearest_fp32
+func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
%resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xf32>
@@ -9,8 +9,17 @@ func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x
// -----
-// CHECK-LABEL: @unary_resize_bilinear_fp
-func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
+// CHECK-LABEL: @unary_resize_nearest_fp16
+func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
+ %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16>
+ // CHECK: return %arg0
+ return %resize : tensor<3x1x1x7xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @unary_resize_bilinear_fp32
+func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
%resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xf32>
@@ -18,6 +27,15 @@ func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1
// -----
+// CHECK-LABEL: @unary_resize_bilinear_fp16
+func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
+ %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16>
+ // CHECK: return %arg0
+ return %resize : tensor<3x1x1x7xf16>
+}
+
+// -----
+
// CHECK-LABEL: @unary_resize_nearest_i8
func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8> {
%resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 1, 3, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8>
@@ -286,7 +304,7 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
// -----
// CHECK-LABEL: @resize_nearest_fp
-func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
+func.func @resize_nearest_fp32(%input: tensor<1x50x48x1xf32>) -> () {
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x1600x1536x1xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[IDX0:.+]] = linalg.index 0
More information about the Mlir-commits
mailing list