[Mlir-commits] [mlir] [mlir][tosa] Add FP8 support (PR #127730)

Luke Hutton llvmlistbot at llvm.org
Thu Mar 6 09:54:35 PST 2025


================
@@ -776,3 +776,286 @@ func.func @test_const_shape() -> !tosa.shape<4> {
   %cst = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   return %cst : !tosa.shape<4>
 }
+
+// F8 support tests
+
+// -----
+// CHECK-LABEL: argmax_f8E5M2
+func.func @test_argmax_f8E5M2(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32> {
+  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32>
+  return %0 : tensor<12x16xi32>
+}
+
+// -----
+// func.func @test_avg_pool2d_f8E5M2(%arg0: tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2> {
+//   %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2>
+//   return %0 : tensor<1x7x7x9xf8E5M2>
+// }
+
+// -----
+// CHECK-LABEL: conv2d_f8E5M2
+func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
+  %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %weight_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: conv3d_f8E5M2
+func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> {
+  %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
+  return %0 : tensor<1x4x8x21x34xf16>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d_f8E5M2
+func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16> {
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
+  return %0 : tensor<1x4x4x8xf16>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_f8E5M2
+func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16>
+  return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f8E5M2
+func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2>
+  return %0 : tensor<1x32x32x8xf8E5M2>
+}
+
+// -----
+
+// func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16> {
----------------
lhutton1 wrote:

nit: same here

https://github.com/llvm/llvm-project/pull/127730


More information about the Mlir-commits mailing list