[Mlir-commits] [mlir] [mlir][tosa] Add FP8 lit tests (PR #127730)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 7 09:48:52 PST 2025


https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/127730

>From dae7298894de9a096ac6627fbb6c6561b22248ad Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Thu, 1 Feb 2024 23:44:37 +0000
Subject: [PATCH] [mlir][tosa] Add FP8 lit tests

Add FP8 (f8E5M2 and f8E4M3FN) lit tests for the following operators:

ARGMAX
AVGPOOL
CONV2D
CONV3D
DEPTHWISE_CONV2D
MATMUL
MAX_POOL2D
TRANSPOSE_CONV2D
CONST
CAST
CONCAT
PAD
RESHAPE
REVERSE
SLICE
TILE
TRANSPOSE
GATHER
SCATTER

Also update the AVG_POOL2D verifier so that it accepts FP8 input/output
element types.

Signed-off-by: Tai Ly <tai.ly at arm.com>
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Change-Id: I56adfabb2396e38b7ed3479e4fd680b740bdb4e4
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp |   3 +
 mlir/test/Dialect/Tosa/ops.mlir      | 288 +++++++++++++++++++++++++++
 2 files changed, 291 insertions(+)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5fdf5f4b5cb6a..35f6ffb845c22 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -522,14 +522,17 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
     return failure();
 
+  // Input and output must have matching supported element types (incl. FP8).
   if ((inputETy.isF32() && resultETy.isF32()) ||
       (inputETy.isF16() && resultETy.isF16()) ||
       (inputETy.isBF16() && resultETy.isBF16()) ||
+      (isa<Float8E5M2Type>(inputETy) && isa<Float8E5M2Type>(resultETy)) ||
+      (isa<Float8E4M3FNType>(inputETy) && isa<Float8E4M3FNType>(resultETy)) ||
       (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
       (inputETy.isInteger(16) && resultETy.isInteger(16)))
     return success();
 
   return emitOpError("input/output element types are incompatible.");
 }
 
 LogicalResult tosa::ClampOp::verify() {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index f9cd342ccf1d4..424bbe76ed6e2 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -783,3 +783,291 @@ func.func @test_const_shape() -> !tosa.shape<4> {
   %cst = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   return %cst : !tosa.shape<4>
 }
+
+// FP8 (f8E5M2 and f8E4M3FN) support tests
+
+// -----
+// CHECK-LABEL: argmax_f8E5M2
+func.func @test_argmax_f8E5M2(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32> {
+  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32>
+  return %0 : tensor<12x16xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f8E5M2
+func.func @test_avg_pool2d_f8E5M2(%arg0: tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2> {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E5M2>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x7x7x9xf8E5M2>
+  return %0 : tensor<1x7x7x9xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: conv2d_f8E5M2
+func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: conv3d_f8E5M2
+func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> {
+  %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
+  return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d_f8E5M2
+func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16> {
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: matmul_f8E5M2
+func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16>
+  return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f8E5M2
+func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2>
+  return %0 : tensor<1x32x32x8xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_f8E5M2
+func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
+  return %0 : tensor<1x32x32x16xf16>
+}
+
+// -----
+// CHECK-LABEL: const_f8E5M2
+func.func @test_const_f8E5M2(%arg0 : index) -> tensor<4xf8E5M2> {
+  %0 = "tosa.const"() {values = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E5M2>} : () -> tensor<4xf8E5M2>
+  return %0 : tensor<4xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: cast_f8E5M2
+func.func @test_cast_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
+  return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: concat_f8E5M2
+func.func @test_concat_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2> {
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>, tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2>
+  return %0 : tensor<26x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: pad_f8E5M2
+func.func @test_pad_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+  %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %cst = "tosa.const"() { values = dense<-0.0> : tensor<1xf8E5M2> } : () -> tensor<1xf8E5M2>
+  %0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E5M2>, !tosa.shape<6>, tensor<1xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+  return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: reshape_f8E5M2
+func.func @test_reshape_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<1x819xf8E5M2> {
+  %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<2>) -> tensor<1x819xf8E5M2>
+  return %0 : tensor<1x819xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: reverse_f8E5M2
+func.func @test_reverse_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+  return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: slice_f8E5M2
+func.func @test_slice_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<4x11x1xf8E5M2> {
+  %0 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %1 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E5M2>
+  return %2 : tensor<4x11x1xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: tile_f8E5M2
+func.func @test_tile_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<39x21x6xf8E5M2> {
+  %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E5M2>, !tosa.shape<3>) -> tensor<39x21x6xf8E5M2>
+  return %0 : tensor<39x21x6xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: transpose_f8E5M2
+func.func @test_transpose_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2> {
+  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2>
+  return %1 : tensor<3x13x21xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: gather_f8E5M2
+func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2>
+  return %0 : tensor<13x26x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: scatter_f8E5M2
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x16xi32>, %arg2: tensor<13x16x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x16xi32>, tensor<13x16x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
+  return %0 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+// CHECK-LABEL: argmax_f8E4M3FN
+func.func @test_argmax_f8E4M3FN(%arg0: tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32> {
+  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32>
+  return %0 : tensor<12x16xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f8E4M3FN
+func.func @test_avg_pool2d_f8E4M3FN(%arg0: tensor<1x7x7x9xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN> {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN>
+  return %0 : tensor<1x7x7x9xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: conv2d_f8E4M3FN
+func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8x1x1x4xf8E4M3FN>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E4M3FN>, tensor<8x1x1x4xf8E4M3FN>, tensor<8xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: conv3d_f8E4M3FN
+func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> {
+  %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<34xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16>
+  return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d_f8E4M3FN
+func.func @test_depthwise_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<1x1x4x2xf8E4M3FN>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16> {
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E4M3FN>, tensor<1x1x4x2xf8E4M3FN>, tensor<8xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: matmul_f8E4M3FN
+func.func @test_matmul_f8E4M3FN(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16>
+  return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f8E4M3FN
+func.func @test_max_pool2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN>
+  return %0 : tensor<1x32x32x8xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_f8E4M3FN
+func.func @test_transpose_conv2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>, %arg1: tensor<16x1x1x8xf8E4M3FN>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x32x32x16xf16> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>, tensor<16x1x1x8xf8E4M3FN>, tensor<16xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x32x32x16xf16>
+  return %0 : tensor<1x32x32x16xf16>
+}
+
+// -----
+// CHECK-LABEL: const_f8E4M3FN
+func.func @test_const_f8E4M3FN(%arg0 : index) -> tensor<4xf8E4M3FN> {
+  %0 = "tosa.const"() {values = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E4M3FN>} : () -> tensor<4xf8E4M3FN>
+  return %0 : tensor<4xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: cast_f8E4M3FN
+func.func @test_cast_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: concat_f8E4M3FN
+func.func @test_concat_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN> {
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>, tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN>
+  return %0 : tensor<26x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: pad_f8E4M3FN
+func.func @test_pad_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+  %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %cst = "tosa.const"() { values = dense<-0.0> : tensor<1xf8E4M3FN> } : () -> tensor<1xf8E4M3FN>
+  %0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<1xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+  return %0 : tensor<13x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: reshape_f8E4M3FN
+func.func @test_reshape_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1x819xf8E4M3FN> {
+  %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<2>) -> tensor<1x819xf8E4M3FN>
+  return %0 : tensor<1x819xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: reverse_f8E4M3FN
+func.func @test_reverse_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+  return %0 : tensor<13x21x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: slice_f8E4M3FN
+func.func @test_slice_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<4x11x1xf8E4M3FN> {
+  %0 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %1 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E4M3FN>
+  return %2 : tensor<4x11x1xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: tile_f8E4M3FN
+func.func @test_tile_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<39x21x6xf8E4M3FN> {
+  %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>) -> tensor<39x21x6xf8E4M3FN>
+  return %0 : tensor<39x21x6xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: transpose_f8E4M3FN
+func.func @test_transpose_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN> {
+  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN>
+  return %1 : tensor<3x13x21xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: gather_f8E4M3FN
+func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN>
+  return %0 : tensor<13x26x3xf8E4M3FN>
+}
+
+// -----
+// CHECK-LABEL: scatter_f8E4M3FN
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x16xi32>, %arg2: tensor<13x16x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x16xi32>, tensor<13x16x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
+  return %0 : tensor<13x21x3xf8E4M3FN>
+}



More information about the Mlir-commits mailing list