[Mlir-commits] [mlir] [mlir][tosa] Add support for BF16 in `tosa.resize` legalization (PR #158616)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 15 04:52:00 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Georgios Pinitas (GeorgeARM)

<details>
<summary>Changes</summary>

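Extend the `tosa.resize` to Linalg legalization with BF16 support, in accordance with the TOSA specification. The floating-point lowering path is now taken for any `FloatType` result element type (previously only F16 and F32).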

---
Full diff: https://github.com/llvm/llvm-project/pull/158616.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+2-2) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+44) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0a6f2477560a1..1955eec9964eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1827,8 +1827,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
     auto resultTy = cast<ShapedType>(op.getType());
     auto resultETy = resultTy.getElementType();
 
-    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
-    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
+    bool floatingPointMode = isa<FloatType>(resultETy);
+    auto floatTy = resultETy;
 
     auto imageH = inputTy.getShape()[1];
     auto imageW = inputTy.getShape()[2];
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index ff2cbbc0b3938..6998aee45b887 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -12,6 +12,18 @@ func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x
 
 // -----
 
+// CHECK-LABEL: @unary_resize_nearest_bf16
+func.func @unary_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> {
+  %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16>
+  // CHECK: return %arg0
+  return %resize : tensor<3x1x1x7xbf16>
+}
+
+// -----
+
 // CHECK-LABEL: @unary_resize_nearest_fp16
 func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
   %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
@@ -36,6 +48,18 @@ func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1
 
 // -----
 
+// CHECK-LABEL: @unary_resize_bilinear_bf16
+func.func @unary_resize_bilinear_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> {
+  %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16>
+  // CHECK: return %arg0
+  return %resize : tensor<3x1x1x7xbf16>
+}
+
+// -----
+
 // CHECK-LABEL: @unary_resize_bilinear_fp16
 func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
   %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
@@ -60,6 +84,26 @@ func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7
 
 // -----
 
+// CHECK-LABEL: @broadcast_resize_nearest_bf16
+func.func @broadcast_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x5x7xbf16> {
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
+  // CHECK-NEXT{literal}: [[0], [1, 2, 3]] : tensor<3x1x1x7xbf16> into tensor<3x7xbf16>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x1x5x7xbf16>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xbf16>) outs(%[[EMPTY]] : tensor<3x1x5x7xbf16>)
+  // CHECK: ^bb0(%[[IN:.+]]: bf16, %[[OUT:.+]]: bf16):
+  // CHECK:   linalg.yield %[[IN]] : bf16
+  %scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xbf16>
+
+  return %resize : tensor<3x1x5x7xbf16>
+}
+
+// -----
+
 // CHECK-LABEL: @broadcast_resize_nearest_f32
 func.func @broadcast_resize_nearest_f32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x5x7xf32> {
   // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0

``````````

</details>


https://github.com/llvm/llvm-project/pull/158616


More information about the Mlir-commits mailing list