[Mlir-commits] [mlir] [mlir][tosa] Add fp16 support to `tosa.resize` (PR #73019)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 21 09:18:03 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Georgios Pinitas (GeorgeARM)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/73019.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+11-11) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+23-5) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96ff..c07fd47d7d8d710 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1503,6 +1503,9 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
     auto resultTy = cast<ShapedType>(op.getType());
     auto resultETy = resultTy.getElementType();
 
+    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
+    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type(); 
+
     auto imageH = inputTy.getShape()[1];
     auto imageW = inputTy.getShape()[2];
 
@@ -1536,16 +1539,13 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
 
       Value zeroI32 =
           b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
-      Value zeroFp32 =
-          b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
+      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
       Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
       Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
 
       Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
       Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
 
-      bool floatingPointMode = resultETy.isF32();
-
       ArrayRef<int64_t> offset = op.getOffset();
       ArrayRef<int64_t> border = op.getBorder();
       ArrayRef<int64_t> scale = op.getScale();
@@ -1568,16 +1568,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
                                     int size, ImplicitLocOpBuilder &b) {
         if (size == 1) {
           index = zeroI32;
-          delta = zeroFp32;
+          delta = zeroFp;
           return;
         }
         // x = x * scale_d + offset;
         // ix = floor(x / scale_n)
         // dx = x / scale_n - ix
-        Value val = b.create<arith::UIToFPOp>(b.getF32Type(), in);
-        scaleN = b.create<arith::UIToFPOp>(b.getF32Type(), scaleN);
-        scaleD = b.create<arith::UIToFPOp>(b.getF32Type(), scaleD);
-        offset = b.create<arith::SIToFPOp>(b.getF32Type(), offset);
+        Value val = b.create<arith::UIToFPOp>(floatTy, in);
+        scaleN = b.create<arith::UIToFPOp>(floatTy, scaleN);
+        scaleD = b.create<arith::UIToFPOp>(floatTy, scaleD);
+        offset = b.create<arith::SIToFPOp>(floatTy, offset);
         val = b.create<arith::MulFOp>(val, scaleD);
         val = b.create<arith::AddFOp>(val, offset);
         val = b.create<arith::DivFOp>(val, scaleN);
@@ -1626,7 +1626,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
 
           Value pred;
           if (floatingPointMode) {
-            auto h = b.create<arith::ConstantOp>(b.getF32FloatAttr(0.5f));
+            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
             pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
           } else {
             Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
@@ -1682,7 +1682,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
             input, ValueRange{batch, y1, x1, channel});
 
         if (floatingPointMode) {
-          auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
+          auto oneVal = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
           auto interpolate = [&](Value val0, Value val1, Value delta,
                                  int inputSize,
                                  ImplicitLocOpBuilder &b) -> Value {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index e7db61a7888397d..361161db863ce6d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -o -| FileCheck %s
 
-// CHECK-LABEL: @unary_resize_nearest_fp
-func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
+// CHECK-LABEL: @unary_resize_nearest_fp32
+func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
   %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
   // CHECK: return %arg0
   return %resize : tensor<3x1x1x7xf32>
@@ -9,8 +9,17 @@ func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x
 
 // -----
 
-// CHECK-LABEL: @unary_resize_bilinear_fp
-func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
+// CHECK-LABEL: @unary_resize_nearest_fp16
+func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
+  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16>
+  // CHECK: return %arg0
+  return %resize : tensor<3x1x1x7xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @unary_resize_bilinear_fp32
+func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
   %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
   // CHECK: return %arg0
   return %resize : tensor<3x1x1x7xf32>
@@ -18,6 +27,15 @@ func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1
 
 // -----
 
+// CHECK-LABEL: @unary_resize_bilinear_fp16
+func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16>
+  // CHECK: return %arg0
+  return %resize : tensor<3x1x1x7xf16>
+}
+
+// -----
+
 // CHECK-LABEL: @unary_resize_nearest_i8
 func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8> {
   %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 1, 3, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8>
@@ -286,7 +304,7 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
 // -----
 
 // CHECK-LABEL: @resize_nearest_fp
-func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
+func.func @resize_nearest_fp32(%input: tensor<1x50x48x1xf32>) -> () {
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x1600x1536x1xf32>
   // CHECK: %[[GENERIC:.+]] = linalg.generic
   // CHECK: %[[IDX0:.+]] = linalg.index 0

``````````

</details>


https://github.com/llvm/llvm-project/pull/73019


More information about the Mlir-commits mailing list