[Mlir-commits] [mlir] b0de4e6 - [mlir][tosa] Add support for BF16 in `tosa.resize` legalization (#158616)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 15 05:26:05 PDT 2025
Author: Georgios Pinitas
Date: 2025-09-15T13:25:59+01:00
New Revision: b0de4e67775869a9e0a7c95236335084165e90ce
URL: https://github.com/llvm/llvm-project/commit/b0de4e67775869a9e0a7c95236335084165e90ce
DIFF: https://github.com/llvm/llvm-project/commit/b0de4e67775869a9e0a7c95236335084165e90ce.diff
LOG: [mlir][tosa] Add support for BF16 in `tosa.resize` legalization (#158616)
Extend the `tosa.resize` to Linalg legalization with BF16 support,
in accordance with the TOSA specification.
Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0a6f2477560a1..1955eec9964eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1827,8 +1827,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
auto resultTy = cast<ShapedType>(op.getType());
auto resultETy = resultTy.getElementType();
- bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
- auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
+ bool floatingPointMode = isa<FloatType>(resultETy);
+ auto floatTy = resultETy;
auto imageH = inputTy.getShape()[1];
auto imageW = inputTy.getShape()[2];
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index ff2cbbc0b3938..6998aee45b887 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -12,6 +12,18 @@ func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x
// -----
+// CHECK-LABEL: @unary_resize_nearest_bf16
+func.func @unary_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> {
+ %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16>
+ // CHECK: return %arg0
+ return %resize : tensor<3x1x1x7xbf16>
+}
+
+// -----
+
// CHECK-LABEL: @unary_resize_nearest_fp16
func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
@@ -36,6 +48,18 @@ func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1
// -----
+// CHECK-LABEL: @unary_resize_bilinear_bf16
+func.func @unary_resize_bilinear_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> {
+ %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16>
+ // CHECK: return %arg0
+ return %resize : tensor<3x1x1x7xbf16>
+}
+
+// -----
+
// CHECK-LABEL: @unary_resize_bilinear_fp16
func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
@@ -60,6 +84,26 @@ func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7
// -----
+// CHECK-LABEL: @broadcast_resize_nearest_bf16
+func.func @broadcast_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x5x7xbf16> {
+ // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
+ // CHECK-NEXT{literal}: [[0], [1, 2, 3]] : tensor<3x1x1x7xbf16> into tensor<3x7xbf16>
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x1x5x7xbf16>
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xbf16>) outs(%[[EMPTY]] : tensor<3x1x5x7xbf16>)
+ // CHECK: ^bb0(%[[IN:.+]]: bf16, %[[OUT:.+]]: bf16):
+ // CHECK: linalg.yield %[[IN]] : bf16
+ %scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xbf16>
+
+ return %resize : tensor<3x1x5x7xbf16>
+}
+
+// -----
+
// CHECK-LABEL: @broadcast_resize_nearest_f32
func.func @broadcast_resize_nearest_f32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x5x7xf32> {
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
More information about the Mlir-commits
mailing list